diff options
Diffstat (limited to 'client_proto.py')
-rw-r--r-- | client_proto.py | 320 |
1 files changed, 320 insertions, 0 deletions
diff --git a/client_proto.py b/client_proto.py new file mode 100644 index 0000000..c29c182 --- /dev/null +++ b/client_proto.py @@ -0,0 +1,320 @@ +""" +Decode spice protocol using spice demarshaller. +""" + +ANNOTATE = False + +import struct +import sys +if not 'proto' in locals(): + sys.path.append('../spice/python_modules') + import spice_parser + import ptypes + proto = spice_parser.parse('../spice/spice.proto') + reloads = 1 +else: + reloads += 1 + + +channels = {} +for channel in proto.channels: + channels[channel.value] = channel + channel.client = dict(zip([x.value for x in + channel.channel_type.client_messages], + channel.channel_type.client_messages)) + channel.server = dict(zip([x.value for x in + channel.channel_type.server_messages], + channel.channel_type.server_messages)) + # todo parsing of messages / building of messages + +def mapdict(f, d): + return dict((k, f(v)) for k,v in d.items()) + +primitives=mapdict(struct.Struct, dict( + uint64='<Q', + int64='<q', + uint32='<I', + int32='<i', + uint16='<H', + int16='<h', + uint8='<B', + int8='<b', + )) + +primitive_arrays=dict([(k, lambda n, e=e: struct.Struct('<'+e*n)) + for k, e in + [('uint64','Q'), + ('int64','q'), + ('uint32','I'), + ('int32','i'), + ('uint16','H'), + ('int16','h'), + ]]) + +class NullUnpacker(object): + def __init__(self, n): + self.size = n + def unpack(self, s): + return s + unpack_from=unpack + +class Flag(object): + def __init__(self, names, value): + self.name = None + for k, v in names.items(): + if value == 1<<k: + self.name = v + break + if self.name is None: + self.name = set(v for k,v in names.items() if (1<<k)&value) + self.value = value + def __str__(self): + return 'F(%s,%s)' % (self.name, self.value) + __repr__ = __str__ + +class Enum(object): + def __init__(self, the_type, name, value): + self.name = name + self.value = value + #self.the_type = the_type + # TODO - cache this + #self.name_to_value = dict(zip(self.the_type.values(), self.the_type.keys())) + def __str__(self): + return 'E(%s,%s)' % (self.name, self.value) + __repr__ = __str__ + + +primitive_arrays['int8'] = primitive_arrays['uint8'] = lambda n: NullUnpacker(n) + +def ensure_dict(maybe_d): + if isinstance(maybe_d, dict): + return maybe_d + return dict(maybe_d) + +def get_sub_member(member_d, var_name): + if var_name in member_d: + return member_d[var_name] + ret = member_d + for part in var_name.split('.'): + ret = ensure_dict(ret)[part] + return ret + +def annotate_data(f): + def wrapped(member, data, i_s, parsed): + i = len(data.an) + i_s_start = i_s + data.an.append((i_s, member)) # annotation (before value is known) + i_s, result_name, result_value = f(member, data, i_s, parsed) + data.an[i] = (((i_s_start, i_s), member, result_value)) + return i_s, result_name, result_value + wrapped.func_name = '%s@annotate_data' % f.func_name + return wrapped + +def simple_annotate_data(f): + def wrapped(member, the_type, data, i_s): + data.an.append((i_s, member)) + return f(member, the_type, data, i_s) + wrapped.func_name = '%s@simple_annotate_data' % f.func_name + return wrapped + +def member_size(m): + pass + +def makelen(n, c, s): + """pad s with c up to length n""" + return s + c * (n - len(s)) + +class AnnotatedString(str): + def __init__(self, x): + super(AnnotatedString, self).__init__(x) + self.an = [] + self.max_pointer_i_s = 0 + def san(self): + def nicify(x): + if len(x) == 2 and not isinstance(x[0], tuple): + start = '='*8 if x[0] == 0 else ' '+'='*8 + return makelen(80, '=', start + str(x)) + return str(x) + print '\n'.join(map(nicify, self.an)) + #print '\n'.join('%s-%s,%s,%s,%s' % (i_s_start, i_s, member.member_type. for (i_s_start, i_s),member,val in self.an])) + +print_annotation = lambda data: data.san() + +if not ANNOTATE: + #AnnotatedString = lambda x: x + print_annotation = lambda x: None + #annotate_data = lambda x: x + #simple_annotate_data = lambda x: x + +@annotate_data +def parse_member(member, data, i_s, parsed): + result_name = member.name + result_value = None + if len(set(member.attributes.keys()) - set( + ['end', 'ctype', 'as_ptr', 'nomarshal', 'anon', 'chunk'])) > 0: + print "has attributes %s" % member.attributes + #import pdb; pdb.set_trace() + if hasattr(member, 'is_switch') and member.is_switch(): + var_name = member.variable + parsed_d = dict(parsed) + var = get_sub_member(parsed_d, var_name) + for case in member.cases: + for value in case.values: + if case.values[0] is None: + # default + return parse_member(case.member, data, i_s, parsed) + cond_wrap = (lambda x: not x) if value[0] == '!' else (lambda x: x) + if cond_wrap(var.name == value[1]): + return parse_member(case.member, data, i_s, parsed) + return i_s, result_name, 'empty switch' + member_type = member.member_type if hasattr(member, 'member_type') else member + member_type_name = member_type.name + results = None + if member_type_name is not None: + if member_type_name in primitives: + primitive = primitives[member_type_name] + result_value = primitive.unpack_from(data[i_s:i_s+primitive.size])[0] + i_s_out = i_s + primitive.size + elif member_type.is_struct(): + i_s_out, result_name, result_value = parse_complex_member(member, member_type, data, i_s) + elif member_type.is_enum(): + primitive = primitives[member_type.primitive_type()] + value = primitive.unpack(data[i_s:i_s+primitive.size])[0] + result_value = Enum(member_type.names, member_type.names.get(value, value), value) + # TODO - use enum for nice display + i_s_out = i_s + primitive.size + elif member_type.is_primitive() and hasattr(member_type, 'names'): + # assume flag + if not isinstance(member_type, ptypes.FlagsType): + print "not really a flag.." + import pdb; pdb.set_trace() + primitive = primitives[member_type.primitive_type()] + value = primitive.unpack(data[i_s:i_s+primitive.size])[0] + result_value = Flag(member_type.names, value) + i_s_out = i_s + primitive.size + else: + import pdb; pdb.set_trace() + elif member_type.is_array(): + if member_type.is_remaining_length(): + data_len = len(data) - i_s + elif member_type.is_identifier_length(): + num_elements = dict(parsed)[member_type.size] + data_len = None + elif member_type.is_image_size_length(): + bpp = member_type.size[1] + width_name = member_type.size[2] + rows_name = member_type.size[3] + assert(isinstance(width_name, str)) + assert(isinstance(rows_name, str)) + parsed_d = dict(parsed) + width = get_sub_member(parsed_d, width_name) + rows = get_sub_member(parsed, rows_name) + if bpp == 8: + data_len = rows * width + elif bpp == 1: + data_len = ((width + 7) // 8 ) * rows + else: + data_len =((bpp * width + 7) // 8 ) * rows + else: + print "unhandled array length type" + import pdb; pdb.set_trace() + element_type = member_type.element_type + if element_type.name in primitive_arrays: + element_size = primitives[element_type.name].size + if data_len is None: + data_len = num_elements * element_size + primitive = primitive_arrays[element_type.name] + contents = primitive(data_len).unpack_from(data[i_s:i_s+data_len]) + i_s_out = i_s + data_len + else: + contents = [] + i_s_out = i_s + for _ in xrange(num_elements): + i_s_out, sub_name, sub_value = parse_member(element_type, data, i_s_out, []) + if sub_value is None: + print "list parsing failed" + contents = [] + break + contents.append((sub_name, sub_value)) + result_value = contents + elif member_type.is_pointer(): + pointer_primitive_name = member_type.primitive_type() + if pointer_primitive_name in primitives: + primitive = primitives[pointer_primitive_name] + offset = primitive.unpack_from(data[i_s:i_s+primitive.size])[0] + if offset == 0: + result_value = [] + elif offset > len(data): + print "ooops.. bad offset %s > %s" % (offset, len(data)) + import pdb; pdb.set_trace() + result_value = None + else: + discard_i_s, discard_result_name, result_value = parse_member( + member_type.target_type, data, offset, []) + data.max_pointer_i_s = max(data.max_pointer_i_s, discard_i_s) + i_s_out = i_s + primitive.size + else: + print "pointer with unknown size??" + import pdb; pdb.set_trace() + if result_name in ['data', 'ents', 'glyphs', 'String', 'str']: + result_value = NoPrint(name='', s=result_value) + elif len(str(result_value)) > 1000: + print "noprint?" + import pdb; pdb.set_trace() + return i_s_out, result_name, result_value + +@simple_annotate_data +def parse_complex_member(member, the_type, data, i_s): + parsed = [] + i_s_out = i_s + for sub_member in the_type.members: + i_s_out, result_name, result_value = parse_member(sub_member, data, i_s_out, parsed) + if result_value is not None: + parsed.append((result_name, result_value)) + else: + print "don't know how to parse %s (%s) in %s" % ( + sub_member.name, the_type, member) + break + return i_s_out, member.name, parsed + +def parse(channel_type, is_client, header, data): + if channel_type not in channels: + return NoPrint(name='unknown channel (%s %s)' % ( + is_client, header), s=data) + # annotation support + data = AnnotatedString(data) + channel = channels[channel_type] + collection = channel.client if is_client else channel.server + if header.e.type not in collection: + print "bad data - no such message in protocol" + i_s, result_name, result_value = 0, None, None + else: + msg_proto = collection[header.e.type] + i_s, result_name, result_value = parse_complex_member( + msg_proto, msg_proto.message_type, data, 0) + i_s = max(data.max_pointer_i_s, i_s) + left_over = NoPrint(name='%s:%s' % (channel_type, i_s), s=data[i_s:]) if i_s < len(data) else '' + print_annotation(data) + if len(left_over) > 0: + print "WARNING: in message %s %s out %s unaccounted for (%s%%)" % ( + msg_proto.name, len(left_over), len(data), float(len(left_over))/len(data)) + #import pdb; pdb.set_trace() + return (msg_proto, (result_name, result_value)) + +class NoPrint(object): + objs = {} + def __init__(self, name, s): + self.s = s + self.id = id(self.s) + self.objs[self.id] = self.s + self.name = name + def orig(self): + return self.s + def __str__(self): + return 'NP' + (str(self.s) if len(self.s) < 20 else '%r...[%s,%s] #%s' % ( + self.s[:20], len(self.s), self.id, + self.name)) + __repr__ = __str__ + def __len__(self): + return len(self.s) + |