summaryrefslogtreecommitdiff
path: root/client.py
blob: 192c8ae3babae50feef780fe258a6574f45c4ef3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
import socket
import logging
import struct

from structutil import (Struct, StructMeta, uint16, uint8, uint64,
    uint32, uint32_arr, list_to_str)
import client_proto

logger = logging.getLogger('client')

# Set to True (here, or at runtime from an interactive session) to make
# debug_break() stop in the pdb debugger.
DEBUG_ENABLE_PDB = False


def debug_break():
    """Enter pdb if DEBUG_ENABLE_PDB is true, otherwise do nothing.

    The flag is looked up on every call rather than once at import time, so
    toggling client.DEBUG_ENABLE_PDB takes effect without reloading the module.
    """
    if DEBUG_ENABLE_PDB:
        import pdb
        pdb.set_trace()

# TODO - parse spice/protocol.h directly (to keep in sync automatically)
# 4-character form of the protocol magic. A later module-level assignment
# rebinds SPICE_MAGIC to its little-endian uint32 form, which is what the link
# header classes compare their parsed `magic` field against.
SPICE_MAGIC = "REDQ"
SPICE_VERSION_MAJOR = 2
SPICE_VERSION_MINOR = 0

################################################################################
# spice-protocol/spice/enums.h

# Channel type identifiers, numbered consecutively starting at 1.
(SPICE_CHANNEL_MAIN,
 SPICE_CHANNEL_DISPLAY,
 SPICE_CHANNEL_INPUTS,
 SPICE_CHANNEL_CURSOR,
 SPICE_CHANNEL_PLAYBACK,
 SPICE_CHANNEL_RECORD,
 SPICE_CHANNEL_TUNNEL,
 SPICE_END_CHANNEL) = range(1, 1 + 8)

################################################################################

# Encryption & Ticketing Parameters
SPICE_MAX_PASSWORD_LENGTH = 60
SPICE_TICKET_KEY_PAIR_LENGTH = 1024
# Floor division keeps this an int even under true division; it is used as
# the array length of SpiceLinkReply.pub_key.
SPICE_TICKET_PUBKEY_BYTES = (SPICE_TICKET_KEY_PAIR_LENGTH // 8 + 34)

class SpiceDataHeader(Struct):
    """Header placed in front of each data message; `size` is the payload length."""
    __metaclass__ = StructMeta
    fields = [
        (uint64, 'serial'),
        (uint16, 'type'),
        (uint32, 'size'),
        (uint32, 'sub_list'),
    ]

    def __str__(self):
        e = self.e
        return 'Spice #%s, T%s, sub %s, size %s' % (
            e.serial, e.type, e.sub_list, e.size)
    __repr__ = __str__

    @classmethod
    def verify(cls, *args, **kw):
        """Construct a header and sanity-check its fields.

        Returns the header, or None (after logging an error and calling
        debug_break()) when sub_list is non-zero but not below size, when
        type falls outside [0, 400), or when size exceeds 2000000.
        """
        header = cls(*args, **kw)
        e = header.e
        if e.sub_list != 0 and e.sub_list >= e.size:
            problem = 'too large sub_list packet - %s'
        elif e.type < 0 or e.type >= 400:
            problem = 'bad type of packet - %s'
        elif e.size > 2000000:
            problem = 'too large packet - %s'
        else:
            return header
        logger.error(problem % header)
        debug_break()
        return None

class SpiceSubMessage(Struct):
    """Sub-message header: a uint16 `type` followed by a uint32 `size`."""
    __metaclass__ = StructMeta
    fields = [
        (uint16, 'type'),
        (uint32, 'size'),
    ]

class SpiceSubMessageList(Struct):
    """Sub-message list: a uint16 `size` field followed by a uint32 array."""
    __metaclass__ = StructMeta
    fields = [
        (uint16, 'size'),
        (uint32_arr, 'sub_messages'),
    ]

# Encoded length in bytes of a SpiceDataHeader.
data_header_size = SpiceDataHeader.size

class SpiceLinkHeader(Struct):
    """Header that opens the link stage of a client or server stream.

    `magic` holds the protocol magic as a little-endian uint32; `size` is the
    length of the link body that follows the header.
    """
    __metaclass__ = StructMeta
    fields = [
        (uint32, 'magic'),
        (uint32, 'major_version'),
        (uint32, 'minor_version'),
        (uint32, 'size'),
    ]

    def __str__(self):
        e = self.e
        magic_text = struct.pack('<I', e.magic)
        return 'SpiceLink magic=%s, (%s,%s), size %s' % (
            magic_text, e.major_version, e.minor_version, e.size)
    __repr__ = __str__

    @classmethod
    def verify(cls, *args, **kw):
        """Construct a header; return it, or None (logging an error) on bad magic."""
        header = cls(*args, **kw)
        if header.e.magic == SPICE_MAGIC:
            return header
        logger.error('bad magic in packet: %s' % header)
        return None

    @classmethod
    def set_proto(cls, *args, **kw):
        """Like verify(), and on success select the client_proto version from it."""
        header = cls.verify(*args, **kw)
        if header is not None:
            client_proto.set_proto(major_version=header.e.major_version,
                                   minor_version=header.e.minor_version)
        return header

# Encoded length in bytes of a SpiceLinkHeader.
link_header_size = SpiceLinkHeader.size

class SpiceLinkMess(Struct):
    """Client link message that follows the SpiceLinkHeader.

    The capability words (num_common_caps common caps, then num_channel_caps
    channel caps) are appended after it; caps_offset locates them (Channel
    sets it to SpiceLinkMess.size — presumably measured from the start of
    this message).
    """
    __metaclass__ = StructMeta
    fields = [
        (uint32, 'connection_id'),
        (uint8, 'channel_type'),
        (uint8, 'channel_id'),
        (uint32, 'num_common_caps'),
        (uint32, 'num_channel_caps'),
        (uint32, 'caps_offset'),
    ]

    def __str__(self):
        e = self.e
        ids = (e.connection_id, e.channel_type, e.channel_id)
        caps = (e.num_common_caps, e.num_channel_caps, e.caps_offset)
        return ('SpiceLinkMess conn id %s, ch type %s, ch id %s caps (%s,%s,%s)'
                % (ids + caps))
    __repr__ = __str__

# Encoded length in bytes of a SpiceLinkMess (without capabilities).
link_mess_size = SpiceLinkMess.size

class SpiceLinkReply(Struct):
    """Server reply to the client's link message."""
    __metaclass__ = StructMeta
    fields = [
        (uint32, 'error'),
        (uint8 * SPICE_TICKET_PUBKEY_BYTES, 'pub_key'),
        (uint32, 'num_common_caps'),
        (uint32, 'num_channel_caps'),
        (uint32, 'caps_offset'),
    ]

    def __str__(self):
        e = self.e
        key = e.pub_key
        return '<SpiceLinkReply error=%s, common=%s, channel=%s, pub=%s:%s>' % (
            e.error, e.num_common_caps, e.num_channel_caps, key[:10], len(key))
    __repr__ = __str__


def parse_link_msg(s):
    """Split a complete client link packet into (link_header, link_msg).

    s must hold exactly one SpiceLinkHeader followed by the link body whose
    length is given by the header's parsed `size` field.

    Raises AssertionError if len(s) is not header length + body length.
    """
    link_header = SpiceLinkHeader(s)
    # Use the parsed body length (link_header.e.size), not the struct's own
    # encoded size (link_header.size), which would always expect 2 headers.
    body_size = link_header.e.size
    assert len(s) == SpiceLinkHeader.size + body_size, (
        'link packet is %s bytes, expected header %s + body %s' % (
            len(s), SpiceLinkHeader.size, body_size))
    link_msg = SpiceLinkMess(s[SpiceLinkHeader.size:])
    return link_header, link_msg

def make_data_msg(serial, type, payload, sub_list):
    """Serialize a data message: a SpiceDataHeader followed by payload.

    :param serial: header serial number.
    :param type: header message type (name kept for keyword callers even
        though it shadows the builtin).
    :param payload: encoded message body; its length becomes the header size.
    :param sub_list: header sub_list field.
    :returns: the header's tostr() encoding concatenated with payload.
    """
    data_header = SpiceDataHeader(serial=serial,
        type=type, size=len(payload), sub_list=sub_list)
    # The header was previously built and thrown away; callers need the
    # header and body together to put them on the wire.
    return data_header.tostr() + payload

def connect(host, port):
    """Open an IPv4 TCP connection to (host, port) and return the socket.

    If connect() fails, the socket is closed before the error propagates so
    an unreachable or refusing server does not leak a file descriptor.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.connect((host, port))
    except BaseException:
        s.close()
        raise
    return s

class Channel(object):
    """One SPICE channel over an already-connected socket-like object.

    Construction performs the link stage immediately: it sends the link
    header, link message and capability words, then reads the server's link
    reply into _link_reply / _link_reply_header.
    """

    def __init__(self, s, connection_id, channel_type, channel_id,
            common_caps, channel_caps):
        self._s = s
        self._channel_type = channel_type
        self._channel_id = channel_id
        self._connection_id = connection_id
        self._common_caps = common_caps
        self._channel_caps = channel_caps
        # Per-channel protocol description, if client_proto has any loaded.
        self._proto = (client_proto.channels[channel_type] if
                       client_proto.channels else None)
        self.send_link_message()

    def send_link_message(self):
        """Send the client link packet and process the server's link reply."""
        link_message = SpiceLinkMess(
            channel_type=self._channel_type,
            channel_id=self._channel_id,
            caps_offset=SpiceLinkMess.size,
            num_channel_caps=len(self._channel_caps),
            num_common_caps=len(self._common_caps),
            connection_id=self._connection_id)
        caps = (list_to_str(self._common_caps)
                + list_to_str(self._channel_caps))
        # The header's size covers everything after the header: the link
        # message *and* the capability words appended to it (the parsers in
        # this file treat bytes beyond header+size as data messages).
        link_header = SpiceLinkHeader(magic=SPICE_MAGIC,
            major_version=SPICE_VERSION_MAJOR,
            minor_version=SPICE_VERSION_MINOR,
            size=link_message.size + len(caps))
        self._s.send(link_header.tostr() + link_message.tostr() + caps)
        self.parse_one(SpiceLinkHeader, self._on_link_reply)

    def _on_link_reply(self, header, data):
        """Store the server's link reply and its header."""
        link_reply = SpiceLinkReply(data)
        self._link_reply = link_reply
        self._link_reply_header = header

    def _on_data(self, header, data):
        """Store the most recent data message and report its length."""
        self._last_data = data
        self._last_data_header = header
        print("_on_data: %s" % len(data))

    def parse_data(self):
        """Read one data message (SpiceDataHeader + payload)."""
        self.parse_one(SpiceDataHeader, self._on_data)

    def parse_one(self, ctor, on_message):
        """Read a ctor header plus its payload and pass both to on_message.

        ctor must expose a class-level `size` (encoded header length) and
        produce an object whose e.size is the payload length.
        """
        header = ctor(self._recv_exact(ctor.size))
        on_message(header, self._recv_exact(header.e.size))

    def _recv_exact(self, count):
        """Return exactly count bytes read from the socket.

        recv() may return fewer bytes than requested because TCP is a byte
        stream, so keep reading until count bytes have arrived.

        Raises EOFError if the peer closes the connection first.
        """
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = self._s.recv(remaining)
            if not chunk:
                raise EOFError('connection closed with %d of %d bytes unread'
                               % (remaining, count))
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

class MainChannel(Channel):
    """Channel specialisation intended for SPICE_CHANNEL_MAIN."""

    def send_attach_channels(self):
        """Send the attach-channels request (not implemented yet).

        The message construction is still missing: the previous body sent an
        undefined name and always failed with NameError. Fail explicitly
        instead, after the optional debugger hook.

        Raises NotImplementedError.
        """
        debug_break()
        raise NotImplementedError(
            'MainChannel.send_attach_channels: message construction is '
            'not implemented')

class Client(object):
    """SPICE client entry point; connect_to_channel is meant for ipython use."""

    def __init__(self, host, port, strace=False):
        self._host = host
        self._port = port
        # When true, connect_main obtains its socket via
        # stracer.StracerNetwork wrapping a spicec command line instead of
        # opening a plain TCP connection.
        self._strace = strace

    def connect_main(self):
        """Open the main channel and store it as self.main."""
        if self._strace:
            import stracer
            # StracerNetwork lives in the stracer module; the bare name is
            # not defined in this file.
            s = stracer.StracerNetwork(
                '/store/upstream/bin/spicec -h %s -p %s' % (
                    self._host, self._port),
                client=self)
            self._main_sock = s.sock_iface()
        else:
            self._main_sock = connect(self._host, self._port)
        self.main = self.connect_to_channel(self._main_sock,
            connection_id=0, channel_type=SPICE_CHANNEL_MAIN,
            channel_id=0)

    # for usage from ipython interactively
    def connect_to_channel(self, s, connection_id, channel_type, channel_id,
        common_caps=None, channel_caps=None):
        """Link a new Channel over socket s and return it.

        common_caps / channel_caps default to empty lists; a fresh list is
        created on each call rather than sharing one mutable default.
        """
        if common_caps is None:
            common_caps = []
        if channel_caps is None:
            channel_caps = []
        return Channel(s,
            connection_id=connection_id,
            channel_type=channel_type,
            channel_id=channel_id,
            common_caps=common_caps,
            channel_caps=channel_caps)

    def run(self):
        print("TODO")

    def on_message(self, msg):
        # Interactive hook (presumably called by stracer, which receives
        # client=self): always stops in pdb regardless of DEBUG_ENABLE_PDB.
        import pdb
        pdb.set_trace()

# Integer form of the protocol magic: the bytes "REDQ" read as a little-endian
# uint32, matching how SpiceLinkHeader parses its `magic` field. This rebinds
# the string form defined near the top of the module. The bytes literal keeps
# struct.unpack working on Python 3 too (it is the same str on Python 2).
SPICE_MAGIC = struct.unpack('<I', b'REDQ')[0]

################################################################################

def find_first_data(s):
    """Locate the first data message (serial == 1) in a captured stream.

    Used when a stream does not start with a valid link header. Offset 182 is
    tried first; otherwise every even offset in the first 500 bytes is
    scanned, and a match is accepted only if it is the unique one.

    :param s: raw captured byte string.
    :returns: s from the matching offset onwards, or [] if no unique match
        was found.
    """
    hdr_size = SpiceDataHeader.size

    def serial_at(offset):
        # Only the header bytes are needed; slicing just those avoids copying
        # the whole remainder of a possibly large capture per candidate.
        return SpiceDataHeader(s[offset:offset + hdr_size]).e.serial

    first_data = None
    if len(s) > 182 + hdr_size and serial_at(182) == 1:
        print("found at 182 (good guess!)")
        first_data = 182
    else:
        candidates = [i for i in range(0, min(500, len(s)) - hdr_size, 2)
                      if serial_at(i) == 1]
        if len(candidates) == 1:
            first_data = candidates[0]
            print("found at %s" % first_data)
    if first_data is None:
        print("not found")
        return []
    return s[first_data:]

def unspice_channels(keys, collecteds):
    """Parse captured conversations and index them by channel type.

    Each key is paired positionally with the parsed conversation at the same
    index. Returns {channel_type: (key, (client_messages, server_messages))};
    if two conversations share a channel type, the later one wins.
    """
    by_channel = {}
    for key, (ch, conv) in zip(keys, parse_conversations(collecteds)):
        by_channel[ch] = (key, conv)
    return by_channel

def parse_conversations(conversations):
    """Parse captured (client_stream, server_stream) pairs.

    Returns a list of (channel_type, (client_messages, server_messages)),
    where channel_type comes from the client stream (-1 if it could not be
    determined).
    """
    parsed = []
    for client_stream, server_stream in conversations:
        print("parsing (client %s, server %s)" % (
            len(client_stream), len(server_stream)))
        channel_type, client_messages = unspice_client(client_stream)
        server_messages = unspice_server(channel_type, server_stream)
        parsed.append((channel_type, (client_messages, server_messages)))
    return parsed

def parse_link_start(s, msg_ctor):
    """Parse a link header and the link body that follows it.

    :param s: byte string starting with a SpiceLinkHeader.
    :param msg_ctor: struct class used to decode the link body.
    :returns: (bytes_consumed, [link_header, msg]), or (0, []) when s is
        shorter than a link header. Asserts that the magic is SPICE_MAGIC.
    """
    if len(s) < SpiceLinkHeader.size:
        return 0, []
    link_header = SpiceLinkHeader(s)
    assert link_header.e.magic == SPICE_MAGIC
    body_end = link_header_size + link_header.e.size
    return body_end, [link_header, msg_ctor(s[link_header_size:body_end])]
 
def parse_link_server_start(s):
    """Parse the server's opening link packet (header + SpiceLinkReply)."""
    return parse_link_start(s, msg_ctor=SpiceLinkReply)

def parse_link_client_start(s):
    """Parse the client's opening link packet (header + SpiceLinkMess)."""
    return parse_link_start(s, msg_ctor=SpiceLinkMess)

def unspice_server(channel_type, s):
    """Split a captured server stream into link and data messages.

    Returns [link_header, link_reply] followed by the parsed data messages,
    or [] when s is shorter than a link header. If the magic is wrong,
    link_reply is None and the data start is searched with find_first_data().
    """
    if len(s) < SpiceLinkHeader.size:
        return []
    link_header = SpiceLinkHeader(s)
    if link_header.e.magic == SPICE_MAGIC:
        body_end = link_header_size + link_header.e.size
        link_reply = SpiceLinkReply(s[link_header_size:body_end])
        rest = s[body_end:]
    else:
        print("ch %s: bad server magic, trying to look for SpiceDataHeader (%s)" % (
            channel_type, len(s)))
        link_reply = None
        rest = find_first_data(s)
    return [link_header, link_reply] + unspice_rest(channel_type, False, rest)

def unspice_client(s):
    """Split a captured client stream into link and data messages.

    Returns (channel_type, [link_header, link_message] + data messages), or
    (-1, []) for an empty stream. If the magic is wrong, link_message is None,
    channel_type is -1 and the data start is searched with find_first_data().
    """
    if len(s) == 0:
        return -1, []
    # A client stream opens with a link header followed by the link message.
    link_header = SpiceLinkHeader(s)
    if link_header.e.magic == SPICE_MAGIC:
        body_end = link_header_size + link_header.e.size
        link_message = SpiceLinkMess(s[link_header_size:body_end])
        channel_type = link_message.e.channel_type
        rest = s[body_end:]
    else:
        print("bad client magic, trying to look for SpiceDataHeader (%s)" % len(s))
        link_message = None
        channel_type = -1  # XXX - can guess based on messages
        rest = find_first_data(s)
    messages = [link_header, link_message] + unspice_rest(channel_type, True, rest)
    return channel_type, messages

def unspice_rest(channel_type, is_client, s):
    """Parse consecutive (SpiceDataHeader, payload) messages from s.

    :param channel_type: forwarded to client_proto.parse.
    :param is_client: True for the client->server direction.
    :param s: raw bytes that follow the link stage.
    :returns: list of (header, client_proto.parse(...)) tuples. Parsing stops
        at the first truncated header or payload; fewer than 8 bytes gives [].
    """
    datas = []
    if len(s) < 8:
        return datas
    i_s = 0
    # The first 8 bytes should be the uint64 serial of the first message. It
    # is decoded explicitly as little-endian ('<Q'), matching the '<I' used for
    # the magic; no ENDIANESS name exists in this module.
    (first_serial,) = struct.unpack('<Q', s[:8])
    if first_serial > 1000:
        # Implausible serial: apply the fixed per-direction skip ("hack").
        if is_client:
            print("client hack - ignoring 128 bytes")
            i_s = 128
        else:
            print("server hack - ignoring 4 bytes")
            i_s = 4
    header_size = SpiceDataHeader.size
    # Require a complete header before parsing one.
    while len(s) - i_s >= header_size:
        header = SpiceDataHeader(s[i_s:])
        data_start = i_s + header_size
        data_end = data_start + header.e.size
        if data_end > len(s):
            break
        datas.append((header, client_proto.parse(
                    channel_type=channel_type,
                    is_client=is_client,
                    header=header,
                    data=s[data_start:data_end])))
        i_s = data_end
    return datas

def parse_ctors(s, ctors):
    """Decode back-to-back structs from s, one per constructor in ctors.

    Each ctor is called on the remaining bytes and the offset advances by its
    class-level `size`. Stops early once the offset reaches the end of s.
    """
    parsed = []
    offset = 0
    for ctor in ctors:
        step = ctor.size
        if offset >= len(s):
            break
        parsed.append(ctor(s[offset:]))
        offset += step
    return parsed