diff options
author | Alon Levy <alevy@redhat.com> | 2010-08-16 13:31:28 +0300 |
---|---|---|
committer | Alon Levy <alevy@redhat.com> | 2010-08-16 13:31:28 +0300 |
commit | 08951c2d430aa1216d186f7b52fb9b2298350291 (patch) | |
tree | c2ea618c5ddaf6702f08ed33ef04f27fe70c985d |
initial
-rw-r--r-- | client.py | 442 | ||||
-rw-r--r-- | client_proto.py | 320 | ||||
-rw-r--r-- | client_proto_tests.py | 25 | ||||
-rw-r--r-- | compress.py | 36 | ||||
-rwxr-xr-x | dumpspice.py | 22 | ||||
-rw-r--r-- | getem.py | 98 | ||||
-rw-r--r-- | main.py | 14 | ||||
-rw-r--r-- | pcapspice.py | 230 | ||||
-rw-r--r-- | pcaputil.py | 171 | ||||
-rw-r--r-- | stracer.py | 163 | ||||
-rw-r--r-- | util.py | 3 |
11 files changed, 1524 insertions, 0 deletions
diff --git a/client.py b/client.py new file mode 100644 index 0000000..646c2f9 --- /dev/null +++ b/client.py @@ -0,0 +1,442 @@ +import struct +import socket + +import client_proto + +# TODO - parse spice/protocol.h directly (to keep in sync automatically) +SPICE_MAGIC = "REDQ" +SPICE_VERSION_MAJOR = 2 +SPICE_VERSION_MINOR = 0 + +################################################################################ +# spice-protocol/spice/enums.h + +(SPICE_CHANNEL_MAIN, +SPICE_CHANNEL_DISPLAY, +SPICE_CHANNEL_INPUTS, +SPICE_CHANNEL_CURSOR, +SPICE_CHANNEL_PLAYBACK, +SPICE_CHANNEL_RECORD, +SPICE_CHANNEL_TUNNEL, +SPICE_END_CHANNEL) = xrange(1,1+8) + +################################################################################ + +# Encryption & Ticketing Parameters +SPICE_MAX_PASSWORD_LENGTH=60 +SPICE_TICKET_KEY_PAIR_LENGTH=1024 +SPICE_TICKET_PUBKEY_BYTES=(SPICE_TICKET_KEY_PAIR_LENGTH / 8 + 34) + +def unpack_list(structs, s): + """ struct has a bug - + sizeof('IBIII') == 20 + sizeof('IIIIB') == 16 + """ + ret = [] + start = 0 + for the_struct in structs: + ret.extend(list(the_struct.unpack(s[start:start + s.size]))) + t += s.size + return ret + +def group(format): + return reduce(lambda cs, c: cs[:-1]+[cs[-1]+c] if len(cs) > 0 and c == cs[-1][-1] else cs+[c], format, []) + +class Elements(object): + pass + +ENDIANESS = '<' # small endian + +class StructList(object): + + def __init__(self, formats): + self._s = map(struct.Struct, (ENDIANESS+f for f in formats)) + self.size = sum([s.size for s in self._s]) + + def pack(self, *args): + i_s, i_e = 0, 0 + r = [] + for s in self._s: + i_e += len(s.format) - (1 if s.format[0] in '<>' else 0) + r.append(s.pack(*args[i_s:i_e])) + i_s = i_e + return ''.join(r) + + def unpack(self, st): + r = [] + i_s, i_e = 0, 0 + for s in self._s: + i_e += s.size + r.append(list(s.unpack(st[i_s:i_e]))) + i_s = i_e + return sum(r, []) + +class StructMeta(type): + def __new__(meta, classname, bases, classDict): + fields = classDict['fields'] + is_complex = classDict['_is_complex'] = callable(fields[-1][0]) + if is_complex: + classDict['complex_field'] = complex_field = fields[-1] + fields = fields[:-1] + assert(not any(map(callable, fields))) + classDict['_s'] = StructList(group(''.join(t for t,n in fields))) + classDict['_names'] = [n for t,n in fields] + classDict['size'] = classDict['_s'].size + classDict['field_elements'] = [len(t) for t, n in fields] + return type.__new__(meta, classname, bases, classDict) + +def indice_pairs(sizes): + s = 0 + for size in sizes: + yield s, s+size + s += size + +def cut(elements, sizes): + for s, e in indice_pairs(sizes): + if e - s == 1: + yield elements[s] + else: + yield elements[s:e] + +class Struct(object): + @classmethod + def parse(cls, s): + base = list(cut(cls._s.unpack(s[:cls._s.size]), cls.field_elements)) + if cls._is_complex: + import pdb; pdb.set_trace() + return base + return base + + @classmethod + def make(cls, **kw): + args = [] + args = [kw[n] for n in cls._names] + assert(len(args) == len(cls._names) == len(kw)) + return cls._s.pack(*args) + + def __init__(self, *args, **kw): + self.e = Elements() + if (len(kw) == 0 and len(args) == 1) or (len(kw) == 1 and kw.has_key('s')): + s = args[0] if len(args) == 1 else kw['s'] + self.elements = elements = self.parse(s) + else: + self.elements = elements = [kw[n] for n in self._names] + self.e.__dict__.update(dict(zip(self._names, elements))) + + def tostr(self): + return self.make(**self.e.__dict__) + +uint16 = 'H' +uint32 = 'I' +uint64 = 'Q' +uint8 = 'B' + +uint32_arr = lambda s: ENDIANESS + uint32*s.e.size + +class SpiceDataHeader(Struct): + __metaclass__ = StructMeta + fields = [(uint64, 'serial'), (uint16, 'type'), (uint32, 'size'), + (uint32, 'sub_list')] + + def __str__(self): + return 'Spice #%s, T%s, sub %s, size %s' % (self.e.serial, + self.e.type, self.e.sub_list, self.e.size) + __repr__ = __str__ + +class SpiceSubMessage(Struct): + __metaclass__ = StructMeta + fields = [(uint16, 'type'), (uint32, 'size')] + +class SpiceSubMessageList(Struct): + __metaclass__ = StructMeta + fields = [(uint16, 'size'), (uint32_arr, 'sub_messages')] + +data_header_size = SpiceDataHeader.size + +class SpiceLinkHeader(Struct): + __metaclass__ = StructMeta + fields = [(uint32, 'magic'), (uint32, 'major_version'), + (uint32, 'minor_version'), (uint32, 'size')] + + def __str__(self): + return 'SpiceLink magic=%s, (%s,%s), size %s' % ( + struct.pack('<I', self.e.magic), + self.e.major_version, self.e.minor_version, self.e.size) + __repr__ = __str__ + + @classmethod + def verify(cls, *args, **kw): + inst = cls(*args, **kw) + assert(inst.e.magic == SPICE_MAGIC) + return inst + +link_header_size = SpiceLinkHeader.size + +class SpiceLinkMess(Struct): + __metaclass__ = StructMeta + fields = [(uint32, 'connection_id'), (uint8, 'channel_type'), (uint8, 'channel_id'), + (uint32, 'num_common_caps'), (uint32, 'num_channel_caps'), (uint32, 'caps_offset')] + + def __str__(self): + return 'SpiceLinkMess conn id %s, ch type %s, ch id %s caps (%s,%s,%s)' % ( + self.e.connection_id, self.e.channel_type, self.e.channel_id, + self.e.num_common_caps, self.e.num_channel_caps, + self.e.caps_offset) + __repr__ = __str__ + +link_mess_size = SpiceLinkMess.size + +class SpiceLinkReply(Struct): + __metaclass__ = StructMeta + fields = [(uint32, 'error'), + (uint8*SPICE_TICKET_PUBKEY_BYTES, 'pub_key'), + (uint32, 'num_common_caps'), + (uint32, 'num_channel_caps'), + (uint32, 'caps_offset')] + + def __str__(self): + return '<SpiceLinkReply error=%s, common=%s, channel=%s, pub=%s:%s>' %( + self.e.error, self.e.num_common_caps, self.e.num_channel_caps, + self.e.pub_key[:10], len(self.e.pub_key)) + __repr__ = __str__ + + +def parse_link_msg(s): + link_header = SpiceLinkHeader(s) + assert(len(s) == SpiceLinkHeader.size + link_header.size) + link_msg = SpiceLinkMess(s[SpiceLinkHeader.size:]) + return link_header, link_msg + +def make_data_msg(serial, type, payload, sub_list): + data_header = SpiceDataHeader(serial=serial, + type=type, size=len(payload), sub_list=sub_list) + +def connect(host, port): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.connect((host, port)) + return s + +def list_to_str(l, type=uint32): + if len(l) == 0: + return '' + return struct.pack(ENDIANESS+len(l)*type, l) + +class Channel(object): + def __init__(self, s, connection_id, channel_type, channel_id, + common_caps, channel_caps): + self._s = s + self._channel_type = channel_type + self._channel_id = channel_id + self._connection_id = connection_id + self._common_caps = common_caps + self._channel_caps = channel_caps + self._proto = client_proto.channels[channel_type] + self.send_link_message() + + def send_link_message(self): + link_message = SpiceLinkMess( + channel_type=self._channel_type, + channel_id=self._channel_id, + caps_offset=SpiceLinkMess.size, + num_channel_caps=len(self._channel_caps), + num_common_caps=len(self._common_caps), + connection_id=self._connection_id) + link_header = SpiceLinkHeader(magic=SPICE_MAGIC, + major_version=SPICE_VERSION_MAJOR, + minor_version=SPICE_VERSION_MINOR, size=link_message.size) + self._s.send( + link_header.tostr() + + link_message.tostr() + + list_to_str(self._common_caps) + + list_to_str(self._channel_caps)) + self.parse_one(SpiceLinkHeader, self._on_link_reply) + + def _on_link_reply(self, header, data): + link_reply = SpiceLinkReply(data) + self._link_reply = link_reply + self._link_reply_header = header + + def _on_data(self, header, data): + self._last_data = data + self._last_data_header = header + print "_on_data: %s" % len(data) + + def parse_data(self): + self.parse_one(SpiceDataHeader, self._on_data) + + def parse_one(self, ctor, on_message): + header = ctor(self._s.recv(ctor.size)) + on_message(header, self._s.recv(header.e.size)) + +class MainChannel(Channel): + def send_attach_channels(self): + import pdb; pdb.set_trace() + self._s.send(s) + +class Client(object): + def __init__(self, host, port, strace=False): + self._host = host + self._port = port + self._strace = strace + + def connect_main(self): + if self._strace: + import stracer + s = StracerNetwork( + '/store/upstream/bin/spicec -h %s -p %s' % ( + self._host, self._port) + , client=self) + self._main_sock = s.sock_iface() + else: + self._main_sock = connect(self._host, self._port) + self.main = self.connect_to_channel(self._main_sock, + connection_id=0, channel_type=SPICE_CHANNEL_MAIN, + channel_id=0) + + # for usage from ipython interactively + def connect_to_channel(self, s, connection_id, channel_type, channel_id, + common_caps=[], channel_caps=[]): + return Channel(s, + connection_id=connection_id, + channel_type=channel_type, + channel_id=channel_id, + common_caps=common_caps, + channel_caps=channel_caps) + + def run(self): + print "TODO" + + def on_message(self, msg): + import pdb; pdb.set_trace() + +SPICE_MAGIC = struct.unpack('<I', 'REDQ')[0] + +################################################################################ + +def find_first_data(s): + rest = [] + found = False + if (len(s) > 182+SpiceDataHeader.size) and ( + SpiceDataHeader(s[182:]).e.serial == 1): + print "found at 182 (good guess!)" + found = True + first_data = 182 + else: + #if len(s) > 100000: + # import pdb; pdb.set_trace() + b = buffer(s) + candidates = [(i, h) for i,h in + ((i, SpiceDataHeader(b[i:])) for i in + xrange(0,min(500,len(s))-SpiceDataHeader.size,2)) if h.e.serial == 1] + if len(candidates) == 1: + found = True + first_data = candidates[0][0] + print "found at %s" % first_data + if found: + rest = s[first_data:] + else: + print "not found" + return rest + +def unspice_channels(keys, collecteds): + return dict([(ch, (key, conv)) for key, (ch, conv) in zip(keys, + parse_conversations(collecteds))]) + +def parse_conversations(conversations): + channels = [] + for client_stream, server_stream in conversations: + print "parsing (client %s, server %s)" % ( + len(client_stream), len(server_stream)) + channel_type, client_messages = unspice_client(client_stream) + channel_conv = (client_messages, + unspice_server(channel_type, server_stream)) + channels.append((channel_type, channel_conv)) + return channels + +def parse_link_start(s, msg_ctor): + if len(s) < SpiceLinkHeader.size: + return 0, [] + link_header = SpiceLinkHeader(s) + assert(link_header.e.magic == SPICE_MAGIC) + msg = msg_ctor(s[ + link_header_size:link_header_size+link_header.e.size]) + return link_header_size+link_header.e.size, [link_header, msg] + +def parse_link_server_start(s): + return parse_link_start(s, SpiceLinkReply) + +def parse_link_client_start(s): + return parse_link_start(s, SpiceLinkMess) + +def unspice_server(channel_type, s): + if len(s) < SpiceLinkHeader.size: + return [] + link_header = SpiceLinkHeader(s) + if link_header.e.magic != SPICE_MAGIC: + print "ch %s: bad server magic, trying to look for SpiceDataHeader (%s)" % ( + channel_type, len(s)) + link_reply = None + rest = find_first_data(s) + else: + link_reply = SpiceLinkReply(s[ + link_header_size:link_header_size+link_header.e.size]) + rest = s[link_header_size+link_header.e.size:] + ret = [link_header, link_reply] + return ret + unspice_rest(channel_type, False, rest) + +def unspice_client(s): + if len(s) == 0: + return -1, [] + # First message starts with a link header, then link message + link_header = SpiceLinkHeader(s) + if link_header.e.magic != SPICE_MAGIC: + print "bad client magic, trying to look for SpiceDataHeader (%s)" % len(s) + link_message = None + channel_type = -1 # XXX - can guess based on messages + rest = find_first_data(s) + else: + link_message = SpiceLinkMess(s[ + link_header_size:link_header_size+link_header.e.size]) + channel_type = link_message.e.channel_type + rest = s[link_header_size+link_header.e.size:] + ret = [link_header, link_message] + return channel_type, ret + unspice_rest(channel_type, True, rest) + +def unspice_rest(channel_type, is_client, s): + datas = [] + if len(s) < 8: + return datas + i_s = 0 + if struct.unpack(ENDIANESS + uint64, s[:8])[0] > 1000: + if is_client: + print "client hack - ignoring 128 bytes" + i_s = 128 + else: + print "server hack - ignoring 4 bytes" + i_s = 4 + while len(s) > i_s: + header = SpiceDataHeader(s[i_s:]) + data_start = i_s + SpiceDataHeader.size + data_end = data_start + header.e.size + if data_end > len(s): + break + datas.append((header, client_proto.parse( + channel_type=channel_type, + is_client=is_client, + header=header, + data=s[data_start:data_end]))) + i_s += SpiceDataHeader.size + header.e.size + return datas + +def parse_ctors(s, ctors): + ret = [] + i = 0 + for c in ctors: + c_size = c.size + if len(s) <= i: + break + ret.append(c(s[i:])) + i += c_size + return ret + + diff --git a/client_proto.py b/client_proto.py new file mode 100644 index 0000000..c29c182 --- /dev/null +++ b/client_proto.py @@ -0,0 +1,320 @@ +""" +Decode spice protocol using spice demarshaller. +""" + +ANNOTATE = False + +import struct +import sys +if not 'proto' in locals(): + sys.path.append('../spice/python_modules') + import spice_parser + import ptypes + proto = spice_parser.parse('../spice/spice.proto') + reloads = 1 +else: + reloads += 1 + + +channels = {} +for channel in proto.channels: + channels[channel.value] = channel + channel.client = dict(zip([x.value for x in + channel.channel_type.client_messages], + channel.channel_type.client_messages)) + channel.server = dict(zip([x.value for x in + channel.channel_type.server_messages], + channel.channel_type.server_messages)) + # todo parsing of messages / building of messages + +def mapdict(f, d): + return dict((k, f(v)) for k,v in d.items()) + +primitives=mapdict(struct.Struct, dict( + uint64='<Q', + int64='<q', + uint32='<I', + int32='<i', + uint16='<H', + int16='<h', + uint8='<B', + int8='<b', + )) + +primitive_arrays=dict([(k, lambda n, e=e: struct.Struct('<'+e*n)) + for k, e in + [('uint64','Q'), + ('int64','q'), + ('uint32','I'), + ('int32','i'), + ('uint16','H'), + ('int16','h'), + ]]) + +class NullUnpacker(object): + def __init__(self, n): + self.size = n + def unpack(self, s): + return s + unpack_from=unpack + +class Flag(object): + def __init__(self, names, value): + self.name = None + for k, v in names.items(): + if value == 1<<k: + self.name = v + break + if self.name is None: + self.name = set(v for k,v in names.items() if (1<<k)&value) + self.value = value + def __str__(self): + return 'F(%s,%s)' % (self.name, self.value) + __repr__ = __str__ + +class Enum(object): + def __init__(self, the_type, name, value): + self.name = name + self.value = value + #self.the_type = the_type + # TODO - cache this + #self.name_to_value = dict(zip(self.the_type.values(), self.the_type.keys())) + def __str__(self): + return 'E(%s,%s)' % (self.name, self.value) + __repr__ = __str__ + + +primitive_arrays['int8'] = primitive_arrays['uint8'] = lambda n: NullUnpacker(n) + +def ensure_dict(maybe_d): + if isinstance(maybe_d, dict): + return maybe_d + return dict(maybe_d) + +def get_sub_member(member_d, var_name): + if var_name in member_d: + return member_d[var_name] + ret = member_d + for part in var_name.split('.'): + ret = ensure_dict(ret)[part] + return ret + +def annotate_data(f): + def wrapped(member, data, i_s, parsed): + i = len(data.an) + i_s_start = i_s + data.an.append((i_s, member)) # annotation (before value is known) + i_s, result_name, result_value = f(member, data, i_s, parsed) + data.an[i] = (((i_s_start, i_s), member, result_value)) + return i_s, result_name, result_value + wrapped.func_name = '%s@annotate_data' % f.func_name + return wrapped + +def simple_annotate_data(f): + def wrapped(member, the_type, data, i_s): + data.an.append((i_s, member)) + return f(member, the_type, data, i_s) + wrapped.func_name = '%s@simple_annotate_data' % f.func_name + return wrapped + +def member_size(m): + pass + +def makelen(n, c, s): + """pad s with c up to length n""" + return s + c * (n - len(s)) + +class AnnotatedString(str): + def __init__(self, x): + super(AnnotatedString, self).__init__(x) + self.an = [] + self.max_pointer_i_s = 0 + def san(self): + def nicify(x): + if len(x) == 2 and not isinstance(x[0], tuple): + start = '='*8 if x[0] == 0 else ' '+'='*8 + return makelen(80, '=', start + str(x)) + return str(x) + print '\n'.join(map(nicify, self.an)) + #print '\n'.join('%s-%s,%s,%s,%s' % (i_s_start, i_s, member.member_type. for (i_s_start, i_s),member,val in self.an])) + +print_annotation = lambda data: data.san() + +if not ANNOTATE: + #AnnotatedString = lambda x: x + print_annotation = lambda x: None + #annotate_data = lambda x: x + #simple_annotate_data = lambda x: x + +@annotate_data +def parse_member(member, data, i_s, parsed): + result_name = member.name + result_value = None + if len(set(member.attributes.keys()) - set( + ['end', 'ctype', 'as_ptr', 'nomarshal', 'anon', 'chunk'])) > 0: + print "has attributes %s" % member.attributes + #import pdb; pdb.set_trace() + if hasattr(member, 'is_switch') and member.is_switch(): + var_name = member.variable + parsed_d = dict(parsed) + var = get_sub_member(parsed_d, var_name) + for case in member.cases: + for value in case.values: + if case.values[0] is None: + # default + return parse_member(case.member, data, i_s, parsed) + cond_wrap = (lambda x: not x) if value[0] == '!' else (lambda x: x) + if cond_wrap(var.name == value[1]): + return parse_member(case.member, data, i_s, parsed) + return i_s, result_name, 'empty switch' + member_type = member.member_type if hasattr(member, 'member_type') else member + member_type_name = member_type.name + results = None + if member_type_name is not None: + if member_type_name in primitives: + primitive = primitives[member_type_name] + result_value = primitive.unpack_from(data[i_s:i_s+primitive.size])[0] + i_s_out = i_s + primitive.size + elif member_type.is_struct(): + i_s_out, result_name, result_value = parse_complex_member(member, member_type, data, i_s) + elif member_type.is_enum(): + primitive = primitives[member_type.primitive_type()] + value = primitive.unpack(data[i_s:i_s+primitive.size])[0] + result_value = Enum(member_type.names, member_type.names.get(value, value), value) + # TODO - use enum for nice display + i_s_out = i_s + primitive.size + elif member_type.is_primitive() and hasattr(member_type, 'names'): + # assume flag + if not isinstance(member_type, ptypes.FlagsType): + print "not really a flag.." + import pdb; pdb.set_trace() + primitive = primitives[member_type.primitive_type()] + value = primitive.unpack(data[i_s:i_s+primitive.size])[0] + result_value = Flag(member_type.names, value) + i_s_out = i_s + primitive.size + else: + import pdb; pdb.set_trace() + elif member_type.is_array(): + if member_type.is_remaining_length(): + data_len = len(data) - i_s + elif member_type.is_identifier_length(): + num_elements = dict(parsed)[member_type.size] + data_len = None + elif member_type.is_image_size_length(): + bpp = member_type.size[1] + width_name = member_type.size[2] + rows_name = member_type.size[3] + assert(isinstance(width_name, str)) + assert(isinstance(rows_name, str)) + parsed_d = dict(parsed) + width = get_sub_member(parsed_d, width_name) + rows = get_sub_member(parsed, rows_name) + if bpp == 8: + data_len = rows * width + elif bpp == 1: + data_len = ((width + 7) // 8 ) * rows + else: + data_len =((bpp * width + 7) // 8 ) * rows + else: + print "unhandled array length type" + import pdb; pdb.set_trace() + element_type = member_type.element_type + if element_type.name in primitive_arrays: + element_size = primitives[element_type.name].size + if data_len is None: + data_len = num_elements * element_size + primitive = primitive_arrays[element_type.name] + contents = primitive(data_len).unpack_from(data[i_s:i_s+data_len]) + i_s_out = i_s + data_len + else: + contents = [] + i_s_out = i_s + for _ in xrange(num_elements): + i_s_out, sub_name, sub_value = parse_member(element_type, data, i_s_out, []) + if sub_value is None: + print "list parsing failed" + contents = [] + break + contents.append((sub_name, sub_value)) + result_value = contents + elif member_type.is_pointer(): + pointer_primitive_name = member_type.primitive_type() + if pointer_primitive_name in primitives: + primitive = primitives[pointer_primitive_name] + offset = primitive.unpack_from(data[i_s:i_s+primitive.size])[0] + if offset == 0: + result_value = [] + elif offset > len(data): + print "ooops.. bad offset %s > %s" % (offset, len(data)) + import pdb; pdb.set_trace() + result_value = None + else: + discard_i_s, discard_result_name, result_value = parse_member( + member_type.target_type, data, offset, []) + data.max_pointer_i_s = max(data.max_pointer_i_s, discard_i_s) + i_s_out = i_s + primitive.size + else: + print "pointer with unknown size??" + import pdb; pdb.set_trace() + if result_name in ['data', 'ents', 'glyphs', 'String', 'str']: + result_value = NoPrint(name='', s=result_value) + elif len(str(result_value)) > 1000: + print "noprint?" + import pdb; pdb.set_trace() + return i_s_out, result_name, result_value + +@simple_annotate_data +def parse_complex_member(member, the_type, data, i_s): + parsed = [] + i_s_out = i_s + for sub_member in the_type.members: + i_s_out, result_name, result_value = parse_member(sub_member, data, i_s_out, parsed) + if result_value is not None: + parsed.append((result_name, result_value)) + else: + print "don't know how to parse %s (%s) in %s" % ( + sub_member.name, the_type, member) + break + return i_s_out, member.name, parsed + +def parse(channel_type, is_client, header, data): + if channel_type not in channels: + return NoPrint(name='unknown channel (%s %s)' % ( + is_client, header), s=data) + # annotation support + data = AnnotatedString(data) + channel = channels[channel_type] + collection = channel.client if is_client else channel.server + if header.e.type not in collection: + print "bad data - no such message in protocol" + i_s, result_name, result_value = 0, None, None + else: + msg_proto = collection[header.e.type] + i_s, result_name, result_value = parse_complex_member( + msg_proto, msg_proto.message_type, data, 0) + i_s = max(data.max_pointer_i_s, i_s) + left_over = NoPrint(name='%s:%s' % (channel_type, i_s), s=data[i_s:]) if i_s < len(data) else '' + print_annotation(data) + if len(left_over) > 0: + print "WARNING: in message %s %s out %s unaccounted for (%s%%)" % ( + msg_proto.name, len(left_over), len(data), float(len(left_over))/len(data)) + #import pdb; pdb.set_trace() + return (msg_proto, (result_name, result_value)) + +class NoPrint(object): + objs = {} + def __init__(self, name, s): + self.s = s + self.id = id(self.s) + self.objs[self.id] = self.s + self.name = name + def orig(self): + return self.s + def __str__(self): + return 'NP' + (str(self.s) if len(self.s) < 20 else '%r...[%s,%s] #%s' % ( + self.s[:20], len(self.s), self.id, + self.name)) + __repr__ = __str__ + def __len__(self): + return len(self.s) + diff --git a/client_proto_tests.py b/client_proto_tests.py new file mode 100644 index 0000000..7bab3f0 --- /dev/null +++ b/client_proto_tests.py @@ -0,0 +1,25 @@ +import client_proto as cp + +def test_complete_parse(data, the_type): + i_s, name, result = cp.parse_member(the_type, cp.AnnotatedString(data), 0, []) + assert(i_s == len(data)) + +def test_parse_complex_member(data, member, the_type): + i_s, name, result = cp.parse_complex_member(member, the_type, cp.AnnotatedString(data), 0) + assert(i_s == len(data)) + + +def tests(): + six_rects = '\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00U\x01\x00\x00\x80\x02\x00\x00U\x01\x00\x00\x00\x00\x00\x00\xd1\x01\x00\x00R\x01\x00\x00U\x01\x00\x00K\x02\x00\x00\xd1\x01\x00\x00\x80\x02\x00\x00\xd1\x01\x00\x00\x00\x00\x00\x00\xd6\x01\x00\x00R\x01\x00\x00\xd1\x01\x00\x00K\x02\x00\x00\xd6\x01\x00\x00\x80\x02\x00\x00\xd6\x01\x00\x00\x00\x00\x00\x00\xe0\x01\x00\x00\x80\x02\x00\x00' + draw_text = '\x00\x00\x00\x00\xc7\x01\x00\x00#\x00\x00\x00\xdd\x01\x00\x00I\x00\x00\x00\x01\x01\x00\x00\x00\xc7\x01\x00\x00#\x00\x00\x00\xdd\x01\x00\x00I\x00\x00\x00G\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\xff\xff\x00\x00\x08\x00\x00\x00\x05\x00\x02#\x00\x00\x00\xd7\x01\x00\x00\xff\xff\xff\xff\xf6\xff\xff\xff\t\x00\n\x00\x03\xaf\xfdp\x00?\xff\xff\xf7\x00\xcf\x90\x03\xfe\x00\x15\x00\x06\xff\x00\x009\xef\xf8\x00\x07\xff\xe90\x00\x0e\xf6\x00\x06\x10\x0e\xf3\x01\xcf\xb0\x07\xff\xff\xfe \x00m\xff\xa2\x00,\x00\x00\x00\xd7\x01\x00\x00\x00\x00\x00\x00\xf4\xff\xff\xff\x05\x00\x0c\x00-\xfb\x00\xaf\xff\x00\xbf`\x00\xaf`\x00\x7f\x80\x00?\xb0\x00\x1f\xe0\x00\x0e\xf1\x00\xff\xff\xb0\xbf\xff\xf0\x05\xfa\x00\x02\xbc\x002\x00\x00\x00\xd7\x01\x00\x00\x00\x00\x00\x00\xf6\xff\xff\xff\x08\x00\n\x00-\xfb/\xf0\xbf\xff\xef\xf0\xff \x8f\xf3\xdfr\t\xf6_\xff\xff\xf9\x03\xae\xff\xfb\x00\x00\x00\xff\x0c\xf5\x05\xff\x03\xff\xff\xfb\x00L\xff\xa2<\x00\x00\x00\xd7\x01\x00\x00\xff\xff\xff\xff\xf6\xff\xff\xff\t\x00\n\x00/\xc0\x00\x00\x00\x0f\xf0\x00\x00\x00\x0b\xf3\x00\x00\x00\t\xf5\x00\x00\x00\x06\xf8\x00\x00\x00\x03\xfd\x00\x00\x00\x00\xff`\x00\x00\x00\xee\xf7\x10\x00\x00\xbf_\xfa\x00\x00\x7f#\xaf\x00C\x00\x00\x00\xd7\x01\x00\x00\x00\x00\x00\x00\xf4\xff\xff\xff\x05\x00\x0c\x00-\xfb\x00\xaf\xff\x00\xbf`\x00\xaf`\x00\x7f\x80\x00?\xb0\x00\x1f\xe0\x00\x0e\xf1\x00\xff\xff\xb0\xbf\xff\xf0\x05\xfa\x00\x02\xbc\x00' + clip_rects = cp.proto.channels[1].server[304].message_type.members[0].member_type.members[2].member_type.members[1].cases[0].member + assert(clip_rects.name == 'rects') + test_complete_parse(six_rects, clip_rects) + draw_text_message = cp.proto.channels[1].server[311] + assert(draw_text_message.name == 'draw_text') + test_parse_complex_member(draw_text, draw_text_message, draw_text_message.message_type) + print '\n'.join(map(str,result[1][1])) + +if __name__ == '__main__': + tests() + diff --git a/compress.py b/compress.py new file mode 100644 index 0000000..736d967 --- /dev/null +++ b/compress.py @@ -0,0 +1,36 @@ +""" +SPICE compression and decompression implementation. + +we have our own lz implementation, see spice/common/lz_common.h + +""" +import struct +from util import reverse_dict + +LZ_MAGIC = struct.unpack('<I', "LZ ")[0] +LZ_MAJOR = 1 +LZ_MINOR = 1 + +LzImageType = (LZ_IMAGE_TYPE_INVALID, + LZ_IMAGE_TYPE_PLT1_LE, + LZ_IMAGE_TYPE_PLT1_BE, # PLT stands for palette + LZ_IMAGE_TYPE_PLT4_LE, + LZ_IMAGE_TYPE_PLT4_BE, + LZ_IMAGE_TYPE_PLT8, + LZ_IMAGE_TYPE_RGB16, + LZ_IMAGE_TYPE_RGB24, + LZ_IMAGE_TYPE_RGB32, + LZ_IMAGE_TYPE_RGBA, + LZ_IMAGE_TYPE_XXXA) = xrange(11) +LzImageType = dict([(k, locals()[k]) for k in locals().keys() if k.startswith('LZ_IMAGE_TYPE')]) +LzImageTypeRev = reverse_dict(LzImageType) + +def lz_decompress(s): + (magic, major, minor, + type, width, height, stride, out_top_down + ) = struct.unpack_from('>IHHIIIII', s) + assert(magic == LZ_MAGIC) + assert(major == LZ_MAJOR) + assert(minor == LZ_MINOR) + del s + return locals() diff --git a/dumpspice.py b/dumpspice.py new file mode 100755 index 0000000..4c69db4 --- /dev/null +++ b/dumpspice.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python +import sys +import pcaputil +import pcapspice + +def main(stdscr=None): + p = pcaputil.packet_iter('lo') + spice = pcapspice.spice_iter(p) + if stdscr: + stdscr.erase() + for d in spice: + if verbose: + print d + +if __name__ == '__main__': + verbose = '-v' in sys.argv + if '-c' in sys.argv: + import curses + curses.wrapper(main) + else: + main(None) + diff --git a/getem.py b/getem.py new file mode 100644 index 0000000..b41bf94 --- /dev/null +++ b/getem.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python2 + +import pcap +import sys +import string +import time +import socket +import struct + +protocols={socket.IPPROTO_TCP:'tcp', + socket.IPPROTO_UDP:'udp', + socket.IPPROTO_ICMP:'icmp'} + +def decode_ip_packet(s): + d={} + d['version']=(ord(s[0]) & 0xf0) >> 4 + d['header_len']=ord(s[0]) & 0x0f + d['tos']=ord(s[1]) + d['total_len']=socket.ntohs(struct.unpack('H',s[2:4])[0]) + d['id']=socket.ntohs(struct.unpack('H',s[4:6])[0]) + d['flags']=(ord(s[6]) & 0xe0) >> 5 + d['fragment_offset']=socket.ntohs(struct.unpack('H',s[6:8])[0] & 0x1f) + d['ttl']=ord(s[8]) + d['protocol']=ord(s[9]) + d['checksum']=socket.ntohs(struct.unpack('H',s[10:12])[0]) + d['source_address']=pcap.ntoa(struct.unpack('i',s[12:16])[0]) + d['destination_address']=pcap.ntoa(struct.unpack('i',s[16:20])[0]) + if d['header_len']>5: + d['options']=s[20:4*(d['header_len']-5)] + else: + d['options']=None + d['data']=s[4*d['header_len']:] + return d + + +def dumphex(s): + bytes = map(lambda x: '%.2x' % x, map(ord, s)) + for i in xrange(0,len(bytes)/16): + print ' %s' % string.join(bytes[i*16:(i+1)*16],' ') + print ' %s' % string.join(bytes[(i+1)*16:],' ') + + +def print_packet(pktlen, data, timestamp): + if not data: + return + + if data[12:14]=='\x08\x00': + decoded=decode_ip_packet(data[14:]) + print '\n%s.%f %s > %s' % (time.strftime('%H:%M', + time.localtime(timestamp)), + timestamp % 60, + decoded['source_address'], + decoded['destination_address']) + for key in ['version', 'header_len', 'tos', 'total_len', 'id', + 'flags', 'fragment_offset', 'ttl']: + print ' %s: %d' % (key, decoded[key]) + print ' protocol: %s' % protocols[decoded['protocol']] + print ' header checksum: %d' % decoded['checksum'] + print ' data:' + dumphex(decoded['data']) + + +if __name__=='__main__': + + if len(sys.argv) < 3: + print 'usage: sniff.py <interface> <expr>' + sys.exit(0) + #dev = pcap.lookupdev() + dev = sys.argv[1] + #net, mask = pcap.lookupnet(dev) + # note: to_ms does nothing on linux + p = pcap.pcap(dev) + #p.open_live(dev, 1600, 0, 100) + #p.dump_open('dumpfile') + #p.setfilter(string.join(sys.argv[2:],' '), 0, 0) + + # try-except block to catch keyboard interrupt. Failure to shut + # down cleanly can result in the interface not being taken out of promisc. + # mode + #p.setnonblock(1) + try: + while 1: + p.dispatch(1, print_packet) + + # specify 'None' to dump to dumpfile, assuming you have called + # the dump_open method + # p.dispatch(0, None) + + # the loop method is another way of doing things + # p.loop(1, print_packet) + + # as is the next() method + # p.next() returns a (pktlen, data, timestamp) tuple + # apply(print_packet,p.next()) + except KeyboardInterrupt: + print '%s' % sys.exc_type + print 'shutting down' + print '%d packets received, %d packets dropped, %d packets dropped by interface' % p.stats() @@ -0,0 +1,14 @@ +import sys +import client +import optparse + +o = optparse.OptionParser() +o.conflict_handler = 'resolve' +o.add_option('-p') +o.add_option('-h') +d, rest = o.parse_args(sys.argv[1:]) +client = client(host=d.h, port=d.p) + +if __name __ == '__main__': + client.run() + diff --git a/pcapspice.py b/pcapspice.py new file mode 100644 index 0000000..b1c1e96 --- /dev/null +++ b/pcapspice.py @@ -0,0 +1,230 @@ +from itertools import chain, repeat +from pcaputil import (tcp_parse, is_tcp, is_tcp_data, is_tcp_syn, + conversations_iter, header_conversation_iter) +import client +from client import SpiceLinkHeader, SpiceLinkMess, SpiceLinkReply, SpiceDataHeader + +def is_single_packet_data(payload): + return len(payload) == SpiceDataHeader(payload).e.size + SpiceDataHeader.size + +num_spice_messages = sum(len(ch.client) + len(ch.server) for ch in + client.client_proto.channels.values()) + +valid_message_ids = set(sum([ch.client.keys() + ch.server.keys() for ch in + client.client_proto.channels.values()], [])) + +all_channels = client.client_proto.channels.keys() + +def possible_channels(server_message, header): + return set(c for c in all_channels if header.e.type in + (client.client_proto.channels[c].server.keys() if server_message + else client.client_proto.channels[c].client.keys())) + +guesses = {} + +def guess_channel_iter(): + channel = None + optional_channels = set(all_channels) + seen_headers = [] + while True: + src, dst, data_header = yield channel + key = (min(src, dst), max(src, dst)) + if guesses.has_key(key): + optional_channels = set([guesses[key]]) + seen_headers.append((src, dst, data_header)) + optional_channels = optional_channels & possible_channels( + src == min(src, dst), data_header) + print optional_channels + if len(optional_channels) == 1: + channel = list(optional_channels)[0] + guesses[key] = channel + if len(optional_channels) == 0: + import pdb; pdb.set_trace() + +spice_size_from_header = lambda header: header.e.size + +def channel_spice_start_iter(lower_port, higher_port): + server_port, client_port = lower_port, higher_port + ports = [server_port, client_port] + messages = dict([(port,None) for port in ports]) + payloads = dict([(port,[]) for port in ports]) + headers = dict([(port,None) for port in ports]) + discarded_len = {client_port:128, server_port:4} + message_ctors = {client_port:SpiceLinkMess, server_port:SpiceLinkReply} + if len(payload) > 0: + payloads[src].append(payload) + while not all(messages.values()): + src, dst, payload = yield None + if len(payload) == 0: continue + payloads[src].append(payload) + len_payload = sum(map(len,payloads[src])) + if headers[src] is None and len_payload >= SpiceLinkHeader.size: + headers[src] = SpiceLinkHeader(''.join(payloads[src])) + import pdb; pdb.set_trace() + expected_len = headers[src].e.size + SpiceLinkHeader.size + discarded_len[src] + if headers[src] is not None and expected_len <= len_payload: + if expected_len != len_payload: + print "extra bytes dumped!!! %d" % (len_payload - expected_len) + messages[src] = message_ctors[src](''.join(payloads[src])[ + SpiceLinkHeader.size:]) + channel = messages[client_port].e.channel_type + print "channel type %s created" % channel + result = None + src, dst, payload = yield None + data_iter = channel_spice_iter(src, dst, payload, channel=channel) + result = data_iter.next() + while True: + src, dst, payload = yield result + result = data_iter.send((src, dst, payload)) + +def channel_spice_iter(lower_port, higher_port, channel=None): + server_port, client_port = lower_port, higher_port + serial = {} + guesser = guess_channel_iter() + guesser.next() + + result = None + # yield spice data parsed results + while True: + src, dst, payload = yield result + data_header = SpiceDataHeader(payload) + # basic sanity check on the header + bad_header = False + if serial.has_key(src): + if data_header.e.serial != serial[src] + 1: + print "bad serial, replacing for %s->%s (%s!=%s+1)" % ( + src, dst, data_header.e.serial, serial[src]) + bad_header = True + serial[src] = data_header.e.serial + if data_header.e.type not in valid_message_ids: + print "bad message type %s in %s->%s" % (data_header.e.type, + src, dst) + if bad_header: continue + if channel is None: + channel = guesser.send((src, dst, data_header)) + if channel is not None: + # read as many messages as it takes to get + (msg_proto, (result_name, result_value), + left_over) = client.client_proto.parse( + channel_type=channel, + is_client=src == client_port, + header=data_header, + data=payload[SpiceDataHeader.size:]) + result = (data_header, msg_proto, result_name, result_value, + left_over) + +# for every tcp packet +# if it is a syn, read the link header, decide which channel we are +# if it is data +# if we know which channel we are, parse it +# else guess which channel. +# if we just figured out which channel we are, releaes all pending packets. + +def spice_iter1(packet_iter): + return conversations_iter( + packet_iter, + channel_spice_start_iter, + channel_spice_iter) + +def packet_header_iter_gen(hdr_ctor, pkt_ctor): + return repeat( + (hdr_ctor.size, + hdr_ctor, + lambda h: h.e.size, + pkt_ctor)) + +def ident(x, **kw): + return x + +def channel_spice_message_iter(src, dst, guesser): + server_port, client_port = min(src, dst), max(src, dst) + serial = {} + channel = None + result = None + # yield spice data parsed results + while True: + src, dst, (data_header, message) = yield result + # basic sanity check on the header + bad_header = False + if serial.has_key(src): + if data_header.e.serial != serial[src] + 1: + print "bad serial, replacing for %s->%s (%s!=%s+1)" % ( + src, dst, data_header.e.serial, serial[src]) + bad_header = True + serial[src] = data_header.e.serial + if data_header.e.type not in valid_message_ids: + print "bad message type %s in %s->%s" % (data_header.e.type, + src, dst) + bad_header = True + if bad_header: continue + if channel is None: + channel = guesser.send((src, dst, data_header)) + if channel is not None: + # read as many messages as it takes to get + (msg_proto, (result_name, result_value) + ) = client.client_proto.parse( + channel_type=channel, + is_client=src == client_port, + header=data_header, + data=message) + result = (msg_proto, result_name, result_value) + +class ChannelGuesser(object): + + def __init__(self): + self.iter = guess_channel_iter() + self.iter.next() + self.channel = None + + def send(self, (src, dst, payload)): + if self.channel: return self.channel + return self.iter.send((src, dst, payload)) + + def on_client_link_message(self, p, **kw): + self._client_message = SpiceLinkMess(p) + self.channel = self._client_message.e.channel_type + print "STARTED channel %s" % self.channel + + def on_server_link_message(self, p, **kw): + self._server_message = SpiceLinkReply(p) + +def spice_iter(packet_iter): + guessers = {} + def make_start_iter(src, dst): + is_server = src < dst + key = (min(src, dst), max(src, dst)) + guesser = guessers[key] = guessers.get(key, ChannelGuesser()) + print guesser, is_server, "**** make_start_iter %s -> %s" % (src, dst) + link_messages = [] + yield (SpiceLinkHeader.size, + SpiceLinkHeader.verify, + lambda h: h.e.size, + (guesser.on_server_link_message if is_server else + guesser.on_client_link_message)) + print guesser, is_server, "2" + message_iter = channel_spice_message_iter(src, dst, guesser) + message_iter.next() + # TODO: way to say "src->dst 128, dst->src 4" + yield 128 if src > dst else 4, ident, lambda h: 0, ident + for i, data in enumerate(packet_header_iter_gen( + SpiceDataHeader, + lambda pkt, src, dst, header: + message_iter.send((src, dst, (header, pkt))))): + print guesser, is_server, "2+%s" % i + yield data + def make_iter(src, dst): + key = (min(src, dst), max(src, dst)) + guesser = guessers[key] = guessers.get(key, ChannelGuesser()) + message_iter = channel_spice_message_iter(src, dst, guesser) + message_iter.next() + return packet_header_iter_gen( + SpiceDataHeader, + lambda pkt, src, dst, header: + message_iter.send((src, dst, (header, pkt)))) + return header_conversation_iter( + packet_iter, + server_start=make_start_iter, + client_start=make_start_iter, + server=make_iter, + client=make_iter) + diff --git a/pcaputil.py b/pcaputil.py new file mode 100644 index 0000000..007c6c2 --- /dev/null +++ b/pcaputil.py @@ -0,0 +1,171 @@ +import os +import struct +from itertools import imap, ifilter +#import pcap + +TCP_PROTOCOL = 6 +TCP_SYN = 2 +TCP_FIN = 1 + +def is_tcp(pkt): + return (len(pkt) > 47 + and pkt[12:14] == '\x08\x00' # Ethernet.Type == IP + and ord(pkt[23]) == TCP_PROTOCOL) + +def is_tcp_data(pkt): + return (is_tcp(pkt) + and (not ord(pkt[47]) & (TCP_SYN | TCP_FIN))) + +def is_tcp_syn(pkt): + return (is_tcp(pkt) + and (ord(pkt[47]) & TCP_SYN)) + +def tcp_parse(pkt): + tcp_offset = 14 + (ord(pkt[14]) & 0xf) * 4 + src, dst, seq, ack, data_offset16, flags, window = struct.unpack( + '>HHIIBBH', pkt[tcp_offset:tcp_offset+16]) + data_offset = tcp_offset + (data_offset16 >> 2) + if data_offset != 66: + print "WHOOHOO" + return src, dst, flags, pkt[data_offset:] + +def mypcap(filename): + with open(filename, 'r') as fd: + file_size = os.stat(filename).st_size + i = 0 + i = 24 + hdr_len = 8 + 4 + 4 + i_pkt = 0 + preamble = fd.read(24) + while i + hdr_len < file_size: + ts, l1, l2 = struct.unpack('QII', fd.read(hdr_len)) + if l1 < l2: + print "short recorded packet: #%s,%s: %s < %s" % (i_pkt, i, l1, l2) + pkt = fd.read(l1) + if is_tcp_data(pkt): + src, dst, tcp_payload = tcp_parse(pkt) + yield ts, src, dst, tcp_payload + i_pkt += 1 + i += hdr_len + l1 + +def get_conversations(file): + convs = {} + port_counts = {} + for ts, src, dst, data in mypcap(file): + key = (src,dst) + if key not in convs: + convs[key] = [ts] + convs[key].append(data) + ports = [src for src,dst in convs.keys()] + server_port = sorted(ports)[0] + times = {} + for src, dst in convs.keys(): + if src != server_port: + src, dst = dst, src + server = convs[(src, dst)] + client = convs[(dst, src)] + assert(client[0] < server[0]) # client sends first msg + times[client[0]] = (client, server) + return [map(lambda x: ''.join(x[1:]), times[t]) for t in sorted(times.keys())] + +def conversations_iter(packet_iter, **keyed_iters): + """ conversation_start_iter is used when we catch the SYN packet, + otherwise we use conversation_iter. the later is expected to be + used by the former. + """ + conversations = {} + keys = {(False, True): 'server', (False, False): 'client', + (True, True): 'server_start', (True, False): 'client_start'} + for src, dst, flags, payload in tcp_iter(packet_iter): + print "conversation_iter %s->%s #(%s)" % (src, dst, len(payload)) + key = (src, dst) + if key not in conversations: + iter_gen = keyed_iters[keys[(flags & TCP_SYN > 0, src < dst)]] + conversations[key] = iter_gen(*key) + conversations[key].next() + maybe_result = conversations[key].send((src, dst, payload)) + if maybe_result is not None: + yield maybe_result + +def consume_packets(packets, needed_len): + pkt = None + total_len = sum(map(len, packets)) + if total_len >= needed_len: + pkt = ''.join(packets) + if total_len > needed_len: + del packets[:-1] + packets[0] = packets[0][-(total_len - needed_len):] + pkt = pkt[:needed_len] + assert(len(packets[0]) + len(pkt) == total_len) + else: + del packets[:] + return pkt + +def ident(x, **kw): + return x + +def collect_packets(header_iter_gen): + """ Example: + >>> list(pcaputil.send_multiple(pcaputil.collect_packets(lambda src, dst: iter([(10, ident, lambda h: 20, ident)]*2))(17, 42), [(42, 17, ' '*30)])) + [(' ', ' ')] + """ + def collector(src, dst): + packets = [] + history = [] + header_iter = header_iter_gen(src, dst) + hdr_size, hdr_ctor, size_from_hdr, pkt_ctor = header_iter.next() + msg = None + header = pkt = None + searched_size = hdr_size + while True: + src, dst, payload = yield (src, dst, msg) + packets.append(payload) + history.append(payload) + print "collect_packets: searching for %s (%s)" % (searched_size, sum(map(len, packets))) + data = consume_packets(packets, searched_size) + while data != None: + print "collect_packets: %s" % len(data) + if header is None: + header = hdr_ctor(data) + print "collect_packets: header of length %s, expect %s" % (len(data), size_from_hdr(header)) + searched_size = size_from_hdr(header) + if searched_size > 1000000: + import pdb; pdb.set_trace() + else: + print "collect_packets: packet of length %s" % len(data) + pkt = pkt_ctor(data, src=src, dst=dst, header=header) + msg = (header, pkt) + hdr_size, hdr_ctor, size_from_hdr, pkt_ctor = header_iter.next() + header = None + searched_size = hdr_size + data = consume_packets(packets, searched_size) + return collector + +def send_multiple(gen, sent_iter): + sent_iter = iter(sent_iter) + gen.next() + try: + while True: + next = sent_iter.next() + yield gen.send(next) + except StopIteration: + return + +def header_conversation_iter(packet_iter, **kw): + return conversations_iter(packet_iter, + **dict([(k, collect_packets(v)) for k,v in kw.items()])) + +def tcp_data_iter(packet_iter): + return ifilter(lambda src, dst, flags, data: + len(data) > 0, imap(tcp_parse, ifilter(is_tcp_data, imap(lambda x: ''.join(x[1]), packet_iter)))) + +def tcp_iter(packet_iter): + return ifilter(lambda (src, dst, flags, data): + flags & (TCP_SYN | TCP_FIN) or len(data) > 0, + imap(tcp_parse, ifilter(is_tcp, imap(lambda x: ''.join(x[1]), packet_iter)))) + +def packet_iter(dev): + import pcap + return pcap.pcap(dev) + + diff --git a/stracer.py b/stracer.py new file mode 100644 index 0000000..60b735f --- /dev/null +++ b/stracer.py @@ -0,0 +1,163 @@ +import subprocess +import re +from client import unspice_channels + +def parse_connect(line): + """'{sa_family=AF_INET, sin_port=htons(5926), sin_addr=inet_addr("127.0.0.1")}, 16) = 0\n'""" + m = re.match('\s+\{sa_family=(?P<family>[A-Z_]+),\s+sin_port=htons\((?P<port>[0-9]+)\),\s+sin_addr=inet_addr\("(?P<host>[0-9.]+)"\)\},\s+([0-9]+)\)\s+=\s+(?P<ret>[0-9])+\n', line) + if not m: return m + return m.groupdict() + +def from_quoted(s): + return ''.join(chr(int(s[i+2:i+4], 16)) for i in + xrange(0, len(s), 4)) + +def match_message_container(line, rexp): + m = re.match(rexp, line) + if not m: return m + d = m.groupdict() + d['msg'] = from_quoted(d['msg']) + d['size'] = int(d['size']) + if 'ret' in d: + d['ret'] = int(d['ret']) + assert(len(d['msg']) == d['ret']) + d['type'] = 'sendto' + return d + +def parse_sendto(line): + """"\x14\x00\x00\x00\x16\x00\x01\x03\x06\x83\x59\x4c\x00\x00\x00\x00\x00\x00\x00\x00", 20, 0, {sa_family=AF_NETLINK, pid=0, groups=00000000}, 12) = 20""" + if 'unfinished' in line: + return match_message_container(line, '\s+"(?P<msg>[x0-9a-f\\\\]+)"\.?\.?\.?,\s+(?P<size>[0-9]+),\s+([0-9]+),\s+.+,\s+([0-9]+) <unfinished ...>\n') + return match_message_container(line, '\s+"(?P<msg>[x0-9a-f\\\\]+)"\.?\.?\.?,\s+(?P<size>[0-9]+),\s+([0-9]+),\s+.+,\s+([0-9]+)\)\s+=\s+(?P<ret>[0-9]+)\n') + +def parse_recvfrom(line): + """"\x00\x00\x00\x00\xd2\x01\x00\x00\x1d\x01\x00\x00\x80\x02\x00\x00\x30\x02\x00\x00\x00\xba\x02\x00\x00\xb6\x01\x00\x00", 29, 0, NULL, NULL) = 29""" + return match_message_container(line, '\s+"(?P<msg>[x0-9a-f\\\\]+)"\.?\.?\.?,\s+(?P<size>[0-9]+),\s+([0-9]+),.+,.+\)\s+=\s+(?P<ret>[0-9]+)\n') + +def parse_recvmsg(line): + print "TODO - parse_recvmsg" + return + +def parse_sendmsg(line): + print "TODO - parse_sendmsg" + return + +def parse_start(line): + # really lame parsing, won't deal with nested first arguments - but + # works for all the system calls we are interested in. Actually just + # adding nesting level to each char would solve this. + m = re.search('\[pid\s+([0-9]+)\]\s+([a-zA-Z0-9]+)\(([^,]+),', line) + return m + +class StracerNetwork(object): + + _unhandled_syscalls = set(['socket', 'bind', 'getsockname', 'setsockopt', + # later + 'getpeername', 'getsockopt', 'shutdown']) + def __init__(self, cmd, on_message = lambda x: None): + self.max_bytes_per_message = max_bytes = 2**20 + self._p = subprocess.Popen( + ('strace -qvxx -e trace=network -s %s -Cf %s' % (max_bytes, cmd)).split(), + stderr=subprocess.PIPE) + self.connections = [] + self.sockets = {} + self._message_count = 0 + self._lines = [] + self.on_message = on_message + + def total_length(self, n): + if n not in self.sockets: + return 0 + return sum([len(d['msg']) for m, d in self.sockets[n] if d is not None and 'msg' in d]) + + def terse(self, count=10): + while sum([len(v) for v in self.sockets.values()]) < count: + print sum(map(len, self.sockets.items())) + self.handle_line() + sorted_keys = list(sorted(self.sockets.keys())) + return sorted_keys, [self.terse_key(k) for k in sorted_keys] + + def terse_key(self, key): + return [(dict(recvfrom='R',sendto='S').get(tag), d['msg']) + for ((tag, line), d) in self.sockets[key] if d is not None] + + + def conversations(self, count=10): + keys, terses = self.terse(count=count) + return keys, [[''.join([msg for t,msg in terse if t==k]) for k in 'SR'] + for terse in terses] + + def spiced(self, count=10): + keys, collecteds = self.conversations(count) + return unspice_channels(keys, collecteds) + + def lines(self): + for line in self._p.stderr: + print line if len(line) < 100 else line[:100] + '...' + self._lines.append(line) + yield line + + def wait_for_connect(self, host): + while not self.host_connected(host): + self.handle_line() + + def host_connected(self, host): + return any([d['host'] == host for d in self.connections]) + + def add_to_socket(self, sock, tag, line, parser): + if sock not in self.sockets: + print "did I miss a connect?? at %s" % len(self._lines) + self.add_socket(sock) + parsed = parser(line) + self.sockets[sock].append(((tag, line), parsed)) + self._message_count += 1 + if parsed: + self.on_message(parsed['msg']) + + def add_socket(self, sock, d={}): + d['sock'] = sock + if 'host' not in d: + d['host'] = 'unknown' + self.connections.append(d) + self.sockets[sock] = [] + + def handle_line(self): + line = self.lines().next() + m = parse_start(line) + if not m: return + line = line[m.end():] + pid, syscall, firstarg = m.groups() + if 'connect' in syscall: + d = parse_connect(line) + if not d: return + self.add_socket(int(firstarg), d) + elif 'sendto' in syscall: + self.add_to_socket(int(firstarg), 'sendto', line, parse_sendto) + elif 'recvmsg' in syscall: + self.add_to_socket(int(firstarg), 'recvmsg', line, parse_recvmsg) + elif 'recvfrom' in syscall: + self.add_to_socket(int(firstarg), 'recvfrom', line, parse_recvfrom) + elif 'sendmsg' in syscall: + self.add_to_socket(int(firstarg), 'sendmsg', line, parse_sendmsg) + elif syscall in self._unhandled_syscalls: + pass + else: + print "unhandled %s" % syscall + import pdb; pdb.set_trace() + + def wait_for(self, num): + ret = [] + start_count = self._message_count + while True: + self.handle_line() + if self._message_count - start_count >= num: + return + +def trace_spicec(host, port): + s = StracerNetwork('/store/upstream/bin/spicec -h %s -p %s' % (host, port)) + sock = s.wait_for_connect(host) + return s.filter(sock, 10) + +if __name__ == '__main__': + trace_spicec('127.0.0.1', 5926) + @@ -0,0 +1,3 @@ +def reverse_dict(d): + return dict([(v,k) for k,v in d.items()]) + |