#!/usr/bin/python from socket import socket, AF_INET, SOCK_STREAM from select import select MAX_PACKET_SIZE=65536 class Socket(socket): sockets = [] def __init__(self, *args, **kw): super(Socket, self).__init__(*args, **kw) self.sockets.append(self) class twowaydict(dict): def __init__(self): self.other = dict() def __delitem__(self, k): if super(twowaydict,self).__contains__(k): other_k = self[k] else: other_k = k k = self.other[k] del self.other[other_k] super(twowaydict,self).__delitem__(k) def __setitem__(self, k, v): super(twowaydict, self).__setitem__(k, v) self.other[v] = k def __contains__(self, k): return super(twowaydict,self).__contains__(k) or k in self.other def __getitem__(self, k): if super(twowaydict,self).__contains__(k): return super(twowaydict,self).__getitem__(k) return self.other[k] def getpair(self, k): if k in self: return k, self[k] elif k in self.other: return k, self.other[k] def allkeys(self): return self.keys() + self.other.keys() BROKEN_PIPE_ERRNO, TRANSPORT_NOT_CONNECTED_ERRNO = 32, 107 def make_accepter(port, host='127.0.0.1'): accepter = Socket(AF_INET, SOCK_STREAM) accepter.bind((host, port)) accepter.listen(1) return accepter def connect(addr): s = Socket(AF_INET, SOCK_STREAM) s.connect(addr) s.setblocking(False) return s class Proxy(object): def __init__(self, local_port, remote_addr, host='127.0.0.1'): self._drop_next = False self._proxy = _proxy(self, local_port, remote_addr, host) def drop_next(self): self._drop_next = True def check_drop_next(self): dn = self._drop_next self._drop_next = False return dn def __iter__(self): for x in self._proxy: yield x def _proxy(proxy, local_port, remote_addr, host = '127.0.0.1'): print "proxying from %s to %s" % (local_port, remote_addr) accepter = make_accepter(local_port, host) open_socks = twowaydict() close_errnos = set([BROKEN_PIPE_ERRNO, TRANSPORT_NOT_CONNECTED_ERRNO]) while True: #print "open: %s" % len(open_socks) rds, _wrs, _ex = select( [accepter]+open_socks.allkeys(), [], []) for s in rds: if s is accepter: s, _addr = accepter.accept() open_socks[s] = connect(remote_addr) else: other = open_socks[s] src_dst_socks = [s, other] src_dst = [None, None] dont_recv = False for i in xrange(len(src_dst)): try: src_dst[i] = src_dst_socks[i].getpeername()[1] except Exception, e: if e.errno in close_errnos: src_dst_socks[1-i].close() if s in open_socks: del open_socks[s] dont_recv = True if dont_recv: continue src, dst = src_dst data = s.recv(MAX_PACKET_SIZE) if len(data) == 0: other.close() del open_socks[s] continue yield src, dst, data if not proxy.check_drop_next(): try: other.send(data) except Exception, e: if e.errno in close_errnos: n = len(open_socks) s.close() open_socks[s].close() del open_socks[s] assert(len(open_socks) == n - 1) else: import pdb; pdb.set_trace() def proxy(local_port, remote_addr): return Proxy(local_port, remote_addr) def closeallsockets(): for s in Socket.sockets: try: s.close() except: pass def example_main(): import sys target_host, target_port = sys.argv[-1].split(':') local_port = int(sys.argv[-2]) for src, dst, data in proxy(local_port , (target_host, int(target_port))): print "%s->%s %s" % (src, dst, len(data)) def tests(): port_num = 8000 from_port, to_port = port_num, port_num+1000 proxy(port_num, ('localhost', port_num+1000)) def main(): import argparse import sys parser = argparse.ArgumentParser() parser.add_argument('-l', '--local-port', type=int, required=True, help='set proxy local port') parser.add_argument('-H', '--remote-host', default='localhost', help='set proxy remote address') parser.add_argument('-p', '--remote-port', type=int, required=True, help='set proxy remote address') parser.add_argument('-v', '--verbose', dest='verbose', action='count', help='verbosity', default=0) opts, rest = parser.parse_known_args(sys.argv[1:]) local_port = opts.local_port remote_addr = (opts.remote_host, opts.remote_port) p = proxy(local_port=local_port, remote_addr=remote_addr) for ret in p: if opts.verbose: print repr(ret) if __name__ == '__main__': main()