summaryrefslogtreecommitdiff
path: root/pcapspice.py
diff options
context:
space:
mode:
Diffstat (limited to 'pcapspice.py')
-rw-r--r--pcapspice.py230
1 files changed, 230 insertions, 0 deletions
diff --git a/pcapspice.py b/pcapspice.py
new file mode 100644
index 0000000..b1c1e96
--- /dev/null
+++ b/pcapspice.py
@@ -0,0 +1,230 @@
+from itertools import chain, repeat
+from pcaputil import (tcp_parse, is_tcp, is_tcp_data, is_tcp_syn,
+ conversations_iter, header_conversation_iter)
+import client
+from client import SpiceLinkHeader, SpiceLinkMess, SpiceLinkReply, SpiceDataHeader
+
+def is_single_packet_data(payload):
+ return len(payload) == SpiceDataHeader(payload).e.size + SpiceDataHeader.size
+
+num_spice_messages = sum(len(ch.client) + len(ch.server) for ch in
+ client.client_proto.channels.values())
+
+valid_message_ids = set(sum([ch.client.keys() + ch.server.keys() for ch in
+ client.client_proto.channels.values()], []))
+
+all_channels = client.client_proto.channels.keys()
+
+def possible_channels(server_message, header):
+ return set(c for c in all_channels if header.e.type in
+ (client.client_proto.channels[c].server.keys() if server_message
+ else client.client_proto.channels[c].client.keys()))
+
+guesses = {}
+
+def guess_channel_iter():
+ channel = None
+ optional_channels = set(all_channels)
+ seen_headers = []
+ while True:
+ src, dst, data_header = yield channel
+ key = (min(src, dst), max(src, dst))
+ if guesses.has_key(key):
+ optional_channels = set([guesses[key]])
+ seen_headers.append((src, dst, data_header))
+ optional_channels = optional_channels & possible_channels(
+ src == min(src, dst), data_header)
+ print optional_channels
+ if len(optional_channels) == 1:
+ channel = list(optional_channels)[0]
+ guesses[key] = channel
+ if len(optional_channels) == 0:
+ import pdb; pdb.set_trace()
+
+spice_size_from_header = lambda header: header.e.size
+
+def channel_spice_start_iter(lower_port, higher_port):
+ server_port, client_port = lower_port, higher_port
+ ports = [server_port, client_port]
+ messages = dict([(port,None) for port in ports])
+ payloads = dict([(port,[]) for port in ports])
+ headers = dict([(port,None) for port in ports])
+ discarded_len = {client_port:128, server_port:4}
+ message_ctors = {client_port:SpiceLinkMess, server_port:SpiceLinkReply}
+ if len(payload) > 0:
+ payloads[src].append(payload)
+ while not all(messages.values()):
+ src, dst, payload = yield None
+ if len(payload) == 0: continue
+ payloads[src].append(payload)
+ len_payload = sum(map(len,payloads[src]))
+ if headers[src] is None and len_payload >= SpiceLinkHeader.size:
+ headers[src] = SpiceLinkHeader(''.join(payloads[src]))
+ import pdb; pdb.set_trace()
+ expected_len = headers[src].e.size + SpiceLinkHeader.size + discarded_len[src]
+ if headers[src] is not None and expected_len <= len_payload:
+ if expected_len != len_payload:
+ print "extra bytes dumped!!! %d" % (len_payload - expected_len)
+ messages[src] = message_ctors[src](''.join(payloads[src])[
+ SpiceLinkHeader.size:])
+ channel = messages[client_port].e.channel_type
+ print "channel type %s created" % channel
+ result = None
+ src, dst, payload = yield None
+ data_iter = channel_spice_iter(src, dst, payload, channel=channel)
+ result = data_iter.next()
+ while True:
+ src, dst, payload = yield result
+ result = data_iter.send((src, dst, payload))
+
+def channel_spice_iter(lower_port, higher_port, channel=None):
+ server_port, client_port = lower_port, higher_port
+ serial = {}
+ guesser = guess_channel_iter()
+ guesser.next()
+
+ result = None
+ # yield spice data parsed results
+ while True:
+ src, dst, payload = yield result
+ data_header = SpiceDataHeader(payload)
+ # basic sanity check on the header
+ bad_header = False
+ if serial.has_key(src):
+ if data_header.e.serial != serial[src] + 1:
+ print "bad serial, replacing for %s->%s (%s!=%s+1)" % (
+ src, dst, data_header.e.serial, serial[src])
+ bad_header = True
+ serial[src] = data_header.e.serial
+ if data_header.e.type not in valid_message_ids:
+ print "bad message type %s in %s->%s" % (data_header.e.type,
+ src, dst)
+ if bad_header: continue
+ if channel is None:
+ channel = guesser.send((src, dst, data_header))
+ if channel is not None:
+ # read as many messages as it takes to get
+ (msg_proto, (result_name, result_value),
+ left_over) = client.client_proto.parse(
+ channel_type=channel,
+ is_client=src == client_port,
+ header=data_header,
+ data=payload[SpiceDataHeader.size:])
+ result = (data_header, msg_proto, result_name, result_value,
+ left_over)
+
+# for every tcp packet
+# if it is a syn, read the link header, decide which channel we are
+# if it is data
+# if we know which channel we are, parse it
+# else guess which channel.
+# if we just figured out which channel we are, releaes all pending packets.
+
+def spice_iter1(packet_iter):
+ return conversations_iter(
+ packet_iter,
+ channel_spice_start_iter,
+ channel_spice_iter)
+
+def packet_header_iter_gen(hdr_ctor, pkt_ctor):
+ return repeat(
+ (hdr_ctor.size,
+ hdr_ctor,
+ lambda h: h.e.size,
+ pkt_ctor))
+
+def ident(x, **kw):
+ return x
+
+def channel_spice_message_iter(src, dst, guesser):
+ server_port, client_port = min(src, dst), max(src, dst)
+ serial = {}
+ channel = None
+ result = None
+ # yield spice data parsed results
+ while True:
+ src, dst, (data_header, message) = yield result
+ # basic sanity check on the header
+ bad_header = False
+ if serial.has_key(src):
+ if data_header.e.serial != serial[src] + 1:
+ print "bad serial, replacing for %s->%s (%s!=%s+1)" % (
+ src, dst, data_header.e.serial, serial[src])
+ bad_header = True
+ serial[src] = data_header.e.serial
+ if data_header.e.type not in valid_message_ids:
+ print "bad message type %s in %s->%s" % (data_header.e.type,
+ src, dst)
+ bad_header = True
+ if bad_header: continue
+ if channel is None:
+ channel = guesser.send((src, dst, data_header))
+ if channel is not None:
+ # read as many messages as it takes to get
+ (msg_proto, (result_name, result_value)
+ ) = client.client_proto.parse(
+ channel_type=channel,
+ is_client=src == client_port,
+ header=data_header,
+ data=message)
+ result = (msg_proto, result_name, result_value)
+
+class ChannelGuesser(object):
+
+ def __init__(self):
+ self.iter = guess_channel_iter()
+ self.iter.next()
+ self.channel = None
+
+ def send(self, (src, dst, payload)):
+ if self.channel: return self.channel
+ return self.iter.send((src, dst, payload))
+
+ def on_client_link_message(self, p, **kw):
+ self._client_message = SpiceLinkMess(p)
+ self.channel = self._client_message.e.channel_type
+ print "STARTED channel %s" % self.channel
+
+ def on_server_link_message(self, p, **kw):
+ self._server_message = SpiceLinkReply(p)
+
+def spice_iter(packet_iter):
+ guessers = {}
+ def make_start_iter(src, dst):
+ is_server = src < dst
+ key = (min(src, dst), max(src, dst))
+ guesser = guessers[key] = guessers.get(key, ChannelGuesser())
+ print guesser, is_server, "**** make_start_iter %s -> %s" % (src, dst)
+ link_messages = []
+ yield (SpiceLinkHeader.size,
+ SpiceLinkHeader.verify,
+ lambda h: h.e.size,
+ (guesser.on_server_link_message if is_server else
+ guesser.on_client_link_message))
+ print guesser, is_server, "2"
+ message_iter = channel_spice_message_iter(src, dst, guesser)
+ message_iter.next()
+ # TODO: way to say "src->dst 128, dst->src 4"
+ yield 128 if src > dst else 4, ident, lambda h: 0, ident
+ for i, data in enumerate(packet_header_iter_gen(
+ SpiceDataHeader,
+ lambda pkt, src, dst, header:
+ message_iter.send((src, dst, (header, pkt))))):
+ print guesser, is_server, "2+%s" % i
+ yield data
+ def make_iter(src, dst):
+ key = (min(src, dst), max(src, dst))
+ guesser = guessers[key] = guessers.get(key, ChannelGuesser())
+ message_iter = channel_spice_message_iter(src, dst, guesser)
+ message_iter.next()
+ return packet_header_iter_gen(
+ SpiceDataHeader,
+ lambda pkt, src, dst, header:
+ message_iter.send((src, dst, (header, pkt))))
+ return header_conversation_iter(
+ packet_iter,
+ server_start=make_start_iter,
+ client_start=make_start_iter,
+ server=make_iter,
+ client=make_iter)
+