""" Decode spice protocol using spice demarshaller. """ ANNOTATE = True NO_PRINT_PROTECTION = False import logging import os from collections import namedtuple import struct import sys # we assume we are side by side with spice. hopefully, that is so. spice_dir = os.path.join(os.path.dirname(sys.modules[__name__].__file__), '../spice') logger = logging.getLogger('client_proto') if not 'proto' in locals(): sys.path.append(os.path.join(spice_dir, 'python_modules')) import spice_parser import ptypes proto = None major_version, minor_version = None, None #channels, num_spice_messages, valid_spice_messages #all_channels reloads = 1 else: reloads += 1 def make_channels_dict(proto): channels = {} for channel in proto.channels: channels[channel.value] = channel channel.client = dict(zip([x.value for x in channel.channel_type.client_messages], channel.channel_type.client_messages)) channel.server = dict(zip([x.value for x in channel.channel_type.server_messages], channel.channel_type.server_messages)) # todo parsing of messages / building of messages return channels channels = None def mapdict(f, d): return dict((k, f(v)) for k,v in d.items()) primitives=mapdict(struct.Struct, dict( uint64=' 0: logging.debug("has attributes %s" % member.attributes) #import pdb; pdb.set_trace() if hasattr(member, 'is_switch') and member.is_switch(): var_name = member.variable parsed_d = dict(parsed) var = get_sub_member(parsed_d, var_name) for case in member.cases: for value in case.values: if case.values[0] is None: # default return parse_member(case.member, data, i_s, parsed) cond_wrap = (lambda x: not x) if value[0] == '!' else (lambda x: x) if cond_wrap(var.name == value[1]): return parse_member(case.member, data, i_s, parsed) return i_s, result_name, 'empty switch' member_type = member.member_type if hasattr(member, 'member_type') else member member_type_name = member_type.name results = None if member_type_name is not None: if member_type_name in primitives: primitive = primitives[member_type_name] result_value = primitive.unpack_from(data[i_s:i_s+primitive.size])[0] i_s_out = i_s + primitive.size elif member_type.is_struct(): i_s_out, result_name, result_value = parse_complex_member(member, member_type, data, i_s) elif hasattr(member_type, 'is_enum') and member_type.is_enum(): primitive = primitives[member_type.primitive_type()] value = primitive.unpack(data[i_s:i_s+primitive.size])[0] result_value = Enum(member_type.names, member_type.names.get(value, value), value) # TODO - use enum for nice display i_s_out = i_s + primitive.size elif member_type.is_primitive(): # assume flag primitive = primitives[member_type.primitive_type()] value = primitive.unpack(data[i_s:i_s+primitive.size])[0] if hasattr(member_type, 'names'): if not isinstance(member_type, ptypes.FlagsType): print "not really a flag.." import pdb; pdb.set_trace() result_value = Flag(member_type.names, value) else: result_value = value i_s_out = i_s + primitive.size else: import pdb; pdb.set_trace() elif member_type.is_array(): num_elements = None if member_type.is_remaining_length(): data_len = len(data) - i_s elif member_type.is_identifier_length(): num_elements = dict(parsed)[member_type.size] data_len = None elif member_type.is_image_size_length(): bpp = member_type.size[1] width_name = member_type.size[2] rows_name = member_type.size[3] assert(isinstance(width_name, str)) assert(isinstance(rows_name, str)) parsed_d = dict(parsed) width = get_sub_member(parsed_d, width_name) rows = get_sub_member(parsed, rows_name) if bpp == 8: data_len = rows * width elif bpp == 1: data_len = ((width + 7) // 8 ) * rows else: data_len =((bpp * width + 7) // 8 ) * rows else: print "unhandled array length type" import pdb; pdb.set_trace() element_type = member_type.element_type if element_type.name in primitive_arrays: element_size = primitives[element_type.name].size if data_len is None: data_len = num_elements * element_size if num_elements is None: num_elements = data_len / element_size primitive = primitive_arrays[element_type.name] contents = primitive(num_elements).unpack_from(data[i_s:i_s+data_len]) i_s_out = i_s + data_len else: contents = [] i_s_out = i_s for _ in xrange(num_elements): i_s_out, sub_name, sub_value = parse_member(element_type, data, i_s_out, []) if sub_value is None: logging.error("list parsing failed") contents = [] break contents.append((sub_name, sub_value)) result_value = contents elif member_type.is_pointer(): pointer_primitive_name = member_type.primitive_type() if pointer_primitive_name in primitives: primitive = primitives[pointer_primitive_name] offset = primitive.unpack_from(data[i_s:i_s+primitive.size])[0] if offset == 0: result_value = [] elif offset > len(data): print "ooops.. bad offset %s > %s" % (offset, len(data)) import pdb; pdb.set_trace() result_value = None else: discard_i_s, discard_result_name, result_value = parse_member( member_type.target_type, data, offset, []) data.max_pointer_i_s = max(data.max_pointer_i_s, discard_i_s) i_s_out = i_s + primitive.size else: print "pointer with unknown size??" import pdb; pdb.set_trace() if NO_PRINT_PROTECTION: if result_name in ['data', 'ents', 'glyphs', 'String', 'str', 'rects']: result_value = NoPrint(name='', s=result_value) elif len(str(result_value)) > 1000: print "noprint?" import pdb; pdb.set_trace() return i_s_out, result_name, result_value @simple_annotate_data def parse_complex_member(member, the_type, data, i_s): parsed = [] i_s_out = i_s for sub_member in the_type.members: i_s_out, result_name, result_value = parse_member(sub_member, data, i_s_out, parsed) if result_value is not None: parsed.append((result_name, result_value)) else: logger.error("don't know how to parse %s (%s) in %s" % ( sub_member.name, the_type, member)) break return i_s_out, member.name, parsed ParseResult=namedtuple('ParseResult', ['msg_proto', 'result_name', 'result_value', 'an']) def parse(channel_type, is_client, header, data): if channel_type not in channels: return NoPrint(name='unknown channel (%s %s)' % ( is_client, header), s=data) # annotation support an_data = AnnotatedString(data) channel = channels[channel_type] collection = channel.client if is_client else channel.server if header.e.type not in collection: logger.error("bad data - no such message in protocol") i_s, result_name, result_value = 0, None, None else: msg_proto = collection[header.e.type] i_s, result_name, result_value = parse_complex_member( msg_proto, msg_proto.message_type, an_data, 0) i_s = max(an_data.max_pointer_i_s, i_s) left_over = NoPrint(name='%s:%s' % (channel_type, i_s), s=an_data[i_s:]) if i_s < len(an_data) else '' #result_value.an_data = an_data # let the reference escape, so we can print annotation an_data if len(left_over) > 0: logger.warning("in message %s %s out %s unaccounted for (%2.1d%%)" % ( msg_proto.name, len(left_over), len(an_data), 100.0*len(left_over)/len(an_data))) #import pdb; pdb.set_trace() return ParseResult(msg_proto, result_name, result_value, an=an_data) class NoPrint(object): objs = {} def __init__(self, name, s): self.s = s self.id = id(self.s) self.objs[self.id] = self.s self.name = name def orig(self): return self.s def __str__(self): return 'NP' + (str(self.s) if len(self.s) < 20 else '%r...[%s,%s] #%s' % ( self.s[:20], len(self.s), self.id, self.name)) __repr__ = __str__ def __len__(self): return len(self.s) def set_proto(major_version, minor_version): global proto, channels, num_spice_messages, valid_spice_messages global all_channels, valid_message_ids if globals()['major_version'] == major_version and globals()['minor_version'] == minor_version: return globals().update(dict(major_version=major_version, minor_version=minor_version)) if major_version == 1 : proto = spice_parser.parse(os.path.join(spice_dir, 'spice1.proto')) else: proto = spice_parser.parse(os.path.join(spice_dir, 'spice.proto')) channels = make_channels_dict(proto) num_spice_messages = sum(len(ch.client) + len(ch.server) for ch in channels.values()) valid_message_ids = set(sum([ch.client.keys() + ch.server.keys() for ch in channels.values()], [])) all_channels = channels.keys() def possible_channels(server_message, header): return set(c for c in all_channels if header.e.type in (client_proto.channels[c].server.keys() if server_message else client_proto.channels[c].client.keys()))