summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlon Levy <alevy@redhat.com>2011-08-12 11:39:04 +0300
committerAlon Levy <alevy@redhat.com>2011-08-12 11:39:04 +0300
commit470de5de12a7474feffdd6e6650e6d797a844605 (patch)
treed09d406376714f29ffaada4dd43208b8e4cdb579
refactored proxy part from spicedump
-rwxr-xr-xproxy.py191
1 files changed, 191 insertions, 0 deletions
diff --git a/proxy.py b/proxy.py
new file mode 100755
index 0000000..8c43a25
--- /dev/null
+++ b/proxy.py
@@ -0,0 +1,191 @@
+#!/usr/bin/python
+from socket import socket, AF_INET, SOCK_STREAM
+from select import select
+
+MAX_PACKET_SIZE=65536
+
+class Socket(socket):
+ sockets = []
+ def __init__(self, *args, **kw):
+ super(Socket, self).__init__(*args, **kw)
+ self.sockets.append(self)
+
+class twowaydict(dict):
+ def __init__(self):
+ self.other = dict()
+ def __delitem__(self, k):
+ if super(twowaydict,self).__contains__(k):
+ other_k = self[k]
+ else:
+ other_k = k
+ k = self.other[k]
+ del self.other[other_k]
+ super(twowaydict,self).__delitem__(k)
+ def __setitem__(self, k, v):
+ super(twowaydict, self).__setitem__(k, v)
+ self.other[v] = k
+ def __contains__(self, k):
+ return super(twowaydict,self).__contains__(k) or k in self.other
+ def __getitem__(self, k):
+ if super(twowaydict,self).__contains__(k):
+ return super(twowaydict,self).__getitem__(k)
+ return self.other[k]
+ def getpair(self, k):
+ if k in self:
+ return k, self[k]
+ elif k in self.other:
+ return k, self.other[k]
+ def allkeys(self):
+ return self.keys() + self.other.keys()
+
+BROKEN_PIPE_ERRNO, TRANSPORT_NOT_CONNECTED_ERRNO = 32, 107
+
+def make_accepter(port, host='127.0.0.1'):
+ accepter = Socket(AF_INET, SOCK_STREAM)
+ accepter.bind((host, port))
+ accepter.listen(1)
+ return accepter
+
+def connect(addr):
+ s = Socket(AF_INET, SOCK_STREAM)
+ s.connect(addr)
+ s.setblocking(False)
+ return s
+
+
+class Proxy(object):
+
+ def __init__(self, local_port, remote_addr, host='127.0.0.1'):
+ self._drop_next = False
+ handle_input, select_based_iterator, get_fds = make_proxy(
+ self, local_port, remote_addr, host)
+ self.select_based_iterator = select_based_iterator()
+
+ def drop_next(self):
+ self._drop_next = True
+
+ def check_drop_next(self):
+ dn = self._drop_next
+ self._drop_next = False
+ return dn
+
+ def __iter__(self):
+ for x in self.select_based_iterator:
+ yield x
+
+
+def make_proxy(local_port, remote_addr, host = '127.0.0.1',
+ check_drop_next=lambda: True):
+ print "proxying from %s to %s" % (local_port, remote_addr)
+ accepter = make_accepter(local_port, host)
+ open_socks = twowaydict()
+ close_errnos = set([BROKEN_PIPE_ERRNO, TRANSPORT_NOT_CONNECTED_ERRNO])
+ def get_fds():
+ return [accepter] + open_socks.allkeys()
+ def iterate_packets():
+ """
+ An iterator with inputs, designed to work with some external
+ event loop. Supply readable sockets list via send, see below example.
+ """
+ readable = yield []
+ results = []
+ while True:
+ for s in readable:
+ result = handle_input(s)
+ if result:
+ src, dst, data, completer = result
+ results.append((src, dst, data))
+ # TODO - do the condition here, have it provided as
+ # a callback. The mixed oop/fp style is bad.
+ completer()
+ readable = yield results
+ results = []
+ def handle_input(s):
+ """
+ returns None if a socket was closed (nothing to do),
+ or a quadtuple of (src, dst, data, completer)
+ The completer will send the packet to the destination. Conversely
+ if you don't call it the packet is simply dropped.
+ """
+ if s is accepter:
+ s, _addr = accepter.accept()
+ open_socks[s] = connect(remote_addr)
+ else:
+ other = open_socks[s]
+ src_dst_socks = [s, other]
+ src_dst = [None, None]
+ dont_recv = False
+ for i in xrange(len(src_dst)):
+ try:
+ src_dst[i] = src_dst_socks[i].getpeername()[1]
+ except Exception, e:
+ if e.errno in close_errnos:
+ src_dst_socks[1-i].close()
+ if s in open_socks:
+ del open_socks[s]
+ dont_recv = True
+ if dont_recv:
+ return None
+ src, dst = src_dst
+ data = s.recv(MAX_PACKET_SIZE)
+ if len(data) == 0:
+ other.close()
+ del open_socks[s]
+ return
+ def completer():
+ if check_drop_next():
+ return
+ try:
+ other.send(data)
+ except Exception, e:
+ if e.errno in close_errnos:
+ n = len(open_socks)
+ s.close()
+ open_socks[s].close()
+ del open_socks[s]
+ assert(len(open_socks) == n - 1)
+ else:
+ import pdb; pdb.set_trace()
+ return src, dst, data, completer
+ def select_based_iterator():
+ packet_iter = iterate_packets()
+ packet_iter.next()
+ while True:
+ #print "open: %s" % len(open_socks)
+ rds, _wrs, _ex = select(get_fds(), [], [])
+ results = packet_iter.send(rds)
+ for r in results:
+ yield r
+ return handle_input, select_based_iterator, get_fds
+
+def proxy(local_port, remote_addr):
+ return Proxy(local_port, remote_addr)
+
+def closeallsockets():
+ for s in Socket.sockets:
+ try:
+ s.close()
+ except:
+ pass
+
+def example_main():
+ import sys
+ target_host, target_port = sys.argv[-1].split(':')
+ local_port = int(sys.argv[-2])
+ for src, dst, data in proxy(local_port , (target_host, int(target_port))):
+ print "%s->%s %s" % (src, dst, len(data))
+
+def tests():
+ import argparse
+ import sys
+ p = argparse.ArgumentParser(description="proxy multiple socket connections")
+ p.add_argument('--listen-port', default=11000)
+ p.add_argument('--remote-port', default=12000)
+ p.add_argument('--remote-host', default='127.0.0.1')
+ args = p.parse_args(sys.argv[1:])
+ for x in proxy(args.listen_port, (args.remote_host, args.remote_port)):
+ print x
+
+if __name__ == '__main__':
+ tests()
+