summaryrefslogtreecommitdiff
path: root/pcapspice.py
blob: 775f8ef29ff938c72510c24fa196ff667486694d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from itertools import repeat
import logging

from pcaputil import header_conversation_iter
import client_proto
from client import (SpiceLinkHeader, SpiceLinkMess,
                    SpiceLinkReply, SpiceDataHeader)

# Module logger. Channel guessing and message parsing below report their
# diagnostics (debug traces, bad serials, unknown message types) through it.
logger = logging.getLogger('pcapspice')

def is_single_packet_data(payload):
    """Return True when *payload* holds exactly one SPICE data message.

    The payload is parsed as a ``SpiceDataHeader``; it is a single message
    when its total length equals the fixed header size plus the body size
    declared in that header.
    """
    declared_body_size = SpiceDataHeader(payload).e.size
    expected_total = SpiceDataHeader.size + declared_body_size
    return expected_total == len(payload)
# Cache of identified channel types keyed by a conversation's
# (lower port, higher port) pair. guess_channel_iter writes an entry once a
# single candidate remains, and reads it back to pin later guesses for the
# same pair. It is module-global, so entries persist for the process lifetime.
guesses = {}

def guess_channel_iter():
    """Coroutine that narrows down which SPICE channel a conversation carries.

    Prime it with ``next()``, then ``send((src, dst, data_header))`` for each
    data header seen on the conversation. Each send yields the guessed
    channel type, or ``None`` while more than one candidate remains.

    Candidates start as ``client_proto.all_channels``. If the module-level
    ``guesses`` cache already holds a channel for this port pair, the
    candidates are pinned to it. They are then intersected with
    ``client_proto.possible_channels(...)`` for every header. When exactly
    one candidate remains, it is yielded and cached in ``guesses``.

    If no candidate survives (the header matched no channel), the problem is
    logged and guessing restarts from the full channel set. The previous code
    dropped into ``pdb`` at this point, which hangs non-interactive runs.
    """
    channel = None
    all_channels = set(client_proto.all_channels)
    optional_channels = set(all_channels)
    while True:
        src, dst, data_header = yield channel
        key = (min(src, dst), max(src, dst))
        if key in guesses:
            # A channel was already identified for this port pair; only
            # keep that one as a candidate.
            optional_channels = set([guesses[key]])
        # This module treats the lower port of a pair as the server side.
        from_lower_port = src == min(src, dst)
        optional_channels = optional_channels & client_proto.possible_channels(
            from_lower_port, data_header)
        logger.debug(str(optional_channels))
        if len(optional_channels) == 1:
            channel = next(iter(optional_channels))
            guesses[key] = channel
        elif not optional_channels:
            logger.error("no candidate channel left for %s->%s; "
                         "restarting channel guess" % (src, dst))
            optional_channels = set(all_channels)

def ident(x, **kw):
    """Return ``x`` unchanged.

    Any keyword arguments are accepted and discarded. This lets the function
    stand in for a parser or handler callback slot, as ``spice_iter`` does
    for the fixed-size handshake chunk.
    """
    return x

def channel_spice_message_iter(src, dst, guesser):
    """Coroutine that parses the SPICE data messages of one conversation.

    Prime it with ``next()``, then ``send((src, dst, (data_header, message)))``
    for every message. Each send yields "maybe X", that is, either ``None``
    or X:

    - X is the result of ``client_proto.parse``, returned once ``guesser``
      has identified the channel.
    - ``None`` is returned when the message was rejected or the channel is
      not yet known.

    A message is rejected, and never shown to ``guesser``, in two cases:

    - its serial is not exactly one more than the previous serial from the
      same source port;
    - its type is not in ``client_proto.valid_message_ids``.

    :param src: one port of the conversation.
    :param dst: the other port. The larger of the two ports is treated as the
        client port.
    :param guesser: object whose ``send((src, dst, data_header))`` returns a
        channel type or ``None`` (e.g. a ``ChannelGuesser``).
    """
    client_port = max(src, dst)
    serial = {}  # last serial number seen, per source port
    result = None
    while True:
        src, dst, (data_header, message) = yield result
        result = None
        bad_header = False
        if src in serial and data_header.e.serial != serial[src] + 1:
            logger.debug("bad serial, replacing for %s->%s (%s!=%s+1)" % (
                src, dst, data_header.e.serial, serial[src]))
            bad_header = True
        # Record the latest serial even on a mismatch. This lets the next
        # message be checked against it, so one gap does not cause every
        # later message to be rejected.
        serial[src] = data_header.e.serial
        if data_header.e.type not in client_proto.valid_message_ids:
            logger.error("bad message type %s in %s->%s" % (data_header.e.type,
                src, dst))
            bad_header = True
        if bad_header:
            continue
        channel = guesser.send((src, dst, data_header))
        if channel is not None:
            # The channel is known, so the message body can be decoded.
            result = client_proto.parse(
                channel_type=channel,
                is_client=src == client_port,
                header=data_header,
                data=message)

class ChannelGuesser(object):
    """Per-conversation holder for the identified SPICE channel type.

    The channel is learned in one of two ways:

    - directly from the client's link message (``on_client_link_message``);
    - failing that, by feeding data headers through ``send`` to a lazily
      created ``guess_channel_iter`` coroutine.
    """

    def __init__(self):
        self.iter = None             # guess_channel_iter coroutine, created on first send
        self.channel = None          # identified channel type; None while unknown
        self._client_message = None  # SpiceLinkMess parsed from the client, if seen
        self._server_message = None  # SpiceLinkReply parsed from the server, if seen

    def send(self, packet):
        """Feed one ``(src, dst, data_header)`` tuple; return the channel or None.

        Once a channel is known, it is returned immediately without
        consulting the guessing coroutine. ``None`` is the "unknown"
        sentinel, so any other value, even a falsy one, counts as known.
        """
        # Explicit unpacking replaces the Python-2-only tuple parameter syntax.
        src, dst, payload = packet
        if self.channel is not None:
            return self.channel
        if self.iter is None:
            self.iter = guess_channel_iter()
            next(self.iter)  # advance the coroutine to its first yield
        self.channel = self.iter.send((src, dst, payload))
        return self.channel

    def on_client_link_message(self, p, **kw):
        """Parse the client's link message and take its channel type as known."""
        self._client_message = SpiceLinkMess(p)
        self.channel = self._client_message.e.channel_type
        logger.info("STARTED channel %s" % self.channel)

    def on_server_link_message(self, p, **kw):
        """Parse and keep the server's link reply."""
        self._server_message = SpiceLinkReply(p)
 
def spice_iter(packet_iter):
    """Decode SPICE traffic from captured packets.

    Builds per-direction parser specs and hands them to
    ``header_conversation_iter``. Each conversation first goes through the
    link handshake (``make_start_iter``) and then continues with regular
    data messages (``make_iter``). Both directions of a port pair share one
    ``ChannelGuesser``. Only results whose ``msg.data`` has a
    ``result_name`` attribute are kept (``filter_result``).

    Each spec is a 4-tuple ``(header_size, header_parser,
    body_size_from_header, handler)``.
    NOTE(review): that field meaning is inferred from the values passed here;
    confirm it against ``pcaputil.header_conversation_iter``.
    """
    guessers = {}

    def get_guesser(src, dst):
        """Return the ChannelGuesser shared by both directions of src<->dst."""
        key = (min(src, dst), max(src, dst))
        if key not in guessers:
            guessers[key] = ChannelGuesser()
        return guessers[key]

    def make_start_iter(src, dst):
        # The lower port of a pair is treated as the server.
        is_server = src < dst
        guesser = get_guesser(src, dst)
        logger.debug('%s %s **** make_start_iter %s -> %s' % (
            guesser, is_server, src, dst))
        # 1. Link header followed by the link message (client side) or the
        #    link reply (server side).
        yield (SpiceLinkHeader.size,
               SpiceLinkHeader.set_proto,
               lambda h: h.e.size,
               (guesser.on_server_link_message if is_server
                else guesser.on_client_link_message))
        logger.debug('%s %s 2' % (guesser, is_server))
        # 2. One fixed-size chunk with no body: 128 bytes from the higher
        #    (client) port, 4 bytes otherwise. NOTE(review): presumably the
        #    encrypted ticket and the link result code; confirm.
        yield 128 if src > dst else 4, ident, lambda h: 0, ident
        # 3. From here on, the regular stream of data messages.
        for i, data in enumerate(make_iter(src, dst)):
            logger.debug('%s %s 2+%s' % (guesser, is_server, i))
            yield data

    def make_iter(src, dst):
        message_iter = channel_spice_message_iter(src, dst, get_guesser(src, dst))
        next(message_iter)  # advance the coroutine to its first yield
        # The same spec applies to every data message. The handler forwards
        # each message into the coroutine and returns its "maybe parsed"
        # result.
        return repeat(
            (SpiceDataHeader.size,
             SpiceDataHeader.verify,
             lambda h: h.e.size,
             lambda pkt, src, dst, header:
                 message_iter.send((src, dst, (header, pkt)))))

    def filter_result(r):
        return (r.msg is not None and r.msg.data is not None
                and hasattr(r.msg.data, 'result_name'))

    return header_conversation_iter(
        packet_iter,
        server_start=make_start_iter,
        client_start=make_start_iter,
        server=make_iter,
        client=make_iter,
        filter_result=filter_result)