summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlon Levy <alevy@redhat.com>2010-08-16 13:31:28 +0300
committerAlon Levy <alevy@redhat.com>2010-08-16 13:31:28 +0300
commit08951c2d430aa1216d186f7b52fb9b2298350291 (patch)
treec2ea618c5ddaf6702f08ed33ef04f27fe70c985d
initial
-rw-r--r--client.py442
-rw-r--r--client_proto.py320
-rw-r--r--client_proto_tests.py25
-rw-r--r--compress.py36
-rwxr-xr-xdumpspice.py22
-rw-r--r--getem.py98
-rw-r--r--main.py14
-rw-r--r--pcapspice.py230
-rw-r--r--pcaputil.py171
-rw-r--r--stracer.py163
-rw-r--r--util.py3
11 files changed, 1524 insertions, 0 deletions
diff --git a/client.py b/client.py
new file mode 100644
index 0000000..646c2f9
--- /dev/null
+++ b/client.py
@@ -0,0 +1,442 @@
+import struct
+import socket
+
+import client_proto
+
# TODO - parse spice/protocol.h directly (to keep in sync automatically)
# NOTE: this string form is a placeholder; it is re-bound to the packed
# little-endian uint32 form at the bottom of this module.
SPICE_MAGIC = "REDQ"
SPICE_VERSION_MAJOR = 2
SPICE_VERSION_MINOR = 0

################################################################################
# spice-protocol/spice/enums.h

# channel type ids - numbering starts at 1 (0 is not a valid channel)
(SPICE_CHANNEL_MAIN,
SPICE_CHANNEL_DISPLAY,
SPICE_CHANNEL_INPUTS,
SPICE_CHANNEL_CURSOR,
SPICE_CHANNEL_PLAYBACK,
SPICE_CHANNEL_RECORD,
SPICE_CHANNEL_TUNNEL,
SPICE_END_CHANNEL) = xrange(1,1+8)
+
################################################################################

# Encryption & Ticketing Parameters (see spice-protocol/spice/protocol.h)
SPICE_MAX_PASSWORD_LENGTH = 60
SPICE_TICKET_KEY_PAIR_LENGTH = 1024
# RSA public key blob: modulus byte count plus fixed encoding overhead.
# Floor division keeps this an int even under py3 / future division.
SPICE_TICKET_PUBKEY_BYTES = (SPICE_TICKET_KEY_PAIR_LENGTH // 8 + 34)
+
def unpack_list(structs, s):
    """Unpack the buffer s using each struct.Struct in structs consecutively.

    Splitting one format over several structs works around alignment
    surprises, e.g. calcsize('IBIII') != calcsize('IIIIB'), by letting the
    caller control exactly where each field group starts.

    Returns a flat list of all unpacked values.
    """
    ret = []
    start = 0
    for the_struct in structs:
        # the original sliced by `s.size` (the buffer has no .size) and
        # advanced an undefined variable `t`; use each struct's own size
        ret.extend(the_struct.unpack(s[start:start + the_struct.size]))
        start += the_struct.size
    return ret
+
def group(format):
    """Split a struct format string into runs of identical characters.

    e.g. 'IIBII' -> ['II', 'B', 'II'].  Used to build one struct.Struct per
    homogeneous run so no alignment padding is introduced between fields.

    The original was an unreadable reduce() one-liner relying on the
    py2-only builtin; itertools.groupby does exactly this.
    """
    from itertools import groupby
    return [''.join(run) for _, run in groupby(format)]
+
class Elements(object):
    """Plain attribute bag; Struct instances expose parsed fields on .e."""
    pass

ENDIANESS = '<' # little-endian; also disables struct's native alignment
+
class StructList(object):
    """Several struct.Structs applied back-to-back over one buffer.

    Splitting one wire format across multiple Structs avoids alignment
    padding between field groups (see unpack_list's docstring).
    """

    def __init__(self, formats):
        self._s = [struct.Struct(ENDIANESS + fmt) for fmt in formats]
        self.size = sum(s.size for s in self._s)

    def pack(self, *args):
        """Pack the positional values through each sub-struct in order."""
        pieces = []
        consumed = 0
        for s in self._s:
            # values consumed by this sub-struct: one per format character,
            # minus the leading byte-order marker when present
            nfields = len(s.format) - (1 if s.format[0] in '<>' else 0)
            pieces.append(s.pack(*args[consumed:consumed + nfields]))
            consumed += nfields
        return ''.join(pieces)

    def unpack(self, st):
        """Unpack `st` into a flat list of values."""
        values = []
        offset = 0
        for s in self._s:
            values.extend(s.unpack(st[offset:offset + s.size]))
            offset += s.size
        return values
+
class StructMeta(type):
    """Metaclass compiling a Struct subclass's `fields` declaration.

    Precomputes on the class:
      _s             -- StructList covering the fixed-size fields
      _names         -- field names in declaration order
      size           -- total byte size of the fixed part
      field_elements -- number of primitive elements per field
      _is_complex    -- True when the last field's "type" is a callable
                        (a variable-sized trailing field; see Struct.parse)
    """
    def __new__(meta, classname, bases, classDict):
        fields = classDict['fields']
        # a callable final field type marks a variable-length tail
        is_complex = classDict['_is_complex'] = callable(fields[-1][0])
        if is_complex:
            classDict['complex_field'] = complex_field = fields[-1]
            fields = fields[:-1]
        # NOTE(review): this maps callable() over the (type, name) tuples,
        # which are never callable, so the assert always passes - it was
        # probably meant to check the types themselves
        assert(not any(map(callable, fields)))
        # group() merges identical runs ACROSS fields; cut() later re-splits
        # the flat value list using field_elements
        classDict['_s'] = StructList(group(''.join(t for t,n in fields)))
        classDict['_names'] = [n for t,n in fields]
        classDict['size'] = classDict['_s'].size
        classDict['field_elements'] = [len(t) for t, n in fields]
        return type.__new__(meta, classname, bases, classDict)
+
def indice_pairs(sizes):
    """Yield (start, end) index pairs for consecutive runs of given sizes."""
    offset = 0
    for size in sizes:
        upto = offset + size
        yield offset, upto
        offset = upto
+
def cut(elements, sizes):
    """Regroup a flat sequence into items/slices of the given sizes.

    Runs of size 1 yield the bare element; longer runs yield a slice.
    """
    pos = 0
    for size in sizes:
        chunk = elements[pos:pos + size]
        yield chunk[0] if size == 1 else chunk
        pos += size
+
class Struct(object):
    """Base class for wire structures.

    Subclasses declare `fields` and use StructMeta as metaclass.  Instances
    parse a byte string into named elements (exposed on .e) or build one
    from keyword arguments (one per field name).
    """

    @classmethod
    def parse(cls, s):
        """Parse the fixed-size part of `s` into a list of field values."""
        base = list(cut(cls._s.unpack(s[:cls._s.size]), cls.field_elements))
        if cls._is_complex:
            # TODO: the variable-sized trailing field is not parsed yet;
            # only the fixed prefix is returned.  (The original left a pdb
            # breakpoint here.)
            return base
        return base

    @classmethod
    def make(cls, **kw):
        """Pack keyword arguments into the wire format byte string."""
        args = [kw[n] for n in cls._names]
        # every declared field must be supplied, and nothing extra
        assert(len(args) == len(cls._names) == len(kw))
        return cls._s.pack(*args)

    def __init__(self, *args, **kw):
        self.e = Elements()
        # either a single byte string to parse...  ('s' in kw instead of
        # the py2-only kw.has_key('s'))
        if (len(kw) == 0 and len(args) == 1) or (len(kw) == 1 and 's' in kw):
            s = args[0] if len(args) == 1 else kw['s']
            self.elements = elements = self.parse(s)
        else:
            # ...or one keyword argument per field
            self.elements = elements = [kw[n] for n in self._names]
        self.e.__dict__.update(dict(zip(self._names, elements)))

    def tostr(self):
        """Serialize back to the wire format."""
        return self.make(**self.e.__dict__)
+
# struct format characters for the primitive wire types
uint16 = 'H'
uint32 = 'I'
uint64 = 'Q'
uint8 = 'B'

# variable-length trailing field: an array of uint32 whose element count is
# the already-parsed `size` field of the enclosing message (complex field,
# see StructMeta)
uint32_arr = lambda s: ENDIANESS + uint32*s.e.size
+
class SpiceDataHeader(Struct):
    """Header preceding every in-band spice data message (both directions)."""
    __metaclass__ = StructMeta
    fields = [(uint64, 'serial'), (uint16, 'type'), (uint32, 'size'),
        (uint32, 'sub_list')]

    def __str__(self):
        return 'Spice #%s, T%s, sub %s, size %s' % (self.e.serial,
            self.e.type, self.e.sub_list, self.e.size)
    __repr__ = __str__
+
class SpiceSubMessage(Struct):
    """Header of a single sub-message referenced from a message's sub-list."""
    __metaclass__ = StructMeta
    fields = [(uint16, 'type'), (uint32, 'size')]

class SpiceSubMessageList(Struct):
    """Sub-message list: a count followed by `size` uint32 offsets."""
    __metaclass__ = StructMeta
    fields = [(uint16, 'size'), (uint32_arr, 'sub_messages')]

# byte size of the fixed data header, exported for convenience
data_header_size = SpiceDataHeader.size
+
class SpiceLinkHeader(Struct):
    """Handshake header exchanged before any data messages.

    `size` is the byte length of the link message/reply that follows.
    """
    __metaclass__ = StructMeta
    fields = [(uint32, 'magic'), (uint32, 'major_version'),
        (uint32, 'minor_version'), (uint32, 'size')]

    def __str__(self):
        return 'SpiceLink magic=%s, (%s,%s), size %s' % (
            struct.pack('<I', self.e.magic),
            self.e.major_version, self.e.minor_version, self.e.size)
    __repr__ = __str__

    @classmethod
    def verify(cls, *args, **kw):
        """Construct an instance and assert its magic matches SPICE_MAGIC
        (the packed uint32 form bound at the bottom of this module)."""
        inst = cls(*args, **kw)
        assert(inst.e.magic == SPICE_MAGIC)
        return inst

link_header_size = SpiceLinkHeader.size
+
class SpiceLinkMess(Struct):
    """Client's half of the link handshake (follows SpiceLinkHeader).

    Capability words (num_common_caps + num_channel_caps uint32s) follow at
    caps_offset bytes from the start of this message.
    """
    __metaclass__ = StructMeta
    fields = [(uint32, 'connection_id'), (uint8, 'channel_type'), (uint8, 'channel_id'),
        (uint32, 'num_common_caps'), (uint32, 'num_channel_caps'), (uint32, 'caps_offset')]

    def __str__(self):
        return 'SpiceLinkMess conn id %s, ch type %s, ch id %s caps (%s,%s,%s)' % (
            self.e.connection_id, self.e.channel_type, self.e.channel_id,
            self.e.num_common_caps, self.e.num_channel_caps,
            self.e.caps_offset)
    __repr__ = __str__

link_mess_size = SpiceLinkMess.size
+
class SpiceLinkReply(Struct):
    """Server's half of the link handshake: error code, public key blob for
    ticket encryption, and capability counts."""
    __metaclass__ = StructMeta
    fields = [(uint32, 'error'),
        (uint8*SPICE_TICKET_PUBKEY_BYTES, 'pub_key'),
        (uint32, 'num_common_caps'),
        (uint32, 'num_channel_caps'),
        (uint32, 'caps_offset')]

    def __str__(self):
        return '<SpiceLinkReply error=%s, common=%s, channel=%s, pub=%s:%s>' %(
            self.e.error, self.e.num_common_caps, self.e.num_channel_caps,
            self.e.pub_key[:10], len(self.e.pub_key))
    __repr__ = __str__
+
+
def parse_link_msg(s):
    """Parse a complete client link packet: SpiceLinkHeader + SpiceLinkMess.

    `s` must contain exactly the header followed by the message it
    announces.
    """
    link_header = SpiceLinkHeader(s)
    # the header's parsed `size` element is the byte length of the following
    # message; the original compared against link_header.size, which is the
    # class attribute (the header's own struct size), so the assert was wrong
    assert(len(s) == SpiceLinkHeader.size + link_header.e.size)
    link_msg = SpiceLinkMess(s[SpiceLinkHeader.size:])
    return link_header, link_msg
+
def make_data_msg(serial, type, payload, sub_list):
    """Build a full data message: packed SpiceDataHeader followed by payload.

    The original constructed the header and dropped it on the floor; return
    the wire bytes.
    """
    data_header = SpiceDataHeader(serial=serial,
        type=type, size=len(payload), sub_list=sub_list)
    return data_header.tostr() + payload
+
def connect(host, port):
    """Open and return a TCP connection to (host, port)."""
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.connect((host, port))
    return s
+
def list_to_str(l, type=uint32):
    """Pack a list of integers into a little-endian byte string.

    `type` is the struct format character for each element (uint32 default).
    Returns the empty string for an empty list.
    """
    if len(l) == 0:
        return ''
    # struct.pack takes the values as positional arguments; the original
    # passed the list object itself, which raises struct.error
    return struct.pack(ENDIANESS + len(l)*type, *l)
+
class Channel(object):
    """One spice channel over an already-connected socket.

    Construction immediately sends the link message and reads the server's
    link reply (stored on _link_reply / _link_reply_header).
    """
    def __init__(self, s, connection_id, channel_type, channel_id,
            common_caps, channel_caps):
        self._s = s
        self._channel_type = channel_type
        self._channel_id = channel_id
        self._connection_id = connection_id
        self._common_caps = common_caps
        self._channel_caps = channel_caps
        # protocol description for this channel's message types
        self._proto = client_proto.channels[channel_type]
        self.send_link_message()

    def send_link_message(self):
        """Send SpiceLinkHeader + SpiceLinkMess + capability words, then read
        the server's reply."""
        link_message = SpiceLinkMess(
            channel_type=self._channel_type,
            channel_id=self._channel_id,
            caps_offset=SpiceLinkMess.size,
            num_channel_caps=len(self._channel_caps),
            num_common_caps=len(self._common_caps),
            connection_id=self._connection_id)
        link_header = SpiceLinkHeader(magic=SPICE_MAGIC,
            major_version=SPICE_VERSION_MAJOR,
            minor_version=SPICE_VERSION_MINOR, size=link_message.size)
        self._s.send(
            link_header.tostr()
            + link_message.tostr()
            + list_to_str(self._common_caps)
            + list_to_str(self._channel_caps))
        self.parse_one(SpiceLinkHeader, self._on_link_reply)

    def _on_link_reply(self, header, data):
        # stash the reply for interactive inspection
        link_reply = SpiceLinkReply(data)
        self._link_reply = link_reply
        self._link_reply_header = header

    def _on_data(self, header, data):
        # keep the latest message around for interactive inspection
        self._last_data = data
        self._last_data_header = header
        print "_on_data: %s" % len(data)

    def parse_data(self):
        """Read and record one data message."""
        self.parse_one(SpiceDataHeader, self._on_data)

    def parse_one(self, ctor, on_message):
        """Read one header of type `ctor`, then its `size`-byte body, and
        hand both to `on_message`.

        NOTE(review): recv() may return fewer bytes than requested; this
        assumes both reads are complete - confirm for large messages.
        """
        header = ctor(self._s.recv(ctor.size))
        on_message(header, self._s.recv(header.e.size))
+
class MainChannel(Channel):
    """The main channel; will grow channel-management requests."""

    def send_attach_channels(self):
        """Send the main-channel attach-channels request.

        Not implemented: the original left a pdb breakpoint and then sent an
        undefined variable `s` (NameError).  Fail loudly instead.
        """
        raise NotImplementedError('attach-channels message building is TODO')
+
+class Client(object):
+ def __init__(self, host, port, strace=False):
+ self._host = host
+ self._port = port
+ self._strace = strace
+
+ def connect_main(self):
+ if self._strace:
+ import stracer
+ s = StracerNetwork(
+ '/store/upstream/bin/spicec -h %s -p %s' % (
+ self._host, self._port)
+ , client=self)
+ self._main_sock = s.sock_iface()
+ else:
+ self._main_sock = connect(self._host, self._port)
+ self.main = self.connect_to_channel(self._main_sock,
+ connection_id=0, channel_type=SPICE_CHANNEL_MAIN,
+ channel_id=0)
+
+ # for usage from ipython interactively
+ def connect_to_channel(self, s, connection_id, channel_type, channel_id,
+ common_caps=[], channel_caps=[]):
+ return Channel(s,
+ connection_id=connection_id,
+ channel_type=channel_type,
+ channel_id=channel_id,
+ common_caps=common_caps,
+ channel_caps=channel_caps)
+
+ def run(self):
+ print "TODO"
+
+ def on_message(self, msg):
+ import pdb; pdb.set_trace()
+
+SPICE_MAGIC = struct.unpack('<I', 'REDQ')[0]
+
+################################################################################
+
def find_first_data(s):
    """Heuristically locate the first spice data message in a raw stream.

    Looks for a SpiceDataHeader whose serial is 1.  Returns the stream from
    that point on, or [] when nothing plausible is found.
    """
    rest = []
    found = False
    # NOTE(review): 182 is an empirically observed handshake length used as
    # a fast first guess - confirm it still holds for current captures
    if (len(s) > 182+SpiceDataHeader.size) and (
        SpiceDataHeader(s[182:]).e.serial == 1):
        print "found at 182 (good guess!)"
        found = True
        first_data = 182
    else:
        b = buffer(s)
        # scan even offsets within the first 500 bytes for a serial==1
        # header; accept only an unambiguous (single) match
        candidates = [(i, h) for i,h in
            ((i, SpiceDataHeader(b[i:])) for i in
            xrange(0,min(500,len(s))-SpiceDataHeader.size,2)) if h.e.serial == 1]
        if len(candidates) == 1:
            found = True
            first_data = candidates[0][0]
            print "found at %s" % first_data
    if found:
        rest = s[first_data:]
    else:
        print "not found"
    return rest
+
def unspice_channels(keys, collecteds):
    """Map channel_type -> (key, conversation) for captured conversations.

    `keys` pairs up positionally with parse_conversations' output; on
    duplicate channel types the last one wins.
    """
    result = {}
    for key, (channel_type, conversation) in zip(
            keys, parse_conversations(collecteds)):
        result[channel_type] = (key, conversation)
    return result
+
def parse_conversations(conversations):
    """Decode (client_stream, server_stream) pairs into spice messages.

    Returns a list of (channel_type, (client_messages, server_messages)).
    """
    channels = []
    for client_stream, server_stream in conversations:
        print "parsing (client %s, server %s)" % (
            len(client_stream), len(server_stream))
        # the client side identifies the channel; reuse it for the server
        channel_type, client_messages = unspice_client(client_stream)
        channel_conv = (client_messages,
            unspice_server(channel_type, server_stream))
        channels.append((channel_type, channel_conv))
    return channels
+
def parse_link_start(s, msg_ctor):
    """Parse the link handshake at the head of a stream.

    Returns (bytes consumed, [link_header, msg]), or (0, []) when the
    stream is too short to hold a header.  Raises (assert) on a bad magic.
    """
    if len(s) < SpiceLinkHeader.size:
        return 0, []
    link_header = SpiceLinkHeader(s)
    assert(link_header.e.magic == SPICE_MAGIC)
    msg = msg_ctor(s[
        link_header_size:link_header_size+link_header.e.size])
    return link_header_size+link_header.e.size, [link_header, msg]

def parse_link_server_start(s):
    """Server side of the handshake: header + SpiceLinkReply."""
    return parse_link_start(s, SpiceLinkReply)

def parse_link_client_start(s):
    """Client side of the handshake: header + SpiceLinkMess."""
    return parse_link_start(s, SpiceLinkMess)
+
def unspice_server(channel_type, s):
    """Decode a server->client stream: link reply, then data messages.

    When the capture does not start at the handshake (bad magic), fall back
    to scanning for the first data header; link_reply is then None.
    """
    if len(s) < SpiceLinkHeader.size:
        return []
    link_header = SpiceLinkHeader(s)
    if link_header.e.magic != SPICE_MAGIC:
        print "ch %s: bad server magic, trying to look for SpiceDataHeader (%s)" % (
            channel_type, len(s))
        link_reply = None
        rest = find_first_data(s)
    else:
        link_reply = SpiceLinkReply(s[
            link_header_size:link_header_size+link_header.e.size])
        rest = s[link_header_size+link_header.e.size:]
    ret = [link_header, link_reply]
    return ret + unspice_rest(channel_type, False, rest)
+
def unspice_client(s):
    """Decode a client->server stream.

    Returns (channel_type, [link_header, link_message, (header, msg)...]).
    channel_type is -1 when the handshake is missing and the channel could
    not be determined.
    """
    if len(s) == 0:
        return -1, []
    # First message starts with a link header, then link message
    link_header = SpiceLinkHeader(s)
    if link_header.e.magic != SPICE_MAGIC:
        print "bad client magic, trying to look for SpiceDataHeader (%s)" % len(s)
        link_message = None
        channel_type = -1 # XXX - can guess based on messages
        rest = find_first_data(s)
    else:
        link_message = SpiceLinkMess(s[
            link_header_size:link_header_size+link_header.e.size])
        channel_type = link_message.e.channel_type
        rest = s[link_header_size+link_header.e.size:]
    ret = [link_header, link_message]
    return channel_type, ret + unspice_rest(channel_type, True, rest)
+
def unspice_rest(channel_type, is_client, s):
    """Decode consecutive data messages from `s`.

    Returns a list of (SpiceDataHeader, parsed message) pairs, stopping at
    the first truncated message.
    """
    datas = []
    if len(s) < 8:
        return datas
    i_s = 0
    # an implausibly large leading serial means the stream still starts
    # with non-message bytes; skip an empirically determined prefix
    # NOTE(review): the 1000 threshold and 128/4 skip sizes are heuristics
    # from captures - confirm (128 may be the RSA-encrypted ticket)
    if struct.unpack(ENDIANESS + uint64, s[:8])[0] > 1000:
        if is_client:
            print "client hack - ignoring 128 bytes"
            i_s = 128
        else:
            print "server hack - ignoring 4 bytes"
            i_s = 4
    while len(s) > i_s:
        header = SpiceDataHeader(s[i_s:])
        data_start = i_s + SpiceDataHeader.size
        data_end = data_start + header.e.size
        if data_end > len(s):
            # truncated final message - stop here
            break
        datas.append((header, client_proto.parse(
            channel_type=channel_type,
            is_client=is_client,
            header=header,
            data=s[data_start:data_end])))
        i_s += SpiceDataHeader.size + header.e.size
    return datas
+
def parse_ctors(s, ctors):
    """Parse consecutive structures from `s`, one per constructor in ctors.

    Each constructor receives the remaining buffer and is assumed to
    consume `ctor.size` bytes of it.  Stops when the buffer runs out.
    """
    parsed = []
    offset = 0
    for ctor in ctors:
        step = ctor.size
        if len(s) <= offset:
            break
        parsed.append(ctor(s[offset:]))
        offset += step
    return parsed
+
+
diff --git a/client_proto.py b/client_proto.py
new file mode 100644
index 0000000..c29c182
--- /dev/null
+++ b/client_proto.py
@@ -0,0 +1,320 @@
+"""
+Decode spice protocol using spice demarshaller.
+"""
+
ANNOTATE = False  # set True to collect and print per-field parse annotations

import struct
import sys
# reload()-friendly bootstrap: parse ../spice/spice.proto only on the first
# import of this module, and count reloads afterwards
if not 'proto' in locals():
    sys.path.append('../spice/python_modules')
    import spice_parser
    import ptypes
    proto = spice_parser.parse('../spice/spice.proto')
    reloads = 1
else:
    reloads += 1


# channel-type value -> channel proto object; each channel also gains
# `client`/`server` dicts mapping message id -> message definition
channels = {}
for channel in proto.channels:
    channels[channel.value] = channel
    channel.client = dict(zip([x.value for x in
        channel.channel_type.client_messages],
        channel.channel_type.client_messages))
    channel.server = dict(zip([x.value for x in
        channel.channel_type.server_messages],
        channel.channel_type.server_messages))
    # todo parsing of messages / building of messages
+
def mapdict(f, d):
    """Return a new dict with `f` applied to every value of `d`."""
    result = {}
    for key, value in d.items():
        result[key] = f(value)
    return result
+
# primitive type name -> compiled little-endian struct
primitives=mapdict(struct.Struct, dict(
    uint64='<Q',
    int64='<q',
    uint32='<I',
    int32='<i',
    uint16='<H',
    int16='<h',
    uint8='<B',
    int8='<b',
    ))

# primitive type name -> factory building an n-element array unpacker
# NOTE(review): the factory repeats the format char n times (n elements),
# while the call site in parse_member passes a BYTE count; only the
# int8/uint8 NullUnpacker path (installed below) is clearly consistent -
# confirm for wider element types
primitive_arrays=dict([(k, lambda n, e=e: struct.Struct('<'+e*n))
    for k, e in
    [('uint64','Q'),
     ('int64','q'),
     ('uint32','I'),
     ('int32','i'),
     ('uint16','H'),
     ('int16','h'),
     ]])
+
class NullUnpacker(object):
    """struct.Struct stand-in for raw byte arrays: "unpacking" n bytes just
    returns the input unchanged."""

    def __init__(self, n):
        # mirror struct.Struct's `size` attribute
        self.size = n

    def unpack(self, s):
        return s

    # raw bytes need no offset handling, so unpack_from is the same thing
    unpack_from = unpack
+
class Flag(object):
    """A decoded flags value, resolved against a {bit_index: name} dict.

    When the value is exactly one known bit, `name` is that bit's name;
    otherwise `name` becomes the set of names of all bits present.
    """
    def __init__(self, names, value):
        self.name = None
        for bit, flag_name in names.items():
            if value == 1 << bit:
                self.name = flag_name
                break
        if self.name is None:
            self.name = set(flag_name for bit, flag_name in names.items()
                if (1 << bit) & value)
        self.value = value

    def __str__(self):
        return 'F(%s,%s)' % (self.name, self.value)
    __repr__ = __str__
+
class Enum(object):
    """A decoded enum value: keeps the resolved name next to the raw value.

    `the_type` is accepted (and currently ignored) so call sites stay
    simple; caching a name<->value map from it remains a TODO.
    """
    def __init__(self, the_type, name, value):
        self.name = name
        self.value = value

    def __str__(self):
        return 'E(%s,%s)' % (self.name, self.value)
    __repr__ = __str__
+
+
+primitive_arrays['int8'] = primitive_arrays['uint8'] = lambda n: NullUnpacker(n)
+
def ensure_dict(maybe_d):
    """Return `maybe_d` as a dict; dicts pass through unchanged."""
    return maybe_d if isinstance(maybe_d, dict) else dict(maybe_d)
+
def get_sub_member(member_d, var_name):
    """Look up a possibly dotted name in parsed results.

    Tries the literal key first, then walks each '.'-separated component,
    coercing intermediate (name, value) pair lists into dicts.
    """
    if var_name in member_d:
        return member_d[var_name]
    current = member_d
    for part in var_name.split('.'):
        if not isinstance(current, dict):
            current = dict(current)
        current = current[part]
    return current
+
def annotate_data(f):
    """Decorator for parse functions that records ((start, end), member,
    value) tuples on data.an so a parse can be pretty-printed afterwards
    (see AnnotatedString.san)."""
    def wrapped(member, data, i_s, parsed):
        i = len(data.an)
        i_s_start = i_s
        data.an.append((i_s, member)) # placeholder (value not yet known)
        i_s, result_name, result_value = f(member, data, i_s, parsed)
        # replace the placeholder now that the span and value are known
        data.an[i] = (((i_s_start, i_s), member, result_value))
        return i_s, result_name, result_value
    wrapped.func_name = '%s@annotate_data' % f.func_name
    return wrapped
+
def simple_annotate_data(f):
    """Like annotate_data, but only records the start offset and member
    (no span/value back-patching)."""
    def wrapped(member, the_type, data, i_s):
        data.an.append((i_s, member))
        return f(member, the_type, data, i_s)
    wrapped.func_name = '%s@simple_annotate_data' % f.func_name
    return wrapped
+
def member_size(m):
    # stub - not implemented and currently unused; returns None
    pass
+
def makelen(n, c, s):
    """Pad s with c up to length n (strings already >= n are unchanged)."""
    padding = c * (n - len(s))
    return s + padding
+
class AnnotatedString(str):
    """A str carrying parse annotations.

    .an holds the (span, member, value) records appended during parsing;
    .max_pointer_i_s tracks the furthest offset reached through pointer
    fields, whose targets live beyond the sequential parse position.
    """
    def __init__(self, x):
        super(AnnotatedString, self).__init__(x)
        self.an = []
        self.max_pointer_i_s = 0
    def san(self):
        """Pretty-print the collected annotations, one per line."""
        def nicify(x):
            # placeholder entries (offset, member) become ===== separators
            if len(x) == 2 and not isinstance(x[0], tuple):
                start = '='*8 if x[0] == 0 else ' '+'='*8
                return makelen(80, '=', start + str(x))
            return str(x)
        print '\n'.join(map(nicify, self.an))
+
# annotation printing hook; a no-op unless ANNOTATE is turned on above
print_annotation = lambda data: data.san()

if not ANNOTATE:
    #AnnotatedString = lambda x: x
    print_annotation = lambda x: None
    #annotate_data = lambda x: x
    #simple_annotate_data = lambda x: x
+
@annotate_data
def parse_member(member, data, i_s, parsed):
    """Parse one protocol member from `data` at offset `i_s`.

    `parsed` holds the (name, value) pairs already decoded from the
    enclosing struct; switch variables and array lengths are resolved
    against it.  Returns (next offset, member name, decoded value).
    """
    result_name = member.name
    result_value = None
    # attributes we know how to ignore; flag anything else for inspection
    if len(set(member.attributes.keys()) - set(
            ['end', 'ctype', 'as_ptr', 'nomarshal', 'anon', 'chunk'])) > 0:
        print "has attributes %s" % member.attributes
        #import pdb; pdb.set_trace()
    if hasattr(member, 'is_switch') and member.is_switch():
        # switch member: pick the case whose guard variable matches
        var_name = member.variable
        parsed_d = dict(parsed)
        var = get_sub_member(parsed_d, var_name)
        for case in member.cases:
            for value in case.values:
                if case.values[0] is None:
                    # default
                    return parse_member(case.member, data, i_s, parsed)
                cond_wrap = (lambda x: not x) if value[0] == '!' else (lambda x: x)
                if cond_wrap(var.name == value[1]):
                    return parse_member(case.member, data, i_s, parsed)
        return i_s, result_name, 'empty switch'
    member_type = member.member_type if hasattr(member, 'member_type') else member
    member_type_name = member_type.name
    results = None
    if member_type_name is not None:
        if member_type_name in primitives:
            # plain fixed-size primitive
            primitive = primitives[member_type_name]
            result_value = primitive.unpack_from(data[i_s:i_s+primitive.size])[0]
            i_s_out = i_s + primitive.size
        elif member_type.is_struct():
            i_s_out, result_name, result_value = parse_complex_member(member, member_type, data, i_s)
        elif member_type.is_enum():
            primitive = primitives[member_type.primitive_type()]
            value = primitive.unpack(data[i_s:i_s+primitive.size])[0]
            result_value = Enum(member_type.names, member_type.names.get(value, value), value)
            # TODO - use enum for nice display
            i_s_out = i_s + primitive.size
        elif member_type.is_primitive() and hasattr(member_type, 'names'):
            # assume flag
            if not isinstance(member_type, ptypes.FlagsType):
                print "not really a flag.."
                import pdb; pdb.set_trace()
            primitive = primitives[member_type.primitive_type()]
            value = primitive.unpack(data[i_s:i_s+primitive.size])[0]
            result_value = Flag(member_type.names, value)
            i_s_out = i_s + primitive.size
        else:
            import pdb; pdb.set_trace()
    elif member_type.is_array():
        # arrays: determine the byte length first
        if member_type.is_remaining_length():
            data_len = len(data) - i_s
        elif member_type.is_identifier_length():
            # length named by another (already parsed) field
            num_elements = dict(parsed)[member_type.size]
            data_len = None
        elif member_type.is_image_size_length():
            # length computed from bpp/width/rows fields of the image
            bpp = member_type.size[1]
            width_name = member_type.size[2]
            rows_name = member_type.size[3]
            assert(isinstance(width_name, str))
            assert(isinstance(rows_name, str))
            parsed_d = dict(parsed)
            width = get_sub_member(parsed_d, width_name)
            # NOTE(review): passes the raw pair list, unlike `width` above
            # which uses parsed_d - works because get_sub_member coerces,
            # but looks unintended
            rows = get_sub_member(parsed, rows_name)
            if bpp == 8:
                data_len = rows * width
            elif bpp == 1:
                data_len = ((width + 7) // 8 ) * rows
            else:
                data_len =((bpp * width + 7) // 8 ) * rows
        else:
            print "unhandled array length type"
            import pdb; pdb.set_trace()
        element_type = member_type.element_type
        if element_type.name in primitive_arrays:
            element_size = primitives[element_type.name].size
            if data_len is None:
                data_len = num_elements * element_size
            primitive = primitive_arrays[element_type.name]
            contents = primitive(data_len).unpack_from(data[i_s:i_s+data_len])
            i_s_out = i_s + data_len
        else:
            # array of complex elements: parse them one by one
            contents = []
            i_s_out = i_s
            for _ in xrange(num_elements):
                i_s_out, sub_name, sub_value = parse_member(element_type, data, i_s_out, [])
                if sub_value is None:
                    print "list parsing failed"
                    contents = []
                    break
                contents.append((sub_name, sub_value))
        result_value = contents
    elif member_type.is_pointer():
        # pointers are offsets from the start of the message body
        pointer_primitive_name = member_type.primitive_type()
        if pointer_primitive_name in primitives:
            primitive = primitives[pointer_primitive_name]
            offset = primitive.unpack_from(data[i_s:i_s+primitive.size])[0]
            if offset == 0:
                result_value = []
            elif offset > len(data):
                print "ooops.. bad offset %s > %s" % (offset, len(data))
                import pdb; pdb.set_trace()
                result_value = None
            else:
                discard_i_s, discard_result_name, result_value = parse_member(
                    member_type.target_type, data, offset, [])
                # remember how far pointer targets reach (for leftover calc)
                data.max_pointer_i_s = max(data.max_pointer_i_s, discard_i_s)
            i_s_out = i_s + primitive.size
        else:
            print "pointer with unknown size??"
            import pdb; pdb.set_trace()
    # large payloads get wrapped so interactive printing stays readable
    if result_name in ['data', 'ents', 'glyphs', 'String', 'str']:
        result_value = NoPrint(name='', s=result_value)
    elif len(str(result_value)) > 1000:
        print "noprint?"
        import pdb; pdb.set_trace()
    return i_s_out, result_name, result_value
+
@simple_annotate_data
def parse_complex_member(member, the_type, data, i_s):
    """Parse a struct/message type: each sub-member in declaration order.

    Stops (with a diagnostic) at the first sub-member that fails to parse.
    Returns (next offset, member name, list of (name, value) pairs).
    """
    parsed = []
    i_s_out = i_s
    for sub_member in the_type.members:
        i_s_out, result_name, result_value = parse_member(sub_member, data, i_s_out, parsed)
        if result_value is not None:
            parsed.append((result_name, result_value))
        else:
            print "don't know how to parse %s (%s) in %s" % (
                sub_member.name, the_type, member)
            break
    return i_s_out, member.name, parsed
+
+def parse(channel_type, is_client, header, data):
+ if channel_type not in channels:
+ return NoPrint(name='unknown channel (%s %s)' % (
+ is_client, header), s=data)
+ # annotation support
+ data = AnnotatedString(data)
+ channel = channels[channel_type]
+ collection = channel.client if is_client else channel.server
+ if header.e.type not in collection:
+ print "bad data - no such message in protocol"
+ i_s, result_name, result_value = 0, None, None
+ else:
+ msg_proto = collection[header.e.type]
+ i_s, result_name, result_value = parse_complex_member(
+ msg_proto, msg_proto.message_type, data, 0)
+ i_s = max(data.max_pointer_i_s, i_s)
+ left_over = NoPrint(name='%s:%s' % (channel_type, i_s), s=data[i_s:]) if i_s < len(data) else ''
+ print_annotation(data)
+ if len(left_over) > 0:
+ print "WARNING: in message %s %s out %s unaccounted for (%s%%)" % (
+ msg_proto.name, len(left_over), len(data), float(len(left_over))/len(data))
+ #import pdb; pdb.set_trace()
+ return (msg_proto, (result_name, result_value))
+
class NoPrint(object):
    """Wrap a large string so interactive printing stays readable.

    Every wrapped string is kept in the class-level `objs` registry, keyed
    by id(), so the original can be recovered interactively.
    """
    objs = {}

    def __init__(self, name, s):
        self.s = s
        self.id = id(self.s)
        self.objs[self.id] = self.s
        self.name = name

    def orig(self):
        """Return the wrapped string."""
        return self.s

    def __str__(self):
        if len(self.s) < 20:
            return 'NP' + str(self.s)
        return 'NP%r...[%s,%s] #%s' % (
            self.s[:20], len(self.s), self.id, self.name)
    __repr__ = __str__

    def __len__(self):
        return len(self.s)
+
diff --git a/client_proto_tests.py b/client_proto_tests.py
new file mode 100644
index 0000000..7bab3f0
--- /dev/null
+++ b/client_proto_tests.py
@@ -0,0 +1,25 @@
+import client_proto as cp
+
def test_complete_parse(data, the_type):
    """Assert that parsing `data` as `the_type` consumes every byte."""
    i_s, name, result = cp.parse_member(the_type, cp.AnnotatedString(data), 0, [])
    assert(i_s == len(data))

def test_parse_complex_member(data, member, the_type):
    """Assert that parsing `data` as complex `member` consumes every byte."""
    i_s, name, result = cp.parse_complex_member(member, the_type, cp.AnnotatedString(data), 0)
    assert(i_s == len(data))
+
+
+def tests():
+ six_rects = '\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00U\x01\x00\x00\x80\x02\x00\x00U\x01\x00\x00\x00\x00\x00\x00\xd1\x01\x00\x00R\x01\x00\x00U\x01\x00\x00K\x02\x00\x00\xd1\x01\x00\x00\x80\x02\x00\x00\xd1\x01\x00\x00\x00\x00\x00\x00\xd6\x01\x00\x00R\x01\x00\x00\xd1\x01\x00\x00K\x02\x00\x00\xd6\x01\x00\x00\x80\x02\x00\x00\xd6\x01\x00\x00\x00\x00\x00\x00\xe0\x01\x00\x00\x80\x02\x00\x00'
+ draw_text = '\x00\x00\x00\x00\xc7\x01\x00\x00#\x00\x00\x00\xdd\x01\x00\x00I\x00\x00\x00\x01\x01\x00\x00\x00\xc7\x01\x00\x00#\x00\x00\x00\xdd\x01\x00\x00I\x00\x00\x00G\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\xff\xff\x00\x00\x08\x00\x00\x00\x05\x00\x02#\x00\x00\x00\xd7\x01\x00\x00\xff\xff\xff\xff\xf6\xff\xff\xff\t\x00\n\x00\x03\xaf\xfdp\x00?\xff\xff\xf7\x00\xcf\x90\x03\xfe\x00\x15\x00\x06\xff\x00\x009\xef\xf8\x00\x07\xff\xe90\x00\x0e\xf6\x00\x06\x10\x0e\xf3\x01\xcf\xb0\x07\xff\xff\xfe \x00m\xff\xa2\x00,\x00\x00\x00\xd7\x01\x00\x00\x00\x00\x00\x00\xf4\xff\xff\xff\x05\x00\x0c\x00-\xfb\x00\xaf\xff\x00\xbf`\x00\xaf`\x00\x7f\x80\x00?\xb0\x00\x1f\xe0\x00\x0e\xf1\x00\xff\xff\xb0\xbf\xff\xf0\x05\xfa\x00\x02\xbc\x002\x00\x00\x00\xd7\x01\x00\x00\x00\x00\x00\x00\xf6\xff\xff\xff\x08\x00\n\x00-\xfb/\xf0\xbf\xff\xef\xf0\xff \x8f\xf3\xdfr\t\xf6_\xff\xff\xf9\x03\xae\xff\xfb\x00\x00\x00\xff\x0c\xf5\x05\xff\x03\xff\xff\xfb\x00L\xff\xa2<\x00\x00\x00\xd7\x01\x00\x00\xff\xff\xff\xff\xf6\xff\xff\xff\t\x00\n\x00/\xc0\x00\x00\x00\x0f\xf0\x00\x00\x00\x0b\xf3\x00\x00\x00\t\xf5\x00\x00\x00\x06\xf8\x00\x00\x00\x03\xfd\x00\x00\x00\x00\xff`\x00\x00\x00\xee\xf7\x10\x00\x00\xbf_\xfa\x00\x00\x7f#\xaf\x00C\x00\x00\x00\xd7\x01\x00\x00\x00\x00\x00\x00\xf4\xff\xff\xff\x05\x00\x0c\x00-\xfb\x00\xaf\xff\x00\xbf`\x00\xaf`\x00\x7f\x80\x00?\xb0\x00\x1f\xe0\x00\x0e\xf1\x00\xff\xff\xb0\xbf\xff\xf0\x05\xfa\x00\x02\xbc\x00'
+ clip_rects = cp.proto.channels[1].server[304].message_type.members[0].member_type.members[2].member_type.members[1].cases[0].member
+ assert(clip_rects.name == 'rects')
+ test_complete_parse(six_rects, clip_rects)
+ draw_text_message = cp.proto.channels[1].server[311]
+ assert(draw_text_message.name == 'draw_text')
+ test_parse_complex_member(draw_text, draw_text_message, draw_text_message.message_type)
+ print '\n'.join(map(str,result[1][1]))
+
+if __name__ == '__main__':
+ tests()
+
diff --git a/compress.py b/compress.py
new file mode 100644
index 0000000..736d967
--- /dev/null
+++ b/compress.py
@@ -0,0 +1,36 @@
+"""
+SPICE compression and decompression implementation.
+
+we have our own lz implementation, see spice/common/lz_common.h
+
+"""
+import struct
+from util import reverse_dict
+
# NOTE(review): '<I' needs exactly 4 bytes; the original "LZ " (3 bytes)
# raised struct.error at import time.  Padded with a trailing space -
# TODO confirm against spice/common/lz_common.h.
LZ_MAGIC = struct.unpack('<I', "LZ  ")[0]
LZ_MAJOR = 1
LZ_MINOR = 1

# image-type enum names, declaration order matches the values below
_LZ_IMAGE_TYPE_NAMES = (
    'LZ_IMAGE_TYPE_INVALID',
    'LZ_IMAGE_TYPE_PLT1_LE',
    'LZ_IMAGE_TYPE_PLT1_BE', # PLT stands for palette
    'LZ_IMAGE_TYPE_PLT4_LE',
    'LZ_IMAGE_TYPE_PLT4_BE',
    'LZ_IMAGE_TYPE_PLT8',
    'LZ_IMAGE_TYPE_RGB16',
    'LZ_IMAGE_TYPE_RGB24',
    'LZ_IMAGE_TYPE_RGB32',
    'LZ_IMAGE_TYPE_RGBA',
    'LZ_IMAGE_TYPE_XXXA')
(LZ_IMAGE_TYPE_INVALID,
    LZ_IMAGE_TYPE_PLT1_LE,
    LZ_IMAGE_TYPE_PLT1_BE,
    LZ_IMAGE_TYPE_PLT4_LE,
    LZ_IMAGE_TYPE_PLT4_BE,
    LZ_IMAGE_TYPE_PLT8,
    LZ_IMAGE_TYPE_RGB16,
    LZ_IMAGE_TYPE_RGB24,
    LZ_IMAGE_TYPE_RGB32,
    LZ_IMAGE_TYPE_RGBA,
    LZ_IMAGE_TYPE_XXXA) = xrange(11)
# name -> value, built explicitly (the original scanned locals(), which is
# fragile and also rebound LzImageType from the tuple to the dict)
LzImageType = dict(zip(_LZ_IMAGE_TYPE_NAMES, xrange(11)))
LzImageTypeRev = reverse_dict(LzImageType)
+
def lz_decompress(s):
    """Parse an LZ image header (big-endian) and validate magic/version.

    NOTE(review): actual LZ decompression is not implemented yet; this
    returns the parsed header fields as a dict (via locals()).
    """
    (magic, major, minor,
        type, width, height, stride, out_top_down
        ) = struct.unpack_from('>IHHIIIII', s)
    assert(magic == LZ_MAGIC)
    assert(major == LZ_MAJOR)
    assert(minor == LZ_MINOR)
    # drop the input buffer so it is not part of the returned locals()
    del s
    return locals()
diff --git a/dumpspice.py b/dumpspice.py
new file mode 100755
index 0000000..4c69db4
--- /dev/null
+++ b/dumpspice.py
@@ -0,0 +1,22 @@
+#!/usr/bin/env python
+import sys
+import pcaputil
+import pcapspice
+
+def main(stdscr=None):
+ p = pcaputil.packet_iter('lo')
+ spice = pcapspice.spice_iter(p)
+ if stdscr:
+ stdscr.erase()
+ for d in spice:
+ if verbose:
+ print d
+
+if __name__ == '__main__':
+ verbose = '-v' in sys.argv
+ if '-c' in sys.argv:
+ import curses
+ curses.wrapper(main)
+ else:
+ main(None)
+
diff --git a/getem.py b/getem.py
new file mode 100644
index 0000000..b41bf94
--- /dev/null
+++ b/getem.py
@@ -0,0 +1,98 @@
+#!/usr/bin/env python2
+
+import pcap
+import sys
+import string
+import time
+import socket
+import struct
+
# IP protocol number -> human-readable name
protocols={socket.IPPROTO_TCP:'tcp',
            socket.IPPROTO_UDP:'udp',
            socket.IPPROTO_ICMP:'icmp'}
+
def decode_ip_packet(s):
    """Decode an IPv4 header (py2 byte string) into a field dict."""
    d={}
    d['version']=(ord(s[0]) & 0xf0) >> 4
    d['header_len']=ord(s[0]) & 0x0f  # in 32-bit words
    d['tos']=ord(s[1])
    d['total_len']=socket.ntohs(struct.unpack('H',s[2:4])[0])
    d['id']=socket.ntohs(struct.unpack('H',s[4:6])[0])
    d['flags']=(ord(s[6]) & 0xe0) >> 5
    # fragment offset is the low 13 bits of bytes 6-7: byte-swap first,
    # then mask (the original masked the raw value with 0x1f, which is
    # both the wrong width and applied before the swap)
    d['fragment_offset']=socket.ntohs(struct.unpack('H',s[6:8])[0]) & 0x1fff
    d['ttl']=ord(s[8])
    d['protocol']=ord(s[9])
    d['checksum']=socket.ntohs(struct.unpack('H',s[10:12])[0])
    d['source_address']=pcap.ntoa(struct.unpack('i',s[12:16])[0])
    d['destination_address']=pcap.ntoa(struct.unpack('i',s[16:20])[0])
    if d['header_len']>5:
        # options run from byte 20 up to the end of the header; the
        # original sliced s[20:4*(header_len-5)], an empty/short slice
        d['options']=s[20:4*d['header_len']]
    else:
        d['options']=None
    d['data']=s[4*d['header_len']:]
    return d
+
+
+def dumphex(s):
+ bytes = map(lambda x: '%.2x' % x, map(ord, s))
+ for i in xrange(0,len(bytes)/16):
+ print ' %s' % string.join(bytes[i*16:(i+1)*16],' ')
+ print ' %s' % string.join(bytes[(i+1)*16:],' ')
+
+
def print_packet(pktlen, data, timestamp):
    """pcap callback: pretty-print IPv4 packets (EtherType 0x0800)."""
    if not data:
        return

    # bytes 12-13 of the ethernet frame are the EtherType; 0x0800 == IPv4
    if data[12:14]=='\x08\x00':
        decoded=decode_ip_packet(data[14:])
        print '\n%s.%f %s > %s' % (time.strftime('%H:%M',
            time.localtime(timestamp)),
            timestamp % 60,
            decoded['source_address'],
            decoded['destination_address'])
        for key in ['version', 'header_len', 'tos', 'total_len', 'id',
                'flags', 'fragment_offset', 'ttl']:
            print ' %s: %d' % (key, decoded[key])
        print ' protocol: %s' % protocols[decoded['protocol']]
        print ' header checksum: %d' % decoded['checksum']
        print ' data:'
        dumphex(decoded['data'])
+
+
if __name__=='__main__':

    # usage: sniff.py <interface> <expr>
    if len(sys.argv) < 3:
        print 'usage: sniff.py <interface> <expr>'
        sys.exit(0)
    #dev = pcap.lookupdev()
    dev = sys.argv[1]
    #net, mask = pcap.lookupnet(dev)
    # note: to_ms does nothing on linux
    p = pcap.pcap(dev)
    #p.open_live(dev, 1600, 0, 100)
    #p.dump_open('dumpfile')
    #p.setfilter(string.join(sys.argv[2:],' '), 0, 0)

    # try-except block to catch keyboard interrupt. Failure to shut
    # down cleanly can result in the interface not being taken out of promisc.
    # mode
    #p.setnonblock(1)
    try:
        while 1:
            p.dispatch(1, print_packet)

        # specify 'None' to dump to dumpfile, assuming you have called
        # the dump_open method
        # p.dispatch(0, None)

        # the loop method is another way of doing things
        # p.loop(1, print_packet)

        # as is the next() method
        # p.next() returns a (pktlen, data, timestamp) tuple
        # apply(print_packet,p.next())
    except KeyboardInterrupt:
        print '%s' % sys.exc_type
        print 'shutting down'
        print '%d packets received, %d packets dropped, %d packets dropped by interface' % p.stats()
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..a8c54c7
--- /dev/null
+++ b/main.py
@@ -0,0 +1,14 @@
import sys
import client
import optparse

o = optparse.OptionParser()
# let -h be reused for "host" instead of optparse's built-in --help
o.conflict_handler = 'resolve'
o.add_option('-p')
o.add_option('-h')
d, rest = o.parse_args(sys.argv[1:])
# the original did `client = client(...)` - modules are not callable - and
# its `if __name __` guard was a syntax error; instantiate the class and
# give the port to the socket layer as an int
spice_client = client.Client(host=d.h, port=int(d.p))

if __name__ == '__main__':
    spice_client.run()
+
diff --git a/pcapspice.py b/pcapspice.py
new file mode 100644
index 0000000..b1c1e96
--- /dev/null
+++ b/pcapspice.py
@@ -0,0 +1,230 @@
+from itertools import chain, repeat
+from pcaputil import (tcp_parse, is_tcp, is_tcp_data, is_tcp_syn,
+ conversations_iter, header_conversation_iter)
+import client
+from client import SpiceLinkHeader, SpiceLinkMess, SpiceLinkReply, SpiceDataHeader
+
def is_single_packet_data(payload):
    """True when *payload* contains exactly one complete spice data message
    (header plus the payload size the header announces) and nothing more."""
    header = SpiceDataHeader(payload)
    expected = SpiceDataHeader.size + header.e.size
    return len(payload) == expected
+
# Aggregates derived from the protocol description in client_proto.

# Total number of distinct messages across all channels, both directions.
num_spice_messages = sum(len(ch.client) + len(ch.server)
                         for ch in client.client_proto.channels.values())

# Every message id that appears in any channel, client- or server-side.
valid_message_ids = set(chain.from_iterable(
    ch.client.keys() + ch.server.keys()
    for ch in client.client_proto.channels.values()))

all_channels = client.client_proto.channels.keys()
+
def possible_channels(server_message, header):
    """Set of channel names whose protocol defines *header*'s message type.

    *server_message* selects which direction's id table to consult:
    server->client when true, client->server otherwise.
    """
    matching = set()
    for name in all_channels:
        proto = client.client_proto.channels[name]
        ids = proto.server.keys() if server_message else proto.client.keys()
        if header.e.type in ids:
            matching.add(name)
    return matching
+
+guesses = {}
+
+def guess_channel_iter():
+ channel = None
+ optional_channels = set(all_channels)
+ seen_headers = []
+ while True:
+ src, dst, data_header = yield channel
+ key = (min(src, dst), max(src, dst))
+ if guesses.has_key(key):
+ optional_channels = set([guesses[key]])
+ seen_headers.append((src, dst, data_header))
+ optional_channels = optional_channels & possible_channels(
+ src == min(src, dst), data_header)
+ print optional_channels
+ if len(optional_channels) == 1:
+ channel = list(optional_channels)[0]
+ guesses[key] = channel
+ if len(optional_channels) == 0:
+ import pdb; pdb.set_trace()
+
+spice_size_from_header = lambda header: header.e.size
+
+def channel_spice_start_iter(lower_port, higher_port):
+ server_port, client_port = lower_port, higher_port
+ ports = [server_port, client_port]
+ messages = dict([(port,None) for port in ports])
+ payloads = dict([(port,[]) for port in ports])
+ headers = dict([(port,None) for port in ports])
+ discarded_len = {client_port:128, server_port:4}
+ message_ctors = {client_port:SpiceLinkMess, server_port:SpiceLinkReply}
+ if len(payload) > 0:
+ payloads[src].append(payload)
+ while not all(messages.values()):
+ src, dst, payload = yield None
+ if len(payload) == 0: continue
+ payloads[src].append(payload)
+ len_payload = sum(map(len,payloads[src]))
+ if headers[src] is None and len_payload >= SpiceLinkHeader.size:
+ headers[src] = SpiceLinkHeader(''.join(payloads[src]))
+ import pdb; pdb.set_trace()
+ expected_len = headers[src].e.size + SpiceLinkHeader.size + discarded_len[src]
+ if headers[src] is not None and expected_len <= len_payload:
+ if expected_len != len_payload:
+ print "extra bytes dumped!!! %d" % (len_payload - expected_len)
+ messages[src] = message_ctors[src](''.join(payloads[src])[
+ SpiceLinkHeader.size:])
+ channel = messages[client_port].e.channel_type
+ print "channel type %s created" % channel
+ result = None
+ src, dst, payload = yield None
+ data_iter = channel_spice_iter(src, dst, payload, channel=channel)
+ result = data_iter.next()
+ while True:
+ src, dst, payload = yield result
+ result = data_iter.send((src, dst, payload))
+
+def channel_spice_iter(lower_port, higher_port, channel=None):
+ server_port, client_port = lower_port, higher_port
+ serial = {}
+ guesser = guess_channel_iter()
+ guesser.next()
+
+ result = None
+ # yield spice data parsed results
+ while True:
+ src, dst, payload = yield result
+ data_header = SpiceDataHeader(payload)
+ # basic sanity check on the header
+ bad_header = False
+ if serial.has_key(src):
+ if data_header.e.serial != serial[src] + 1:
+ print "bad serial, replacing for %s->%s (%s!=%s+1)" % (
+ src, dst, data_header.e.serial, serial[src])
+ bad_header = True
+ serial[src] = data_header.e.serial
+ if data_header.e.type not in valid_message_ids:
+ print "bad message type %s in %s->%s" % (data_header.e.type,
+ src, dst)
+ if bad_header: continue
+ if channel is None:
+ channel = guesser.send((src, dst, data_header))
+ if channel is not None:
+ # read as many messages as it takes to get
+ (msg_proto, (result_name, result_value),
+ left_over) = client.client_proto.parse(
+ channel_type=channel,
+ is_client=src == client_port,
+ header=data_header,
+ data=payload[SpiceDataHeader.size:])
+ result = (data_header, msg_proto, result_name, result_value,
+ left_over)
+
+# for every tcp packet
+# if it is a syn, read the link header, decide which channel we are
+# if it is data
+# if we know which channel we are, parse it
+# else guess which channel.
+# if we just figured out which channel we are, releaes all pending packets.
+
def spice_iter1(packet_iter):
    """Wire the link-handshake and data coroutines into conversations_iter.

    conversations_iter only accepts its per-role generators as keyword
    arguments (server_start/client_start/server/client); the original passed
    them positionally, which raised TypeError on every call.
    """
    return conversations_iter(
        packet_iter,
        server_start=channel_spice_start_iter,
        client_start=channel_spice_start_iter,
        server=channel_spice_iter,
        client=channel_spice_iter)
+
def packet_header_iter_gen(hdr_ctor, pkt_ctor):
    """Endless stream of identical framing quads for collect_packets:
    (header size, header constructor, payload-size-from-header, packet
    constructor)."""
    quad = (hdr_ctor.size,
            hdr_ctor,
            lambda h: h.e.size,
            pkt_ctor)
    return repeat(quad)
+
def ident(x, **kw):
    """Identity function: returns *x* unchanged; keyword args are ignored
    (lets it stand in for any constructor in the framing quads)."""
    return x
+
def channel_spice_message_iter(src, dst, guesser):
    """Coroutine fed (src, dst, (data_header, message)) tuples, one complete
    spice message at a time (framing already done by collect_packets).

    Yields (msg_proto, result_name, result_value) once *guesser* has settled
    on a channel, None until then.  Messages with an out-of-sequence serial
    or an unknown message type are dropped.
    """
    server_port, client_port = min(src, dst), max(src, dst)
    serial = {}  # last serial number seen, per sending port
    channel = None
    result = None
    # yield spice data parsed results
    while True:
        src, dst, (data_header, message) = yield result
        # basic sanity check on the header
        bad_header = False
        if serial.has_key(src):
            if data_header.e.serial != serial[src] + 1:
                print "bad serial, replacing for %s->%s (%s!=%s+1)" % (
                    src, dst, data_header.e.serial, serial[src])
                bad_header = True
        serial[src] = data_header.e.serial
        if data_header.e.type not in valid_message_ids:
            print "bad message type %s in %s->%s" % (data_header.e.type,
                src, dst)
            bad_header = True
        if bad_header: continue
        if channel is None:
            channel = guesser.send((src, dst, data_header))
        if channel is not None:
            # read as many messages as it takes to get
            (msg_proto, (result_name, result_value)
                ) = client.client_proto.parse(
                channel_type=channel,
                is_client=src == client_port,
                header=data_header,
                data=message)
            result = (msg_proto, result_name, result_value)
+
+class ChannelGuesser(object):
+
+ def __init__(self):
+ self.iter = guess_channel_iter()
+ self.iter.next()
+ self.channel = None
+
+ def send(self, (src, dst, payload)):
+ if self.channel: return self.channel
+ return self.iter.send((src, dst, payload))
+
+ def on_client_link_message(self, p, **kw):
+ self._client_message = SpiceLinkMess(p)
+ self.channel = self._client_message.e.channel_type
+ print "STARTED channel %s" % self.channel
+
+ def on_server_link_message(self, p, **kw):
+ self._server_message = SpiceLinkReply(p)
+
def spice_iter(packet_iter):
    """Turn raw pcap packets into parsed spice messages.

    For each TCP conversation, a ChannelGuesser is shared between both
    directions (keyed by the sorted port pair).  Conversations caught from
    their SYN first consume the link handshake (header + link message +
    the discarded ticket bytes), then stream framed data messages; ones
    joined mid-stream go straight to data framing.
    """
    guessers = {}
    def make_start_iter(src, dst):
        # Framing-quad generator for a conversation seen from its SYN.
        is_server = src < dst
        key = (min(src, dst), max(src, dst))
        guesser = guessers[key] = guessers.get(key, ChannelGuesser())
        print guesser, is_server, "**** make_start_iter %s -> %s" % (src, dst)
        link_messages = []  # NOTE(review): assigned but never used
        # 1st frame: the link header, whose e.size gives the link message
        # length; the link message lands in the guesser callback.
        yield (SpiceLinkHeader.size,
               SpiceLinkHeader.verify,
               lambda h: h.e.size,
               (guesser.on_server_link_message if is_server else
                guesser.on_client_link_message))
        print guesser, is_server, "2"
        message_iter = channel_spice_message_iter(src, dst, guesser)
        message_iter.next()
        # 2nd frame: skip the unparsed handshake bytes (128 one way, 4 the
        # other).
        # TODO: way to say "src->dst 128, dst->src 4"
        yield 128 if src > dst else 4, ident, lambda h: 0, ident
        # From here on: an endless stream of data-header framed messages.
        for i, data in enumerate(packet_header_iter_gen(
            SpiceDataHeader,
            lambda pkt, src, dst, header:
                message_iter.send((src, dst, (header, pkt))))):
            print guesser, is_server, "2+%s" % i
            yield data
    def make_iter(src, dst):
        # Framing-quad generator for a conversation joined mid-stream.
        key = (min(src, dst), max(src, dst))
        guesser = guessers[key] = guessers.get(key, ChannelGuesser())
        message_iter = channel_spice_message_iter(src, dst, guesser)
        message_iter.next()
        return packet_header_iter_gen(
            SpiceDataHeader,
            lambda pkt, src, dst, header:
                message_iter.send((src, dst, (header, pkt))))
    return header_conversation_iter(
        packet_iter,
        server_start=make_start_iter,
        client_start=make_start_iter,
        server=make_iter,
        client=make_iter)
+
diff --git a/pcaputil.py b/pcaputil.py
new file mode 100644
index 0000000..007c6c2
--- /dev/null
+++ b/pcaputil.py
@@ -0,0 +1,171 @@
+import os
+import struct
+from itertools import imap, ifilter
+#import pcap
+
TCP_PROTOCOL = 6  # IPv4 header protocol number for TCP
TCP_SYN = 2       # TCP flags byte: SYN bit
TCP_FIN = 1       # TCP flags byte: FIN bit
+
def is_tcp(pkt):
    """True when raw Ethernet frame *pkt* carries IPv4/TCP and is long
    enough for the flag checks below to index safely."""
    if len(pkt) <= 47:
        return False
    # Ethertype 0x0800 (IPv4) at offset 12, IP protocol 6 (TCP) at 23.
    return pkt[12:14] == '\x08\x00' and ord(pkt[23]) == TCP_PROTOCOL
+
def is_tcp_data(pkt):
    """True for a TCP segment with neither SYN nor FIN set (plain data).

    NOTE(review): the flags byte is read at fixed offset 47, which assumes a
    20-byte IP header -- verify captures never carry IP options.
    """
    if not is_tcp(pkt):
        return False
    return not (ord(pkt[47]) & (TCP_SYN | TCP_FIN))
+
def is_tcp_syn(pkt):
    """Truthy when *pkt* is a TCP segment with SYN set.

    Returns False for non-TCP frames, otherwise the masked SYN bit (an int),
    exactly like the original short-circuit expression.
    """
    if not is_tcp(pkt):
        return False
    return ord(pkt[47]) & TCP_SYN
+
def tcp_parse(pkt):
    """Parse a raw Ethernet/IPv4/TCP frame.

    Returns (src_port, dst_port, flags, payload).  Assumes an IPv4 frame --
    callers should have checked is_tcp() first.
    """
    # IP header length: low nibble of the first IP byte, in 32-bit words.
    # (pkt[14:15] + ord works for both py2 str and py3 bytes input.)
    tcp_offset = 14 + (ord(pkt[14:15]) & 0xf) * 4
    src, dst, seq, ack, data_offset16, flags, window = struct.unpack(
        '>HHIIBBH', pkt[tcp_offset:tcp_offset + 16])
    # TCP data offset is the HIGH nibble of that byte, in 32-bit words.
    # The original used `>> 2`, which only happens to work while the
    # reserved low bits are zero.  (Its "WHOOHOO" debug print, which assumed
    # a fixed 66-byte offset, is gone too.)
    data_offset = tcp_offset + (data_offset16 >> 4) * 4
    return src, dst, flags, pkt[data_offset:]
+
def mypcap(filename):
    """Minimal pcap file reader.

    Yields (timestamp, src_port, dst_port, tcp_payload) for every TCP data
    packet in *filename*.

    NOTE(review): the record header is unpacked as native 'QII', i.e.
    ts_sec/ts_usec fused into one 64-bit field -- this only works for
    captures written with the same endianness as the reader; confirm if
    capture files move between hosts.
    """
    with open(filename, 'r') as fd:
        file_size = os.stat(filename).st_size
        i = 24  # byte offset past the 24-byte pcap global header
        hdr_len = 8 + 4 + 4  # per-record header: ts (8) + incl_len + orig_len
        i_pkt = 0
        preamble = fd.read(24)  # consume the global header
        while i + hdr_len < file_size:
            ts, l1, l2 = struct.unpack('QII', fd.read(hdr_len))
            if l1 < l2:
                # capture truncated this packet (incl_len < orig_len)
                print("short recorded packet: #%s,%s: %s < %s" % (i_pkt, i, l1, l2))
            pkt = fd.read(l1)
            if is_tcp_data(pkt):
                # tcp_parse returns (src, dst, flags, payload); the original
                # unpacked only three values and raised ValueError on every
                # data packet.
                src, dst, flags, tcp_payload = tcp_parse(pkt)
                yield ts, src, dst, tcp_payload
            i_pkt += 1
            i += hdr_len + l1
+
def get_conversations(file):
    """Group the TCP payload stream of pcap *file* into conversations.

    Returns a list of [client_bytes, server_bytes] pairs (each direction
    joined into one string), ordered by each conversation's first client
    timestamp.  The lowest port seen is presumed to be the server.

    NOTE(review): the parameter shadows the `file` builtin, and
    `port_counts` is never written -- apparently dead.
    """
    convs = {}
    port_counts = {}
    for ts, src, dst, data in mypcap(file):
        key = (src,dst)
        if key not in convs:
            # element 0 of each conversation list is its start timestamp
            convs[key] = [ts]
        convs[key].append(data)
    ports = [src for src,dst in convs.keys()]
    server_port = sorted(ports)[0]
    times = {}
    for src, dst in convs.keys():
        # normalize so (src, dst) is always (server, client)
        if src != server_port:
            src, dst = dst, src
        server = convs[(src, dst)]
        client = convs[(dst, src)]
        assert(client[0] < server[0]) # client sends first msg
        times[client[0]] = (client, server)
    # drop the timestamp (element 0) and join each direction's fragments
    return [map(lambda x: ''.join(x[1:]), times[t]) for t in sorted(times.keys())]
+
def conversations_iter(packet_iter, **keyed_iters):
    """ conversation_start_iter is used when we catch the SYN packet,
    otherwise we use conversation_iter. the later is expected to be
    used by the former.

    *keyed_iters* must provide generator factories under the keys
    'server_start', 'client_start', 'server' and 'client'; the role is
    picked per (src, dst) flow from the SYN flag and port order (lower
    port is presumed the server).  Each factory is called as f(src, dst),
    primed, and then sent (src, dst, payload) tuples; non-None values it
    yields back are re-yielded to the caller.
    """
    conversations = {}
    keys = {(False, True): 'server', (False, False): 'client',
            (True, True): 'server_start', (True, False): 'client_start'}
    for src, dst, flags, payload in tcp_iter(packet_iter):
        print "conversation_iter %s->%s #(%s)" % (src, dst, len(payload))
        key = (src, dst)
        if key not in conversations:
            iter_gen = keyed_iters[keys[(flags & TCP_SYN > 0, src < dst)]]
            conversations[key] = iter_gen(*key)
            conversations[key].next()
        maybe_result = conversations[key].send((src, dst, payload))
        if maybe_result is not None:
            yield maybe_result
+
def consume_packets(packets, needed_len):
    """Pop exactly *needed_len* leading bytes from the buffered *packets*.

    *packets* is a list of string fragments treated as one logical stream.
    When at least *needed_len* bytes are buffered, returns them as a single
    string and leaves only the unconsumed remainder in *packets* (as one
    element); otherwise returns None and leaves *packets* untouched.
    """
    total_len = sum(map(len, packets))
    if total_len < needed_len:
        return None
    joined = ''.join(packets)
    del packets[:]
    if total_len > needed_len:
        # Keep the remainder as a single fragment.  The original kept only
        # the tail of the *last* fragment, silently losing data (and
        # tripping its own assert) whenever the excess spanned more than
        # one fragment.
        packets.append(joined[needed_len:])
    return joined[:needed_len]
+
def ident(x, **kw):
    """Identity function: returns *x*; any keyword arguments are ignored
    (duplicate of the helper in pcapspice, kept for standalone use)."""
    return x
+
+def collect_packets(header_iter_gen):
+ """ Example:
+ >>> list(pcaputil.send_multiple(pcaputil.collect_packets(lambda src, dst: iter([(10, ident, lambda h: 20, ident)]*2))(17, 42), [(42, 17, ' '*30)]))
+ [(' ', ' ')]
+ """
+ def collector(src, dst):
+ packets = []
+ history = []
+ header_iter = header_iter_gen(src, dst)
+ hdr_size, hdr_ctor, size_from_hdr, pkt_ctor = header_iter.next()
+ msg = None
+ header = pkt = None
+ searched_size = hdr_size
+ while True:
+ src, dst, payload = yield (src, dst, msg)
+ packets.append(payload)
+ history.append(payload)
+ print "collect_packets: searching for %s (%s)" % (searched_size, sum(map(len, packets)))
+ data = consume_packets(packets, searched_size)
+ while data != None:
+ print "collect_packets: %s" % len(data)
+ if header is None:
+ header = hdr_ctor(data)
+ print "collect_packets: header of length %s, expect %s" % (len(data), size_from_hdr(header))
+ searched_size = size_from_hdr(header)
+ if searched_size > 1000000:
+ import pdb; pdb.set_trace()
+ else:
+ print "collect_packets: packet of length %s" % len(data)
+ pkt = pkt_ctor(data, src=src, dst=dst, header=header)
+ msg = (header, pkt)
+ hdr_size, hdr_ctor, size_from_hdr, pkt_ctor = header_iter.next()
+ header = None
+ searched_size = hdr_size
+ data = consume_packets(packets, searched_size)
+ return collector
+
def send_multiple(gen, sent_iter):
    """Prime coroutine *gen*, then yield gen.send(item) for every item in
    *sent_iter*; stops cleanly when either side is exhausted."""
    pending = iter(sent_iter)
    gen.next()
    try:
        while True:
            item = pending.next()
            yield gen.send(item)
    except StopIteration:
        return
+
def header_conversation_iter(packet_iter, **kw):
    """Like conversations_iter, but every keyed generator factory is wrapped
    with collect_packets so it receives whole framed messages instead of raw
    TCP fragments."""
    wrapped = dict((name, collect_packets(gen)) for name, gen in kw.items())
    return conversations_iter(packet_iter, **wrapped)
+
+def tcp_data_iter(packet_iter):
+ return ifilter(lambda src, dst, flags, data:
+ len(data) > 0, imap(tcp_parse, ifilter(is_tcp_data, imap(lambda x: ''.join(x[1]), packet_iter))))
+
def tcp_iter(packet_iter):
    """All interesting TCP segments (SYN/FIN, or any with data) parsed from
    raw pcap tuples.

    *packet_iter* yields pcap-style tuples whose element [1] holds the raw
    frame pieces (hence the ''.join); yields tcp_parse 4-tuples.
    """
    return ifilter(lambda (src, dst, flags, data):
                   flags & (TCP_SYN | TCP_FIN) or len(data) > 0,
                   imap(tcp_parse, ifilter(is_tcp, imap(lambda x: ''.join(x[1]), packet_iter))))
+
def packet_iter(dev):
    """Live-capture packet iterator for interface *dev*.

    pcap is imported lazily so the rest of this module is usable without
    the pcap bindings installed.
    """
    import pcap
    return pcap.pcap(dev)
+
+
diff --git a/stracer.py b/stracer.py
new file mode 100644
index 0000000..60b735f
--- /dev/null
+++ b/stracer.py
@@ -0,0 +1,163 @@
+import subprocess
+import re
+from client import unspice_channels
+
def parse_connect(line):
    """Parse the tail of an strace connect() line (after parse_start has
    stripped the '[pid N] connect(fd,' prefix).

    Example tail:
      ' {sa_family=AF_INET, sin_port=htons(5926), sin_addr=inet_addr("127.0.0.1")}, 16) = 0\\n'

    Returns a dict with 'family', 'port', 'host' and 'ret' (all strings),
    or None when the line does not match.
    """
    # Fix: the return-value group was written `(?P<ret>[0-9])+`, which only
    # captured the LAST digit of multi-digit return values; also raw-string
    # the pattern so the escapes are unambiguous.
    m = re.match(r'\s+\{sa_family=(?P<family>[A-Z_]+),\s+sin_port=htons\((?P<port>[0-9]+)\),\s+sin_addr=inet_addr\("(?P<host>[0-9.]+)"\)\},\s+([0-9]+)\)\s+=\s+(?P<ret>[0-9]+)\n', line)
    if not m: return m
    return m.groupdict()
+
def from_quoted(s):
    """Decode strace's hex-escaped text ('\\x41\\x42...') into the raw byte
    string it represents; each 4-character '\\xNN' unit becomes one byte."""
    chars = [chr(int(s[i + 2:i + 4], 16)) for i in range(0, len(s), 4)]
    return ''.join(chars)
+
def match_message_container(line, rexp, tag='sendto'):
    """Match an strace send/recv line tail against *rexp* and normalize it.

    Returns None when *rexp* does not match; otherwise the match's
    groupdict with 'msg' hex-decoded, 'size' and (when present) 'ret' as
    ints, and 'type' set to *tag*.

    The *tag* default preserves the original behavior, which labelled every
    message -- including recvfrom ones -- as 'sendto'; callers should pass
    their own tag.
    """
    m = re.match(rexp, line)
    if not m: return m
    d = m.groupdict()
    d['msg'] = from_quoted(d['msg'])
    d['size'] = int(d['size'])
    if 'ret' in d:
        d['ret'] = int(d['ret'])
        # NOTE(review): assert is stripped under -O; raise if this check
        # matters in production.
        assert(len(d['msg']) == d['ret'])
    d['type'] = tag
    return d
+
def parse_sendto(line):
    """"\x14\x00\x00\x00\x16\x00\x01\x03\x06\x83\x59\x4c\x00\x00\x00\x00\x00\x00\x00\x00", 20, 0, {sa_family=AF_NETLINK, pid=0, groups=00000000}, 12) = 20"""
    # Parse the tail of an strace sendto() line (docstring shows a sample).
    # '<unfinished ...>' lines have no '= ret' suffix yet, so they need a
    # separate pattern without the ret group.
    if 'unfinished' in line:
        return match_message_container(line, '\s+"(?P<msg>[x0-9a-f\\\\]+)"\.?\.?\.?,\s+(?P<size>[0-9]+),\s+([0-9]+),\s+.+,\s+([0-9]+) <unfinished ...>\n')
    return match_message_container(line, '\s+"(?P<msg>[x0-9a-f\\\\]+)"\.?\.?\.?,\s+(?P<size>[0-9]+),\s+([0-9]+),\s+.+,\s+([0-9]+)\)\s+=\s+(?P<ret>[0-9]+)\n')
+
def parse_recvfrom(line):
    """"\x00\x00\x00\x00\xd2\x01\x00\x00\x1d\x01\x00\x00\x80\x02\x00\x00\x30\x02\x00\x00\x00\xba\x02\x00\x00\xb6\x01\x00\x00", 29, 0, NULL, NULL) = 29"""
    # Parse the tail of an strace recvfrom() line (docstring shows a sample).
    # NOTE(review): match_message_container stamps d['type'] = 'sendto' on
    # these too -- received messages end up mislabelled.
    return match_message_container(line, '\s+"(?P<msg>[x0-9a-f\\\\]+)"\.?\.?\.?,\s+(?P<size>[0-9]+),\s+([0-9]+),.+,.+\)\s+=\s+(?P<ret>[0-9]+)\n')
+
def parse_recvmsg(line):
    """Stub: recvmsg() lines are not parsed yet; always returns None."""
    print("TODO - parse_recvmsg")
    return None
+
def parse_sendmsg(line):
    """Stub: sendmsg() lines are not parsed yet; always returns None."""
    print("TODO - parse_sendmsg")
    return None
+
def parse_start(line):
    """Match the '[pid N] syscall(firstarg,' prefix of an strace line.

    Returns the re match (groups: pid, syscall name, first argument) or
    None.  Really lame parsing: a first argument containing a comma would
    break it, but none of the syscalls we care about has one.  (Adding a
    nesting level per character would solve it properly.)
    """
    pattern = r'\[pid\s+([0-9]+)\]\s+([a-zA-Z0-9]+)\(([^,]+),'
    return re.search(pattern, line)
+
+class StracerNetwork(object):
+
+ _unhandled_syscalls = set(['socket', 'bind', 'getsockname', 'setsockopt',
+ # later
+ 'getpeername', 'getsockopt', 'shutdown'])
+ def __init__(self, cmd, on_message = lambda x: None):
+ self.max_bytes_per_message = max_bytes = 2**20
+ self._p = subprocess.Popen(
+ ('strace -qvxx -e trace=network -s %s -Cf %s' % (max_bytes, cmd)).split(),
+ stderr=subprocess.PIPE)
+ self.connections = []
+ self.sockets = {}
+ self._message_count = 0
+ self._lines = []
+ self.on_message = on_message
+
+ def total_length(self, n):
+ if n not in self.sockets:
+ return 0
+ return sum([len(d['msg']) for m, d in self.sockets[n] if d is not None and 'msg' in d])
+
+ def terse(self, count=10):
+ while sum([len(v) for v in self.sockets.values()]) < count:
+ print sum(map(len, self.sockets.items()))
+ self.handle_line()
+ sorted_keys = list(sorted(self.sockets.keys()))
+ return sorted_keys, [self.terse_key(k) for k in sorted_keys]
+
+ def terse_key(self, key):
+ return [(dict(recvfrom='R',sendto='S').get(tag), d['msg'])
+ for ((tag, line), d) in self.sockets[key] if d is not None]
+
+
+ def conversations(self, count=10):
+ keys, terses = self.terse(count=count)
+ return keys, [[''.join([msg for t,msg in terse if t==k]) for k in 'SR']
+ for terse in terses]
+
+ def spiced(self, count=10):
+ keys, collecteds = self.conversations(count)
+ return unspice_channels(keys, collecteds)
+
+ def lines(self):
+ for line in self._p.stderr:
+ print line if len(line) < 100 else line[:100] + '...'
+ self._lines.append(line)
+ yield line
+
+ def wait_for_connect(self, host):
+ while not self.host_connected(host):
+ self.handle_line()
+
+ def host_connected(self, host):
+ return any([d['host'] == host for d in self.connections])
+
+ def add_to_socket(self, sock, tag, line, parser):
+ if sock not in self.sockets:
+ print "did I miss a connect?? at %s" % len(self._lines)
+ self.add_socket(sock)
+ parsed = parser(line)
+ self.sockets[sock].append(((tag, line), parsed))
+ self._message_count += 1
+ if parsed:
+ self.on_message(parsed['msg'])
+
+ def add_socket(self, sock, d={}):
+ d['sock'] = sock
+ if 'host' not in d:
+ d['host'] = 'unknown'
+ self.connections.append(d)
+ self.sockets[sock] = []
+
+ def handle_line(self):
+ line = self.lines().next()
+ m = parse_start(line)
+ if not m: return
+ line = line[m.end():]
+ pid, syscall, firstarg = m.groups()
+ if 'connect' in syscall:
+ d = parse_connect(line)
+ if not d: return
+ self.add_socket(int(firstarg), d)
+ elif 'sendto' in syscall:
+ self.add_to_socket(int(firstarg), 'sendto', line, parse_sendto)
+ elif 'recvmsg' in syscall:
+ self.add_to_socket(int(firstarg), 'recvmsg', line, parse_recvmsg)
+ elif 'recvfrom' in syscall:
+ self.add_to_socket(int(firstarg), 'recvfrom', line, parse_recvfrom)
+ elif 'sendmsg' in syscall:
+ self.add_to_socket(int(firstarg), 'sendmsg', line, parse_sendmsg)
+ elif syscall in self._unhandled_syscalls:
+ pass
+ else:
+ print "unhandled %s" % syscall
+ import pdb; pdb.set_trace()
+
+ def wait_for(self, num):
+ ret = []
+ start_count = self._message_count
+ while True:
+ self.handle_line()
+ if self._message_count - start_count >= num:
+ return
+
def trace_spicec(host, port):
    """Run the spice client under strace against host:port and return the
    first handful of parsed spice messages."""
    s = StracerNetwork('/store/upstream/bin/spicec -h %s -p %s' % (host, port))
    s.wait_for_connect(host)
    # NOTE(review): the original did `sock = s.wait_for_connect(host)`
    # (which returns None) and then called the nonexistent s.filter(sock,
    # 10).  spiced() looks like the intended entry point -- confirm.
    return s.spiced(10)
+
if __name__ == '__main__':
    # smoke test: trace a spicec session against a locally running server
    trace_spicec('127.0.0.1', 5926)
+
diff --git a/util.py b/util.py
new file mode 100644
index 0000000..2170eaa
--- /dev/null
+++ b/util.py
@@ -0,0 +1,3 @@
def reverse_dict(d):
    """Return a new dict mapping each value of *d* back to its key.

    When values repeat, one (arbitrary) key survives -- same as the
    original.
    """
    return dict((value, key) for key, value in d.items())
+