#!/usr/bin/python
"""A small select()-based TCP proxy.

Connections accepted on a local port are paired with a fresh connection to
a fixed remote address.  Every chunk read from either side is reported to
the caller as ``(src_port, dst_port, data)`` and forwarded to the paired
socket, unless the caller asked for the next chunk to be dropped.

The code is written to run unchanged on Python 2.7 and Python 3.
"""

import errno
import socket
from socket import AF_INET, SOCK_STREAM
from select import select

MAX_PACKET_SIZE = 65536


class Socket(socket.socket):
    """A socket.socket that registers every instance in ``Socket.sockets``.

    The registry lets closeallsockets() close everything this module
    created explicitly.
    """

    # Deliberately a class attribute: shared by all instances.
    sockets = []

    def __init__(self, *args, **kw):
        super(Socket, self).__init__(*args, **kw)
        self.sockets.append(self)


class twowaydict(dict):
    """A dict that can also be looked up by value.

    Forward pairs (k -> v) live in the dict itself, reverse pairs (v -> k)
    in ``self.other``.  ``in``, ``[]`` and ``del`` accept either member of
    a pair.  Only item assignment and deletion keep both directions in
    sync; other mutators inherited from dict (update, pop, clear, ...) do
    not touch ``self.other``.
    """

    def __init__(self):
        super(twowaydict, self).__init__()
        self.other = dict()

    def _has_forward(self, k):
        """True if k is a forward key (ignores the reverse mapping)."""
        return super(twowaydict, self).__contains__(k)

    def __delitem__(self, k):
        """Remove the pair that k belongs to; KeyError if there is none."""
        if self._has_forward(k):
            other_k = super(twowaydict, self).__getitem__(k)
        else:
            # k is the value side of a pair: swap roles.
            other_k = k
            k = self.other[k]
        del self.other[other_k]
        super(twowaydict, self).__delitem__(k)

    def __setitem__(self, k, v):
        """Store the pair (k, v), discarding any pair it supersedes."""
        # Re-assigning k must not leave its old value reachable in reverse.
        if self._has_forward(k):
            self.other.pop(super(twowaydict, self).__getitem__(k), None)
        # Likewise v may only belong to one pair at a time.
        if v in self.other:
            super(twowaydict, self).pop(self.other[v], None)
        super(twowaydict, self).__setitem__(k, v)
        self.other[v] = k

    def __contains__(self, k):
        return self._has_forward(k) or k in self.other

    def __getitem__(self, k):
        if self._has_forward(k):
            return super(twowaydict, self).__getitem__(k)
        return self.other[k]

    def getpair(self, k):
        """Return (k, partner of k), or None if k is in no pair."""
        if k in self:
            return k, self[k]
        return None

    def allkeys(self):
        """Return a list with both members of every pair."""
        return list(self.keys()) + list(self.other.keys())


# Symbolic errno values: 32 and 107 on Linux, correct on other platforms too.
BROKEN_PIPE_ERRNO, TRANSPORT_NOT_CONNECTED_ERRNO = errno.EPIPE, errno.ENOTCONN


def make_accepter(port, host='127.0.0.1'):
    """Return a Socket bound to (host, port) and listening."""
    accepter = Socket(AF_INET, SOCK_STREAM)
    accepter.bind((host, port))
    accepter.listen(1)
    return accepter


def connect(addr):
    """Return a non-blocking Socket connected to addr, or None on failure."""
    s = Socket(AF_INET, SOCK_STREAM)
    try:
        s.connect(addr)
    except socket.error:
        # Do not leave a half-initialised descriptor open.
        s.close()
        return None
    s.setblocking(False)
    return s


class Proxy(object):
    """Iterable wrapper around make_proxy().

    Iterating yields (src_port, dst_port, data) for every chunk seen.
    Call drop_next() to have the next chunk reported but not forwarded.
    """

    def __init__(self, local_port, remote_addr, host='127.0.0.1'):
        self._drop_next = False
        _iterate_packets, _handle_input, select_based_iterator, _get_fds = \
            make_proxy(local_port, remote_addr, host,
                       check_drop_next=self.check_drop_next)
        self.select_based_iterator = select_based_iterator()

    def drop_next(self):
        """Request that the next chunk is not forwarded."""
        self._drop_next = True

    def check_drop_next(self):
        """Return the pending drop request and reset it to False."""
        dn = self._drop_next
        self._drop_next = False
        return dn

    def __iter__(self):
        for x in self.select_based_iterator:
            yield x


def make_proxy(local_port, remote_addr, host='127.0.0.1',
               check_drop_next=lambda: False, debug=False):
    """Start listening on (host, local_port) and build the proxy closures.

    remote_addr     -- (host, port) every accepted connection is paired with
    check_drop_next -- called before each forward; a true result drops
                       that chunk instead of sending it
    debug           -- when true, print diagnostics

    Returns (iterate_packets, handle_input, select_based_iterator, get_fds).
    """
    print("proxying from %s to %s" % (local_port, remote_addr))
    accepter = make_accepter(local_port, host)
    open_socks = twowaydict()
    close_errnos = set([BROKEN_PIPE_ERRNO, TRANSPORT_NOT_CONNECTED_ERRNO])

    def dprint(s):
        if debug:
            print(s)

    def get_fds():
        """Return every socket the event loop should watch for reading."""
        fds = [accepter] + open_socks.allkeys()
        dprint("make_proxy: %s" % repr(fds))
        return fds

    def close_pair(sock):
        """Close sock and its paired socket and forget the pairing."""
        pair = open_socks.getpair(sock) or (sock,)
        if sock in open_socks:
            del open_socks[sock]
        for member in pair:
            try:
                member.close()
            except socket.error:
                # Best effort: the connection is being discarded anyway.
                pass

    def iterate_packets():
        """
        An iterator with inputs, designed to work with some external event
        loop.  Prime it with next(), then supply the list of readable
        sockets via send(); each send() returns the list of
        (src, dst, data) tuples handled for that batch.
        """
        readable = yield []
        results = []
        while True:
            for s in readable:
                result = handle_input(s)
                if result:
                    # handle_input returns five items; the paired socket is
                    # not needed here.
                    src, dst, data, _other, completer = result
                    results.append((src, dst, data))
                    # TODO - do the condition here, have it provided as
                    # a callback. The mixed oop/fp style is bad.
                    completer()
            readable = yield results
            results = []

    def handle_input(s):
        """
        Process one readable socket.

        Returns None when there is nothing to forward (a new connection
        was accepted, or a connection was closed), otherwise the tuple
        (src, dst, data, other, completer): the peer ports of the source
        and destination sockets, the bytes read, the destination socket,
        and a completer.  The completer will send the packet to the
        destination.  Conversely if you don't call it the packet is simply
        dropped.
        """
        if s is accepter:
            client, _addr = accepter.accept()
            remote = connect(remote_addr)
            if remote is not None:
                open_socks[client] = remote
            else:
                print("connection to remote %s failed" % repr(remote_addr))
                # Without a remote side the client would hang forever.
                client.close()
            return None

        if s not in open_socks:
            # Its pair was already torn down earlier in this select round.
            return None

        other = open_socks[s]
        ports = []
        for sock in (s, other):
            try:
                ports.append(sock.getpeername()[1])
            except socket.error as e:
                if getattr(e, 'errno', None) in close_errnos:
                    close_pair(s)
                    return None
                # Port unknown, but the connection may still be usable.
                ports.append(None)
        src, dst = ports

        try:
            dprint("recv from %d" % s.fileno())
            data = s.recv(MAX_PACKET_SIZE)
        except socket.error:
            print("handle_input: bad event loop - socket invalid %s"
                  % s.fileno())
            return None
        if not data:
            # Empty read: the peer closed the connection.
            close_pair(s)
            return None

        def completer():
            """
            default completion of proxy action. sends data to destination
            port.
            """
            if check_drop_next():
                return
            try:
                # NOTE: send() may transmit fewer bytes than len(data).
                other.send(data)
            except socket.error as e:
                if getattr(e, 'errno', None) in close_errnos:
                    close_pair(s)
                else:
                    print("Caught exception writing: %s" % repr(e))

        return src, dst, data, other, completer

    def select_based_iterator():
        """Drive iterate_packets() with select(); yield (src, dst, data)."""
        packet_iter = iterate_packets()
        next(packet_iter)  # advance to the first yield so send() works
        while True:
            rds, _wrs, _ex = select(get_fds(), [], [])
            results = packet_iter.send(rds)
            for r in results:
                yield r

    return iterate_packets, handle_input, select_based_iterator, get_fds


def proxy(local_port, remote_addr):
    """Convenience constructor for Proxy on the default host."""
    return Proxy(local_port, remote_addr)


def closeallsockets():
    """Close every Socket this module created, ignoring close errors."""
    for s in Socket.sockets:
        try:
            s.close()
        except socket.error:
            pass


def tests():
    """Command-line entry point: run a proxy and print one line per chunk."""
    import argparse
    import sys
    p = argparse.ArgumentParser(
        description="proxy multiple socket connections")
    p.add_argument('--listen-port', default=11000, type=int)
    p.add_argument('--listen-host', default='127.0.0.1')
    p.add_argument('--remote-port', default=12000, type=int)
    p.add_argument('--remote-host', default='127.0.0.1')
    args = p.parse_args(sys.argv[1:])
    _iterate_packets, _handle_input, select_based_iterator, _get_fds = \
        make_proxy(args.listen_port, (args.remote_host, args.remote_port),
                   args.listen_host, debug=True)
    for src, dst, data in select_based_iterator():
        print("%s->%s %s" % (src, dst, len(data)))


# The name matches the prefix test collectors look for, but this function
# runs a proxy forever; tell them to leave it alone.
tests.__test__ = False


if __name__ == '__main__':
    tests()