diff options
Diffstat (limited to 'pcapspice.py')
-rw-r--r-- | pcapspice.py | 230 |
1 files changed, 230 insertions, 0 deletions
diff --git a/pcapspice.py b/pcapspice.py new file mode 100644 index 0000000..b1c1e96 --- /dev/null +++ b/pcapspice.py @@ -0,0 +1,230 @@ +from itertools import chain, repeat +from pcaputil import (tcp_parse, is_tcp, is_tcp_data, is_tcp_syn, + conversations_iter, header_conversation_iter) +import client +from client import SpiceLinkHeader, SpiceLinkMess, SpiceLinkReply, SpiceDataHeader + +def is_single_packet_data(payload): + return len(payload) == SpiceDataHeader(payload).e.size + SpiceDataHeader.size + +num_spice_messages = sum(len(ch.client) + len(ch.server) for ch in + client.client_proto.channels.values()) + +valid_message_ids = set(sum([ch.client.keys() + ch.server.keys() for ch in + client.client_proto.channels.values()], [])) + +all_channels = client.client_proto.channels.keys() + +def possible_channels(server_message, header): + return set(c for c in all_channels if header.e.type in + (client.client_proto.channels[c].server.keys() if server_message + else client.client_proto.channels[c].client.keys())) + +guesses = {} + +def guess_channel_iter(): + channel = None + optional_channels = set(all_channels) + seen_headers = [] + while True: + src, dst, data_header = yield channel + key = (min(src, dst), max(src, dst)) + if guesses.has_key(key): + optional_channels = set([guesses[key]]) + seen_headers.append((src, dst, data_header)) + optional_channels = optional_channels & possible_channels( + src == min(src, dst), data_header) + print optional_channels + if len(optional_channels) == 1: + channel = list(optional_channels)[0] + guesses[key] = channel + if len(optional_channels) == 0: + import pdb; pdb.set_trace() + +spice_size_from_header = lambda header: header.e.size + +def channel_spice_start_iter(lower_port, higher_port): + server_port, client_port = lower_port, higher_port + ports = [server_port, client_port] + messages = dict([(port,None) for port in ports]) + payloads = dict([(port,[]) for port in ports]) + headers = dict([(port,None) for port in ports]) + discarded_len = {client_port:128, server_port:4} + message_ctors = {client_port:SpiceLinkMess, server_port:SpiceLinkReply} + if len(payload) > 0: + payloads[src].append(payload) + while not all(messages.values()): + src, dst, payload = yield None + if len(payload) == 0: continue + payloads[src].append(payload) + len_payload = sum(map(len,payloads[src])) + if headers[src] is None and len_payload >= SpiceLinkHeader.size: + headers[src] = SpiceLinkHeader(''.join(payloads[src])) + import pdb; pdb.set_trace() + expected_len = headers[src].e.size + SpiceLinkHeader.size + discarded_len[src] + if headers[src] is not None and expected_len <= len_payload: + if expected_len != len_payload: + print "extra bytes dumped!!! %d" % (len_payload - expected_len) + messages[src] = message_ctors[src](''.join(payloads[src])[ + SpiceLinkHeader.size:]) + channel = messages[client_port].e.channel_type + print "channel type %s created" % channel + result = None + src, dst, payload = yield None + data_iter = channel_spice_iter(src, dst, payload, channel=channel) + result = data_iter.next() + while True: + src, dst, payload = yield result + result = data_iter.send((src, dst, payload)) + +def channel_spice_iter(lower_port, higher_port, channel=None): + server_port, client_port = lower_port, higher_port + serial = {} + guesser = guess_channel_iter() + guesser.next() + + result = None + # yield spice data parsed results + while True: + src, dst, payload = yield result + data_header = SpiceDataHeader(payload) + # basic sanity check on the header + bad_header = False + if serial.has_key(src): + if data_header.e.serial != serial[src] + 1: + print "bad serial, replacing for %s->%s (%s!=%s+1)" % ( + src, dst, data_header.e.serial, serial[src]) + bad_header = True + serial[src] = data_header.e.serial + if data_header.e.type not in valid_message_ids: + print "bad message type %s in %s->%s" % (data_header.e.type, + src, dst) + if bad_header: continue + if channel is None: + channel = guesser.send((src, dst, data_header)) + if channel is not None: + # read as many messages as it takes to get + (msg_proto, (result_name, result_value), + left_over) = client.client_proto.parse( + channel_type=channel, + is_client=src == client_port, + header=data_header, + data=payload[SpiceDataHeader.size:]) + result = (data_header, msg_proto, result_name, result_value, + left_over) + +# for every tcp packet +# if it is a syn, read the link header, decide which channel we are +# if it is data +# if we know which channel we are, parse it +# else guess which channel. +# if we just figured out which channel we are, releaes all pending packets. + +def spice_iter1(packet_iter): + return conversations_iter( + packet_iter, + channel_spice_start_iter, + channel_spice_iter) + +def packet_header_iter_gen(hdr_ctor, pkt_ctor): + return repeat( + (hdr_ctor.size, + hdr_ctor, + lambda h: h.e.size, + pkt_ctor)) + +def ident(x, **kw): + return x + +def channel_spice_message_iter(src, dst, guesser): + server_port, client_port = min(src, dst), max(src, dst) + serial = {} + channel = None + result = None + # yield spice data parsed results + while True: + src, dst, (data_header, message) = yield result + # basic sanity check on the header + bad_header = False + if serial.has_key(src): + if data_header.e.serial != serial[src] + 1: + print "bad serial, replacing for %s->%s (%s!=%s+1)" % ( + src, dst, data_header.e.serial, serial[src]) + bad_header = True + serial[src] = data_header.e.serial + if data_header.e.type not in valid_message_ids: + print "bad message type %s in %s->%s" % (data_header.e.type, + src, dst) + bad_header = True + if bad_header: continue + if channel is None: + channel = guesser.send((src, dst, data_header)) + if channel is not None: + # read as many messages as it takes to get + (msg_proto, (result_name, result_value) + ) = client.client_proto.parse( + channel_type=channel, + is_client=src == client_port, + header=data_header, + data=message) + result = (msg_proto, result_name, result_value) + +class ChannelGuesser(object): + + def __init__(self): + self.iter = guess_channel_iter() + self.iter.next() + self.channel = None + + def send(self, (src, dst, payload)): + if self.channel: return self.channel + return self.iter.send((src, dst, payload)) + + def on_client_link_message(self, p, **kw): + self._client_message = SpiceLinkMess(p) + self.channel = self._client_message.e.channel_type + print "STARTED channel %s" % self.channel + + def on_server_link_message(self, p, **kw): + self._server_message = SpiceLinkReply(p) + +def spice_iter(packet_iter): + guessers = {} + def make_start_iter(src, dst): + is_server = src < dst + key = (min(src, dst), max(src, dst)) + guesser = guessers[key] = guessers.get(key, ChannelGuesser()) + print guesser, is_server, "**** make_start_iter %s -> %s" % (src, dst) + link_messages = [] + yield (SpiceLinkHeader.size, + SpiceLinkHeader.verify, + lambda h: h.e.size, + (guesser.on_server_link_message if is_server else + guesser.on_client_link_message)) + print guesser, is_server, "2" + message_iter = channel_spice_message_iter(src, dst, guesser) + message_iter.next() + # TODO: way to say "src->dst 128, dst->src 4" + yield 128 if src > dst else 4, ident, lambda h: 0, ident + for i, data in enumerate(packet_header_iter_gen( + SpiceDataHeader, + lambda pkt, src, dst, header: + message_iter.send((src, dst, (header, pkt))))): + print guesser, is_server, "2+%s" % i + yield data + def make_iter(src, dst): + key = (min(src, dst), max(src, dst)) + guesser = guessers[key] = guessers.get(key, ChannelGuesser()) + message_iter = channel_spice_message_iter(src, dst, guesser) + message_iter.next() + return packet_header_iter_gen( + SpiceDataHeader, + lambda pkt, src, dst, header: + message_iter.send((src, dst, (header, pkt)))) + return header_conversation_iter( + packet_iter, + server_start=make_start_iter, + client_start=make_start_iter, + server=make_iter, + client=make_iter) + |