author     Alon Levy <alevy@redhat.com>   2010-08-18 13:43:29 +0300
committer  Alon Levy <alevy@redhat.com>   2010-08-18 13:43:29 +0300
commit     1ecec7be766f59d2f19e74f03f01a61589c32313
tree       a7a0819acdd576f49b39c41cc2df0821243ab5a6
parent     08951c2d430aa1216d186f7b52fb9b2298350291
proxy version - working; pcap version - misses some packets due to missing reordering logic; doesn't handle bad headers very nicely
-rw-r--r--   client.py         26
-rw-r--r--   client_proto.py   65
-rwxr-xr-x   dumpspice.py      83
-rwxr-xr-x   mypcap.py        111
-rw-r--r--   pcapspice.py     178
-rw-r--r--   pcaputil.py      238
-rwxr-xr-x   proxy.py         122
-rwxr-xr-x   test_send.py      12
8 files changed, 631 insertions, 204 deletions
diff --git a/client.py b/client.py
index 646c2f9..dc9d796 100644
--- a/client.py
+++ b/client.py
@@ -1,8 +1,12 @@
import struct
import socket
+import logging
+
import client_proto
+logger = logging.getLogger('client')
+
# TODO - parse spice/protocol.h directly (to keep in sync automatically)
SPICE_MAGIC = "REDQ"
SPICE_VERSION_MAJOR = 2
@@ -143,6 +147,24 @@ class SpiceDataHeader(Struct):
self.e.type, self.e.sub_list, self.e.size)
__repr__ = __str__
+ @classmethod
+ def verify(cls, *args, **kw):
+ inst = cls(*args, **kw)
+ if inst.e.sub_list != 0 and inst.e.sub_list >= inst.e.size:
+ logger.error('too large sub_list packet - %s' % inst)
+ import pdb; pdb.set_trace()
+ return None
+ if inst.e.type < 0 or inst.e.type >= 400:
+ logger.error('bad type of packet - %s' % inst)
+ import pdb; pdb.set_trace()
+ return None
+ if inst.e.size > 1000000:
+ logger.error('too large packet - %s' % inst)
+ import pdb; pdb.set_trace()
+ return None
+ return inst
+
+
class SpiceSubMessage(Struct):
__metaclass__ = StructMeta
fields = [(uint16, 'type'), (uint32, 'size')]
@@ -167,7 +189,9 @@ class SpiceLinkHeader(Struct):
@classmethod
def verify(cls, *args, **kw):
inst = cls(*args, **kw)
- assert(inst.e.magic == SPICE_MAGIC)
+ if inst.e.magic != SPICE_MAGIC:
+ logger.error('bad magic in packet: %s' % inst)
+ return None
return inst
link_header_size = SpiceLinkHeader.size
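
The verify() classmethods added above validate headers and return None on failure instead of asserting, so callers are expected to check the result and decide how to recover. A minimal caller-side sketch of that pattern, using a stand-in header class (FakeHeader is hypothetical, not part of client.py):

    import logging

    logging.basicConfig(level=logging.ERROR)
    logger = logging.getLogger('verify_example')

    class FakeHeader(object):
        # Stand-in for SpiceDataHeader/SpiceLinkHeader: verify() logs and
        # returns None for a malformed header instead of raising.
        def __init__(self, buf):
            self.magic = buf[:4]
        @classmethod
        def verify(cls, buf):
            inst = cls(buf)
            if inst.magic != 'REDQ':
                logger.error('bad magic %r' % inst.magic)
                return None
            return inst

    print FakeHeader.verify('REDQ' + '\x00' * 4) is not None   # True
    print FakeHeader.verify('JUNK' + '\x00' * 4) is None       # True, error logged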
diff --git a/client_proto.py b/client_proto.py
index c29c182..3095243 100644
--- a/client_proto.py
+++ b/client_proto.py
@@ -3,7 +3,13 @@ Decode spice protocol using spice demarshaller.
"""
ANNOTATE = False
+NO_PRINT_PROTECTION = False
+import logging
+
+logger = logging.getLogger('client_proto')
+
+from collections import namedtuple
import struct
import sys
if not 'proto' in locals():
@@ -152,7 +158,7 @@ def parse_member(member, data, i_s, parsed):
result_value = None
if len(set(member.attributes.keys()) - set(
['end', 'ctype', 'as_ptr', 'nomarshal', 'anon', 'chunk'])) > 0:
- print "has attributes %s" % member.attributes
+ logger.debug("has attributes %s" % member.attributes)
#import pdb; pdb.set_trace()
if hasattr(member, 'is_switch') and member.is_switch():
var_name = member.variable
@@ -177,24 +183,28 @@ def parse_member(member, data, i_s, parsed):
i_s_out = i_s + primitive.size
elif member_type.is_struct():
i_s_out, result_name, result_value = parse_complex_member(member, member_type, data, i_s)
- elif member_type.is_enum():
+ elif hasattr(member_type, 'is_enum') and member_type.is_enum():
primitive = primitives[member_type.primitive_type()]
value = primitive.unpack(data[i_s:i_s+primitive.size])[0]
result_value = Enum(member_type.names, member_type.names.get(value, value), value)
# TODO - use enum for nice display
i_s_out = i_s + primitive.size
- elif member_type.is_primitive() and hasattr(member_type, 'names'):
+ elif member_type.is_primitive():
# assume flag
- if not isinstance(member_type, ptypes.FlagsType):
- print "not really a flag.."
- import pdb; pdb.set_trace()
primitive = primitives[member_type.primitive_type()]
value = primitive.unpack(data[i_s:i_s+primitive.size])[0]
- result_value = Flag(member_type.names, value)
+ if hasattr(member_type, 'names'):
+ if not isinstance(member_type, ptypes.FlagsType):
+ print "not really a flag.."
+ import pdb; pdb.set_trace()
+ result_value = Flag(member_type.names, value)
+ else:
+ result_value = value
i_s_out = i_s + primitive.size
else:
import pdb; pdb.set_trace()
elif member_type.is_array():
+ num_elements = None
if member_type.is_remaining_length():
data_len = len(data) - i_s
elif member_type.is_identifier_length():
@@ -223,8 +233,10 @@ def parse_member(member, data, i_s, parsed):
element_size = primitives[element_type.name].size
if data_len is None:
data_len = num_elements * element_size
+ if num_elements is None:
+ num_elements = data_len / element_size
primitive = primitive_arrays[element_type.name]
- contents = primitive(data_len).unpack_from(data[i_s:i_s+data_len])
+ contents = primitive(num_elements).unpack_from(data[i_s:i_s+data_len])
i_s_out = i_s + data_len
else:
contents = []
@@ -232,7 +244,7 @@ def parse_member(member, data, i_s, parsed):
for _ in xrange(num_elements):
i_s_out, sub_name, sub_value = parse_member(element_type, data, i_s_out, [])
if sub_value is None:
- print "list parsing failed"
+ logger.error("list parsing failed")
contents = []
break
contents.append((sub_name, sub_value))
@@ -256,11 +268,12 @@ def parse_member(member, data, i_s, parsed):
else:
print "pointer with unknown size??"
import pdb; pdb.set_trace()
- if result_name in ['data', 'ents', 'glyphs', 'String', 'str']:
- result_value = NoPrint(name='', s=result_value)
- elif len(str(result_value)) > 1000:
- print "noprint?"
- import pdb; pdb.set_trace()
+ if NO_PRINT_PROTECTION:
+ if result_name in ['data', 'ents', 'glyphs', 'String', 'str', 'rects']:
+ result_value = NoPrint(name='', s=result_value)
+ elif len(str(result_value)) > 1000:
+ print "noprint?"
+ import pdb; pdb.set_trace()
return i_s_out, result_name, result_value
@simple_annotate_data
@@ -272,11 +285,13 @@ def parse_complex_member(member, the_type, data, i_s):
if result_value is not None:
parsed.append((result_name, result_value))
else:
- print "don't know how to parse %s (%s) in %s" % (
- sub_member.name, the_type, member)
+ logger.error("don't know how to parse %s (%s) in %s" % (
+ sub_member.name, the_type, member))
break
return i_s_out, member.name, parsed
+ParseResult=namedtuple('ParseResult', ['msg_proto', 'result_name', 'result_value'])
+
def parse(channel_type, is_client, header, data):
if channel_type not in channels:
return NoPrint(name='unknown channel (%s %s)' % (
@@ -286,20 +301,20 @@ def parse(channel_type, is_client, header, data):
channel = channels[channel_type]
collection = channel.client if is_client else channel.server
if header.e.type not in collection:
- print "bad data - no such message in protocol"
+ logger.error("bad data - no such message in protocol")
i_s, result_name, result_value = 0, None, None
else:
msg_proto = collection[header.e.type]
i_s, result_name, result_value = parse_complex_member(
msg_proto, msg_proto.message_type, data, 0)
- i_s = max(data.max_pointer_i_s, i_s)
- left_over = NoPrint(name='%s:%s' % (channel_type, i_s), s=data[i_s:]) if i_s < len(data) else ''
- print_annotation(data)
- if len(left_over) > 0:
- print "WARNING: in message %s %s out %s unaccounted for (%s%%)" % (
- msg_proto.name, len(left_over), len(data), float(len(left_over))/len(data))
- #import pdb; pdb.set_trace()
- return (msg_proto, (result_name, result_value))
+ i_s = max(data.max_pointer_i_s, i_s)
+ left_over = NoPrint(name='%s:%s' % (channel_type, i_s), s=data[i_s:]) if i_s < len(data) else ''
+ #result_value.data = data # let the reference escape, so we can print annotation data
+ if len(left_over) > 0:
+ logger.warning("in message %s %s out of %s unaccounted for (%.1f%%)" % (
+ msg_proto.name, len(left_over), len(data), 100.0*len(left_over)/len(data)))
+ #import pdb; pdb.set_trace()
+ return ParseResult(msg_proto, result_name, result_value)
class NoPrint(object):
objs = {}
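
parse() now returns a ParseResult namedtuple instead of the earlier nested (msg_proto, (result_name, result_value)) tuple, which is what lets dumpspice.py key its histogram on result_name. A small sketch of the shape with placeholder field values (the real msg_proto is a demarshaller message object, not a string):

    from collections import namedtuple

    ParseResult = namedtuple('ParseResult', ['msg_proto', 'result_name', 'result_value'])

    r = ParseResult(msg_proto='display_draw_copy', result_name='draw_copy', result_value=[])
    proto, name, value = r        # still unpacks positionally, like the old tuple
    print r.result_name           # attribute access used by the new callers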
diff --git a/dumpspice.py b/dumpspice.py
index 4c69db4..9c643c9 100755
--- a/dumpspice.py
+++ b/dumpspice.py
@@ -1,22 +1,83 @@
#!/usr/bin/env python
import sys
import pcaputil
-import pcapspice
+from proxy import proxy, closeallsockets
+from collections import defaultdict
+from time import time
+from select import select
+from optparse import OptionParser
+import logging
-def main(stdscr=None):
- p = pcaputil.packet_iter('lo')
+dt = 1.0
+
+class Histogram(defaultdict):
+ def __init__(self):
+ super(Histogram, self).__init__(lambda: (0,0))
+ self.last = defaultdict(lambda: (0,0))
+ def show(self):
+ print "----------------------------"
+ print '\n'.join(['%20s: %6d %4d' % (k, self[k][1], self[k][1] - self.last[k][1]) for t,k in sorted((t,k) for k,(t,c) in self.items())])
+ self.last.update(self)
+
+verbose = 0
+
+def dumpspice(p, stdscr=None):
+ import pcapspice
spice = pcapspice.spice_iter(p)
+ hist = Histogram()
+ last_print = start_time = time()
if stdscr:
stdscr.erase()
- for d in spice:
- if verbose:
- print d
+ # replace the "for d in spice:" loop with a select
+ while True:
+ #rds, _ws, _xs = select([p.fileno()],[],[],dt)
+ cur_time = time()
+ do_read = True # = len(rds) > 0:
+ if do_read:
+ d = spice.next()
+ if verbose:
+ print d
+ old_time, old_count = hist[d.msg.data.result_name]
+ hist[d.msg.data.result_name] = (cur_time, old_count + 1)
+ if cur_time - last_print > dt:
+ hist.show()
+ last_print = cur_time
+
+def frompcap(stdscr=None):
+ p = pcaputil.packet_iter('lo')
+ return dumpspice(p, stdscr)
+
+def fromproxy(stdscr, local_port, remote_addr):
+ p = proxy(local_port=local_port, remote_addr=remote_addr)
+ return dumpspice(p, stdscr=stdscr)
if __name__ == '__main__':
- verbose = '-v' in sys.argv
- if '-c' in sys.argv:
- import curses
- curses.wrapper(main)
+ parser = OptionParser()
+ parser.add_option('-p', '--proxy', dest='proxy', help='use proxy',
+ action='store_true')
+ parser.add_option('-l', '--localport', dest='local_port', help='set proxy local port')
+ parser.add_option('-r', '--remoteaddr', dest='remote_addr', help='set proxy remote address')
+ parser.add_option('-v', '--verbose', dest='verbose', action='count', help='verbosity', default=0)
+ parser.add_option('-c', '--curses', dest='curses', action='store_true', help='use curses')
+ opts, rest = parser.parse_args(sys.argv[1:])
+ if opts.verbose >= 2:
+ logging.basicConfig(filename='dumpspice.log', level=logging.DEBUG)
+ print "saving debug log to dumpspice.log"
+ if opts.proxy:
+ local_port = int(opts.local_port)
+ remote_addr = opts.remote_addr.split(':')
+ remote_addr = (remote_addr[0], int(remote_addr[1]))
+ main = (lambda stdscr, local_port=local_port, remote_addr=remote_addr:
+ fromproxy(stdscr, local_port, remote_addr))
else:
- main(None)
+ main = frompcap
+ verbose = opts.verbose
+ try:
+ if opts.curses:
+ import curses
+ curses.wrapper(main)
+ else:
+ main(None)
+ except KeyboardInterrupt, e:
+ closeallsockets()
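
The Histogram in dumpspice.py keeps a (last_seen_time, count) pair per message name and, every dt seconds, prints the running count next to the delta since the previous show(). A self-contained sketch of the same idea; the add() helper stands in for the inline update inside dumpspice() and is not part of the file:

    from collections import defaultdict
    from time import time

    class Histogram(defaultdict):
        def __init__(self):
            super(Histogram, self).__init__(lambda: (0, 0))
            self.last = defaultdict(lambda: (0, 0))
        def add(self, key):
            _seen, count = self[key]
            self[key] = (time(), count + 1)
        def show(self):
            # running count, then how many arrived since the previous show()
            for key in sorted(self):
                print '%20s: %6d %+4d' % (key, self[key][1], self[key][1] - self.last[key][1])
            self.last.update(self)

    h = Histogram()
    for name in ['ping', 'ping', 'draw_copy']:
        h.add(name)
    h.show()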
diff --git a/mypcap.py b/mypcap.py
new file mode 100755
index 0000000..9a3b3c4
--- /dev/null
+++ b/mypcap.py
@@ -0,0 +1,111 @@
+#!/usr/bin/python
+
+from time import time
+import logging
+from ctypes import (cdll, c_char_p, create_string_buffer, Structure, byref
+ , c_int32, c_long, pointer, POINTER, c_char)
+
+logger = logging.getLogger('mypcap')
+pcap = cdll.LoadLibrary('libpcap.so')
+PCAP_ERRBUF_SIZE = 1000 # bogus
+DATA_SIZE = 2000000
+for f in "pcap_next pcap_create pcap_set_snaplen pcap_set_buffer_size pcap_activate pcap_next_ex pcap_perror".split():
+ if len(f) > 0:
+ globals()[f] = getattr(pcap, f)
+
+time_t = c_long
+class timeval(Structure):
+ _fields_ = [('tv_sec', time_t),
+ ('tv_usec', c_long)]
+
+class pcap_pkthdr(Structure):
+ _fields_ = [('ts', timeval), # time stamp
+ ('caplen', c_int32), # length of portion present
+ ('len', c_int32)] # length this packet (off wire)
+
+PCAP_ERROR = -1 # generic error code
+PCAP_ERROR_BREAK = -2 # loop terminated by pcap_breakloop
+PCAP_ERROR_NOT_ACTIVATED = -3 # the capture needs to be activated
+PCAP_ERROR_ACTIVATED = -4 # the operation can't be performed on already activated captures
+PCAP_ERROR_NO_SUCH_DEVICE = -5 # no such device exists
+PCAP_ERROR_RFMON_NOTSUP = -6 # this device doesn't support rfmon (monitor) mode
+PCAP_ERROR_NOT_RFMON = -7 # operation supported only in monitor mode
+PCAP_ERROR_PERM_DENIED = -8 # no permission to open the device
+PCAP_ERROR_IFACE_NOT_UP = -9 # interface isn't up
+
+PCAP_WARNING = 1 # generic warning code
+PCAP_WARNING_PROMISC_NOTSUP = 2 # this device doesn't support promiscuous mode
+
+pcap_errors = [PCAP_ERROR_ACTIVATED
+ ,PCAP_ERROR_NO_SUCH_DEVICE
+ ,PCAP_ERROR_PERM_DENIED
+ ,PCAP_ERROR_RFMON_NOTSUP
+ ,PCAP_ERROR_IFACE_NOT_UP
+ ,PCAP_ERROR]
+
+pcap_next.restype = POINTER(c_char)
+
+class PCAPError(Exception):
+ pass
+
+class PCAP(object):
+ def __init__(self, dev):
+ self.errbuf = create_string_buffer(PCAP_ERRBUF_SIZE)
+ self.hdr = pcap_pkthdr()
+ self.hdrp = pointer(self.hdr)
+ self.pcap = pcap_create(dev, self.errbuf)
+ if self.pcap == 0:
+ raise PCAPError("error: %s" % self.errbuf.value)
+ if pcap_set_snaplen(self.pcap, 65536) != 0:
+ raise PCAPError("error: set snaplen failed")
+ if pcap_set_buffer_size(self.pcap, DATA_SIZE) != 0:
+ raise PCAPError("error: set buffer size failed")
+ ret = pcap_activate(self.pcap)
+ if ret in [PCAP_WARNING_PROMISC_NOTSUP, PCAP_WARNING]:
+ logger.warning("pcap_activate returned a warning - %d" % ret)
+ elif ret in pcap_errors:
+ pcap_perror(self.pcap, "")
+ raise PCAPError('pcap_activate failed')
+
+ def __iter__(self):
+ i_pkt = 0
+ hdr = self.hdr
+ total = 0
+ while True:
+ data = pcap_next(self.pcap, self.hdrp)
+ #ret = pcap_next_ex(self.pcap, byref(hdrp), byref(data));
+ ret = 1
+ if ret == 1:
+ total += hdr.len
+ i_pkt += 1
+ secs = float(hdr.ts.tv_sec) + float(hdr.ts.tv_usec)/1000000
+ logger.debug("%3d: %5.3f, %d, %d| %d" % (
+ i_pkt, secs, hdr.caplen, hdr.len, total))
+ yield secs, data[:hdr.len] # if data is NULL?
+ elif ret == 0:
+ logger.debug("%3d: timeout" % i_pkt)
+ elif ret == -1:
+ pcap_perror(self.pcap, "")
+ raise PCAPError("pcap_next errored")
+ else:
+ raise PCAPError("error: unexpected return %d" % ret)
+
+pcap = PCAP
+
+def main():
+ cap_2 = 10
+ cap = 2*cap_2
+ total = 0
+ start = time()
+ for secs, data in pcap('lo'):
+ if len(data) > cap:
+ r = data[:cap_2] + ('...%s...' % (len(data) - cap)) + data[-cap_2:]
+ else:
+ r = data
+ total += len(data)
+ print '%7.3f %12d %10d %s' % (secs, total, len(data), repr(r))
+
+if __name__ == '__main__':
+ import sys
+ sys.exit(main())
+
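
mypcap.py pulls the libpcap entry points it needs straight out of the shared object; only pcap_next gets an explicit restype. A hedged sketch of how the remaining signatures could be declared, based on libpcap's documented prototypes (pcap_create(source, errbuf) returns a pcap_t*; pcap_set_snaplen, pcap_set_buffer_size and pcap_activate return int) and assuming the same libpcap.so loaded above:

    from ctypes import cdll, c_char_p, c_void_p, c_int, POINTER, c_char

    libpcap = cdll.LoadLibrary('libpcap.so')

    # With a declared restype, a failed pcap_create shows up as None rather than
    # as a raw integer handle compared against 0.
    libpcap.pcap_create.restype = c_void_p
    libpcap.pcap_create.argtypes = [c_char_p, c_char_p]

    for name in ('pcap_set_snaplen', 'pcap_set_buffer_size', 'pcap_activate'):
        getattr(libpcap, name).restype = c_int

    libpcap.pcap_next.restype = POINTER(c_char)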
diff --git a/pcapspice.py b/pcapspice.py
index b1c1e96..05a5bef 100644
--- a/pcapspice.py
+++ b/pcapspice.py
@@ -1,24 +1,27 @@
from itertools import chain, repeat
-from pcaputil import (tcp_parse, is_tcp, is_tcp_data, is_tcp_syn,
- conversations_iter, header_conversation_iter)
-import client
+import logging
+
+from pcaputil import header_conversation_iter
+import client_proto
from client import SpiceLinkHeader, SpiceLinkMess, SpiceLinkReply, SpiceDataHeader
+logger = logging.getLogger('pcapspice')
+
def is_single_packet_data(payload):
return len(payload) == SpiceDataHeader(payload).e.size + SpiceDataHeader.size
num_spice_messages = sum(len(ch.client) + len(ch.server) for ch in
- client.client_proto.channels.values())
+ client_proto.channels.values())
valid_message_ids = set(sum([ch.client.keys() + ch.server.keys() for ch in
- client.client_proto.channels.values()], []))
+ client_proto.channels.values()], []))
-all_channels = client.client_proto.channels.keys()
+all_channels = client_proto.channels.keys()
def possible_channels(server_message, header):
return set(c for c in all_channels if header.e.type in
- (client.client_proto.channels[c].server.keys() if server_message
- else client.client_proto.channels[c].client.keys()))
+ (client_proto.channels[c].server.keys() if server_message
+ else client_proto.channels[c].client.keys()))
guesses = {}
@@ -34,140 +37,49 @@ def guess_channel_iter():
seen_headers.append((src, dst, data_header))
optional_channels = optional_channels & possible_channels(
src == min(src, dst), data_header)
- print optional_channels
+ logger.debug(str(optional_channels))
if len(optional_channels) == 1:
channel = list(optional_channels)[0]
guesses[key] = channel
if len(optional_channels) == 0:
import pdb; pdb.set_trace()
-spice_size_from_header = lambda header: header.e.size
-
-def channel_spice_start_iter(lower_port, higher_port):
- server_port, client_port = lower_port, higher_port
- ports = [server_port, client_port]
- messages = dict([(port,None) for port in ports])
- payloads = dict([(port,[]) for port in ports])
- headers = dict([(port,None) for port in ports])
- discarded_len = {client_port:128, server_port:4}
- message_ctors = {client_port:SpiceLinkMess, server_port:SpiceLinkReply}
- if len(payload) > 0:
- payloads[src].append(payload)
- while not all(messages.values()):
- src, dst, payload = yield None
- if len(payload) == 0: continue
- payloads[src].append(payload)
- len_payload = sum(map(len,payloads[src]))
- if headers[src] is None and len_payload >= SpiceLinkHeader.size:
- headers[src] = SpiceLinkHeader(''.join(payloads[src]))
- import pdb; pdb.set_trace()
- expected_len = headers[src].e.size + SpiceLinkHeader.size + discarded_len[src]
- if headers[src] is not None and expected_len <= len_payload:
- if expected_len != len_payload:
- print "extra bytes dumped!!! %d" % (len_payload - expected_len)
- messages[src] = message_ctors[src](''.join(payloads[src])[
- SpiceLinkHeader.size:])
- channel = messages[client_port].e.channel_type
- print "channel type %s created" % channel
- result = None
- src, dst, payload = yield None
- data_iter = channel_spice_iter(src, dst, payload, channel=channel)
- result = data_iter.next()
- while True:
- src, dst, payload = yield result
- result = data_iter.send((src, dst, payload))
-
-def channel_spice_iter(lower_port, higher_port, channel=None):
- server_port, client_port = lower_port, higher_port
- serial = {}
- guesser = guess_channel_iter()
- guesser.next()
-
- result = None
- # yield spice data parsed results
- while True:
- src, dst, payload = yield result
- data_header = SpiceDataHeader(payload)
- # basic sanity check on the header
- bad_header = False
- if serial.has_key(src):
- if data_header.e.serial != serial[src] + 1:
- print "bad serial, replacing for %s->%s (%s!=%s+1)" % (
- src, dst, data_header.e.serial, serial[src])
- bad_header = True
- serial[src] = data_header.e.serial
- if data_header.e.type not in valid_message_ids:
- print "bad message type %s in %s->%s" % (data_header.e.type,
- src, dst)
- if bad_header: continue
- if channel is None:
- channel = guesser.send((src, dst, data_header))
- if channel is not None:
- # read as many messages as it takes to get
- (msg_proto, (result_name, result_value),
- left_over) = client.client_proto.parse(
- channel_type=channel,
- is_client=src == client_port,
- header=data_header,
- data=payload[SpiceDataHeader.size:])
- result = (data_header, msg_proto, result_name, result_value,
- left_over)
-
-# for every tcp packet
-# if it is a syn, read the link header, decide which channel we are
-# if it is data
-# if we know which channel we are, parse it
-# else guess which channel.
-# if we just figured out which channel we are, releaes all pending packets.
-
-def spice_iter1(packet_iter):
- return conversations_iter(
- packet_iter,
- channel_spice_start_iter,
- channel_spice_iter)
-
-def packet_header_iter_gen(hdr_ctor, pkt_ctor):
- return repeat(
- (hdr_ctor.size,
- hdr_ctor,
- lambda h: h.e.size,
- pkt_ctor))
-
def ident(x, **kw):
return x
def channel_spice_message_iter(src, dst, guesser):
+ """ coroutine, accepts (src,dst,header,data) packets
+ and yields maybe parsed spice results.
+
+ maybe X - returns None or X
+ """
server_port, client_port = min(src, dst), max(src, dst)
serial = {}
channel = None
result = None
- # yield spice data parsed results
while True:
src, dst, (data_header, message) = yield result
- # basic sanity check on the header
+ result = None
bad_header = False
if serial.has_key(src):
if data_header.e.serial != serial[src] + 1:
- print "bad serial, replacing for %s->%s (%s!=%s+1)" % (
- src, dst, data_header.e.serial, serial[src])
+ logger.debug("bad serial, replacing for %s->%s (%s!=%s+1)" % (
+ src, dst, data_header.e.serial, serial[src]))
bad_header = True
serial[src] = data_header.e.serial
if data_header.e.type not in valid_message_ids:
- print "bad message type %s in %s->%s" % (data_header.e.type,
- src, dst)
+ logger.error("bad message type %s in %s->%s" % (data_header.e.type,
+ src, dst))
bad_header = True
if bad_header: continue
- if channel is None:
- channel = guesser.send((src, dst, data_header))
+ channel = guesser.send((src, dst, data_header))
if channel is not None:
# read as many messages as it takes to get
- (msg_proto, (result_name, result_value)
- ) = client.client_proto.parse(
- channel_type=channel,
- is_client=src == client_port,
- header=data_header,
- data=message)
- result = (msg_proto, result_name, result_value)
+ result = client_proto.parse(
+ channel_type=channel,
+ is_client=src == client_port,
+ header=data_header,
+ data=message)
class ChannelGuesser(object):
@@ -177,13 +89,15 @@ class ChannelGuesser(object):
self.channel = None
def send(self, (src, dst, payload)):
- if self.channel: return self.channel
- return self.iter.send((src, dst, payload))
+ if self.channel:
+ return self.channel
+ self.channel = self.iter.send((src, dst, payload))
+ return self.channel
def on_client_link_message(self, p, **kw):
self._client_message = SpiceLinkMess(p)
self.channel = self._client_message.e.channel_type
- print "STARTED channel %s" % self.channel
+ logger.info("STARTED channel %s" % self.channel)
def on_server_link_message(self, p, **kw):
self._server_message = SpiceLinkReply(p)
@@ -194,37 +108,35 @@ def spice_iter(packet_iter):
is_server = src < dst
key = (min(src, dst), max(src, dst))
guesser = guessers[key] = guessers.get(key, ChannelGuesser())
- print guesser, is_server, "**** make_start_iter %s -> %s" % (src, dst)
+ logger.debug('%s %s **** make_start_iter %s -> %s' %( guesser, is_server, src, dst))
link_messages = []
yield (SpiceLinkHeader.size,
SpiceLinkHeader.verify,
lambda h: h.e.size,
(guesser.on_server_link_message if is_server else
guesser.on_client_link_message))
- print guesser, is_server, "2"
- message_iter = channel_spice_message_iter(src, dst, guesser)
- message_iter.next()
- # TODO: way to say "src->dst 128, dst->src 4"
+ logger.debug('%s %s 2' % (guesser, is_server))
yield 128 if src > dst else 4, ident, lambda h: 0, ident
- for i, data in enumerate(packet_header_iter_gen(
- SpiceDataHeader,
- lambda pkt, src, dst, header:
- message_iter.send((src, dst, (header, pkt))))):
- print guesser, is_server, "2+%s" % i
+ for i, data in enumerate(make_iter(src, dst)):
+ logger.debug('%s %s 2+%s' % (guesser, is_server, i))
yield data
def make_iter(src, dst):
key = (min(src, dst), max(src, dst))
guesser = guessers[key] = guessers.get(key, ChannelGuesser())
message_iter = channel_spice_message_iter(src, dst, guesser)
message_iter.next()
- return packet_header_iter_gen(
- SpiceDataHeader,
+ return repeat(
+ (SpiceDataHeader.size,
+ SpiceDataHeader.verify,
+ lambda h: h.e.size,
lambda pkt, src, dst, header:
- message_iter.send((src, dst, (header, pkt))))
+ message_iter.send((src, dst, (header, pkt)))))
return header_conversation_iter(
packet_iter,
server_start=make_start_iter,
client_start=make_start_iter,
server=make_iter,
- client=make_iter)
+ client=make_iter,
+ filter_result=lambda r: (r.msg is not None and r.msg.data != None
+ and hasattr(r.msg.data, 'result_name')))
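
spice_iter() drives header_conversation_iter() with 4-tuples of (header_size, header_ctor, size_from_header, packet_ctor): collect_packets() in pcaputil.py reads header_size bytes, builds the header, asks size_from_header how much payload follows, then hands the payload to packet_ctor. A toy length-prefixed protocol makes the contract concrete; this sketch is illustrative only and is not part of pcapspice.py:

    from itertools import repeat

    def toy_header_iter(src, dst):
        # header is a single length byte; the "header object" is just that length
        return repeat((
            1,                                    # bytes to read for the header
            lambda b: ord(b),                     # header ctor (None would mean: bad header)
            lambda h: h,                          # payload size taken from the header
            lambda pkt, src, dst, header: pkt))   # packet ctor gets payload plus src/dst/header

    # Driven through pcaputil.collect_packets / send_multiple in the same way as
    # their doctests, a stream such as '\x03abc\x02xy' is reassembled into
    # (length, payload) messages.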
diff --git a/pcaputil.py b/pcaputil.py
index 007c6c2..c6432fe 100644
--- a/pcaputil.py
+++ b/pcaputil.py
@@ -1,7 +1,11 @@
import os
import struct
from itertools import imap, ifilter
-#import pcap
+import logging
+from collections import namedtuple, defaultdict
+import hashlib
+
+logger = logging.getLogger('pcaputil')
TCP_PROTOCOL = 6
TCP_SYN = 2
@@ -20,16 +24,17 @@ def is_tcp_syn(pkt):
return (is_tcp(pkt)
and (ord(pkt[47]) & TCP_SYN))
+TCP = namedtuple('TCP', ['src', 'dst', 'flags', 'seq', 'ack', 'window', 'data'])
def tcp_parse(pkt):
tcp_offset = 14 + (ord(pkt[14]) & 0xf) * 4
src, dst, seq, ack, data_offset16, flags, window = struct.unpack(
'>HHIIBBH', pkt[tcp_offset:tcp_offset+16])
data_offset = tcp_offset + (data_offset16 >> 2)
if data_offset != 66:
- print "WHOOHOO"
- return src, dst, flags, pkt[data_offset:]
+ logger.debug("WHOOHOO")
+ return TCP(src, dst, flags, seq, ack, window, pkt[data_offset:])
-def mypcap(filename):
+def myparsepcap(filename):
with open(filename, 'r') as fd:
file_size = os.stat(filename).st_size
i = 0
@@ -40,7 +45,7 @@ def mypcap(filename):
while i + hdr_len < file_size:
ts, l1, l2 = struct.unpack('QII', fd.read(hdr_len))
if l1 < l2:
- print "short recorded packet: #%s,%s: %s < %s" % (i_pkt, i, l1, l2)
+ logger.debug("short recorded packet: #%s,%s: %s < %s" % (i_pkt, i, l1, l2))
pkt = fd.read(l1)
if is_tcp_data(pkt):
src, dst, tcp_payload = tcp_parse(pkt)
@@ -51,7 +56,7 @@ def mypcap(filename):
def get_conversations(file):
convs = {}
port_counts = {}
- for ts, src, dst, data in mypcap(file):
+ for ts, src, dst, data in myparsepcap(file):
key = (src,dst)
if key not in convs:
convs[key] = [ts]
@@ -68,7 +73,80 @@ def get_conversations(file):
times[client[0]] = (client, server)
return [map(lambda x: ''.join(x[1:]), times[t]) for t in sorted(times.keys())]
+class Conversation(object):
+ def __init__(self, npkt, iter):
+ self.npkt = npkt
+ self.iter = iter
+
def conversations_iter(packet_iter, **keyed_iters):
+ convs = {}
+ keys = {(False, True): 'server', (False, False): 'client',
+ (True, True): 'server_start', (True, False): 'client_start'}
+ filter_result = ((lambda x: x is not None and keyed_iters['filter_result'](x))
+ if 'filter_result' in keyed_iters
+ else (lambda x: x is not None))
+ for i_pkt, (src, dst, payload) in enumerate(packet_iter):
+ logger.debug("%3d: conversation_iter %s->%s #(%s)" % (i_pkt,
+ src, dst, len(payload)))
+ key = (src, dst)
+ if key not in convs:
+ iter_gen = keyed_iters[keys[(True, src < dst)]]
+ convs[key] = Conversation(npkt=1, iter=iter_gen(*key))
+ convs[key].iter.next()
+ else:
+ conv = convs[key]
+ conv.npkt += 1
+ result = convs[key].iter.send((src, dst, payload))
+ if filter_result(result):
+ yield result
+
+class TCPConversation(object):
+ def __init__(self, npkt, min_next_seq, iter):
+ self.npkt = npkt
+ self.min_next_seq = min_next_seq
+ self.iter = iter
+
+def tcp_conversations_iter(packet_iter, **keyed_iters):
+ """ conversation_start_iter is used when we catch the SYN packet,
+ otherwise we use conversation_iter. the later is expected to be
+ used by the former.
+ """
+ convs = {}
+ prints = set()
+ keys = {(False, True): 'server', (False, False): 'client',
+ (True, True): 'server_start', (True, False): 'client_start'}
+ filter_result = ((lambda x: x is not None and keyed_iters['filter_result'](x))
+ if 'filter_result' in keyed_iters
+ else (lambda x: x is not None))
+ for i_pkt, tcp in enumerate(non_reordering_tcp_iter(packet_iter)):
+ src, dst, flags, payload = tcp.src, tcp.dst, tcp.flags, tcp.data
+ logger.debug("%3d: conversation_iter %s->%s #(%s)" % (i_pkt,
+ src, dst, len(payload)))
+ key = (src, dst)
+ fingerprint = hashlib.md5(payload).digest()
+ if key not in convs:
+ iter_gen = keyed_iters[keys[(flags & TCP_SYN > 0, src < dst)]]
+ convs[key] = TCPConversation(npkt=1, iter=iter_gen(*key),
+ min_next_seq = tcp.seq + len(payload) + (1 if flags & TCP_SYN else 0))
+ convs[key].iter.next()
+ else:
+ conv = convs[key]
+ conv.npkt += 1
+ seq_change = tcp.seq - conv.min_next_seq
+ next_min_seq = max(conv.min_next_seq, tcp.seq + len(payload))
+ logger.debug("%4d: seq changed by %s, expect +%s" % (
+ conv.npkt, seq_change, next_min_seq - conv.min_next_seq))
+ if seq_change < 0 and fingerprint not in prints:
+ print "busted!"
+ conv.min_next_seq = next_min_seq
+ if seq_change < 0:
+ continue
+ prints.add(fingerprint)
+ result = convs[key].iter.send((src, dst, payload))
+ if filter_result(result):
+ yield result
+
+def tcp_conversations_iter_with_sequence(packet_iter, **keyed_iters):
""" conversation_start_iter is used when we catch the SYN packet,
otherwise we use conversation_iter. the later is expected to be
used by the former.
@@ -76,69 +154,124 @@ def conversations_iter(packet_iter, **keyed_iters):
conversations = {}
keys = {(False, True): 'server', (False, False): 'client',
(True, True): 'server_start', (True, False): 'client_start'}
- for src, dst, flags, payload in tcp_iter(packet_iter):
- print "conversation_iter %s->%s #(%s)" % (src, dst, len(payload))
+ filter_result = ((lambda x: x is not None and keyed_iters['filter_result'](x))
+ if 'filter_result' in keyed_iters
+ else (lambda x: x is not None))
+ outstanding = defaultdict(lambda: [])
+ seqs = {}
+ seqs_start = {}
+ acks = {}
+ def show_seq(key):
+ return '%s (R%s)' % (seqs[key], sum(seqs[key]) - seqs_start[key])
+ for i_pkt, tcp in enumerate(non_reordering_tcp_iter(packet_iter)):
+ src, dst, flags, seq, payload = tcp.src, tcp.dst, tcp.flags, tcp.seq, tcp.data
+ logger.debug("%3d: conversation_iter %s->%s #(%s)" % (i_pkt,
+ src, dst, len(payload)))
key = (src, dst)
+ rev_key = (dst, src)
if key not in conversations:
iter_gen = keyed_iters[keys[(flags & TCP_SYN > 0, src < dst)]]
conversations[key] = iter_gen(*key)
conversations[key].next()
- maybe_result = conversations[key].send((src, dst, payload))
- if maybe_result is not None:
- yield maybe_result
+ seqs[key] = (seq, 0) # (seq of first payload byte, payload length); the next packet should start at their sum
+ seqs_start[key] = seq
+ acks[key] = tcp.ack
+ logger.debug("_ %s setting seq to %s" % (
+ key, show_seq(key)))
+ outstanding[key].append(tcp)
+ removed = []
+ for (the_seq, i, ot) in sorted(
+ (ot.seq, i, ot) for i, ot in enumerate(outstanding[key])):
+ expected = sum(seqs[key])
+ if expected == the_seq:
+ logger.debug("A %s (%s,%s) %s" % (str(key), len(payload),
+ flags & TCP_SYN > 0, show_seq(key)))
+ removed.append(i)
+ result = conversations[key].send((ot.src, ot.dst, ot.data))
+ if filter_result(result):
+ yield result
+ seqs[key] = (the_seq, max(len(ot.data),
+ (ot.flags & TCP_SYN) / 2))
+ else:
+ logger.debug("C %s +%s == %s != %s, ack = %s" % (
+ key, seqs[key], show_seq(key),
+ the_seq, acks.get(rev_key, None)))
+ for i in reversed(removed):
+ del outstanding[key][i]
+ if len(removed) == 0:
+ logger.debug("B %s: seq %s" % (key, show_seq(key)))
+ #print "B %s: %s" % (key, [(x.seq-seqs[key], len(x.data)) for x in outstanding[key]])
+ #import pdb; pdb.set_trace()
+
def consume_packets(packets, needed_len):
+ """
+ >>> pcaputil.consume_packets(['01234','5678','90ab','cdef'], 7)
+ ('0123456', ['78', '90ab', 'cdef'])
+ """
pkt = None
+ ret_packets = packets
total_len = sum(map(len, packets))
if total_len >= needed_len:
pkt = ''.join(packets)
if total_len > needed_len:
- del packets[:-1]
- packets[0] = packets[0][-(total_len - needed_len):]
+ part_len = 0
+ for i, p in enumerate(packets):
+ if part_len + len(p) > needed_len:
+ ret_packets = [p[needed_len - part_len:]] + packets[i+1:]
+ break
+ part_len += len(p)
pkt = pkt[:needed_len]
- assert(len(packets[0]) + len(pkt) == total_len)
+ assert(sum(map(len, ret_packets)) + len(pkt) == total_len)
else:
- del packets[:]
- return pkt
+ ret_packets = []
+ assert(pkt is None or len(pkt) == needed_len)
+ return pkt, ret_packets
def ident(x, **kw):
return x
+CollectorResult = namedtuple('CollectorResult', ('src', 'dst', 'msg'))
+CollectorPacket = namedtuple('CollectorPacket', ('header', 'data'))
+
def collect_packets(header_iter_gen):
""" Example:
>>> list(pcaputil.send_multiple(pcaputil.collect_packets(lambda src, dst: iter([(10, ident, lambda h: 20, ident)]*2))(17, 42), [(42, 17, ' '*30)]))
[(' ', ' ')]
"""
- def collector(src, dst):
+ def collector(start_src, start_dst):
packets = []
history = []
+ src, dst = start_src, start_dst
header_iter = header_iter_gen(src, dst)
hdr_size, hdr_ctor, size_from_hdr, pkt_ctor = header_iter.next()
msg = None
header = pkt = None
searched_size = hdr_size
while True:
- src, dst, payload = yield (src, dst, msg)
+ src, dst, payload = yield CollectorResult(src, dst, msg)
packets.append(payload)
history.append(payload)
- print "collect_packets: searching for %s (%s)" % (searched_size, sum(map(len, packets)))
- data = consume_packets(packets, searched_size)
+ logger.debug("collect_packets: (%s->%s) searching for %s (%s)" % (src, dst,
+ searched_size, sum(map(len, packets))))
+ data, packets = consume_packets(packets, searched_size)
while data != None:
- print "collect_packets: %s" % len(data)
+ logger.debug("collect_packets: (%s->%s) %s" % (src, dst, len(data)))
if header is None:
header = hdr_ctor(data)
- print "collect_packets: header of length %s, expect %s" % (len(data), size_from_hdr(header))
+ if header is None: # invalid header
+ break
+ logger.debug("collect_packets: (%s->%s) header of length %s, expect %s" % (
+ src, dst, len(data), size_from_hdr(header)))
searched_size = size_from_hdr(header)
- if searched_size > 1000000:
- import pdb; pdb.set_trace()
else:
- print "collect_packets: packet of length %s" % len(data)
+ logger.debug("collect_packets: (%s->%s) packet of length %s" % (src, dst, len(data)))
pkt = pkt_ctor(data, src=src, dst=dst, header=header)
- msg = (header, pkt)
+ msg = CollectorPacket(header, pkt)
hdr_size, hdr_ctor, size_from_hdr, pkt_ctor = header_iter.next()
header = None
searched_size = hdr_size
- data = consume_packets(packets, searched_size)
+ data, packets = consume_packets(packets, searched_size)
return collector
def send_multiple(gen, sent_iter):
@@ -152,20 +285,57 @@ def send_multiple(gen, sent_iter):
return
def header_conversation_iter(packet_iter, **kw):
- return conversations_iter(packet_iter,
- **dict([(k, collect_packets(v)) for k,v in kw.items()]))
+ opts = dict([(k, collect_packets(v)) for k,v in kw.items() if k != 'filter_result'])
+ if 'filter_result' in kw:
+ opts['filter_result'] = kw['filter_result']
+ return conversations_iter(packet_iter, **opts)
def tcp_data_iter(packet_iter):
return ifilter(lambda src, dst, flags, data:
len(data) > 0, imap(tcp_parse, ifilter(is_tcp_data, imap(lambda x: ''.join(x[1]), packet_iter))))
-def tcp_iter(packet_iter):
- return ifilter(lambda (src, dst, flags, data):
- flags & (TCP_SYN | TCP_FIN) or len(data) > 0,
+def non_reordering_tcp_iter(packet_iter):
+ return ifilter(lambda t:
+ t.flags & (TCP_SYN | TCP_FIN) or len(t.data) > 0,
imap(tcp_parse, ifilter(is_tcp, imap(lambda x: ''.join(x[1]), packet_iter))))
+def iter_conv(n_tcp, src, dst):
+ for t in n_tcp:
+ if t.src == src and t.dst == dst:
+ yield t
+
def packet_iter(dev):
- import pcap
- return pcap.pcap(dev)
+ import mypcap as pcap
+ def filter_none():
+ for t in pcap.pcap(dev):
+ if t is None:
+ print "another None from pcap??"
+ else:
+ yield t
+ return filter_none()
+
+def test():
+ logging.basicConfig(level=logging.DEBUG,filename="testout.log")
+ p = packet_iter('lo')
+ class Handler(object):
+ def __init__(self, src, dst):
+ self.src = src
+ self.dst = dst
+ self.data = []
+ def __iter__(self):
+ data = ''
+ src, dst = self.src, self.dst
+ while True:
+ src, dst, data = yield None
+ print ('%s->%s %s' % (src, dst, len(data)))
+ self.data.append(data)
+ def handler(src, dst):
+ return iter(Handler(src, dst))
+ server = client = server_start = client_start = handler
+ for b in conversations_iter(p, server=server,client=client,
+ server_start=server_start,client_start=client_start):
+ print b
+if __name__ == '__main__':
+ test()
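
The commit message notes that the pcap path still misses packets because reordering is incomplete; tcp_conversations_iter_with_sequence() above is a step in that direction. For reference, a self-contained sketch of per-direction reassembly by sequence number (the underlying idea, not the module's implementation): buffer segments that arrive ahead of the expected sequence number and flush them once the gap is filled.

    def reorder_segments(segments, first_seq):
        # segments: iterable of (seq, data) for one direction of one TCP connection
        expected = first_seq
        pending = {}
        for seq, data in segments:
            if seq < expected:
                # retransmission or overlap: drop the part already delivered
                data = data[expected - seq:]
                seq = expected
            if data:
                pending[seq] = data
            while expected in pending:
                # flush everything that is now contiguous
                chunk = pending.pop(expected)
                expected += len(chunk)
                yield chunk

    # out-of-order input comes back as one contiguous byte stream:
    print ''.join(reorder_segments([(105, 'world'), (100, 'hello'), (110, '!')], 100))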
diff --git a/proxy.py b/proxy.py
new file mode 100755
index 0000000..a8478d6
--- /dev/null
+++ b/proxy.py
@@ -0,0 +1,122 @@
+#!/usr/bin/python
+from socket import socket, AF_INET, SOCK_STREAM
+from select import select
+
+MAX_PACKET_SIZE=65536
+
+class Socket(socket):
+ sockets = []
+ def __init__(self, *args, **kw):
+ super(Socket, self).__init__(*args, **kw)
+ self.sockets.append(self)
+
+class twowaydict(dict):
+ def __init__(self):
+ self.other = dict()
+ def __delitem__(self, k):
+ if super(twowaydict,self).__contains__(k):
+ other_k = self[k]
+ else:
+ other_k = k
+ k = self.other[k]
+ del self.other[other_k]
+ super(twowaydict,self).__delitem__(k)
+ def __setitem__(self, k, v):
+ super(twowaydict, self).__setitem__(k, v)
+ self.other[v] = k
+ def __contains__(self, k):
+ return super(twowaydict,self).__contains__(k) or k in self.other
+ def __getitem__(self, k):
+ if super(twowaydict,self).__contains__(k):
+ return super(twowaydict,self).__getitem__(k)
+ return self.other[k]
+ def getpair(self, k):
+ if k in self:
+ return k, self[k]
+ elif k in self.other:
+ return k, self.other[k]
+ def allkeys(self):
+ return self.keys() + self.other.keys()
+
+BROKEN_PIPE_ERRNO, TRANSPORT_NOT_CONNECTED_ERRNO = 32, 107
+
+def make_accepter(port):
+ accepter = Socket(AF_INET, SOCK_STREAM)
+ accepter.bind(('0.0.0.0', port))
+ accepter.listen(1)
+ return accepter
+
+def connect(addr):
+ s = Socket(AF_INET, SOCK_STREAM)
+ s.connect(addr)
+ s.setblocking(False)
+ return s
+
+def proxy(local_port, remote_addr):
+ print "proxying from %s to %s" % (local_port, remote_addr)
+ accepter = make_accepter(local_port)
+ open_socks = twowaydict()
+ close_errnos = set([BROKEN_PIPE_ERRNO, TRANSPORT_NOT_CONNECTED_ERRNO])
+ while True:
+ #print "open: %s" % len(open_socks)
+ rds, _wrs, _ex = select(
+ [accepter]+open_socks.allkeys(), [], [])
+ for s in rds:
+ if s is accepter:
+ s, _addr = accepter.accept()
+ open_socks[s] = connect(remote_addr)
+ else:
+ other = open_socks[s]
+ src_dst = [s, other]
+ dont_recv = False
+ for i in xrange(len(src_dst)):
+ try:
+ src_dst[i] = src_dst[i].getpeername()[1]
+ except Exception, e:
+ if e.errno in close_errnos:
+ src_dst[1-i].close()
+ if s in open_socks:
+ del open_socks[s]
+ dont_recv = True
+ if dont_recv: continue
+ src, dst = src_dst
+ data = s.recv(MAX_PACKET_SIZE)
+ if len(data) == 0:
+ other.close()
+ del open_socks[s]
+ continue
+ yield src, dst, data
+ try:
+ other.send(data)
+ except Exception, e:
+ if e.errno in close_errnos:
+ n = len(open_socks)
+ s.close()
+ open_socks[s].close()
+ del open_socks[s]
+ assert(len(open_socks) == n - 1)
+ else:
+ import pdb; pdb.set_trace()
+
+def closeallsockets():
+ for s in Socket.sockets:
+ try:
+ s.close()
+ except:
+ pass
+
+def example_main():
+ import sys
+ target_host, target_port = sys.argv[-1].split(':')
+ local_port = int(sys.argv[-2])
+ for src, dst, data in proxy(local_port , (target_host, int(target_port))):
+ print "%s->%s %s" % (src, dst, len(data))
+
+def tests():
+ port_num = 8000
+ from_port, to_port = port_num, port_num+1000
+ proxy(port_num, ('localhost', port_num+1000))
+
+if __name__ == '__main__':
+ tests()
+
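
twowaydict gives the proxy a symmetric socket-pair table: either socket of a pair can be used to look up, test for, or delete the other. A quick illustration using strings in place of socket objects, assuming proxy.py is importable from the working directory:

    from proxy import twowaydict

    pairs = twowaydict()
    pairs['client_sock'] = 'server_sock'

    print pairs['client_sock']     # server_sock
    print pairs['server_sock']     # client_sock (reverse lookup works too)
    print 'server_sock' in pairs   # True

    del pairs['server_sock']       # deleting by either key drops both directions
    print 'client_sock' in pairs   # False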
diff --git a/test_send.py b/test_send.py
new file mode 100755
index 0000000..ef9a99c
--- /dev/null
+++ b/test_send.py
@@ -0,0 +1,12 @@
+#!/usr/bin/python
+
+import socket
+
+s=socket.socket(socket.AF_INET,socket.SOCK_STREAM)
+s.connect(('localhost',9999))
+def test(n):
+ s.send(''.join(map(str,xrange(n)))[:n])
+
+while True:
+ n = int(raw_input('n?>'))
+ test(n+34)