"""SPICE protocol helpers (Python 2).

Contains:
  * struct definitions for SPICE data/link headers (built on ``structutil``),
  * a small socket-based client (``Channel``, ``MainChannel``, ``Client``),
  * functions that split pairs of client/server byte streams into parsed
    messages (``parse_conversations`` and the ``unspice_*`` helpers).

NOTE(review): this copy of the file is damaged -- two spans of text are
missing (see the "SOURCE IS DAMAGED" markers below), so it will not import
as-is. Restore it from version control rather than reconstructing by hand.
"""
import socket
import logging
import struct

from structutil import (Struct, StructMeta, uint16, uint8, uint64, uint32,
                        uint32_arr, list_to_str)
import client_proto

logger = logging.getLogger('client')

# Set to True to drop into pdb wherever debug_break() is called
# (SpiceDataHeader.verify calls it before rejecting a header).
DEBUG_ENABLE_PDB=False

if DEBUG_ENABLE_PDB:
    def debug_break():
        import pdb; pdb.set_trace()
else:
    def debug_break():
        pass

# TODO - parse spice/protocol.h directly (to keep in sync automatically)
SPICE_MAGIC = "REDQ"
SPICE_VERSION_MAJOR = 2
SPICE_VERSION_MINOR = 0

################################################################################
# spice-protocol/spice/enums.h
# Channel type ids, assigned 1..8 in this order.
(SPICE_CHANNEL_MAIN,
 SPICE_CHANNEL_DISPLAY,
 SPICE_CHANNEL_INPUTS,
 SPICE_CHANNEL_CURSOR,
 SPICE_CHANNEL_PLAYBACK,
 SPICE_CHANNEL_RECORD,
 SPICE_CHANNEL_TUNNEL,
 SPICE_END_CHANNEL) = xrange(1,1+8)

################################################################################
# Encryption & Ticketing Parameters
SPICE_MAX_PASSWORD_LENGTH=60
SPICE_TICKET_KEY_PAIR_LENGTH=1024
# 1024 / 8 + 34 == 162 (integer division under Python 2).
SPICE_TICKET_PUBKEY_BYTES=(SPICE_TICKET_KEY_PAIR_LENGTH / 8 + 34)


class SpiceDataHeader(Struct):
    """Header preceding every data-stage message: serial, type, size, sub_list.

    Elsewhere in this module ``size`` (read as ``header.e.size``) is the
    number of body bytes that follow the header.
    """
    __metaclass__ = StructMeta
    fields = [(uint64, 'serial'), (uint16, 'type'), (uint32, 'size'),
              (uint32, 'sub_list')]

    def __str__(self):
        return 'Spice #%s, T%s, sub %s, size %s' % (self.e.serial,
            self.e.type, self.e.sub_list, self.e.size)

    __repr__ = __str__

    @classmethod
    def verify(cls, *args, **kw):
        """Build ``cls(*args, **kw)`` and sanity-check it.

        Returns the instance, or None (after logging an error and calling
        debug_break()) when:
          * sub_list is non-zero and not smaller than size,
          * type is outside [0, 400),
          * size exceeds 2000000.
        The numeric limits are heuristics, not protocol constants.
        """
        inst = cls(*args, **kw)
        if inst.e.sub_list != 0 and inst.e.sub_list >= inst.e.size:
            logger.error('too large sub_list packet - %s' % inst)
            debug_break()
            return None
        if inst.e.type < 0 or inst.e.type >= 400:
            logger.error('bad type of packet - %s' % inst)
            debug_break()
            return None
        if inst.e.size > 2000000:
            logger.error('too large packet - %s' % inst)
            debug_break()
            return None
        return inst


class SpiceSubMessage(Struct):
    # type + size of a single sub-message.
    __metaclass__ = StructMeta
    fields = [(uint16, 'type'), (uint32, 'size')]


class SpiceSubMessageList(Struct):
    # Count followed by an array of uint32 sub-message entries.
    __metaclass__ = StructMeta
    fields = [(uint16, 'size'), (uint32_arr, 'sub_messages')]


# Module-level alias for the data header's byte size.
data_header_size = SpiceDataHeader.size


class SpiceLinkHeader(Struct):
    """Header preceding a link message (SpiceLinkMess / SpiceLinkReply).

    ``size`` (read as ``header.e.size`` in parse_link_start and the
    ``unspice_*`` functions) is the length of the link message that follows.
    """
    __metaclass__ = StructMeta
    fields = [(uint32, 'magic'), (uint32, 'major_version'),
              (uint32, 'minor_version'), (uint32, 'size')]

    def __str__(self):
        return 'SpiceLink magic=%s, (%s,%s), size %s' % (
            # NOTE(review): SOURCE IS DAMAGED HERE. Text appears to have been
            # stripped from a '<' inside the struct.pack format string up to a
            # later '>'. Lost with it:
            #   * the end of this __str__,
            #   * the whole SpiceLinkMess class,
            #   * the start of the SpiceLinkReply class (its fields and the
            #     head of its __str__).
            # Evidence: both classes are used below but defined nowhere, and
            # the arguments that follow (error, pub_key, ...) are not
            # SpiceLinkHeader fields. Restore this span from version control.
            struct.pack('' %(
                self.e.error, self.e.num_common_caps, self.e.num_channel_caps,
                self.e.pub_key[:10], len(self.e.pub_key))

    # NOTE(review): presumably the tail of the missing SpiceLinkReply class.
    __repr__ = __str__


def parse_link_msg(s):
    """Split ``s`` into (SpiceLinkHeader, SpiceLinkMess)."""
    link_header = SpiceLinkHeader(s)
    # NOTE(review): ``link_header.size`` is likely the struct's own byte size
    # (as in ``SpiceLinkHeader.size``), not the payload length
    # ``link_header.e.size`` used everywhere else in this module -- verify.
    # Also: assert is stripped under ``python -O``.
    assert(len(s) == SpiceLinkHeader.size + link_header.size)
    link_msg = SpiceLinkMess(s[SpiceLinkHeader.size:])
    return link_header, link_msg


def make_data_msg(serial, type, payload, sub_list):
    # NOTE(review): builds a header and then discards it. The function
    # returns None, and ``payload`` is only used for its length. Looks
    # unfinished.
    data_header = SpiceDataHeader(serial=serial, type=type, size=len(payload),
                                  sub_list=sub_list)


def connect(host, port):
    """Open a TCP connection to (host, port) and return the socket."""
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.connect((host, port))
    return s


class Channel(object):
    """One SPICE channel over an already-connected socket.

    Construction immediately sends the link message and synchronously reads
    the link reply (see send_link_message).
    """

    def __init__(self, s, connection_id, channel_type, channel_id,
                 common_caps, channel_caps):
        self._s = s
        self._channel_type = channel_type
        self._channel_id = channel_id
        self._connection_id = connection_id
        self._common_caps = common_caps
        self._channel_caps = channel_caps
        # Per-channel-type entry from client_proto.channels, or None when
        # that table is empty/false.
        self._proto = (client_proto.channels[channel_type]
                       if client_proto.channels else None)
        self.send_link_message()

    def send_link_message(self):
        """Send header + link message + caps lists, then read the reply.

        The caps lists are appended right after the link message, which is
        why caps_offset is SpiceLinkMess.size.
        """
        link_message = SpiceLinkMess(
            channel_type=self._channel_type,
            channel_id=self._channel_id,
            caps_offset=SpiceLinkMess.size,
            num_channel_caps=len(self._channel_caps),
            num_common_caps=len(self._common_caps),
            connection_id=self._connection_id)
        # NOTE(review): ``size`` covers only the SpiceLinkMess struct, not the
        # caps lists appended after it. Confirm whether the peer expects the
        # caps to be included.
        link_header = SpiceLinkHeader(magic=SPICE_MAGIC,
                                      major_version=SPICE_VERSION_MAJOR,
                                      minor_version=SPICE_VERSION_MINOR,
                                      size=link_message.size)
        # NOTE(review): socket.send may transmit only part of the buffer;
        # sendall() would guarantee the whole message goes out.
        self._s.send(
            link_header.tostr() + link_message.tostr()
            + list_to_str(self._common_caps)
            + list_to_str(self._channel_caps))
        self.parse_one(SpiceLinkHeader, self._on_link_reply)

    def _on_link_reply(self, header, data):
        # Keep the parsed reply and its header for later inspection.
        link_reply = SpiceLinkReply(data)
        self._link_reply = link_reply
        self._link_reply_header = header

    def _on_data(self, header, data):
        # Keep the last data message and print its body length.
        self._last_data = data
        self._last_data_header = header
        print "_on_data: %s" % len(data)

    def parse_data(self):
        """Read one data-stage message from the socket."""
        self.parse_one(SpiceDataHeader, self._on_data)

    def parse_one(self, ctor, on_message):
        """Read a ``ctor``-sized header, then ``header.e.size`` body bytes.

        Calls ``on_message(header, body)``.
        """
        # NOTE(review): socket.recv(n) may return fewer than n bytes. A short
        # read here yields a truncated header or body; a read-exactly loop
        # would be needed for robustness.
        header = ctor(self._s.recv(ctor.size))
        on_message(header, self._s.recv(header.e.size))


class MainChannel(Channel):
    def send_attach_channels(self):
        # NOTE(review): unfinished. It breaks into pdb unconditionally, and
        # ``s`` is not defined in this method, so the send below raises
        # NameError unless a module-level ``s`` exists.
        import pdb; pdb.set_trace()
        self._s.send(s)


class Client(object):
    """Holds host/port and opens the main channel.

    With ``strace=True`` the socket comes from a StracerNetwork wrapping a
    local spicec command line instead of a direct TCP connection.
    """

    def __init__(self, host, port, strace=False):
        self._host = host
        self._port = port
        self._strace = strace

    def connect_main(self):
        """Connect and store the main Channel in ``self.main``."""
        if self._strace:
            import stracer
            # NOTE(review): ``StracerNetwork`` is used unqualified after
            # ``import stracer``. Unless it is defined elsewhere this raises
            # NameError; ``stracer.StracerNetwork`` was probably intended.
            s = StracerNetwork(
                '/store/upstream/bin/spicec -h %s -p %s' % (
                    self._host, self._port)
                , client=self)
            self._main_sock = s.sock_iface()
        else:
            self._main_sock = connect(self._host, self._port)
        self.main = self.connect_to_channel(self._main_sock,
            connection_id=0, channel_type=SPICE_CHANNEL_MAIN, channel_id=0)

    # for usage from ipython interactively
    # NOTE(review): the [] defaults are shared across calls. That is harmless
    # only as long as nothing mutates the caps lists.
    def connect_to_channel(self, s, connection_id, channel_type, channel_id,
                           common_caps=[], channel_caps=[]):
        """Create a Channel on socket ``s`` (performs the link handshake)."""
        return Channel(s, connection_id=connection_id,
            channel_type=channel_type, channel_id=channel_id,
            common_caps=common_caps, channel_caps=channel_caps)

    def run(self):
        # Placeholder: only prints "TODO".
        print "TODO"

    def on_message(self, msg):
        # Debugging hook: breaks into pdb unconditionally.
        import pdb; pdb.set_trace()

# NOTE(review): SOURCE IS DAMAGED HERE TOO. As above, text appears to have
# been stripped from the '<' of a struct.unpack format string up to a later
# '>'. Lost:
#   * the rest of this SPICE_MAGIC re-assignment (presumably converting
#     "REDQ" into the integer compared against ``link_header.e.magic``),
#   * the definitions of ``link_header_size`` and ``ENDIANESS`` (used below,
#     defined nowhere visible),
#   * the head of the function whose body continues below -- presumably
#     ``find_first_data(s)``, which unspice_server/unspice_client call.
# What survives of that function:
#   * it tries offset 182 first;
#   * otherwise it scans even offsets below
#     ``min(500, len(s)) - SpiceDataHeader.size`` for a unique
#     SpiceDataHeader with serial == 1;
#   * it returns ``s[first_data:]`` when one is found.
# On the "not found" path ``rest`` is only bound if the missing head set it;
# otherwise ``return rest`` raises UnboundLocalError.
SPICE_MAGIC = struct.unpack(' 182+SpiceDataHeader.size) and (
            SpiceDataHeader(s[182:]).e.serial == 1):
        print "found at 182 (good guess!)"
        found = True
        first_data = 182
    else:
        #if len(s) > 100000:
        #    import pdb; pdb.set_trace()
        b = buffer(s)
        candidates = [(i, h) for i,h in
                      ((i, SpiceDataHeader(b[i:])) for i in
                       xrange(0,min(500,len(s))-SpiceDataHeader.size,2))
                      if h.e.serial == 1]
        if len(candidates) == 1:
            found = True
            first_data = candidates[0][0]
            print "found at %s" % first_data
    if found:
        rest = s[first_data:]
    else:
        print "not found"
    return rest


def unspice_channels(keys, collecteds):
    """Map channel type -> (key, conversation) for each parsed conversation.

    ``keys`` is zipped in order with the results of
    parse_conversations(collecteds).
    NOTE(review): if two conversations report the same channel type (e.g.
    -1 for unrecognised client streams), the later one silently wins.
    """
    return dict([(ch, (key, conv)) for key, (ch, conv) in
                 zip(keys, parse_conversations(collecteds))])


def parse_conversations(conversations):
    """Parse (client_stream, server_stream) pairs.

    Returns a list of ``(channel_type, (client_messages, server_messages))``.
    The channel type comes from the client's link message, and is reused to
    parse the server side.
    """
    channels = []
    for client_stream, server_stream in conversations:
        print "parsing (client %s, server %s)" % (
            len(client_stream), len(server_stream))
        channel_type, client_messages = unspice_client(client_stream)
        channel_conv = (client_messages, unspice_server(channel_type,
                                                        server_stream))
        channels.append((channel_type, channel_conv))
    return channels


def parse_link_start(s, msg_ctor):
    """Parse a SpiceLinkHeader plus its ``msg_ctor`` body from the start of ``s``.

    Returns ``(bytes_consumed, [link_header, msg])``, or ``(0, [])`` when ``s``
    is shorter than a link header.
    """
    if len(s) < SpiceLinkHeader.size:
        return 0, []
    link_header = SpiceLinkHeader(s)
    # NOTE(review): assert is stripped under ``python -O``.
    assert(link_header.e.magic == SPICE_MAGIC)
    msg = msg_ctor(s[
        link_header_size:link_header_size+link_header.e.size])
    return link_header_size+link_header.e.size, [link_header, msg]


def parse_link_server_start(s):
    """parse_link_start for the server side (SpiceLinkReply body)."""
    return parse_link_start(s, SpiceLinkReply)


def parse_link_client_start(s):
    """parse_link_start for the client side (SpiceLinkMess body)."""
    return parse_link_start(s, SpiceLinkMess)


def unspice_server(channel_type, s):
    """Parse a server-side stream into a list of messages.

    Returns ``[link_header, link_reply_or_None] + unspice_rest(...)``, or
    ``[]`` when ``s`` is shorter than a link header. On a magic mismatch it
    falls back to find_first_data to locate the first data header.
    """
    if len(s) < SpiceLinkHeader.size:
        return []
    link_header = SpiceLinkHeader(s)
    if link_header.e.magic != SPICE_MAGIC:
        print "ch %s: bad server magic, trying to look for SpiceDataHeader (%s)" % (
            channel_type, len(s))
        link_reply = None
        rest = find_first_data(s)
    else:
        link_reply = SpiceLinkReply(s[
            link_header_size:link_header_size+link_header.e.size])
        rest = s[link_header_size+link_header.e.size:]
    ret = [link_header, link_reply]
    return ret + unspice_rest(channel_type, False, rest)


def unspice_client(s):
    """Parse a client-side stream.

    Returns ``(channel_type, messages)``. ``channel_type`` is -1 when ``s``
    is empty or the link magic does not match. ``messages`` is
    ``[link_header, link_message_or_None] + unspice_rest(...)``.
    NOTE(review): only ``len(s) == 0`` is guarded, whereas unspice_server
    guards ``len(s) < SpiceLinkHeader.size``. A shorter stream reaches
    SpiceLinkHeader(s) with too few bytes.
    """
    if len(s) == 0:
        return -1, []
    # First message starts with a link header, then link message
    link_header = SpiceLinkHeader(s)
    if link_header.e.magic != SPICE_MAGIC:
        print "bad client magic, trying to look for SpiceDataHeader (%s)" % len(s)
        link_message = None
        channel_type = -1 # XXX - can guess based on messages
        rest = find_first_data(s)
    else:
        link_message = SpiceLinkMess(s[
            link_header_size:link_header_size+link_header.e.size])
        channel_type = link_message.e.channel_type
        rest = s[link_header_size+link_header.e.size:]
    ret = [link_header, link_message]
    return channel_type, ret + unspice_rest(channel_type, True, rest)


def unspice_rest(channel_type, is_client, s):
    """Split ``s`` into consecutive SpiceDataHeader + body records.

    Returns a list of ``(header, client_proto.parse(...))`` pairs. Parsing
    stops at the first record whose body would run past the end of ``s``,
    so a truncated trailing record is dropped silently.
    """
    datas = []
    if len(s) < 8:
        return datas
    i_s = 0
    # The first 8 bytes would be the uint64 ``serial`` of a SpiceDataHeader.
    # A value above 1000 is treated as "stream does not start on a header",
    # and a fixed number of bytes is skipped (128 client-side, 4 server-side).
    # NOTE(review): magic offsets, presumably tuned to specific captures.
    if struct.unpack(ENDIANESS + uint64, s[:8])[0] > 1000:
        if is_client:
            print "client hack - ignoring 128 bytes"
            i_s = 128
        else:
            print "server hack - ignoring 4 bytes"
            i_s = 4
    while len(s) > i_s:
        # NOTE(review): when fewer than SpiceDataHeader.size bytes remain, the
        # header is built from a short slice. Behavior then depends on
        # structutil.
        header = SpiceDataHeader(s[i_s:])
        data_start = i_s + SpiceDataHeader.size
        data_end = data_start + header.e.size
        if data_end > len(s):
            break
        datas.append((header, client_proto.parse(
            channel_type=channel_type, is_client=is_client, header=header,
            data=s[data_start:data_end])))
        i_s += SpiceDataHeader.size + header.e.size
    return datas


def parse_ctors(s, ctors):
    """Instantiate each ctor on consecutive slices of ``s``.

    The offset advances by ``ctor.size`` each time. Stops early once the
    offset reaches ``len(s)``.
    NOTE(review): the last ctor may receive fewer bytes than it needs; no
    length check is made.
    """
    ret = []
    i = 0
    for c in ctors:
        c_size = c.size
        if len(s) <= i:
            break
        ret.append(c(s[i:]))
        i += c_size
    return ret