summaryrefslogtreecommitdiff
path: root/proxy.py
blob: 227d88c5dc16efa1041b5ca9a22aebb04c94c149 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
#!/usr/bin/python
import errno
from select import select
from socket import socket, AF_INET, SOCK_STREAM

# Upper bound, in bytes, on what one recv() call may return while relaying.
# TCP is a byte stream, so this is a read size rather than a packet size.
MAX_PACKET_SIZE = 64 * 1024

class Socket(socket):
    """A socket that records itself in the class-wide ``Socket.sockets`` list.

    The registry lets closeallsockets() reach every Socket that was built.
    Entries are never removed, so closed sockets stay referenced here.
    """

    # Class attribute shared by all instances: every Socket constructed so
    # far, in construction order.
    sockets = []

    def __init__(self, *args, **kw):
        super(Socket, self).__init__(*args, **kw)
        # Registered only once the underlying socket was created successfully.
        registry = self.sockets
        registry.append(self)

class twowaydict(dict):
    """A one-to-one mapping that can be queried from either side.

    Forward pairs ``k -> v`` are stored in the dict itself; the reverse
    links ``v -> k`` are kept in ``self.other``. ``x in d``, ``d[x]``,
    ``del d[x]`` and getpair() accept either a key or a value. len(),
    keys(), plain iteration and repr() only see the forward pairs.
    """
    def __init__(self):
        super(twowaydict, self).__init__()
        # Reverse index: value -> key.
        self.other = dict()

    def __delitem__(self, k):
        """Remove the pair containing k, given either its key or its value.

        Raises KeyError if k is on neither side.
        """
        if super(twowaydict, self).__contains__(k):
            other_k = super(twowaydict, self).__getitem__(k)
        else:
            other_k = k
            k = self.other[k]
        del self.other[other_k]
        super(twowaydict, self).__delitem__(k)

    def __setitem__(self, k, v):
        """Pair k with v, dropping any pairing either of them had before."""
        if super(twowaydict, self).__contains__(k):
            # Overwriting k: its old value must no longer map back to k.
            del self.other[super(twowaydict, self).__getitem__(k)]
        if v in self.other:
            # v was paired with a different key; break that pair so the
            # mapping stays one-to-one.
            super(twowaydict, self).__delitem__(self.other[v])
        super(twowaydict, self).__setitem__(k, v)
        self.other[v] = k

    def __contains__(self, k):
        """True if k is a key or a value of some pair."""
        return super(twowaydict, self).__contains__(k) or k in self.other

    def __getitem__(self, k):
        """Return k's partner: its value if k is a key, else its key.

        Raises KeyError if k is on neither side.
        """
        if super(twowaydict, self).__contains__(k):
            return super(twowaydict, self).__getitem__(k)
        return self.other[k]

    def getpair(self, k):
        """Return ``(k, partner of k)``, or None if k is on neither side."""
        if k in self:
            return k, self[k]
        return None

    def allkeys(self):
        """Return a new list of every key followed by every value."""
        return list(self.keys()) + list(self.other.keys())

# errno codes for "broken pipe" and "transport endpoint is not connected",
# taken from the errno module rather than hard-coding Linux's numbers
# (32, 107): ENOTCONN in particular has a different value on BSD/macOS.
BROKEN_PIPE_ERRNO, TRANSPORT_NOT_CONNECTED_ERRNO = errno.EPIPE, errno.ENOTCONN

def make_accepter(port, host='127.0.0.1', backlog=1):
    """Create and return a TCP socket listening on (host, port).

    port: TCP port to bind.
    host: interface address to bind; loopback by default.
    backlog: queue length passed to listen(); defaults to 1.

    SO_REUSEADDR is set so the proxy can be restarted immediately on a port
    whose earlier connections are still in TIME_WAIT.
    If setup fails, the socket is closed and the error re-raised.
    """
    from socket import SOL_SOCKET, SO_REUSEADDR
    accepter = Socket(AF_INET, SOCK_STREAM)
    try:
        accepter.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
        accepter.bind((host, port))
        accepter.listen(backlog)
    except Exception:
        # Don't leak the descriptor; Socket.sockets would keep it referenced.
        accepter.close()
        raise
    return accepter

def connect(addr):
    """Open a TCP connection to addr and return it in non-blocking mode.

    addr: (host, port) address for an AF_INET socket.
    The connect() call itself blocks; the socket is switched to
    non-blocking only once connected. If connecting fails, the socket is
    closed and the error re-raised.
    """
    s = Socket(AF_INET, SOCK_STREAM)
    try:
        s.connect(addr)
    except Exception:
        # Don't leak the descriptor; Socket.sockets would keep it referenced.
        s.close()
        raise
    s.setblocking(False)
    return s

class Proxy(object):
    """Iterable handle on a TCP proxy driven by the _proxy generator.

    Iterating yields (src, dst, data) tuples. Calling drop_next() after
    receiving a tuple, before asking for the next one, makes the proxy
    discard that chunk instead of forwarding it.
    """

    def __init__(self, local_port, remote_addr, host='127.0.0.1'):
        self._drop_next = False
        self._proxy = _proxy(self, local_port, remote_addr, host)

    def drop_next(self):
        """Request that the current/next chunk be discarded."""
        self._drop_next = True

    def check_drop_next(self):
        """Return the pending drop request and clear it (one-shot flag)."""
        pending, self._drop_next = self._drop_next, False
        return pending

    def __iter__(self):
        for item in self._proxy:
            yield item

def _proxy(proxy, local_port, remote_addr, host='127.0.0.1'):
    """Relay TCP traffic, yielding every chunk before it is forwarded.

    This is a generator: nothing happens until it is first advanced. It then
    listens on (host, local_port). For each accepted client it opens a
    connection to remote_addr and relays bytes in both directions using
    select().

    For every chunk received it yields (src, dst, data). src and dst are the
    peer ports (getpeername()[1]) of the receiving socket and of the socket
    the chunk will be sent to. When the consumer resumes the generator,
    proxy.check_drop_next() is called; if it returns True, the chunk is
    discarded instead of forwarded.

    proxy: object providing check_drop_next() (normally the Proxy wrapper).

    A failure on one connection (upstream unreachable, reset, broken pipe)
    closes only that client/remote pair; the loop keeps running. If the
    generator is closed, or an error escapes the loop, the listener and all
    relayed sockets are closed.
    """
    print("proxying from %s to %s" % (local_port, remote_addr))
    accepter = make_accepter(local_port, host)
    # client socket <-> remote socket; lookups work from either side.
    open_socks = twowaydict()

    def teardown(sock):
        # Close both ends of sock's pair and forget it; no-op if already gone.
        if sock not in open_socks:
            return
        peer = open_socks[sock]
        del open_socks[sock]
        for end in (sock, peer):
            try:
                end.close()
            except EnvironmentError:
                pass

    try:
        while True:
            readable, _writable, _errored = select(
                [accepter] + open_socks.allkeys(), [], [])
            for s in readable:
                if s is accepter:
                    client, _addr = accepter.accept()
                    try:
                        remote = connect(remote_addr)
                    except EnvironmentError as e:
                        print("cannot reach %s: %s" % (remote_addr, e))
                        client.close()
                        continue
                    # sendall() only guarantees full delivery on a blocking
                    # socket (a non-blocking one can short-write or raise
                    # EAGAIN). select() already ensures recv() is called
                    # only when data or EOF is waiting.
                    remote.setblocking(True)
                    open_socks[client] = remote
                    continue
                if s not in open_socks:
                    # Its pair was torn down earlier in this same select() batch.
                    continue
                other = open_socks[s]
                try:
                    src = s.getpeername()[1]
                    dst = other.getpeername()[1]
                    data = s.recv(MAX_PACKET_SIZE)
                except EnvironmentError:
                    teardown(s)
                    continue
                if not data:
                    # Orderly EOF from s: shut down both directions.
                    teardown(s)
                    continue
                yield src, dst, data
                if proxy.check_drop_next():
                    continue
                try:
                    other.sendall(data)
                except EnvironmentError:
                    teardown(s)
    finally:
        for client in list(open_socks.keys()):
            teardown(client)
        accepter.close()

def proxy(local_port, remote_addr, host='127.0.0.1'):
    """Return a Proxy relaying connections on (host, local_port) to remote_addr.

    host: local interface to listen on; loopback by default, matching Proxy.
    Nothing is bound until the returned Proxy is iterated.
    """
    return Proxy(local_port, remote_addr, host)

def closeallsockets():
    """Best-effort close of every Socket recorded in Socket.sockets.

    Socket errors from close() are ignored so one failure doesn't stop the
    rest. KeyboardInterrupt/SystemExit and programming errors propagate.
    """
    # Iterate a snapshot so the loop is unaffected if the registry changes.
    for s in list(Socket.sockets):
        try:
            s.close()
        except EnvironmentError:
            pass

def example_main():
    """Minimal CLI: ``... LOCAL_PORT HOST:PORT`` (the last two arguments).

    Runs the proxy forever, printing ``src->dst nbytes`` for each chunk.
    """
    import sys
    host_and_port = sys.argv[-1].split(':')
    target_host, target_port = host_and_port
    local_port = int(sys.argv[-2])
    remote = (target_host, int(target_port))
    for src, dst, data in proxy(local_port, remote):
        print("%s->%s %s" % (src, dst, len(data)))

def tests():
    """Smoke test: building a proxy must not touch the network.

    _proxy() is a generator, so constructing the Proxy binds and connects
    nothing. This only checks that the object is created in its initial
    state (no pending drop request).
    """
    from_port = 8000
    to_port = from_port + 1000
    p = proxy(from_port, ('localhost', to_port))
    assert isinstance(p, Proxy)
    assert not p.check_drop_next()

def main():
    """Command-line entry point: run the proxy until interrupted.

    Options:
      -l/--local-port   local port to listen on (required)
      -H/--remote-host  host to forward to (default 'localhost')
      -p/--remote-port  port to forward to (required)
      -v/--verbose      print repr() of every (src, dst, data) tuple
    Unrecognised arguments are ignored (parse_known_args).
    """
    import argparse
    import sys
    parser = argparse.ArgumentParser()
    parser.add_argument('-l', '--local-port', type=int, required=True, help='set proxy local port')
    parser.add_argument('-H', '--remote-host', default='localhost', help='set proxy remote host')
    parser.add_argument('-p', '--remote-port', type=int, required=True, help='set proxy remote port')
    parser.add_argument('-v', '--verbose', dest='verbose', action='count', help='verbosity', default=0)
    opts, _unknown = parser.parse_known_args(sys.argv[1:])
    remote_addr = (opts.remote_host, opts.remote_port)
    p = proxy(local_port=opts.local_port, remote_addr=remote_addr)
    # Iterating is what drives the proxy, so loop even when not printing.
    for ret in p:
        if opts.verbose:
            print(repr(ret))

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()