summaryrefslogtreecommitdiff
path: root/client_proto.py
diff options
context:
space:
mode:
Diffstat (limited to 'client_proto.py')
-rw-r--r--client_proto.py320
1 files changed, 320 insertions, 0 deletions
diff --git a/client_proto.py b/client_proto.py
new file mode 100644
index 0000000..c29c182
--- /dev/null
+++ b/client_proto.py
@@ -0,0 +1,320 @@
+"""
+Decode spice protocol using spice demarshaller.
+"""
+
+ANNOTATE = False
+
+import struct
+import sys
+if not 'proto' in locals():
+ sys.path.append('../spice/python_modules')
+ import spice_parser
+ import ptypes
+ proto = spice_parser.parse('../spice/spice.proto')
+ reloads = 1
+else:
+ reloads += 1
+
+
+channels = {}
+for channel in proto.channels:
+ channels[channel.value] = channel
+ channel.client = dict(zip([x.value for x in
+ channel.channel_type.client_messages],
+ channel.channel_type.client_messages))
+ channel.server = dict(zip([x.value for x in
+ channel.channel_type.server_messages],
+ channel.channel_type.server_messages))
+ # todo parsing of messages / building of messages
+
+def mapdict(f, d):
+ return dict((k, f(v)) for k,v in d.items())
+
+primitives=mapdict(struct.Struct, dict(
+ uint64='<Q',
+ int64='<q',
+ uint32='<I',
+ int32='<i',
+ uint16='<H',
+ int16='<h',
+ uint8='<B',
+ int8='<b',
+ ))
+
+primitive_arrays=dict([(k, lambda n, e=e: struct.Struct('<'+e*n))
+ for k, e in
+ [('uint64','Q'),
+ ('int64','q'),
+ ('uint32','I'),
+ ('int32','i'),
+ ('uint16','H'),
+ ('int16','h'),
+ ]])
+
+class NullUnpacker(object):
+ def __init__(self, n):
+ self.size = n
+ def unpack(self, s):
+ return s
+ unpack_from=unpack
+
+class Flag(object):
+ def __init__(self, names, value):
+ self.name = None
+ for k, v in names.items():
+ if value == 1<<k:
+ self.name = v
+ break
+ if self.name is None:
+ self.name = set(v for k,v in names.items() if (1<<k)&value)
+ self.value = value
+ def __str__(self):
+ return 'F(%s,%s)' % (self.name, self.value)
+ __repr__ = __str__
+
+class Enum(object):
+ def __init__(self, the_type, name, value):
+ self.name = name
+ self.value = value
+ #self.the_type = the_type
+ # TODO - cache this
+ #self.name_to_value = dict(zip(self.the_type.values(), self.the_type.keys()))
+ def __str__(self):
+ return 'E(%s,%s)' % (self.name, self.value)
+ __repr__ = __str__
+
+
+primitive_arrays['int8'] = primitive_arrays['uint8'] = lambda n: NullUnpacker(n)
+
+def ensure_dict(maybe_d):
+ if isinstance(maybe_d, dict):
+ return maybe_d
+ return dict(maybe_d)
+
+def get_sub_member(member_d, var_name):
+ if var_name in member_d:
+ return member_d[var_name]
+ ret = member_d
+ for part in var_name.split('.'):
+ ret = ensure_dict(ret)[part]
+ return ret
+
+def annotate_data(f):
+ def wrapped(member, data, i_s, parsed):
+ i = len(data.an)
+ i_s_start = i_s
+ data.an.append((i_s, member)) # annotation (before value is known)
+ i_s, result_name, result_value = f(member, data, i_s, parsed)
+ data.an[i] = (((i_s_start, i_s), member, result_value))
+ return i_s, result_name, result_value
+ wrapped.func_name = '%s@annotate_data' % f.func_name
+ return wrapped
+
+def simple_annotate_data(f):
+ def wrapped(member, the_type, data, i_s):
+ data.an.append((i_s, member))
+ return f(member, the_type, data, i_s)
+ wrapped.func_name = '%s@simple_annotate_data' % f.func_name
+ return wrapped
+
+def member_size(m):
+ pass
+
+def makelen(n, c, s):
+ """pad s with c up to length n"""
+ return s + c * (n - len(s))
+
+class AnnotatedString(str):
+ def __init__(self, x):
+ super(AnnotatedString, self).__init__(x)
+ self.an = []
+ self.max_pointer_i_s = 0
+ def san(self):
+ def nicify(x):
+ if len(x) == 2 and not isinstance(x[0], tuple):
+ start = '='*8 if x[0] == 0 else ' '+'='*8
+ return makelen(80, '=', start + str(x))
+ return str(x)
+ print '\n'.join(map(nicify, self.an))
+ #print '\n'.join('%s-%s,%s,%s,%s' % (i_s_start, i_s, member.member_type. for (i_s_start, i_s),member,val in self.an]))
+
+print_annotation = lambda data: data.san()
+
+if not ANNOTATE:
+ #AnnotatedString = lambda x: x
+ print_annotation = lambda x: None
+ #annotate_data = lambda x: x
+ #simple_annotate_data = lambda x: x
+
+@annotate_data
+def parse_member(member, data, i_s, parsed):
+ result_name = member.name
+ result_value = None
+ if len(set(member.attributes.keys()) - set(
+ ['end', 'ctype', 'as_ptr', 'nomarshal', 'anon', 'chunk'])) > 0:
+ print "has attributes %s" % member.attributes
+ #import pdb; pdb.set_trace()
+ if hasattr(member, 'is_switch') and member.is_switch():
+ var_name = member.variable
+ parsed_d = dict(parsed)
+ var = get_sub_member(parsed_d, var_name)
+ for case in member.cases:
+ for value in case.values:
+ if case.values[0] is None:
+ # default
+ return parse_member(case.member, data, i_s, parsed)
+ cond_wrap = (lambda x: not x) if value[0] == '!' else (lambda x: x)
+ if cond_wrap(var.name == value[1]):
+ return parse_member(case.member, data, i_s, parsed)
+ return i_s, result_name, 'empty switch'
+ member_type = member.member_type if hasattr(member, 'member_type') else member
+ member_type_name = member_type.name
+ results = None
+ if member_type_name is not None:
+ if member_type_name in primitives:
+ primitive = primitives[member_type_name]
+ result_value = primitive.unpack_from(data[i_s:i_s+primitive.size])[0]
+ i_s_out = i_s + primitive.size
+ elif member_type.is_struct():
+ i_s_out, result_name, result_value = parse_complex_member(member, member_type, data, i_s)
+ elif member_type.is_enum():
+ primitive = primitives[member_type.primitive_type()]
+ value = primitive.unpack(data[i_s:i_s+primitive.size])[0]
+ result_value = Enum(member_type.names, member_type.names.get(value, value), value)
+ # TODO - use enum for nice display
+ i_s_out = i_s + primitive.size
+ elif member_type.is_primitive() and hasattr(member_type, 'names'):
+ # assume flag
+ if not isinstance(member_type, ptypes.FlagsType):
+ print "not really a flag.."
+ import pdb; pdb.set_trace()
+ primitive = primitives[member_type.primitive_type()]
+ value = primitive.unpack(data[i_s:i_s+primitive.size])[0]
+ result_value = Flag(member_type.names, value)
+ i_s_out = i_s + primitive.size
+ else:
+ import pdb; pdb.set_trace()
+ elif member_type.is_array():
+ if member_type.is_remaining_length():
+ data_len = len(data) - i_s
+ elif member_type.is_identifier_length():
+ num_elements = dict(parsed)[member_type.size]
+ data_len = None
+ elif member_type.is_image_size_length():
+ bpp = member_type.size[1]
+ width_name = member_type.size[2]
+ rows_name = member_type.size[3]
+ assert(isinstance(width_name, str))
+ assert(isinstance(rows_name, str))
+ parsed_d = dict(parsed)
+ width = get_sub_member(parsed_d, width_name)
+ rows = get_sub_member(parsed, rows_name)
+ if bpp == 8:
+ data_len = rows * width
+ elif bpp == 1:
+ data_len = ((width + 7) // 8 ) * rows
+ else:
+ data_len =((bpp * width + 7) // 8 ) * rows
+ else:
+ print "unhandled array length type"
+ import pdb; pdb.set_trace()
+ element_type = member_type.element_type
+ if element_type.name in primitive_arrays:
+ element_size = primitives[element_type.name].size
+ if data_len is None:
+ data_len = num_elements * element_size
+ primitive = primitive_arrays[element_type.name]
+ contents = primitive(data_len).unpack_from(data[i_s:i_s+data_len])
+ i_s_out = i_s + data_len
+ else:
+ contents = []
+ i_s_out = i_s
+ for _ in xrange(num_elements):
+ i_s_out, sub_name, sub_value = parse_member(element_type, data, i_s_out, [])
+ if sub_value is None:
+ print "list parsing failed"
+ contents = []
+ break
+ contents.append((sub_name, sub_value))
+ result_value = contents
+ elif member_type.is_pointer():
+ pointer_primitive_name = member_type.primitive_type()
+ if pointer_primitive_name in primitives:
+ primitive = primitives[pointer_primitive_name]
+ offset = primitive.unpack_from(data[i_s:i_s+primitive.size])[0]
+ if offset == 0:
+ result_value = []
+ elif offset > len(data):
+ print "ooops.. bad offset %s > %s" % (offset, len(data))
+ import pdb; pdb.set_trace()
+ result_value = None
+ else:
+ discard_i_s, discard_result_name, result_value = parse_member(
+ member_type.target_type, data, offset, [])
+ data.max_pointer_i_s = max(data.max_pointer_i_s, discard_i_s)
+ i_s_out = i_s + primitive.size
+ else:
+ print "pointer with unknown size??"
+ import pdb; pdb.set_trace()
+ if result_name in ['data', 'ents', 'glyphs', 'String', 'str']:
+ result_value = NoPrint(name='', s=result_value)
+ elif len(str(result_value)) > 1000:
+ print "noprint?"
+ import pdb; pdb.set_trace()
+ return i_s_out, result_name, result_value
+
+@simple_annotate_data
+def parse_complex_member(member, the_type, data, i_s):
+ parsed = []
+ i_s_out = i_s
+ for sub_member in the_type.members:
+ i_s_out, result_name, result_value = parse_member(sub_member, data, i_s_out, parsed)
+ if result_value is not None:
+ parsed.append((result_name, result_value))
+ else:
+ print "don't know how to parse %s (%s) in %s" % (
+ sub_member.name, the_type, member)
+ break
+ return i_s_out, member.name, parsed
+
+def parse(channel_type, is_client, header, data):
+ if channel_type not in channels:
+ return NoPrint(name='unknown channel (%s %s)' % (
+ is_client, header), s=data)
+ # annotation support
+ data = AnnotatedString(data)
+ channel = channels[channel_type]
+ collection = channel.client if is_client else channel.server
+ if header.e.type not in collection:
+ print "bad data - no such message in protocol"
+ i_s, result_name, result_value = 0, None, None
+ else:
+ msg_proto = collection[header.e.type]
+ i_s, result_name, result_value = parse_complex_member(
+ msg_proto, msg_proto.message_type, data, 0)
+ i_s = max(data.max_pointer_i_s, i_s)
+ left_over = NoPrint(name='%s:%s' % (channel_type, i_s), s=data[i_s:]) if i_s < len(data) else ''
+ print_annotation(data)
+ if len(left_over) > 0:
+ print "WARNING: in message %s %s out %s unaccounted for (%s%%)" % (
+ msg_proto.name, len(left_over), len(data), float(len(left_over))/len(data))
+ #import pdb; pdb.set_trace()
+ return (msg_proto, (result_name, result_value))
+
+class NoPrint(object):
+ objs = {}
+ def __init__(self, name, s):
+ self.s = s
+ self.id = id(self.s)
+ self.objs[self.id] = self.s
+ self.name = name
+ def orig(self):
+ return self.s
+ def __str__(self):
+ return 'NP' + (str(self.s) if len(self.s) < 20 else '%r...[%s,%s] #%s' % (
+ self.s[:20], len(self.s), self.id,
+ self.name))
+ __repr__ = __str__
+ def __len__(self):
+ return len(self.s)
+