From 08951c2d430aa1216d186f7b52fb9b2298350291 Mon Sep 17 00:00:00 2001 From: Alon Levy Date: Mon, 16 Aug 2010 13:31:28 +0300 Subject: initial --- client.py | 442 ++++++++++++++++++++++++++++++++++++++++++++++++++ client_proto.py | 320 ++++++++++++++++++++++++++++++++++++ client_proto_tests.py | 25 +++ compress.py | 36 ++++ dumpspice.py | 22 +++ getem.py | 98 +++++++++++ main.py | 14 ++ pcapspice.py | 230 ++++++++++++++++++++++++++ pcaputil.py | 171 +++++++++++++++++++ stracer.py | 163 +++++++++++++++++++ util.py | 3 + 11 files changed, 1524 insertions(+) create mode 100644 client.py create mode 100644 client_proto.py create mode 100644 client_proto_tests.py create mode 100644 compress.py create mode 100755 dumpspice.py create mode 100644 getem.py create mode 100644 main.py create mode 100644 pcapspice.py create mode 100644 pcaputil.py create mode 100644 stracer.py create mode 100644 util.py diff --git a/client.py b/client.py new file mode 100644 index 0000000..646c2f9 --- /dev/null +++ b/client.py @@ -0,0 +1,442 @@ +import struct +import socket + +import client_proto + +# TODO - parse spice/protocol.h directly (to keep in sync automatically) +SPICE_MAGIC = "REDQ" +SPICE_VERSION_MAJOR = 2 +SPICE_VERSION_MINOR = 0 + +################################################################################ +# spice-protocol/spice/enums.h + +(SPICE_CHANNEL_MAIN, +SPICE_CHANNEL_DISPLAY, +SPICE_CHANNEL_INPUTS, +SPICE_CHANNEL_CURSOR, +SPICE_CHANNEL_PLAYBACK, +SPICE_CHANNEL_RECORD, +SPICE_CHANNEL_TUNNEL, +SPICE_END_CHANNEL) = xrange(1,1+8) + +################################################################################ + +# Encryption & Ticketing Parameters +SPICE_MAX_PASSWORD_LENGTH=60 +SPICE_TICKET_KEY_PAIR_LENGTH=1024 +SPICE_TICKET_PUBKEY_BYTES=(SPICE_TICKET_KEY_PAIR_LENGTH / 8 + 34) + +def unpack_list(structs, s): + """ struct has a bug - + sizeof('IBIII') == 20 + sizeof('IIIIB') == 16 + """ + ret = [] + start = 0 + for the_struct in structs: + ret.extend(list(the_struct.unpack(s[start:start + s.size]))) + t += s.size + return ret + +def group(format): + return reduce(lambda cs, c: cs[:-1]+[cs[-1]+c] if len(cs) > 0 and c == cs[-1][-1] else cs+[c], format, []) + +class Elements(object): + pass + +ENDIANESS = '<' # small endian + +class StructList(object): + + def __init__(self, formats): + self._s = map(struct.Struct, (ENDIANESS+f for f in formats)) + self.size = sum([s.size for s in self._s]) + + def pack(self, *args): + i_s, i_e = 0, 0 + r = [] + for s in self._s: + i_e += len(s.format) - (1 if s.format[0] in '<>' else 0) + r.append(s.pack(*args[i_s:i_e])) + i_s = i_e + return ''.join(r) + + def unpack(self, st): + r = [] + i_s, i_e = 0, 0 + for s in self._s: + i_e += s.size + r.append(list(s.unpack(st[i_s:i_e]))) + i_s = i_e + return sum(r, []) + +class StructMeta(type): + def __new__(meta, classname, bases, classDict): + fields = classDict['fields'] + is_complex = classDict['_is_complex'] = callable(fields[-1][0]) + if is_complex: + classDict['complex_field'] = complex_field = fields[-1] + fields = fields[:-1] + assert(not any(map(callable, fields))) + classDict['_s'] = StructList(group(''.join(t for t,n in fields))) + classDict['_names'] = [n for t,n in fields] + classDict['size'] = classDict['_s'].size + classDict['field_elements'] = [len(t) for t, n in fields] + return type.__new__(meta, classname, bases, classDict) + +def indice_pairs(sizes): + s = 0 + for size in sizes: + yield s, s+size + s += size + +def cut(elements, sizes): + for s, e in indice_pairs(sizes): + if e - s == 1: + yield elements[s] + else: + yield elements[s:e] + +class Struct(object): + @classmethod + def parse(cls, s): + base = list(cut(cls._s.unpack(s[:cls._s.size]), cls.field_elements)) + if cls._is_complex: + import pdb; pdb.set_trace() + return base + return base + + @classmethod + def make(cls, **kw): + args = [] + args = [kw[n] for n in cls._names] + assert(len(args) == len(cls._names) == len(kw)) + return cls._s.pack(*args) + + def __init__(self, *args, **kw): + self.e = Elements() + if (len(kw) == 0 and len(args) == 1) or (len(kw) == 1 and kw.has_key('s')): + s = args[0] if len(args) == 1 else kw['s'] + self.elements = elements = self.parse(s) + else: + self.elements = elements = [kw[n] for n in self._names] + self.e.__dict__.update(dict(zip(self._names, elements))) + + def tostr(self): + return self.make(**self.e.__dict__) + +uint16 = 'H' +uint32 = 'I' +uint64 = 'Q' +uint8 = 'B' + +uint32_arr = lambda s: ENDIANESS + uint32*s.e.size + +class SpiceDataHeader(Struct): + __metaclass__ = StructMeta + fields = [(uint64, 'serial'), (uint16, 'type'), (uint32, 'size'), + (uint32, 'sub_list')] + + def __str__(self): + return 'Spice #%s, T%s, sub %s, size %s' % (self.e.serial, + self.e.type, self.e.sub_list, self.e.size) + __repr__ = __str__ + +class SpiceSubMessage(Struct): + __metaclass__ = StructMeta + fields = [(uint16, 'type'), (uint32, 'size')] + +class SpiceSubMessageList(Struct): + __metaclass__ = StructMeta + fields = [(uint16, 'size'), (uint32_arr, 'sub_messages')] + +data_header_size = SpiceDataHeader.size + +class SpiceLinkHeader(Struct): + __metaclass__ = StructMeta + fields = [(uint32, 'magic'), (uint32, 'major_version'), + (uint32, 'minor_version'), (uint32, 'size')] + + def __str__(self): + return 'SpiceLink magic=%s, (%s,%s), size %s' % ( + struct.pack('' %( + self.e.error, self.e.num_common_caps, self.e.num_channel_caps, + self.e.pub_key[:10], len(self.e.pub_key)) + __repr__ = __str__ + + +def parse_link_msg(s): + link_header = SpiceLinkHeader(s) + assert(len(s) == SpiceLinkHeader.size + link_header.size) + link_msg = SpiceLinkMess(s[SpiceLinkHeader.size:]) + return link_header, link_msg + +def make_data_msg(serial, type, payload, sub_list): + data_header = SpiceDataHeader(serial=serial, + type=type, size=len(payload), sub_list=sub_list) + +def connect(host, port): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.connect((host, port)) + return s + +def list_to_str(l, type=uint32): + if len(l) == 0: + return '' + return struct.pack(ENDIANESS+len(l)*type, l) + +class Channel(object): + def __init__(self, s, connection_id, channel_type, channel_id, + common_caps, channel_caps): + self._s = s + self._channel_type = channel_type + self._channel_id = channel_id + self._connection_id = connection_id + self._common_caps = common_caps + self._channel_caps = channel_caps + self._proto = client_proto.channels[channel_type] + self.send_link_message() + + def send_link_message(self): + link_message = SpiceLinkMess( + channel_type=self._channel_type, + channel_id=self._channel_id, + caps_offset=SpiceLinkMess.size, + num_channel_caps=len(self._channel_caps), + num_common_caps=len(self._common_caps), + connection_id=self._connection_id) + link_header = SpiceLinkHeader(magic=SPICE_MAGIC, + major_version=SPICE_VERSION_MAJOR, + minor_version=SPICE_VERSION_MINOR, size=link_message.size) + self._s.send( + link_header.tostr() + + link_message.tostr() + + list_to_str(self._common_caps) + + list_to_str(self._channel_caps)) + self.parse_one(SpiceLinkHeader, self._on_link_reply) + + def _on_link_reply(self, header, data): + link_reply = SpiceLinkReply(data) + self._link_reply = link_reply + self._link_reply_header = header + + def _on_data(self, header, data): + self._last_data = data + self._last_data_header = header + print "_on_data: %s" % len(data) + + def parse_data(self): + self.parse_one(SpiceDataHeader, self._on_data) + + def parse_one(self, ctor, on_message): + header = ctor(self._s.recv(ctor.size)) + on_message(header, self._s.recv(header.e.size)) + +class MainChannel(Channel): + def send_attach_channels(self): + import pdb; pdb.set_trace() + self._s.send(s) + +class Client(object): + def __init__(self, host, port, strace=False): + self._host = host + self._port = port + self._strace = strace + + def connect_main(self): + if self._strace: + import stracer + s = StracerNetwork( + '/store/upstream/bin/spicec -h %s -p %s' % ( + self._host, self._port) + , client=self) + self._main_sock = s.sock_iface() + else: + self._main_sock = connect(self._host, self._port) + self.main = self.connect_to_channel(self._main_sock, + connection_id=0, channel_type=SPICE_CHANNEL_MAIN, + channel_id=0) + + # for usage from ipython interactively + def connect_to_channel(self, s, connection_id, channel_type, channel_id, + common_caps=[], channel_caps=[]): + return Channel(s, + connection_id=connection_id, + channel_type=channel_type, + channel_id=channel_id, + common_caps=common_caps, + channel_caps=channel_caps) + + def run(self): + print "TODO" + + def on_message(self, msg): + import pdb; pdb.set_trace() + +SPICE_MAGIC = struct.unpack(' 182+SpiceDataHeader.size) and ( + SpiceDataHeader(s[182:]).e.serial == 1): + print "found at 182 (good guess!)" + found = True + first_data = 182 + else: + #if len(s) > 100000: + # import pdb; pdb.set_trace() + b = buffer(s) + candidates = [(i, h) for i,h in + ((i, SpiceDataHeader(b[i:])) for i in + xrange(0,min(500,len(s))-SpiceDataHeader.size,2)) if h.e.serial == 1] + if len(candidates) == 1: + found = True + first_data = candidates[0][0] + print "found at %s" % first_data + if found: + rest = s[first_data:] + else: + print "not found" + return rest + +def unspice_channels(keys, collecteds): + return dict([(ch, (key, conv)) for key, (ch, conv) in zip(keys, + parse_conversations(collecteds))]) + +def parse_conversations(conversations): + channels = [] + for client_stream, server_stream in conversations: + print "parsing (client %s, server %s)" % ( + len(client_stream), len(server_stream)) + channel_type, client_messages = unspice_client(client_stream) + channel_conv = (client_messages, + unspice_server(channel_type, server_stream)) + channels.append((channel_type, channel_conv)) + return channels + +def parse_link_start(s, msg_ctor): + if len(s) < SpiceLinkHeader.size: + return 0, [] + link_header = SpiceLinkHeader(s) + assert(link_header.e.magic == SPICE_MAGIC) + msg = msg_ctor(s[ + link_header_size:link_header_size+link_header.e.size]) + return link_header_size+link_header.e.size, [link_header, msg] + +def parse_link_server_start(s): + return parse_link_start(s, SpiceLinkReply) + +def parse_link_client_start(s): + return parse_link_start(s, SpiceLinkMess) + +def unspice_server(channel_type, s): + if len(s) < SpiceLinkHeader.size: + return [] + link_header = SpiceLinkHeader(s) + if link_header.e.magic != SPICE_MAGIC: + print "ch %s: bad server magic, trying to look for SpiceDataHeader (%s)" % ( + channel_type, len(s)) + link_reply = None + rest = find_first_data(s) + else: + link_reply = SpiceLinkReply(s[ + link_header_size:link_header_size+link_header.e.size]) + rest = s[link_header_size+link_header.e.size:] + ret = [link_header, link_reply] + return ret + unspice_rest(channel_type, False, rest) + +def unspice_client(s): + if len(s) == 0: + return -1, [] + # First message starts with a link header, then link message + link_header = SpiceLinkHeader(s) + if link_header.e.magic != SPICE_MAGIC: + print "bad client magic, trying to look for SpiceDataHeader (%s)" % len(s) + link_message = None + channel_type = -1 # XXX - can guess based on messages + rest = find_first_data(s) + else: + link_message = SpiceLinkMess(s[ + link_header_size:link_header_size+link_header.e.size]) + channel_type = link_message.e.channel_type + rest = s[link_header_size+link_header.e.size:] + ret = [link_header, link_message] + return channel_type, ret + unspice_rest(channel_type, True, rest) + +def unspice_rest(channel_type, is_client, s): + datas = [] + if len(s) < 8: + return datas + i_s = 0 + if struct.unpack(ENDIANESS + uint64, s[:8])[0] > 1000: + if is_client: + print "client hack - ignoring 128 bytes" + i_s = 128 + else: + print "server hack - ignoring 4 bytes" + i_s = 4 + while len(s) > i_s: + header = SpiceDataHeader(s[i_s:]) + data_start = i_s + SpiceDataHeader.size + data_end = data_start + header.e.size + if data_end > len(s): + break + datas.append((header, client_proto.parse( + channel_type=channel_type, + is_client=is_client, + header=header, + data=s[data_start:data_end]))) + i_s += SpiceDataHeader.size + header.e.size + return datas + +def parse_ctors(s, ctors): + ret = [] + i = 0 + for c in ctors: + c_size = c.size + if len(s) <= i: + break + ret.append(c(s[i:])) + i += c_size + return ret + + diff --git a/client_proto.py b/client_proto.py new file mode 100644 index 0000000..c29c182 --- /dev/null +++ b/client_proto.py @@ -0,0 +1,320 @@ +""" +Decode spice protocol using spice demarshaller. +""" + +ANNOTATE = False + +import struct +import sys +if not 'proto' in locals(): + sys.path.append('../spice/python_modules') + import spice_parser + import ptypes + proto = spice_parser.parse('../spice/spice.proto') + reloads = 1 +else: + reloads += 1 + + +channels = {} +for channel in proto.channels: + channels[channel.value] = channel + channel.client = dict(zip([x.value for x in + channel.channel_type.client_messages], + channel.channel_type.client_messages)) + channel.server = dict(zip([x.value for x in + channel.channel_type.server_messages], + channel.channel_type.server_messages)) + # todo parsing of messages / building of messages + +def mapdict(f, d): + return dict((k, f(v)) for k,v in d.items()) + +primitives=mapdict(struct.Struct, dict( + uint64=' 0: + print "has attributes %s" % member.attributes + #import pdb; pdb.set_trace() + if hasattr(member, 'is_switch') and member.is_switch(): + var_name = member.variable + parsed_d = dict(parsed) + var = get_sub_member(parsed_d, var_name) + for case in member.cases: + for value in case.values: + if case.values[0] is None: + # default + return parse_member(case.member, data, i_s, parsed) + cond_wrap = (lambda x: not x) if value[0] == '!' else (lambda x: x) + if cond_wrap(var.name == value[1]): + return parse_member(case.member, data, i_s, parsed) + return i_s, result_name, 'empty switch' + member_type = member.member_type if hasattr(member, 'member_type') else member + member_type_name = member_type.name + results = None + if member_type_name is not None: + if member_type_name in primitives: + primitive = primitives[member_type_name] + result_value = primitive.unpack_from(data[i_s:i_s+primitive.size])[0] + i_s_out = i_s + primitive.size + elif member_type.is_struct(): + i_s_out, result_name, result_value = parse_complex_member(member, member_type, data, i_s) + elif member_type.is_enum(): + primitive = primitives[member_type.primitive_type()] + value = primitive.unpack(data[i_s:i_s+primitive.size])[0] + result_value = Enum(member_type.names, member_type.names.get(value, value), value) + # TODO - use enum for nice display + i_s_out = i_s + primitive.size + elif member_type.is_primitive() and hasattr(member_type, 'names'): + # assume flag + if not isinstance(member_type, ptypes.FlagsType): + print "not really a flag.." + import pdb; pdb.set_trace() + primitive = primitives[member_type.primitive_type()] + value = primitive.unpack(data[i_s:i_s+primitive.size])[0] + result_value = Flag(member_type.names, value) + i_s_out = i_s + primitive.size + else: + import pdb; pdb.set_trace() + elif member_type.is_array(): + if member_type.is_remaining_length(): + data_len = len(data) - i_s + elif member_type.is_identifier_length(): + num_elements = dict(parsed)[member_type.size] + data_len = None + elif member_type.is_image_size_length(): + bpp = member_type.size[1] + width_name = member_type.size[2] + rows_name = member_type.size[3] + assert(isinstance(width_name, str)) + assert(isinstance(rows_name, str)) + parsed_d = dict(parsed) + width = get_sub_member(parsed_d, width_name) + rows = get_sub_member(parsed, rows_name) + if bpp == 8: + data_len = rows * width + elif bpp == 1: + data_len = ((width + 7) // 8 ) * rows + else: + data_len =((bpp * width + 7) // 8 ) * rows + else: + print "unhandled array length type" + import pdb; pdb.set_trace() + element_type = member_type.element_type + if element_type.name in primitive_arrays: + element_size = primitives[element_type.name].size + if data_len is None: + data_len = num_elements * element_size + primitive = primitive_arrays[element_type.name] + contents = primitive(data_len).unpack_from(data[i_s:i_s+data_len]) + i_s_out = i_s + data_len + else: + contents = [] + i_s_out = i_s + for _ in xrange(num_elements): + i_s_out, sub_name, sub_value = parse_member(element_type, data, i_s_out, []) + if sub_value is None: + print "list parsing failed" + contents = [] + break + contents.append((sub_name, sub_value)) + result_value = contents + elif member_type.is_pointer(): + pointer_primitive_name = member_type.primitive_type() + if pointer_primitive_name in primitives: + primitive = primitives[pointer_primitive_name] + offset = primitive.unpack_from(data[i_s:i_s+primitive.size])[0] + if offset == 0: + result_value = [] + elif offset > len(data): + print "ooops.. bad offset %s > %s" % (offset, len(data)) + import pdb; pdb.set_trace() + result_value = None + else: + discard_i_s, discard_result_name, result_value = parse_member( + member_type.target_type, data, offset, []) + data.max_pointer_i_s = max(data.max_pointer_i_s, discard_i_s) + i_s_out = i_s + primitive.size + else: + print "pointer with unknown size??" + import pdb; pdb.set_trace() + if result_name in ['data', 'ents', 'glyphs', 'String', 'str']: + result_value = NoPrint(name='', s=result_value) + elif len(str(result_value)) > 1000: + print "noprint?" + import pdb; pdb.set_trace() + return i_s_out, result_name, result_value + +@simple_annotate_data +def parse_complex_member(member, the_type, data, i_s): + parsed = [] + i_s_out = i_s + for sub_member in the_type.members: + i_s_out, result_name, result_value = parse_member(sub_member, data, i_s_out, parsed) + if result_value is not None: + parsed.append((result_name, result_value)) + else: + print "don't know how to parse %s (%s) in %s" % ( + sub_member.name, the_type, member) + break + return i_s_out, member.name, parsed + +def parse(channel_type, is_client, header, data): + if channel_type not in channels: + return NoPrint(name='unknown channel (%s %s)' % ( + is_client, header), s=data) + # annotation support + data = AnnotatedString(data) + channel = channels[channel_type] + collection = channel.client if is_client else channel.server + if header.e.type not in collection: + print "bad data - no such message in protocol" + i_s, result_name, result_value = 0, None, None + else: + msg_proto = collection[header.e.type] + i_s, result_name, result_value = parse_complex_member( + msg_proto, msg_proto.message_type, data, 0) + i_s = max(data.max_pointer_i_s, i_s) + left_over = NoPrint(name='%s:%s' % (channel_type, i_s), s=data[i_s:]) if i_s < len(data) else '' + print_annotation(data) + if len(left_over) > 0: + print "WARNING: in message %s %s out %s unaccounted for (%s%%)" % ( + msg_proto.name, len(left_over), len(data), float(len(left_over))/len(data)) + #import pdb; pdb.set_trace() + return (msg_proto, (result_name, result_value)) + +class NoPrint(object): + objs = {} + def __init__(self, name, s): + self.s = s + self.id = id(self.s) + self.objs[self.id] = self.s + self.name = name + def orig(self): + return self.s + def __str__(self): + return 'NP' + (str(self.s) if len(self.s) < 20 else '%r...[%s,%s] #%s' % ( + self.s[:20], len(self.s), self.id, + self.name)) + __repr__ = __str__ + def __len__(self): + return len(self.s) + diff --git a/client_proto_tests.py b/client_proto_tests.py new file mode 100644 index 0000000..7bab3f0 --- /dev/null +++ b/client_proto_tests.py @@ -0,0 +1,25 @@ +import client_proto as cp + +def test_complete_parse(data, the_type): + i_s, name, result = cp.parse_member(the_type, cp.AnnotatedString(data), 0, []) + assert(i_s == len(data)) + +def test_parse_complex_member(data, member, the_type): + i_s, name, result = cp.parse_complex_member(member, the_type, cp.AnnotatedString(data), 0) + assert(i_s == len(data)) + + +def tests(): + six_rects = '\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00U\x01\x00\x00\x80\x02\x00\x00U\x01\x00\x00\x00\x00\x00\x00\xd1\x01\x00\x00R\x01\x00\x00U\x01\x00\x00K\x02\x00\x00\xd1\x01\x00\x00\x80\x02\x00\x00\xd1\x01\x00\x00\x00\x00\x00\x00\xd6\x01\x00\x00R\x01\x00\x00\xd1\x01\x00\x00K\x02\x00\x00\xd6\x01\x00\x00\x80\x02\x00\x00\xd6\x01\x00\x00\x00\x00\x00\x00\xe0\x01\x00\x00\x80\x02\x00\x00' + draw_text = '\x00\x00\x00\x00\xc7\x01\x00\x00#\x00\x00\x00\xdd\x01\x00\x00I\x00\x00\x00\x01\x01\x00\x00\x00\xc7\x01\x00\x00#\x00\x00\x00\xdd\x01\x00\x00I\x00\x00\x00G\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\xff\xff\x00\x00\x08\x00\x00\x00\x05\x00\x02#\x00\x00\x00\xd7\x01\x00\x00\xff\xff\xff\xff\xf6\xff\xff\xff\t\x00\n\x00\x03\xaf\xfdp\x00?\xff\xff\xf7\x00\xcf\x90\x03\xfe\x00\x15\x00\x06\xff\x00\x009\xef\xf8\x00\x07\xff\xe90\x00\x0e\xf6\x00\x06\x10\x0e\xf3\x01\xcf\xb0\x07\xff\xff\xfe \x00m\xff\xa2\x00,\x00\x00\x00\xd7\x01\x00\x00\x00\x00\x00\x00\xf4\xff\xff\xff\x05\x00\x0c\x00-\xfb\x00\xaf\xff\x00\xbf`\x00\xaf`\x00\x7f\x80\x00?\xb0\x00\x1f\xe0\x00\x0e\xf1\x00\xff\xff\xb0\xbf\xff\xf0\x05\xfa\x00\x02\xbc\x002\x00\x00\x00\xd7\x01\x00\x00\x00\x00\x00\x00\xf6\xff\xff\xff\x08\x00\n\x00-\xfb/\xf0\xbf\xff\xef\xf0\xff \x8f\xf3\xdfr\t\xf6_\xff\xff\xf9\x03\xae\xff\xfb\x00\x00\x00\xff\x0c\xf5\x05\xff\x03\xff\xff\xfb\x00L\xff\xa2<\x00\x00\x00\xd7\x01\x00\x00\xff\xff\xff\xff\xf6\xff\xff\xff\t\x00\n\x00/\xc0\x00\x00\x00\x0f\xf0\x00\x00\x00\x0b\xf3\x00\x00\x00\t\xf5\x00\x00\x00\x06\xf8\x00\x00\x00\x03\xfd\x00\x00\x00\x00\xff`\x00\x00\x00\xee\xf7\x10\x00\x00\xbf_\xfa\x00\x00\x7f#\xaf\x00C\x00\x00\x00\xd7\x01\x00\x00\x00\x00\x00\x00\xf4\xff\xff\xff\x05\x00\x0c\x00-\xfb\x00\xaf\xff\x00\xbf`\x00\xaf`\x00\x7f\x80\x00?\xb0\x00\x1f\xe0\x00\x0e\xf1\x00\xff\xff\xb0\xbf\xff\xf0\x05\xfa\x00\x02\xbc\x00' + clip_rects = cp.proto.channels[1].server[304].message_type.members[0].member_type.members[2].member_type.members[1].cases[0].member + assert(clip_rects.name == 'rects') + test_complete_parse(six_rects, clip_rects) + draw_text_message = cp.proto.channels[1].server[311] + assert(draw_text_message.name == 'draw_text') + test_parse_complex_member(draw_text, draw_text_message, draw_text_message.message_type) + print '\n'.join(map(str,result[1][1])) + +if __name__ == '__main__': + tests() + diff --git a/compress.py b/compress.py new file mode 100644 index 0000000..736d967 --- /dev/null +++ b/compress.py @@ -0,0 +1,36 @@ +""" +SPICE compression and decompression implementation. + +we have our own lz implementation, see spice/common/lz_common.h + +""" +import struct +from util import reverse_dict + +LZ_MAGIC = struct.unpack('IHHIIIII', s) + assert(magic == LZ_MAGIC) + assert(major == LZ_MAJOR) + assert(minor == LZ_MINOR) + del s + return locals() diff --git a/dumpspice.py b/dumpspice.py new file mode 100755 index 0000000..4c69db4 --- /dev/null +++ b/dumpspice.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python +import sys +import pcaputil +import pcapspice + +def main(stdscr=None): + p = pcaputil.packet_iter('lo') + spice = pcapspice.spice_iter(p) + if stdscr: + stdscr.erase() + for d in spice: + if verbose: + print d + +if __name__ == '__main__': + verbose = '-v' in sys.argv + if '-c' in sys.argv: + import curses + curses.wrapper(main) + else: + main(None) + diff --git a/getem.py b/getem.py new file mode 100644 index 0000000..b41bf94 --- /dev/null +++ b/getem.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python2 + +import pcap +import sys +import string +import time +import socket +import struct + +protocols={socket.IPPROTO_TCP:'tcp', + socket.IPPROTO_UDP:'udp', + socket.IPPROTO_ICMP:'icmp'} + +def decode_ip_packet(s): + d={} + d['version']=(ord(s[0]) & 0xf0) >> 4 + d['header_len']=ord(s[0]) & 0x0f + d['tos']=ord(s[1]) + d['total_len']=socket.ntohs(struct.unpack('H',s[2:4])[0]) + d['id']=socket.ntohs(struct.unpack('H',s[4:6])[0]) + d['flags']=(ord(s[6]) & 0xe0) >> 5 + d['fragment_offset']=socket.ntohs(struct.unpack('H',s[6:8])[0] & 0x1f) + d['ttl']=ord(s[8]) + d['protocol']=ord(s[9]) + d['checksum']=socket.ntohs(struct.unpack('H',s[10:12])[0]) + d['source_address']=pcap.ntoa(struct.unpack('i',s[12:16])[0]) + d['destination_address']=pcap.ntoa(struct.unpack('i',s[16:20])[0]) + if d['header_len']>5: + d['options']=s[20:4*(d['header_len']-5)] + else: + d['options']=None + d['data']=s[4*d['header_len']:] + return d + + +def dumphex(s): + bytes = map(lambda x: '%.2x' % x, map(ord, s)) + for i in xrange(0,len(bytes)/16): + print ' %s' % string.join(bytes[i*16:(i+1)*16],' ') + print ' %s' % string.join(bytes[(i+1)*16:],' ') + + +def print_packet(pktlen, data, timestamp): + if not data: + return + + if data[12:14]=='\x08\x00': + decoded=decode_ip_packet(data[14:]) + print '\n%s.%f %s > %s' % (time.strftime('%H:%M', + time.localtime(timestamp)), + timestamp % 60, + decoded['source_address'], + decoded['destination_address']) + for key in ['version', 'header_len', 'tos', 'total_len', 'id', + 'flags', 'fragment_offset', 'ttl']: + print ' %s: %d' % (key, decoded[key]) + print ' protocol: %s' % protocols[decoded['protocol']] + print ' header checksum: %d' % decoded['checksum'] + print ' data:' + dumphex(decoded['data']) + + +if __name__=='__main__': + + if len(sys.argv) < 3: + print 'usage: sniff.py ' + sys.exit(0) + #dev = pcap.lookupdev() + dev = sys.argv[1] + #net, mask = pcap.lookupnet(dev) + # note: to_ms does nothing on linux + p = pcap.pcap(dev) + #p.open_live(dev, 1600, 0, 100) + #p.dump_open('dumpfile') + #p.setfilter(string.join(sys.argv[2:],' '), 0, 0) + + # try-except block to catch keyboard interrupt. Failure to shut + # down cleanly can result in the interface not being taken out of promisc. + # mode + #p.setnonblock(1) + try: + while 1: + p.dispatch(1, print_packet) + + # specify 'None' to dump to dumpfile, assuming you have called + # the dump_open method + # p.dispatch(0, None) + + # the loop method is another way of doing things + # p.loop(1, print_packet) + + # as is the next() method + # p.next() returns a (pktlen, data, timestamp) tuple + # apply(print_packet,p.next()) + except KeyboardInterrupt: + print '%s' % sys.exc_type + print 'shutting down' + print '%d packets received, %d packets dropped, %d packets dropped by interface' % p.stats() diff --git a/main.py b/main.py new file mode 100644 index 0000000..a8c54c7 --- /dev/null +++ b/main.py @@ -0,0 +1,14 @@ +import sys +import client +import optparse + +o = optparse.OptionParser() +o.conflict_handler = 'resolve' +o.add_option('-p') +o.add_option('-h') +d, rest = o.parse_args(sys.argv[1:]) +client = client(host=d.h, port=d.p) + +if __name __ == '__main__': + client.run() + diff --git a/pcapspice.py b/pcapspice.py new file mode 100644 index 0000000..b1c1e96 --- /dev/null +++ b/pcapspice.py @@ -0,0 +1,230 @@ +from itertools import chain, repeat +from pcaputil import (tcp_parse, is_tcp, is_tcp_data, is_tcp_syn, + conversations_iter, header_conversation_iter) +import client +from client import SpiceLinkHeader, SpiceLinkMess, SpiceLinkReply, SpiceDataHeader + +def is_single_packet_data(payload): + return len(payload) == SpiceDataHeader(payload).e.size + SpiceDataHeader.size + +num_spice_messages = sum(len(ch.client) + len(ch.server) for ch in + client.client_proto.channels.values()) + +valid_message_ids = set(sum([ch.client.keys() + ch.server.keys() for ch in + client.client_proto.channels.values()], [])) + +all_channels = client.client_proto.channels.keys() + +def possible_channels(server_message, header): + return set(c for c in all_channels if header.e.type in + (client.client_proto.channels[c].server.keys() if server_message + else client.client_proto.channels[c].client.keys())) + +guesses = {} + +def guess_channel_iter(): + channel = None + optional_channels = set(all_channels) + seen_headers = [] + while True: + src, dst, data_header = yield channel + key = (min(src, dst), max(src, dst)) + if guesses.has_key(key): + optional_channels = set([guesses[key]]) + seen_headers.append((src, dst, data_header)) + optional_channels = optional_channels & possible_channels( + src == min(src, dst), data_header) + print optional_channels + if len(optional_channels) == 1: + channel = list(optional_channels)[0] + guesses[key] = channel + if len(optional_channels) == 0: + import pdb; pdb.set_trace() + +spice_size_from_header = lambda header: header.e.size + +def channel_spice_start_iter(lower_port, higher_port): + server_port, client_port = lower_port, higher_port + ports = [server_port, client_port] + messages = dict([(port,None) for port in ports]) + payloads = dict([(port,[]) for port in ports]) + headers = dict([(port,None) for port in ports]) + discarded_len = {client_port:128, server_port:4} + message_ctors = {client_port:SpiceLinkMess, server_port:SpiceLinkReply} + if len(payload) > 0: + payloads[src].append(payload) + while not all(messages.values()): + src, dst, payload = yield None + if len(payload) == 0: continue + payloads[src].append(payload) + len_payload = sum(map(len,payloads[src])) + if headers[src] is None and len_payload >= SpiceLinkHeader.size: + headers[src] = SpiceLinkHeader(''.join(payloads[src])) + import pdb; pdb.set_trace() + expected_len = headers[src].e.size + SpiceLinkHeader.size + discarded_len[src] + if headers[src] is not None and expected_len <= len_payload: + if expected_len != len_payload: + print "extra bytes dumped!!! %d" % (len_payload - expected_len) + messages[src] = message_ctors[src](''.join(payloads[src])[ + SpiceLinkHeader.size:]) + channel = messages[client_port].e.channel_type + print "channel type %s created" % channel + result = None + src, dst, payload = yield None + data_iter = channel_spice_iter(src, dst, payload, channel=channel) + result = data_iter.next() + while True: + src, dst, payload = yield result + result = data_iter.send((src, dst, payload)) + +def channel_spice_iter(lower_port, higher_port, channel=None): + server_port, client_port = lower_port, higher_port + serial = {} + guesser = guess_channel_iter() + guesser.next() + + result = None + # yield spice data parsed results + while True: + src, dst, payload = yield result + data_header = SpiceDataHeader(payload) + # basic sanity check on the header + bad_header = False + if serial.has_key(src): + if data_header.e.serial != serial[src] + 1: + print "bad serial, replacing for %s->%s (%s!=%s+1)" % ( + src, dst, data_header.e.serial, serial[src]) + bad_header = True + serial[src] = data_header.e.serial + if data_header.e.type not in valid_message_ids: + print "bad message type %s in %s->%s" % (data_header.e.type, + src, dst) + if bad_header: continue + if channel is None: + channel = guesser.send((src, dst, data_header)) + if channel is not None: + # read as many messages as it takes to get + (msg_proto, (result_name, result_value), + left_over) = client.client_proto.parse( + channel_type=channel, + is_client=src == client_port, + header=data_header, + data=payload[SpiceDataHeader.size:]) + result = (data_header, msg_proto, result_name, result_value, + left_over) + +# for every tcp packet +# if it is a syn, read the link header, decide which channel we are +# if it is data +# if we know which channel we are, parse it +# else guess which channel. +# if we just figured out which channel we are, releaes all pending packets. + +def spice_iter1(packet_iter): + return conversations_iter( + packet_iter, + channel_spice_start_iter, + channel_spice_iter) + +def packet_header_iter_gen(hdr_ctor, pkt_ctor): + return repeat( + (hdr_ctor.size, + hdr_ctor, + lambda h: h.e.size, + pkt_ctor)) + +def ident(x, **kw): + return x + +def channel_spice_message_iter(src, dst, guesser): + server_port, client_port = min(src, dst), max(src, dst) + serial = {} + channel = None + result = None + # yield spice data parsed results + while True: + src, dst, (data_header, message) = yield result + # basic sanity check on the header + bad_header = False + if serial.has_key(src): + if data_header.e.serial != serial[src] + 1: + print "bad serial, replacing for %s->%s (%s!=%s+1)" % ( + src, dst, data_header.e.serial, serial[src]) + bad_header = True + serial[src] = data_header.e.serial + if data_header.e.type not in valid_message_ids: + print "bad message type %s in %s->%s" % (data_header.e.type, + src, dst) + bad_header = True + if bad_header: continue + if channel is None: + channel = guesser.send((src, dst, data_header)) + if channel is not None: + # read as many messages as it takes to get + (msg_proto, (result_name, result_value) + ) = client.client_proto.parse( + channel_type=channel, + is_client=src == client_port, + header=data_header, + data=message) + result = (msg_proto, result_name, result_value) + +class ChannelGuesser(object): + + def __init__(self): + self.iter = guess_channel_iter() + self.iter.next() + self.channel = None + + def send(self, (src, dst, payload)): + if self.channel: return self.channel + return self.iter.send((src, dst, payload)) + + def on_client_link_message(self, p, **kw): + self._client_message = SpiceLinkMess(p) + self.channel = self._client_message.e.channel_type + print "STARTED channel %s" % self.channel + + def on_server_link_message(self, p, **kw): + self._server_message = SpiceLinkReply(p) + +def spice_iter(packet_iter): + guessers = {} + def make_start_iter(src, dst): + is_server = src < dst + key = (min(src, dst), max(src, dst)) + guesser = guessers[key] = guessers.get(key, ChannelGuesser()) + print guesser, is_server, "**** make_start_iter %s -> %s" % (src, dst) + link_messages = [] + yield (SpiceLinkHeader.size, + SpiceLinkHeader.verify, + lambda h: h.e.size, + (guesser.on_server_link_message if is_server else + guesser.on_client_link_message)) + print guesser, is_server, "2" + message_iter = channel_spice_message_iter(src, dst, guesser) + message_iter.next() + # TODO: way to say "src->dst 128, dst->src 4" + yield 128 if src > dst else 4, ident, lambda h: 0, ident + for i, data in enumerate(packet_header_iter_gen( + SpiceDataHeader, + lambda pkt, src, dst, header: + message_iter.send((src, dst, (header, pkt))))): + print guesser, is_server, "2+%s" % i + yield data + def make_iter(src, dst): + key = (min(src, dst), max(src, dst)) + guesser = guessers[key] = guessers.get(key, ChannelGuesser()) + message_iter = channel_spice_message_iter(src, dst, guesser) + message_iter.next() + return packet_header_iter_gen( + SpiceDataHeader, + lambda pkt, src, dst, header: + message_iter.send((src, dst, (header, pkt)))) + return header_conversation_iter( + packet_iter, + server_start=make_start_iter, + client_start=make_start_iter, + server=make_iter, + client=make_iter) + diff --git a/pcaputil.py b/pcaputil.py new file mode 100644 index 0000000..007c6c2 --- /dev/null +++ b/pcaputil.py @@ -0,0 +1,171 @@ +import os +import struct +from itertools import imap, ifilter +#import pcap + +TCP_PROTOCOL = 6 +TCP_SYN = 2 +TCP_FIN = 1 + +def is_tcp(pkt): + return (len(pkt) > 47 + and pkt[12:14] == '\x08\x00' # Ethernet.Type == IP + and ord(pkt[23]) == TCP_PROTOCOL) + +def is_tcp_data(pkt): + return (is_tcp(pkt) + and (not ord(pkt[47]) & (TCP_SYN | TCP_FIN))) + +def is_tcp_syn(pkt): + return (is_tcp(pkt) + and (ord(pkt[47]) & TCP_SYN)) + +def tcp_parse(pkt): + tcp_offset = 14 + (ord(pkt[14]) & 0xf) * 4 + src, dst, seq, ack, data_offset16, flags, window = struct.unpack( + '>HHIIBBH', pkt[tcp_offset:tcp_offset+16]) + data_offset = tcp_offset + (data_offset16 >> 2) + if data_offset != 66: + print "WHOOHOO" + return src, dst, flags, pkt[data_offset:] + +def mypcap(filename): + with open(filename, 'r') as fd: + file_size = os.stat(filename).st_size + i = 0 + i = 24 + hdr_len = 8 + 4 + 4 + i_pkt = 0 + preamble = fd.read(24) + while i + hdr_len < file_size: + ts, l1, l2 = struct.unpack('QII', fd.read(hdr_len)) + if l1 < l2: + print "short recorded packet: #%s,%s: %s < %s" % (i_pkt, i, l1, l2) + pkt = fd.read(l1) + if is_tcp_data(pkt): + src, dst, tcp_payload = tcp_parse(pkt) + yield ts, src, dst, tcp_payload + i_pkt += 1 + i += hdr_len + l1 + +def get_conversations(file): + convs = {} + port_counts = {} + for ts, src, dst, data in mypcap(file): + key = (src,dst) + if key not in convs: + convs[key] = [ts] + convs[key].append(data) + ports = [src for src,dst in convs.keys()] + server_port = sorted(ports)[0] + times = {} + for src, dst in convs.keys(): + if src != server_port: + src, dst = dst, src + server = convs[(src, dst)] + client = convs[(dst, src)] + assert(client[0] < server[0]) # client sends first msg + times[client[0]] = (client, server) + return [map(lambda x: ''.join(x[1:]), times[t]) for t in sorted(times.keys())] + +def conversations_iter(packet_iter, **keyed_iters): + """ conversation_start_iter is used when we catch the SYN packet, + otherwise we use conversation_iter. the later is expected to be + used by the former. + """ + conversations = {} + keys = {(False, True): 'server', (False, False): 'client', + (True, True): 'server_start', (True, False): 'client_start'} + for src, dst, flags, payload in tcp_iter(packet_iter): + print "conversation_iter %s->%s #(%s)" % (src, dst, len(payload)) + key = (src, dst) + if key not in conversations: + iter_gen = keyed_iters[keys[(flags & TCP_SYN > 0, src < dst)]] + conversations[key] = iter_gen(*key) + conversations[key].next() + maybe_result = conversations[key].send((src, dst, payload)) + if maybe_result is not None: + yield maybe_result + +def consume_packets(packets, needed_len): + pkt = None + total_len = sum(map(len, packets)) + if total_len >= needed_len: + pkt = ''.join(packets) + if total_len > needed_len: + del packets[:-1] + packets[0] = packets[0][-(total_len - needed_len):] + pkt = pkt[:needed_len] + assert(len(packets[0]) + len(pkt) == total_len) + else: + del packets[:] + return pkt + +def ident(x, **kw): + return x + +def collect_packets(header_iter_gen): + """ Example: + >>> list(pcaputil.send_multiple(pcaputil.collect_packets(lambda src, dst: iter([(10, ident, lambda h: 20, ident)]*2))(17, 42), [(42, 17, ' '*30)])) + [(' ', ' ')] + """ + def collector(src, dst): + packets = [] + history = [] + header_iter = header_iter_gen(src, dst) + hdr_size, hdr_ctor, size_from_hdr, pkt_ctor = header_iter.next() + msg = None + header = pkt = None + searched_size = hdr_size + while True: + src, dst, payload = yield (src, dst, msg) + packets.append(payload) + history.append(payload) + print "collect_packets: searching for %s (%s)" % (searched_size, sum(map(len, packets))) + data = consume_packets(packets, searched_size) + while data != None: + print "collect_packets: %s" % len(data) + if header is None: + header = hdr_ctor(data) + print "collect_packets: header of length %s, expect %s" % (len(data), size_from_hdr(header)) + searched_size = size_from_hdr(header) + if searched_size > 1000000: + import pdb; pdb.set_trace() + else: + print "collect_packets: packet of length %s" % len(data) + pkt = pkt_ctor(data, src=src, dst=dst, header=header) + msg = (header, pkt) + hdr_size, hdr_ctor, size_from_hdr, pkt_ctor = header_iter.next() + header = None + searched_size = hdr_size + data = consume_packets(packets, searched_size) + return collector + +def send_multiple(gen, sent_iter): + sent_iter = iter(sent_iter) + gen.next() + try: + while True: + next = sent_iter.next() + yield gen.send(next) + except StopIteration: + return + +def header_conversation_iter(packet_iter, **kw): + return conversations_iter(packet_iter, + **dict([(k, collect_packets(v)) for k,v in kw.items()])) + +def tcp_data_iter(packet_iter): + return ifilter(lambda src, dst, flags, data: + len(data) > 0, imap(tcp_parse, ifilter(is_tcp_data, imap(lambda x: ''.join(x[1]), packet_iter)))) + +def tcp_iter(packet_iter): + return ifilter(lambda (src, dst, flags, data): + flags & (TCP_SYN | TCP_FIN) or len(data) > 0, + imap(tcp_parse, ifilter(is_tcp, imap(lambda x: ''.join(x[1]), packet_iter)))) + +def packet_iter(dev): + import pcap + return pcap.pcap(dev) + + diff --git a/stracer.py b/stracer.py new file mode 100644 index 0000000..60b735f --- /dev/null +++ b/stracer.py @@ -0,0 +1,163 @@ +import subprocess +import re +from client import unspice_channels + +def parse_connect(line): + """'{sa_family=AF_INET, sin_port=htons(5926), sin_addr=inet_addr("127.0.0.1")}, 16) = 0\n'""" + m = re.match('\s+\{sa_family=(?P[A-Z_]+),\s+sin_port=htons\((?P[0-9]+)\),\s+sin_addr=inet_addr\("(?P[0-9.]+)"\)\},\s+([0-9]+)\)\s+=\s+(?P[0-9])+\n', line) + if not m: return m + return m.groupdict() + +def from_quoted(s): + return ''.join(chr(int(s[i+2:i+4], 16)) for i in + xrange(0, len(s), 4)) + +def match_message_container(line, rexp): + m = re.match(rexp, line) + if not m: return m + d = m.groupdict() + d['msg'] = from_quoted(d['msg']) + d['size'] = int(d['size']) + if 'ret' in d: + d['ret'] = int(d['ret']) + assert(len(d['msg']) == d['ret']) + d['type'] = 'sendto' + return d + +def parse_sendto(line): + """"\x14\x00\x00\x00\x16\x00\x01\x03\x06\x83\x59\x4c\x00\x00\x00\x00\x00\x00\x00\x00", 20, 0, {sa_family=AF_NETLINK, pid=0, groups=00000000}, 12) = 20""" + if 'unfinished' in line: + return match_message_container(line, '\s+"(?P[x0-9a-f\\\\]+)"\.?\.?\.?,\s+(?P[0-9]+),\s+([0-9]+),\s+.+,\s+([0-9]+) \n') + return match_message_container(line, '\s+"(?P[x0-9a-f\\\\]+)"\.?\.?\.?,\s+(?P[0-9]+),\s+([0-9]+),\s+.+,\s+([0-9]+)\)\s+=\s+(?P[0-9]+)\n') + +def parse_recvfrom(line): + """"\x00\x00\x00\x00\xd2\x01\x00\x00\x1d\x01\x00\x00\x80\x02\x00\x00\x30\x02\x00\x00\x00\xba\x02\x00\x00\xb6\x01\x00\x00", 29, 0, NULL, NULL) = 29""" + return match_message_container(line, '\s+"(?P[x0-9a-f\\\\]+)"\.?\.?\.?,\s+(?P[0-9]+),\s+([0-9]+),.+,.+\)\s+=\s+(?P[0-9]+)\n') + +def parse_recvmsg(line): + print "TODO - parse_recvmsg" + return + +def parse_sendmsg(line): + print "TODO - parse_sendmsg" + return + +def parse_start(line): + # really lame parsing, won't deal with nested first arguments - but + # works for all the system calls we are interested in. Actually just + # adding nesting level to each char would solve this. + m = re.search('\[pid\s+([0-9]+)\]\s+([a-zA-Z0-9]+)\(([^,]+),', line) + return m + +class StracerNetwork(object): + + _unhandled_syscalls = set(['socket', 'bind', 'getsockname', 'setsockopt', + # later + 'getpeername', 'getsockopt', 'shutdown']) + def __init__(self, cmd, on_message = lambda x: None): + self.max_bytes_per_message = max_bytes = 2**20 + self._p = subprocess.Popen( + ('strace -qvxx -e trace=network -s %s -Cf %s' % (max_bytes, cmd)).split(), + stderr=subprocess.PIPE) + self.connections = [] + self.sockets = {} + self._message_count = 0 + self._lines = [] + self.on_message = on_message + + def total_length(self, n): + if n not in self.sockets: + return 0 + return sum([len(d['msg']) for m, d in self.sockets[n] if d is not None and 'msg' in d]) + + def terse(self, count=10): + while sum([len(v) for v in self.sockets.values()]) < count: + print sum(map(len, self.sockets.items())) + self.handle_line() + sorted_keys = list(sorted(self.sockets.keys())) + return sorted_keys, [self.terse_key(k) for k in sorted_keys] + + def terse_key(self, key): + return [(dict(recvfrom='R',sendto='S').get(tag), d['msg']) + for ((tag, line), d) in self.sockets[key] if d is not None] + + + def conversations(self, count=10): + keys, terses = self.terse(count=count) + return keys, [[''.join([msg for t,msg in terse if t==k]) for k in 'SR'] + for terse in terses] + + def spiced(self, count=10): + keys, collecteds = self.conversations(count) + return unspice_channels(keys, collecteds) + + def lines(self): + for line in self._p.stderr: + print line if len(line) < 100 else line[:100] + '...' + self._lines.append(line) + yield line + + def wait_for_connect(self, host): + while not self.host_connected(host): + self.handle_line() + + def host_connected(self, host): + return any([d['host'] == host for d in self.connections]) + + def add_to_socket(self, sock, tag, line, parser): + if sock not in self.sockets: + print "did I miss a connect?? at %s" % len(self._lines) + self.add_socket(sock) + parsed = parser(line) + self.sockets[sock].append(((tag, line), parsed)) + self._message_count += 1 + if parsed: + self.on_message(parsed['msg']) + + def add_socket(self, sock, d={}): + d['sock'] = sock + if 'host' not in d: + d['host'] = 'unknown' + self.connections.append(d) + self.sockets[sock] = [] + + def handle_line(self): + line = self.lines().next() + m = parse_start(line) + if not m: return + line = line[m.end():] + pid, syscall, firstarg = m.groups() + if 'connect' in syscall: + d = parse_connect(line) + if not d: return + self.add_socket(int(firstarg), d) + elif 'sendto' in syscall: + self.add_to_socket(int(firstarg), 'sendto', line, parse_sendto) + elif 'recvmsg' in syscall: + self.add_to_socket(int(firstarg), 'recvmsg', line, parse_recvmsg) + elif 'recvfrom' in syscall: + self.add_to_socket(int(firstarg), 'recvfrom', line, parse_recvfrom) + elif 'sendmsg' in syscall: + self.add_to_socket(int(firstarg), 'sendmsg', line, parse_sendmsg) + elif syscall in self._unhandled_syscalls: + pass + else: + print "unhandled %s" % syscall + import pdb; pdb.set_trace() + + def wait_for(self, num): + ret = [] + start_count = self._message_count + while True: + self.handle_line() + if self._message_count - start_count >= num: + return + +def trace_spicec(host, port): + s = StracerNetwork('/store/upstream/bin/spicec -h %s -p %s' % (host, port)) + sock = s.wait_for_connect(host) + return s.filter(sock, 10) + +if __name__ == '__main__': + trace_spicec('127.0.0.1', 5926) + diff --git a/util.py b/util.py new file mode 100644 index 0000000..2170eaa --- /dev/null +++ b/util.py @@ -0,0 +1,3 @@ +def reverse_dict(d): + return dict([(v,k) for k,v in d.items()]) + -- cgit v1.2.3