diff options
author | Alon Levy <alevy@redhat.com> | 2011-08-12 11:39:04 +0300 |
---|---|---|
committer | Alon Levy <alevy@redhat.com> | 2011-08-12 11:39:04 +0300 |
commit | 470de5de12a7474feffdd6e6650e6d797a844605 (patch) | |
tree | d09d406376714f29ffaada4dd43208b8e4cdb579 |
refactored proxy part from spicedump
-rwxr-xr-x | proxy.py | 191 |
1 files changed, 191 insertions, 0 deletions
diff --git a/proxy.py b/proxy.py new file mode 100755 index 0000000..8c43a25 --- /dev/null +++ b/proxy.py @@ -0,0 +1,191 @@ +#!/usr/bin/python +from socket import socket, AF_INET, SOCK_STREAM +from select import select + +MAX_PACKET_SIZE=65536 + +class Socket(socket): + sockets = [] + def __init__(self, *args, **kw): + super(Socket, self).__init__(*args, **kw) + self.sockets.append(self) + +class twowaydict(dict): + def __init__(self): + self.other = dict() + def __delitem__(self, k): + if super(twowaydict,self).__contains__(k): + other_k = self[k] + else: + other_k = k + k = self.other[k] + del self.other[other_k] + super(twowaydict,self).__delitem__(k) + def __setitem__(self, k, v): + super(twowaydict, self).__setitem__(k, v) + self.other[v] = k + def __contains__(self, k): + return super(twowaydict,self).__contains__(k) or k in self.other + def __getitem__(self, k): + if super(twowaydict,self).__contains__(k): + return super(twowaydict,self).__getitem__(k) + return self.other[k] + def getpair(self, k): + if k in self: + return k, self[k] + elif k in self.other: + return k, self.other[k] + def allkeys(self): + return self.keys() + self.other.keys() + +BROKEN_PIPE_ERRNO, TRANSPORT_NOT_CONNECTED_ERRNO = 32, 107 + +def make_accepter(port, host='127.0.0.1'): + accepter = Socket(AF_INET, SOCK_STREAM) + accepter.bind((host, port)) + accepter.listen(1) + return accepter + +def connect(addr): + s = Socket(AF_INET, SOCK_STREAM) + s.connect(addr) + s.setblocking(False) + return s + + +class Proxy(object): + + def __init__(self, local_port, remote_addr, host='127.0.0.1'): + self._drop_next = False + handle_input, select_based_iterator, get_fds = make_proxy( + self, local_port, remote_addr, host) + self.select_based_iterator = select_based_iterator() + + def drop_next(self): + self._drop_next = True + + def check_drop_next(self): + dn = self._drop_next + self._drop_next = False + return dn + + def __iter__(self): + for x in self.select_based_iterator: + yield x + + +def make_proxy(local_port, remote_addr, host = '127.0.0.1', + check_drop_next=lambda: True): + print "proxying from %s to %s" % (local_port, remote_addr) + accepter = make_accepter(local_port, host) + open_socks = twowaydict() + close_errnos = set([BROKEN_PIPE_ERRNO, TRANSPORT_NOT_CONNECTED_ERRNO]) + def get_fds(): + return [accepter] + open_socks.allkeys() + def iterate_packets(): + """ + An iterator with inputs, designed to work with some external + event loop. Supply readable sockets list via send, see below example. + """ + readable = yield [] + results = [] + while True: + for s in readable: + result = handle_input(s) + if result: + src, dst, data, completer = result + results.append((src, dst, data)) + # TODO - do the condition here, have it provided as + # a callback. The mixed oop/fp style is bad. + completer() + readable = yield results + results = [] + def handle_input(s): + """ + returns None if a socket was closed (nothing to do), + or a quadtuple of (src, dst, data, completer) + The completer will send the packet to the destination. Conversely + if you don't call it the packet is simply dropped. + """ + if s is accepter: + s, _addr = accepter.accept() + open_socks[s] = connect(remote_addr) + else: + other = open_socks[s] + src_dst_socks = [s, other] + src_dst = [None, None] + dont_recv = False + for i in xrange(len(src_dst)): + try: + src_dst[i] = src_dst_socks[i].getpeername()[1] + except Exception, e: + if e.errno in close_errnos: + src_dst_socks[1-i].close() + if s in open_socks: + del open_socks[s] + dont_recv = True + if dont_recv: + return None + src, dst = src_dst + data = s.recv(MAX_PACKET_SIZE) + if len(data) == 0: + other.close() + del open_socks[s] + return + def completer(): + if check_drop_next(): + return + try: + other.send(data) + except Exception, e: + if e.errno in close_errnos: + n = len(open_socks) + s.close() + open_socks[s].close() + del open_socks[s] + assert(len(open_socks) == n - 1) + else: + import pdb; pdb.set_trace() + return src, dst, data, completer + def select_based_iterator(): + packet_iter = iterate_packets() + packet_iter.next() + while True: + #print "open: %s" % len(open_socks) + rds, _wrs, _ex = select(get_fds(), [], []) + results = packet_iter.send(rds) + for r in results: + yield r + return handle_input, select_based_iterator, get_fds + +def proxy(local_port, remote_addr): + return Proxy(local_port, remote_addr) + +def closeallsockets(): + for s in Socket.sockets: + try: + s.close() + except: + pass + +def example_main(): + import sys + target_host, target_port = sys.argv[-1].split(':') + local_port = int(sys.argv[-2]) + for src, dst, data in proxy(local_port , (target_host, int(target_port))): + print "%s->%s %s" % (src, dst, len(data)) + +def tests(): + import argparse + import sys + p = argparse.ArgumentParser(description="proxy multiple socket connections") + p.add_argument('--listen-port', default=11000) + p.add_argument('--remote-port', default=12000) + p.add_argument('--remote-host', default='127.0.0.1') + args = p.parse_args(sys.argv[1:]) + for x in proxy(args.listen_port, (args.remote_host, args.remote_port)): + print x + +if __name__ == '__main__': + tests() + |