diff options
author | Alon Levy <alevy@redhat.com> | 2010-08-18 13:43:29 +0300 |
---|---|---|
committer | Alon Levy <alevy@redhat.com> | 2010-08-18 13:43:29 +0300 |
commit | 1ecec7be766f59d2f19e74f03f01a61589c32313 (patch) | |
tree | a7a0819acdd576f49b39c41cc2df0821243ab5a6 | |
parent | 08951c2d430aa1216d186f7b52fb9b2298350291 (diff) |
proxy version - working; pcap version - misses some packets due to missing reordering logic; doesn't handle bad headers very nicely
-rw-r--r-- | client.py | 26 | ||||
-rw-r--r-- | client_proto.py | 65 | ||||
-rwxr-xr-x | dumpspice.py | 83 | ||||
-rwxr-xr-x | mypcap.py | 111 | ||||
-rw-r--r-- | pcapspice.py | 178 | ||||
-rw-r--r-- | pcaputil.py | 238 | ||||
-rwxr-xr-x | proxy.py | 122 | ||||
-rwxr-xr-x | test_send.py | 12 |
8 files changed, 631 insertions, 204 deletions
@@ -1,8 +1,12 @@ import struct import socket +import logging + import client_proto +logger = logging.getLogger('client') + # TODO - parse spice/protocol.h directly (to keep in sync automatically) SPICE_MAGIC = "REDQ" SPICE_VERSION_MAJOR = 2 @@ -143,6 +147,24 @@ class SpiceDataHeader(Struct): self.e.type, self.e.sub_list, self.e.size) __repr__ = __str__ + @classmethod + def verify(cls, *args, **kw): + inst = cls(*args, **kw) + if inst.e.sub_list != 0 and inst.e.sub_list >= inst.e.size: + logger.error('too large sub_list packet - %s' % inst) + import pdb; pdb.set_trace() + return None + if inst.e.type < 0 or inst.e.type >= 400: + logger.error('bad type of packet - %s' % inst) + import pdb; pdb.set_trace() + return None + if inst.e.size > 1000000: + logger.error('too large packet - %s' % inst) + import pdb; pdb.set_trace() + return None + return inst + + class SpiceSubMessage(Struct): __metaclass__ = StructMeta fields = [(uint16, 'type'), (uint32, 'size')] @@ -167,7 +189,9 @@ class SpiceLinkHeader(Struct): @classmethod def verify(cls, *args, **kw): inst = cls(*args, **kw) - assert(inst.e.magic == SPICE_MAGIC) + if inst.e.magic != SPICE_MAGIC: + logger.error('bad magic in packet: %s' % inst) + return None return inst link_header_size = SpiceLinkHeader.size diff --git a/client_proto.py b/client_proto.py index c29c182..3095243 100644 --- a/client_proto.py +++ b/client_proto.py @@ -3,7 +3,13 @@ Decode spice protocol using spice demarshaller. """ ANNOTATE = False +NO_PRINT_PROTECTION = False +import logging + +logger = logging.getLogger('client_proto') + +from collections import namedtuple import struct import sys if not 'proto' in locals(): @@ -152,7 +158,7 @@ def parse_member(member, data, i_s, parsed): result_value = None if len(set(member.attributes.keys()) - set( ['end', 'ctype', 'as_ptr', 'nomarshal', 'anon', 'chunk'])) > 0: - print "has attributes %s" % member.attributes + logging.debug("has attributes %s" % member.attributes) #import pdb; pdb.set_trace() if hasattr(member, 'is_switch') and member.is_switch(): var_name = member.variable @@ -177,24 +183,28 @@ def parse_member(member, data, i_s, parsed): i_s_out = i_s + primitive.size elif member_type.is_struct(): i_s_out, result_name, result_value = parse_complex_member(member, member_type, data, i_s) - elif member_type.is_enum(): + elif hasattr(member_type, 'is_enum') and member_type.is_enum(): primitive = primitives[member_type.primitive_type()] value = primitive.unpack(data[i_s:i_s+primitive.size])[0] result_value = Enum(member_type.names, member_type.names.get(value, value), value) # TODO - use enum for nice display i_s_out = i_s + primitive.size - elif member_type.is_primitive() and hasattr(member_type, 'names'): + elif member_type.is_primitive(): # assume flag - if not isinstance(member_type, ptypes.FlagsType): - print "not really a flag.." - import pdb; pdb.set_trace() primitive = primitives[member_type.primitive_type()] value = primitive.unpack(data[i_s:i_s+primitive.size])[0] - result_value = Flag(member_type.names, value) + if hasattr(member_type, 'names'): + if not isinstance(member_type, ptypes.FlagsType): + print "not really a flag.." + import pdb; pdb.set_trace() + result_value = Flag(member_type.names, value) + else: + result_value = value i_s_out = i_s + primitive.size else: import pdb; pdb.set_trace() elif member_type.is_array(): + num_elements = None if member_type.is_remaining_length(): data_len = len(data) - i_s elif member_type.is_identifier_length(): @@ -223,8 +233,10 @@ def parse_member(member, data, i_s, parsed): element_size = primitives[element_type.name].size if data_len is None: data_len = num_elements * element_size + if num_elements is None: + num_elements = data_len / element_size primitive = primitive_arrays[element_type.name] - contents = primitive(data_len).unpack_from(data[i_s:i_s+data_len]) + contents = primitive(num_elements).unpack_from(data[i_s:i_s+data_len]) i_s_out = i_s + data_len else: contents = [] @@ -232,7 +244,7 @@ def parse_member(member, data, i_s, parsed): for _ in xrange(num_elements): i_s_out, sub_name, sub_value = parse_member(element_type, data, i_s_out, []) if sub_value is None: - print "list parsing failed" + logging.error("list parsing failed") contents = [] break contents.append((sub_name, sub_value)) @@ -256,11 +268,12 @@ def parse_member(member, data, i_s, parsed): else: print "pointer with unknown size??" import pdb; pdb.set_trace() - if result_name in ['data', 'ents', 'glyphs', 'String', 'str']: - result_value = NoPrint(name='', s=result_value) - elif len(str(result_value)) > 1000: - print "noprint?" - import pdb; pdb.set_trace() + if NO_PRINT_PROTECTION: + if result_name in ['data', 'ents', 'glyphs', 'String', 'str', 'rects']: + result_value = NoPrint(name='', s=result_value) + elif len(str(result_value)) > 1000: + print "noprint?" + import pdb; pdb.set_trace() return i_s_out, result_name, result_value @simple_annotate_data @@ -272,11 +285,13 @@ def parse_complex_member(member, the_type, data, i_s): if result_value is not None: parsed.append((result_name, result_value)) else: - print "don't know how to parse %s (%s) in %s" % ( - sub_member.name, the_type, member) + logger.error("don't know how to parse %s (%s) in %s" % ( + sub_member.name, the_type, member)) break return i_s_out, member.name, parsed +ParseResult=namedtuple('ParseResult', ['msg_proto', 'result_name', 'result_value']) + def parse(channel_type, is_client, header, data): if channel_type not in channels: return NoPrint(name='unknown channel (%s %s)' % ( @@ -286,20 +301,20 @@ def parse(channel_type, is_client, header, data): channel = channels[channel_type] collection = channel.client if is_client else channel.server if header.e.type not in collection: - print "bad data - no such message in protocol" + logger.error("bad data - no such message in protocol") i_s, result_name, result_value = 0, None, None else: msg_proto = collection[header.e.type] i_s, result_name, result_value = parse_complex_member( msg_proto, msg_proto.message_type, data, 0) - i_s = max(data.max_pointer_i_s, i_s) - left_over = NoPrint(name='%s:%s' % (channel_type, i_s), s=data[i_s:]) if i_s < len(data) else '' - print_annotation(data) - if len(left_over) > 0: - print "WARNING: in message %s %s out %s unaccounted for (%s%%)" % ( - msg_proto.name, len(left_over), len(data), float(len(left_over))/len(data)) - #import pdb; pdb.set_trace() - return (msg_proto, (result_name, result_value)) + i_s = max(data.max_pointer_i_s, i_s) + left_over = NoPrint(name='%s:%s' % (channel_type, i_s), s=data[i_s:]) if i_s < len(data) else '' + #result_value.data = data # let the reference escape, so we can print annotation data + if len(left_over) > 0: + logger.warning("in message %s %s out %s unaccounted for (%2.1d%%)" % ( + msg_proto.name, len(left_over), len(data), 100.0*len(left_over)/len(data))) + #import pdb; pdb.set_trace() + return ParseResult(msg_proto, result_name, result_value) class NoPrint(object): objs = {} diff --git a/dumpspice.py b/dumpspice.py index 4c69db4..9c643c9 100755 --- a/dumpspice.py +++ b/dumpspice.py @@ -1,22 +1,83 @@ #!/usr/bin/env python import sys import pcaputil -import pcapspice +from proxy import proxy, closeallsockets +from collections import defaultdict +from time import time +from select import select +from optparse import OptionParser +import logging -def main(stdscr=None): - p = pcaputil.packet_iter('lo') +dt = 1.0 + +class Histogram(defaultdict): + def __init__(self): + super(Histogram, self).__init__(lambda: (0,0)) + self.last = defaultdict(lambda: (0,0)) + def show(self): + print "----------------------------" + print '\n'.join(['%20s: %6d %4d' % (k, self[k][1], self[k][1] - self.last[k][1]) for t,k in sorted((t,k) for k,(t,c) in self.items())]) + self.last.update(self) + +verbose = 0 + +def dumpspice(p, stdscr=None): + import pcapspice spice = pcapspice.spice_iter(p) + hist = Histogram() + last_print = start_time = time() if stdscr: stdscr.erase() - for d in spice: - if verbose: - print d + # replace the "for d in spice:" loop with a select + while True: + #rds, _ws, _xs = select([p.fileno()],[],[],dt) + cur_time = time() + do_read = True # = len(rds) > 0: + if do_read: + d = spice.next() + if verbose: + print d + old_time, old_count = hist[d.msg.data.result_name] + hist[d.msg.data.result_name] = (cur_time, old_count + 1) + if cur_time - last_print > dt: + hist.show() + last_print = cur_time + +def frompcap(stdscr=None): + p = pcaputil.packet_iter('lo') + return dumpspice(p, stdscr) + +def fromproxy(stdscr, local_port, remote_addr): + p = proxy(local_port=local_port, remote_addr=remote_addr) + return dumpspice(p, stdscr=stdscr) if __name__ == '__main__': - verbose = '-v' in sys.argv - if '-c' in sys.argv: - import curses - curses.wrapper(main) + parser = OptionParser() + parser.add_option('-p', '--proxy', dest='proxy', help='use proxy', + action='store_true') + parser.add_option('-l', '--localport', dest='local_port', help='set proxy local port') + parser.add_option('-r', '--remoteaddr', dest='remote_addr', help='set proxy remote address') + parser.add_option('-v', '--verbose', dest='verbose', action='count', help='verbosity', default=0) + parser.add_option('-c', '--curses', dest='curses', action='store_true', help='use curses') + opts, rest = parser.parse_args(sys.argv[1:]) + if opts.verbose >= 2 in sys.argv: + logging.basicConfig(filename='dumpspice.log', level=logging.DEBUG) + print "saving debug log to dumpspice.log" + if opts.proxy: + local_port = int(opts.local_port) + remote_addr = opts.remote_addr.split(':') + remote_addr = (remote_addr[0], int(remote_addr[1])) + main = (lambda stdscr, local_port=local_port, remote_addr=remote_addr: + fromproxy(stdscr, local_port, remote_addr)) else: - main(None) + main = frompcap + verbose = opts.verbose + try: + if opts.curses in sys.argv: + import curses + curses.wrapper(main) + else: + main(None) + except KeyboardInterrupt, e: + closeallsockets() diff --git a/mypcap.py b/mypcap.py new file mode 100755 index 0000000..9a3b3c4 --- /dev/null +++ b/mypcap.py @@ -0,0 +1,111 @@ +#!/usr/bin/python + +from time import time +import logging +from ctypes import (cdll, c_char_p, create_string_buffer, Structure, byref + , c_int32, c_long, pointer, POINTER, c_char) + +logger = logging.getLogger('mypcap') +pcap = cdll.LoadLibrary('libpcap.so') +PCAP_ERRBUF_SIZE = 1000 # bogus +DATA_SIZE = 2000000 +for f in "pcap_next pcap_create pcap_set_snaplen pcap_set_buffer_size pcap_activate pcap_next_ex pcap_perror".split(): + if len(f) > 0: + globals()[f] = getattr(pcap, f) + +time_t = c_long +class timeval(Structure): + _fields_ = [('tv_sec', time_t), + ('tv_usec', c_long)] + +class pcap_pkthdr(Structure): + _fields_ = [('ts', timeval), # time stamp + ('caplen', c_int32), # length of portion present + ('len', c_int32)] # length this packet (off wire) + +PCAP_ERROR = -1 # generic error code +PCAP_ERROR_BREAK = -2 # loop terminated by pcap_breakloop +PCAP_ERROR_NOT_ACTIVATED = -3 # the capture needs to be activated +PCAP_ERROR_ACTIVATED = -4 # the operation can't be performed on already activated captures +PCAP_ERROR_NO_SUCH_DEVICE = -5 # no such device exists +PCAP_ERROR_RFMON_NOTSUP = -6 # this device doesn't support rfmon (monitor) mode +PCAP_ERROR_NOT_RFMON = -7 # operation supported only in monitor mode +PCAP_ERROR_PERM_DENIED = -8 # no permission to open the device +PCAP_ERROR_IFACE_NOT_UP = -9 # interface isn't up + +PCAP_WARNING = 1 # generic warning code +PCAP_WARNING_PROMISC_NOTSUP = 2 # this device doesn't support promiscuous mode + +pcap_errors = [PCAP_ERROR_ACTIVATED + ,PCAP_ERROR_NO_SUCH_DEVICE + ,PCAP_ERROR_PERM_DENIED + ,PCAP_ERROR_RFMON_NOTSUP + ,PCAP_ERROR_IFACE_NOT_UP + ,PCAP_ERROR] + +pcap_next.restype = POINTER(c_char) + +class PCAPError(Exception): + pass + +class PCAP(object): + def __init__(self, dev): + self.errbuf = create_string_buffer(PCAP_ERRBUF_SIZE) + self.hdr = pcap_pkthdr() + self.hdrp = pointer(self.hdr) + self.pcap = pcap_create(dev, self.errbuf) + if self.pcap == 0: + raise PCAPError("error: %s" % self.errbuf.value) + if pcap_set_snaplen(self.pcap, 65536) != 0: + raise PCAPError("error: set snaplen failed") + if pcap_set_buffer_size(self.pcap, DATA_SIZE) != 0: + raise PCAPError("error: set buffer size failed") + ret = pcap_activate(self.pcap) + if ret in [PCAP_WARNING_PROMISC_NOTSUP, PCAP_WARNING]: + logger.warning("pcap_activate returned a warning - %d" % ret) + elif ret in pcap_errors: + pcap_perror(self.pcap, "") + raise PCAPError('pcap_activate failed') + + def __iter__(self): + i_pkt = 0 + hdr = self.hdr + total = 0 + while True: + data = pcap_next(self.pcap, self.hdrp) + #ret = pcap_next_ex(self.pcap, byref(hdrp), byref(data)); + ret = 1 + if ret == 1: + total += hdr.len + i_pkt += 1 + secs = float(hdr.ts.tv_sec) + float(hdr.ts.tv_usec)/1000000 + logger.debug("%3d: %5.3f, %d, %d| %d" % ( + i_pkt, secs, hdr.caplen, hdr.len, total)) + yield secs, data[:hdr.len] # if data is NULL? + elif ret == 0: + logger.debug("%3d: timeout" % i_pkt) + elif ret == -1: + pcap_perror(self.pcap, "") + raise PCAPError("pcap_next errored") + else: + raise PCAPError("error: unexpected return %d" % ret) + +pcap = PCAP + +def main(): + cap_2 = 10 + cap = 2*cap_2 + total = 0 + start = time() + for secs, data in pcap('lo'): + if len(data) > cap: + r = data[:cap_2] + ('...%s...' % (len(data) - cap)) + data[-cap_2:] + else: + r = data + total += len(data) + print '%7.3f %12d %10d %s' % (secs, total, len(data), repr(r)) + +if __name__ == '__main__': + import sys + sys.exit(main()) + diff --git a/pcapspice.py b/pcapspice.py index b1c1e96..05a5bef 100644 --- a/pcapspice.py +++ b/pcapspice.py @@ -1,24 +1,27 @@ from itertools import chain, repeat -from pcaputil import (tcp_parse, is_tcp, is_tcp_data, is_tcp_syn, - conversations_iter, header_conversation_iter) -import client +import logging + +from pcaputil import header_conversation_iter +import client_proto from client import SpiceLinkHeader, SpiceLinkMess, SpiceLinkReply, SpiceDataHeader +logger = logging.getLogger('pcapspice') + def is_single_packet_data(payload): return len(payload) == SpiceDataHeader(payload).e.size + SpiceDataHeader.size num_spice_messages = sum(len(ch.client) + len(ch.server) for ch in - client.client_proto.channels.values()) + client_proto.channels.values()) valid_message_ids = set(sum([ch.client.keys() + ch.server.keys() for ch in - client.client_proto.channels.values()], [])) + client_proto.channels.values()], [])) -all_channels = client.client_proto.channels.keys() +all_channels = client_proto.channels.keys() def possible_channels(server_message, header): return set(c for c in all_channels if header.e.type in - (client.client_proto.channels[c].server.keys() if server_message - else client.client_proto.channels[c].client.keys())) + (client_proto.channels[c].server.keys() if server_message + else client_proto.channels[c].client.keys())) guesses = {} @@ -34,140 +37,49 @@ def guess_channel_iter(): seen_headers.append((src, dst, data_header)) optional_channels = optional_channels & possible_channels( src == min(src, dst), data_header) - print optional_channels + logger.debug(str(optional_channels)) if len(optional_channels) == 1: channel = list(optional_channels)[0] guesses[key] = channel if len(optional_channels) == 0: import pdb; pdb.set_trace() -spice_size_from_header = lambda header: header.e.size - -def channel_spice_start_iter(lower_port, higher_port): - server_port, client_port = lower_port, higher_port - ports = [server_port, client_port] - messages = dict([(port,None) for port in ports]) - payloads = dict([(port,[]) for port in ports]) - headers = dict([(port,None) for port in ports]) - discarded_len = {client_port:128, server_port:4} - message_ctors = {client_port:SpiceLinkMess, server_port:SpiceLinkReply} - if len(payload) > 0: - payloads[src].append(payload) - while not all(messages.values()): - src, dst, payload = yield None - if len(payload) == 0: continue - payloads[src].append(payload) - len_payload = sum(map(len,payloads[src])) - if headers[src] is None and len_payload >= SpiceLinkHeader.size: - headers[src] = SpiceLinkHeader(''.join(payloads[src])) - import pdb; pdb.set_trace() - expected_len = headers[src].e.size + SpiceLinkHeader.size + discarded_len[src] - if headers[src] is not None and expected_len <= len_payload: - if expected_len != len_payload: - print "extra bytes dumped!!! %d" % (len_payload - expected_len) - messages[src] = message_ctors[src](''.join(payloads[src])[ - SpiceLinkHeader.size:]) - channel = messages[client_port].e.channel_type - print "channel type %s created" % channel - result = None - src, dst, payload = yield None - data_iter = channel_spice_iter(src, dst, payload, channel=channel) - result = data_iter.next() - while True: - src, dst, payload = yield result - result = data_iter.send((src, dst, payload)) - -def channel_spice_iter(lower_port, higher_port, channel=None): - server_port, client_port = lower_port, higher_port - serial = {} - guesser = guess_channel_iter() - guesser.next() - - result = None - # yield spice data parsed results - while True: - src, dst, payload = yield result - data_header = SpiceDataHeader(payload) - # basic sanity check on the header - bad_header = False - if serial.has_key(src): - if data_header.e.serial != serial[src] + 1: - print "bad serial, replacing for %s->%s (%s!=%s+1)" % ( - src, dst, data_header.e.serial, serial[src]) - bad_header = True - serial[src] = data_header.e.serial - if data_header.e.type not in valid_message_ids: - print "bad message type %s in %s->%s" % (data_header.e.type, - src, dst) - if bad_header: continue - if channel is None: - channel = guesser.send((src, dst, data_header)) - if channel is not None: - # read as many messages as it takes to get - (msg_proto, (result_name, result_value), - left_over) = client.client_proto.parse( - channel_type=channel, - is_client=src == client_port, - header=data_header, - data=payload[SpiceDataHeader.size:]) - result = (data_header, msg_proto, result_name, result_value, - left_over) - -# for every tcp packet -# if it is a syn, read the link header, decide which channel we are -# if it is data -# if we know which channel we are, parse it -# else guess which channel. -# if we just figured out which channel we are, releaes all pending packets. - -def spice_iter1(packet_iter): - return conversations_iter( - packet_iter, - channel_spice_start_iter, - channel_spice_iter) - -def packet_header_iter_gen(hdr_ctor, pkt_ctor): - return repeat( - (hdr_ctor.size, - hdr_ctor, - lambda h: h.e.size, - pkt_ctor)) - def ident(x, **kw): return x def channel_spice_message_iter(src, dst, guesser): + """ coroutine, accepts (src,dst,header,data) packets + and yields maybe parsed spice results. + + maybe X - returns None or X + """ server_port, client_port = min(src, dst), max(src, dst) serial = {} channel = None result = None - # yield spice data parsed results while True: src, dst, (data_header, message) = yield result - # basic sanity check on the header + result = None bad_header = False if serial.has_key(src): if data_header.e.serial != serial[src] + 1: - print "bad serial, replacing for %s->%s (%s!=%s+1)" % ( - src, dst, data_header.e.serial, serial[src]) + logger.debug("bad serial, replacing for %s->%s (%s!=%s+1)" % ( + src, dst, data_header.e.serial, serial[src])) bad_header = True serial[src] = data_header.e.serial if data_header.e.type not in valid_message_ids: - print "bad message type %s in %s->%s" % (data_header.e.type, - src, dst) + logger.error("bad message type %s in %s->%s" % (data_header.e.type, + src, dst)) bad_header = True if bad_header: continue - if channel is None: - channel = guesser.send((src, dst, data_header)) + channel = guesser.send((src, dst, data_header)) if channel is not None: # read as many messages as it takes to get - (msg_proto, (result_name, result_value) - ) = client.client_proto.parse( - channel_type=channel, - is_client=src == client_port, - header=data_header, - data=message) - result = (msg_proto, result_name, result_value) + result = client_proto.parse( + channel_type=channel, + is_client=src == client_port, + header=data_header, + data=message) class ChannelGuesser(object): @@ -177,13 +89,15 @@ class ChannelGuesser(object): self.channel = None def send(self, (src, dst, payload)): - if self.channel: return self.channel - return self.iter.send((src, dst, payload)) + if self.channel: + return self.channel + self.channel = self.iter.send((src, dst, payload)) + return self.channel def on_client_link_message(self, p, **kw): self._client_message = SpiceLinkMess(p) self.channel = self._client_message.e.channel_type - print "STARTED channel %s" % self.channel + logger.info("STARTED channel %s" % self.channel) def on_server_link_message(self, p, **kw): self._server_message = SpiceLinkReply(p) @@ -194,37 +108,35 @@ def spice_iter(packet_iter): is_server = src < dst key = (min(src, dst), max(src, dst)) guesser = guessers[key] = guessers.get(key, ChannelGuesser()) - print guesser, is_server, "**** make_start_iter %s -> %s" % (src, dst) + logger.debug('%s %s **** make_start_iter %s -> %s' %( guesser, is_server, src, dst)) link_messages = [] yield (SpiceLinkHeader.size, SpiceLinkHeader.verify, lambda h: h.e.size, (guesser.on_server_link_message if is_server else guesser.on_client_link_message)) - print guesser, is_server, "2" - message_iter = channel_spice_message_iter(src, dst, guesser) - message_iter.next() - # TODO: way to say "src->dst 128, dst->src 4" + logger.debug('%s %s 2' % (guesser, is_server)) yield 128 if src > dst else 4, ident, lambda h: 0, ident - for i, data in enumerate(packet_header_iter_gen( - SpiceDataHeader, - lambda pkt, src, dst, header: - message_iter.send((src, dst, (header, pkt))))): - print guesser, is_server, "2+%s" % i + for i, data in enumerate(make_iter(src, dst)): + logger.debug('%s %s 2+%s' % (guesser, is_server, i)) yield data def make_iter(src, dst): key = (min(src, dst), max(src, dst)) guesser = guessers[key] = guessers.get(key, ChannelGuesser()) message_iter = channel_spice_message_iter(src, dst, guesser) message_iter.next() - return packet_header_iter_gen( - SpiceDataHeader, + return repeat( + (SpiceDataHeader.size, + SpiceDataHeader.verify, + lambda h: h.e.size, lambda pkt, src, dst, header: - message_iter.send((src, dst, (header, pkt)))) + message_iter.send((src, dst, (header, pkt))))) return header_conversation_iter( packet_iter, server_start=make_start_iter, client_start=make_start_iter, server=make_iter, - client=make_iter) + client=make_iter, + filter_result=lambda r: (r.msg is not None and r.msg.data != None + and hasattr(r.msg.data, 'result_name'))) diff --git a/pcaputil.py b/pcaputil.py index 007c6c2..c6432fe 100644 --- a/pcaputil.py +++ b/pcaputil.py @@ -1,7 +1,11 @@ import os import struct from itertools import imap, ifilter -#import pcap +import logging +from collections import namedtuple, defaultdict +import hashlib + +logger = logging.getLogger('pcaputil') TCP_PROTOCOL = 6 TCP_SYN = 2 @@ -20,16 +24,17 @@ def is_tcp_syn(pkt): return (is_tcp(pkt) and (ord(pkt[47]) & TCP_SYN)) +TCP = namedtuple('TCP', ['src', 'dst', 'flags', 'seq', 'ack', 'window', 'data']) def tcp_parse(pkt): tcp_offset = 14 + (ord(pkt[14]) & 0xf) * 4 src, dst, seq, ack, data_offset16, flags, window = struct.unpack( '>HHIIBBH', pkt[tcp_offset:tcp_offset+16]) data_offset = tcp_offset + (data_offset16 >> 2) if data_offset != 66: - print "WHOOHOO" - return src, dst, flags, pkt[data_offset:] + logger.debug("WHOOHOO") + return TCP(src, dst, flags, seq, ack, window, pkt[data_offset:]) -def mypcap(filename): +def myparsepcap(filename): with open(filename, 'r') as fd: file_size = os.stat(filename).st_size i = 0 @@ -40,7 +45,7 @@ def mypcap(filename): while i + hdr_len < file_size: ts, l1, l2 = struct.unpack('QII', fd.read(hdr_len)) if l1 < l2: - print "short recorded packet: #%s,%s: %s < %s" % (i_pkt, i, l1, l2) + logger.debug("short recorded packet: #%s,%s: %s < %s" % (i_pkt, i, l1, l2)) pkt = fd.read(l1) if is_tcp_data(pkt): src, dst, tcp_payload = tcp_parse(pkt) @@ -51,7 +56,7 @@ def mypcap(filename): def get_conversations(file): convs = {} port_counts = {} - for ts, src, dst, data in mypcap(file): + for ts, src, dst, data in myparsepcap(file): key = (src,dst) if key not in convs: convs[key] = [ts] @@ -68,7 +73,80 @@ def get_conversations(file): times[client[0]] = (client, server) return [map(lambda x: ''.join(x[1:]), times[t]) for t in sorted(times.keys())] +class Conversation(object): + def __init__(self, npkt, iter): + self.npkt = npkt + self.iter = iter + def conversations_iter(packet_iter, **keyed_iters): + convs = {} + keys = {(False, True): 'server', (False, False): 'client', + (True, True): 'server_start', (True, False): 'client_start'} + filter_result = ((lambda x: x is not None and keyed_iters['filter_result'](x)) + if 'filter_result' in keyed_iters + else (lambda x: x is not None)) + for i_pkt, (src, dst, payload) in enumerate(packet_iter): + logger.debug("%3d: conversation_iter %s->%s #(%s)" % (i_pkt, + src, dst, len(payload))) + key = (src, dst) + if key not in convs: + iter_gen = keyed_iters[keys[(True, src < dst)]] + convs[key] = Conversation(npkt=1, iter=iter_gen(*key)) + convs[key].iter.next() + else: + conv = convs[key] + conv.npkt += 1 + result = convs[key].iter.send((src, dst, payload)) + if filter_result(result): + yield result + +class TCPConversation(object): + def __init__(self, npkt, min_next_seq, iter): + self.npkt = npkt + self.min_next_seq = min_next_seq + self.iter = iter + +def tcp_conversations_iter(packet_iter, **keyed_iters): + """ conversation_start_iter is used when we catch the SYN packet, + otherwise we use conversation_iter. the later is expected to be + used by the former. + """ + convs = {} + prints = set() + keys = {(False, True): 'server', (False, False): 'client', + (True, True): 'server_start', (True, False): 'client_start'} + filter_result = ((lambda x: x is not None and keyed_iters['filter_result'](x)) + if 'filter_result' in keyed_iters + else (lambda x: x is not None)) + for i_pkt, tcp in enumerate(non_reordering_tcp_iter(packet_iter)): + src, dst, flags, payload = tcp.src, tcp.dst, tcp.flags, tcp.data + logger.debug("%3d: conversation_iter %s->%s #(%s)" % (i_pkt, + src, dst, len(payload))) + key = (src, dst) + fingerprint = hashlib.md5(payload).digest() + if key not in convs: + iter_gen = keyed_iters[keys[(flags & TCP_SYN > 0, src < dst)]] + convs[key] = TCPConversation(npkt=1, iter=iter_gen(*key), + min_next_seq = tcp.seq + len(payload) + (1 if flags & TCP_SYN else 0)) + convs[key].iter.next() + else: + conv = convs[key] + conv.npkt += 1 + seq_change = tcp.seq - conv.min_next_seq + next_min_seq = max(conv.min_next_seq, tcp.seq + len(payload)) + logger.debug("%4d: seq changed by %s, expect +%s" % ( + conv.npkt, seq_change, next_min_seq - conv.min_next_seq)) + if seq_change < 0 and fingerprint not in prints: + print "busted!" + conv.min_next_seq = next_min_seq + if seq_change < 0: + continue + prints.add(fingerprint) + result = convs[key].iter.send((src, dst, payload)) + if filter_result(result): + yield result + +def tcp_conversations_iter_with_sequence(packet_iter, **keyed_iters): """ conversation_start_iter is used when we catch the SYN packet, otherwise we use conversation_iter. the later is expected to be used by the former. @@ -76,69 +154,124 @@ def conversations_iter(packet_iter, **keyed_iters): conversations = {} keys = {(False, True): 'server', (False, False): 'client', (True, True): 'server_start', (True, False): 'client_start'} - for src, dst, flags, payload in tcp_iter(packet_iter): - print "conversation_iter %s->%s #(%s)" % (src, dst, len(payload)) + filter_result = ((lambda x: x is not None and keyed_iters['filter_result'](x)) + if 'filter_result' in keyed_iters + else (lambda x: x is not None)) + outstanding = defaultdict(lambda: []) + seqs = {} + seqs_start = {} + acks = {} + def show_seq(key): + return '%s (R%s)' % (seqs[key], sum(seqs[key]) - seqs_start[key]) + for i_pkt, tcp in enumerate(non_reordering_tcp_iter(packet_iter)): + src, dst, flags, seq, payload = tcp.src, tcp.dst, tcp.flags, tcp.seq, tcp.data + logger.debug("%3d: conversation_iter %s->%s #(%s)" % (i_pkt, + src, dst, len(payload))) key = (src, dst) + rev_key = (dst, src) if key not in conversations: iter_gen = keyed_iters[keys[(flags & TCP_SYN > 0, src < dst)]] conversations[key] = iter_gen(*key) conversations[key].next() - maybe_result = conversations[key].send((src, dst, payload)) - if maybe_result is not None: - yield maybe_result + seqs[key] = (seq , 0) # sequence# of first byte in payload, payload length - next packet should have seq of a+b + seqs_start[key] = seq + acks[key] = tcp.ack + logger.debug("_ %s setting seq to %s" % ( + key, show_seq(key))) + outstanding[key].append(tcp) + removed = [] + for (the_seq, i, ot) in sorted( + (ot.seq, i, ot) for i, ot in enumerate(outstanding[key])): + expected = sum(seqs[key]) + if expected == the_seq: + logger.debug("A %s (%s,%s) %s" % (str(key), len(payload), + flags & TCP_SYN > 0, show_seq(key))) + removed.append(i) + result = conversations[key].send((ot.src, ot.dst, ot.data)) + if filter_result(result): + yield result + seqs[key] = (the_seq, max(len(ot.data), + (ot.flags & TCP_SYN) / 2)) + else: + logger.debug("C %s +%s == %s != %s, ack = %s" % ( + key, seqs[key], show_seq(key), + the_seq, acks.get(rev_key, None))) + for i in reversed(removed): + del outstanding[key][i] + if len(removed) == 0: + logger.debug("B %s: seq %s" % (key, show_seq(key))) + #print "B %s: %s" % (key, [(x.seq-seqs[key], len(x.data)) for x in outstanding[key]]) + #import pdb; pdb.set_trace() + def consume_packets(packets, needed_len): + """ + >>> pcaputil.consume_packets(['01234','5678','90ab','cdef'], 7) + ('0123456', ['78', '90ab', 'cdef']) + """ pkt = None + ret_packets = packets total_len = sum(map(len, packets)) if total_len >= needed_len: pkt = ''.join(packets) if total_len > needed_len: - del packets[:-1] - packets[0] = packets[0][-(total_len - needed_len):] + part_len = 0 + for i, p in enumerate(packets): + if part_len + len(p) > needed_len: + ret_packets = [p[needed_len - part_len:]] + packets[i+1:] + break + part_len += len(p) pkt = pkt[:needed_len] - assert(len(packets[0]) + len(pkt) == total_len) + assert(sum(map(len, ret_packets)) + len(pkt) == total_len) else: - del packets[:] - return pkt + ret_packets = [] + assert(pkt is None or len(pkt) == needed_len) + return pkt, ret_packets def ident(x, **kw): return x +CollectorResult = namedtuple('CollectorResult', ('src', 'dst', 'msg')) +CollectorPacket = namedtuple('CollectorPacket', ('header', 'data')) + def collect_packets(header_iter_gen): """ Example: >>> list(pcaputil.send_multiple(pcaputil.collect_packets(lambda src, dst: iter([(10, ident, lambda h: 20, ident)]*2))(17, 42), [(42, 17, ' '*30)])) [(' ', ' ')] """ - def collector(src, dst): + def collector(start_src, start_dst): packets = [] history = [] + src, dst = start_src, start_dst header_iter = header_iter_gen(src, dst) hdr_size, hdr_ctor, size_from_hdr, pkt_ctor = header_iter.next() msg = None header = pkt = None searched_size = hdr_size while True: - src, dst, payload = yield (src, dst, msg) + src, dst, payload = yield CollectorResult(src, dst, msg) packets.append(payload) history.append(payload) - print "collect_packets: searching for %s (%s)" % (searched_size, sum(map(len, packets))) - data = consume_packets(packets, searched_size) + logger.debug("collect_packets: (%s->%s) searching for %s (%s)" % (src, dst, + searched_size, sum(map(len, packets)))) + data, packets = consume_packets(packets, searched_size) while data != None: - print "collect_packets: %s" % len(data) + logger.debug("collect_packets: (%s->%s) %s" % (src, dst, len(data))) if header is None: header = hdr_ctor(data) - print "collect_packets: header of length %s, expect %s" % (len(data), size_from_hdr(header)) + if header is None: # invalid header + break + logger.debug("collect_packets: (%s->%s) header of length %s, expect %s" % ( + src, dst, len(data), size_from_hdr(header))) searched_size = size_from_hdr(header) - if searched_size > 1000000: - import pdb; pdb.set_trace() else: - print "collect_packets: packet of length %s" % len(data) + logger.debug("collect_packets: (%s->%s) packet of length %s" % (src, dst, len(data))) pkt = pkt_ctor(data, src=src, dst=dst, header=header) - msg = (header, pkt) + msg = CollectorPacket(header, pkt) hdr_size, hdr_ctor, size_from_hdr, pkt_ctor = header_iter.next() header = None searched_size = hdr_size - data = consume_packets(packets, searched_size) + data, packets = consume_packets(packets, searched_size) return collector def send_multiple(gen, sent_iter): @@ -152,20 +285,57 @@ def send_multiple(gen, sent_iter): return def header_conversation_iter(packet_iter, **kw): - return conversations_iter(packet_iter, - **dict([(k, collect_packets(v)) for k,v in kw.items()])) + opts = dict([(k, collect_packets(v)) for k,v in kw.items() if k != 'filter_result']) + if 'filter_result' in kw: + opts['filter_result'] = kw['filter_result'] + return conversations_iter(packet_iter, **opts) def tcp_data_iter(packet_iter): return ifilter(lambda src, dst, flags, data: len(data) > 0, imap(tcp_parse, ifilter(is_tcp_data, imap(lambda x: ''.join(x[1]), packet_iter)))) -def tcp_iter(packet_iter): - return ifilter(lambda (src, dst, flags, data): - flags & (TCP_SYN | TCP_FIN) or len(data) > 0, +def non_reordering_tcp_iter(packet_iter): + return ifilter(lambda t: + t.flags & (TCP_SYN | TCP_FIN) or len(t.data) > 0, imap(tcp_parse, ifilter(is_tcp, imap(lambda x: ''.join(x[1]), packet_iter)))) +def iter_conv(n_tcp, src, dst): + for t in n_tcp: + if t.src == src and t.dst == dst: + yield t + def packet_iter(dev): - import pcap - return pcap.pcap(dev) + import mypcap as pcap + def filter_none(): + for t in pcap.pcap(dev): + if t is None: + print "another None from pcap??" + else: + yield t + return filter_none() + +def test(): + logging.basicConfig(level=logging.DEBUG,filename="testout.log") + p = packet_iter('lo') + class Handler(object): + def __init__(self, src, dst): + self.src = src + self.dst = dst + self.data = [] + def __iter__(self): + data = '' + src, dst = self.src, self.dst + while True: + src, dst, data = yield None + print ('%s->%s %s' % (src, dst, len(data))) + self.data.append(data) + def handler(src, dst): + return iter(Handler(src, dst)) + server = client = server_start = client_start = handler + for b in conversations_iter(p, server=server,client=client, + server_start=server_start,client_start=client_start): + print b +if __name__ == '__main__': + test() diff --git a/proxy.py b/proxy.py new file mode 100755 index 0000000..a8478d6 --- /dev/null +++ b/proxy.py @@ -0,0 +1,122 @@ +#!/usr/bin/python +from socket import socket, AF_INET, SOCK_STREAM +from select import select + +MAX_PACKET_SIZE=65536 + +class Socket(socket): + sockets = [] + def __init__(self, *args, **kw): + super(Socket, self).__init__(*args, **kw) + self.sockets.append(self) + +class twowaydict(dict): + def __init__(self): + self.other = dict() + def __delitem__(self, k): + if super(twowaydict,self).__contains__(k): + other_k = self[k] + else: + other_k = k + k = self.other[k] + del self.other[other_k] + super(twowaydict,self).__delitem__(k) + def __setitem__(self, k, v): + super(twowaydict, self).__setitem__(k, v) + self.other[v] = k + def __contains__(self, k): + return super(twowaydict,self).__contains__(k) or k in self.other + def __getitem__(self, k): + if super(twowaydict,self).__contains__(k): + return super(twowaydict,self).__getitem__(k) + return self.other[k] + def getpair(self, k): + if k in self: + return k, self[k] + elif k in self.other: + return k, self.other[k] + def allkeys(self): + return self.keys() + self.other.keys() + +BROKEN_PIPE_ERRNO, TRANSPORT_NOT_CONNECTED_ERRNO = 32, 107 + +def make_accepter(port): + accepter = Socket(AF_INET, SOCK_STREAM) + accepter.bind(('0.0.0.0', port)) + accepter.listen(1) + return accepter + +def connect(addr): + s = Socket(AF_INET, SOCK_STREAM) + s.connect(addr) + s.setblocking(False) + return s + +def proxy(local_port, remote_addr): + print "proxying from %s to %s" % (local_port, remote_addr) + accepter = make_accepter(local_port) + open_socks = twowaydict() + close_errnos = set([BROKEN_PIPE_ERRNO, TRANSPORT_NOT_CONNECTED_ERRNO]) + while True: + #print "open: %s" % len(open_socks) + rds, _wrs, _ex = select( + [accepter]+open_socks.allkeys(), [], []) + for s in rds: + if s is accepter: + s, _addr = accepter.accept() + open_socks[s] = connect(remote_addr) + else: + other = open_socks[s] + src_dst = [s, other] + dont_recv = False + for i in xrange(len(src_dst)): + try: + src_dst[i] = src_dst[i].getpeername()[1] + except Exception, e: + if e.errno in close_errnos: + src_dst[1-i].close() + if s in open_socks: + del open_socks[s] + dont_recv = True + if dont_recv: continue + src, dst = src_dst + data = s.recv(MAX_PACKET_SIZE) + if len(data) == 0: + other.close() + del open_socks[s] + continue + yield src, dst, data + try: + other.send(data) + except Exception, e: + if e.errno in close_errnos: + n = len(open_socks) + s.close() + open_socks[s].close() + del open_socks[s] + assert(len(open_socks) == n - 1) + else: + import pdb; pdb.set_trace() + +def closeallsockets(): + for s in Socket.sockets: + try: + s.close() + except: + pass + +def example_main(): + import sys + target_host, target_port = sys.argv[-1].split(':') + local_port = int(sys.argv[-2]) + for src, dst, data in proxy(local_port , (target_host, int(target_port))): + print "%s->%s %s" % (src, dst, len(data)) + +def tests(): + port_num = 8000 + from_port, to_port = port_num, port_num+1000 + proxy(port_num, ('localhost', port_num+1000)) + +if __name__ == '__main__': + tests() + diff --git a/test_send.py b/test_send.py new file mode 100755 index 0000000..ef9a99c --- /dev/null +++ b/test_send.py @@ -0,0 +1,12 @@ +#!/usr/bin/python + +import socket + +s=socket.socket(socket.AF_INET,socket.SOCK_STREAM) +s.connect(('localhost',9999)) +def test(n): + s.send(''.join(map(str,xrange(n)))[:n]) + +while True: + n = int(raw_input('n?>')) + test(n+34) |