summaryrefslogtreecommitdiff
path: root/bandwidthmon
blob: 4f35ce9170ef263aead46715ef4895ae9795d85b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
#!/usr/bin/env python
import proxy
import argparse
import sys
import glib
import gtk
import time
import socket

# Module-wide switch consulted by dprint(); main() reassigns it from --debug.
debug = True


def dprint(x):
    """Print x, but only while the module-level ``debug`` flag is set."""
    if debug:
        print(x)

class Bandwidth(object):
    """ Compute average bandwidth over last X seconds

    Keeps a sliding window of (timestamp, byte_count) samples covering the
    last ``window_size`` seconds, a running byte total since the last
    reset(), and an optional rate limit expressed in bytes per second.
    """

    def __init__(self, window_size, average_callback=None, limit=None):
        """
        window_size -- length of the sliding window, in seconds.
        average_callback -- optional callable, invoked as callback(self)
            every time remove_old() refreshes the window.
        limit -- optional maximum average rate over the window, in bytes
            per second; a falsy value means "unlimited".
        """
        self.window_size = window_size
        self.limit = limit
        self.average_callback = average_callback
        self.reset()

    def __str__(self):
        return "<Bandwidth window:%d sec>" % (self.window_size)

    def add_packet(self, now, data):
        """Record that ``data`` passed through at time ``now``.

        Returns the windowed average (bytes/second) after adding it.
        """
        num_bytes = len(data)
        # TODO - other stuff then just length
        self.total += num_bytes
        # Advance the end of the measurement period; without this
        # global_average() always saw last_time == start_time and returned 0.
        self.last_time = max(self.last_time, now)
        # add new datapoint
        self.window.append((now, num_bytes))
        self.remove_old(now)
        # Expire this sample (and notify the callback) once it has left the
        # window, even if no further packets arrive. remove_old() returns a
        # false value, so glib runs this timeout only once.
        glib.timeout_add_seconds(self.window_size, self.remove_old,
                                 now + self.window_size + 0.0001)
        return self.average()

    def reset(self):
        """Forget all samples and restart the since-reset statistics."""
        self.total = 0
        self.start_time = time.time()
        self.last_time = self.start_time
        self.window = []

    def remove_old(self, now):
        """Drop samples whose timestamp is more than window_size before now.

        Always notifies average_callback. Returns False so that, when used as
        a glib timeout callback, the timeout is not repeated.
        """
        self.window[:] = [(t, x) for t, x in self.window
                          if t + self.window_size >= now]
        self.on_average_updated()
        return False

    def on_average_updated(self):
        if self.average_callback:
            self.average_callback(self)

    def bytes_in_window(self):
        # suboptimal - can update per packet
        return sum(x for t, x in self.window)

    def average(self):
        """Bytes per second over the sliding window."""
        # suboptimal
        return float(self.bytes_in_window()) / self.window_size

    def global_average(self):
        """Bytes per second from reset() until the latest recorded packet."""
        elapsed = self.last_time - self.start_time
        if elapsed <= 0:
            return 0
        return float(self.total) / elapsed

    def bytes_sendable_in_limit(self, now, count):
        """ Rate limiting model:
        proxy wakes up and wants to send some bytes. It asks bandwidth how
        many bytes it can send. Calculation assumes we want to send as many
        bytes as we can and still hold to a Average_over_last_T_seconds max.

        Returns a value in [0, count]; never negative, because callers use
        it as a slice bound (data[:negative] would send nearly everything).
        """
        if not self.limit:
            return count
        self.remove_old(now)
        allowance = self.limit * self.window_size - self.bytes_in_window()
        return max(0, min(count, allowance))

def rate_to_string(rate):
    """Render a byte rate (as returned by Bandwidth averages) for a label.

    Zero renders as '0', anything below one KiB as '< 1k', otherwise the
    value is scaled to KiB or MiB with two decimal places.
    """
    kib = 1024
    mib = kib * kib
    if rate == 0:
        text = '0'
    elif rate < kib:
        text = "< 1k"
    elif rate < mib:
        text = '%3.2f KiB' % (float(rate) / kib)
    else:
        text = '%3.2f MiB' % (float(rate) / kib / kib)
    return text

class UI(object):
    """GTK window showing bandwidth figures for each monitored port pair.

    Layout, per pair, followed by a single reset button::

             |  hbox1: avg_in              avg_out
        vbox |  hbox2: avg_in_from_reset   avg_out_from_reset
             ;  reset_button
    """

    def __init__(self, pairs, bandwidth_limit):
        """
        pairs -- iterable of (src, dst) port pairs; main() passes
            (listen_port, remote_port) tuples.
        bandwidth_limit -- passed as ``limit`` to every Bandwidth created
            here (bytes/second; falsy means unlimited).
        """
        self.pairs = pairs
        self.bandwidth_limit = bandwidth_limit
        self.window = gtk.Window()
        self.vbox = gtk.VBox()
        self.window.add(self.vbox)
        self.port_in_bw = {}
        self.port_out_bw = {}
        self.bw_to_ui = {}
        for src, dst in self.pairs:
            self.add_in_out_pair(src, dst, self.vbox)
        self.bw_reset_button = gtk.Button('reset')
        self.bw_reset_button.connect("clicked", self.reset, None)
        self.vbox.add(self.bw_reset_button)
        self.window.set_size_request(200, 80)
        self.window.show_all()

    def add_in_out_pair(self, src, dst, vbox):
        """Create the labels and the in/out Bandwidth objects for one pair."""
        hbox1 = gtk.HBox()
        hbox2 = gtk.HBox()
        vbox.add(hbox1)
        vbox.add(hbox2)
        label_avg_in = gtk.Label('0')
        label_avg_out = gtk.Label('0')
        label_avg_in_from_reset = gtk.Label('0')
        label_avg_out_from_reset = gtk.Label('0')
        hbox1.add(label_avg_in)
        hbox1.add(label_avg_out)
        hbox2.add(label_avg_in_from_reset)
        hbox2.add(label_avg_out_from_reset)
        in_bw = Bandwidth(1, self.on_bw_average, limit=self.bandwidth_limit)
        out_bw = Bandwidth(1, self.on_bw_average, limit=self.bandwidth_limit)
        # NOTE(review): both ports of the pair go into both dicts, so any
        # packet whose src is one of these ports resolves to in_bw in
        # _find_bandwidth(); out_bw is only reached when src is some other
        # port and dst is one of these. Confirm this matches how proxy
        # reports src/dst.
        self.port_in_bw[src] = in_bw
        self.port_out_bw[src] = out_bw
        self.port_in_bw[dst] = in_bw
        self.port_out_bw[dst] = out_bw
        # TODO - use a class
        self.bw_to_ui[in_bw] = dict(label_avg=label_avg_in,
            label_avg_from_reset=label_avg_in_from_reset,
            hbox1=hbox1, hbox2=hbox2)
        self.bw_to_ui[out_bw] = dict(label_avg=label_avg_out,
            label_avg_from_reset=label_avg_out_from_reset,
            hbox1=hbox1, hbox2=hbox2)

    def reset(self, widget, data=None):
        """'clicked' handler: restart every counter and refresh its labels."""
        for bw in self.bw_to_ui:
            bw.reset()
            # Redraw now; otherwise the labels keep the pre-reset figures
            # until the next packet arrives or a window sample expires.
            self.on_bw_average(bw)

    def on_bw_average(self, bw):
        """Bandwidth callback: show bw's window and since-reset averages."""
        ui = self.bw_to_ui[bw]
        label_avg = ui['label_avg']
        label_avg.set_label(rate_to_string(bw.average()))
        label_avg_from_reset = ui['label_avg_from_reset']
        label_avg_from_reset.set_label(rate_to_string(bw.global_average()))
        self.vbox.queue_draw()

    def _find_bandwidth(self, src, dst):
        """Return the Bandwidth for a src->dst packet, or None if unmonitored.

        A match on src in port_in_bw wins over a match on dst in port_out_bw.
        """
        if src in self.port_in_bw:
            return self.port_in_bw[src]
        return self.port_out_bw.get(dst)

    def add_packet(self, now, src, dst, data):
        """Account ``data`` to the matching Bandwidth, or report it."""
        bw = self._find_bandwidth(src, dst)
        if bw is None:
            print("got packet for unmonitored src port %d (%d->%d %d#)" % (
                src, src, dst, len(data)))
            return
        bw.add_packet(now, data)

    def bandwidth(self, src, dst):
        """Return the Bandwidth for src->dst.

        Raises Exception when neither port is monitored. (The old message
        had one placeholder for three arguments and raised TypeError.)
        """
        bw = self._find_bandwidth(src, dst)
        if bw is None:
            raise Exception("%s: non existent src %s and dst %s" % (
                self, src, dst))
        return bw

# Delay, in milliseconds, before trying again to send data that was held
# back because sending it would have exceeded the bandwidth limit.
LIMIT_RETRY_TIMEOUT_MS = 10
def setup_proxy(ui, listen_port, listen_host, remote_addr, debug_proxy,
                bandwidth_limit):
    """Proxy listen_host:listen_port to remote_addr from the glib main loop.

    ui -- UI instance; every forwarded chunk is reported via
        ui.add_packet(), and ui.bandwidth() supplies the Bandwidth objects
        used for rate limiting.
    listen_port, listen_host -- local address handed to proxy.make_proxy().
    remote_addr -- (host, port) tuple to forward to.
    debug_proxy -- passed as ``debug`` to proxy.make_proxy().
    bandwidth_limit -- when truthy, data that would exceed the limit is held
        in a per-(src, dst) queue and drip-fed to the peer.

    Returns None; all further work happens in glib callbacks set up here.
    """
    iterate_packets, handle_input, select_based_iterator, get_fds = proxy.make_proxy(
        listen_port, remote_addr, listen_host, debug=debug_proxy)
    assert(len(get_fds()) == 1) # only the accepting socket
    accepter = get_fds()[0]
    added_fds = set()
    # queued data due to bandwidth limiting, keyed by (src, dst). We live in
    # a single threaded world: everything runs from glib callbacks.
    # NOTE(review): if several client connections yield the same (src, dst)
    # their bytes share one queue -- confirm how proxy defines src/dst.
    queued = {}
    # (src, dst) keys that already have a retry timeout pending, so repeated
    # over-limit reads don't start parallel retry chains for one queue.
    retry_pending = set()

    def on_new_fd(fd):
        """Start watching a newly created proxy socket."""
        dprint("on_new_fd %s" % fd.fileno())
        # Don't add glib.IO_OUT unless you mean it.
        # corollary: when I want to really implement this (to implement non
        # blocking writes) I'll have to io_add_watch(OUT) and return False and
        # io_add_watch(~OUT)
        glib.io_add_watch(fd.fileno(),
                          glib.IO_IN | glib.IO_HUP | glib.IO_ERR,
                          on_read, fd)

    def schedule_retry(target, src, dst):
        """Arrange a delayed resend() of the (src, dst) queue, at most once."""
        key = (src, dst)
        if key in retry_pending:
            return
        retry_pending.add(key)
        glib.timeout_add(LIMIT_RETRY_TIMEOUT_MS, on_retry,
                         (target, src, dst, None))

    def on_retry(args):
        """glib timeout callback: clear the pending mark, then resend."""
        retry_pending.discard((args[1], args[2]))
        resend(args)
        # One-shot timeout; resend() schedules another if data remains.
        return False

    def resend(args):
        """Send as much of the (src, dst) queue to target as the limit allows.

        args -- (target, src, dst, now); a falsy ``now`` means "use the
            current time". Bytes actually sent are reported through
            ui.add_packet() so they count against the bandwidth window; the
            remainder is re-queued and a retry is scheduled.
        """
        target, src, dst, now = args
        key = (src, dst)
        data = queued.pop(key, None)
        if not data:
            return
        if not now:
            now = time.time()
        sendable = ui.bandwidth(src, dst
                        ).bytes_sendable_in_limit(now, len(data))
        # Clamp defensively: data[:negative] would send almost everything.
        sendable = max(0, min(sendable, len(data)))
        if sendable:
            chunk = data[:sendable]
            # NOTE(review): if target is a plain socket, send() may write
            # fewer bytes than given; the return value is not checked here.
            target.send(chunk)
            ui.add_packet(now, src, dst, chunk)
        if sendable < len(data):
            queued[key] = data[sendable:]
            schedule_retry(target, src, dst)

    def on_read(glib_fd, condition, fd):
        """glib IO watch callback; returns False to drop the watch."""
        # lame check to find out the fd is closed
        try:
            fd.fileno()
        except socket.error:
            dprint("removing socket %s" % fd)
            added_fds.discard(fd)
            return False
        dprint("called back %s, condition %s" % (fd.fileno(), condition))
        if not condition & glib.IO_IN:
            dprint("not reading from %d" % fd.fileno())
            if condition & glib.IO_HUP:
                return False
            return True
        result = handle_input(fd)
        update_fds()
        if not result:
            # accepter port returns nothing, don't close it.
            dprint("%s: result %s" % (fd.fileno(), repr(result)))
            return fd == accepter
        (src, dst, data, other, completer) = result
        now = time.time()
        key = (src, dst)
        if key in queued:
            # Older bytes for this direction are still waiting; forwarding
            # this chunk now would reorder the stream, so queue it behind.
            sendable = 0
        elif bandwidth_limit:
            # ugly - why access ui to get bw object?
            sendable = ui.bandwidth(src, dst
                            ).bytes_sendable_in_limit(now, len(data))
        else:
            sendable = len(data)
        if sendable >= len(data):
            completer()
            dprint("%s: result %d->%d #%d" % (fd.fileno(), src, dst, len(data)))
            ui.add_packet(now, src, dst, data)
        else:
            # policy time: do we split packets or not. Let's cut.
            dprint("%s: %s->%s over bandwidth limit, sending %d/%d" % (
                   fd.fileno(), src, dst, sendable, len(data)))
            if key in queued:
                dprint("%s: %s data increased from %d to %d" %
                       (fd.fileno(), key, len(queued[key]),
                        len(queued[key]) + len(data)))
                # NB If this ever becomes a problem use an array (scatter/gather)
                queued[key] = queued[key] + data
            else:
                queued[key] = data
            resend((other, src, dst, now))
        return True

    def update_fds():
        """Register watches for proxy sockets that are not watched yet."""
        fds = get_fds()
        dprint("update_fds: %s" % (','.join(str(f.fileno()) for f in fds)))
        for fd in set(fds) - added_fds:
            on_new_fd(fd)
            added_fds.add(fd)

    update_fds()

def main(argv=None):
    """Parse arguments, build the UI and one proxy per port pair, run GTK.

    argv -- argument list to parse; defaults to sys.argv[1:].
    """
    global debug
    p = argparse.ArgumentParser(description="proxy multiple socket connections")
    p.add_argument('--listen-port', required=True, type=int, action='append')
    p.add_argument('--remote-port', required=True, type=int, action='append')
    # only allow same host for all ports - easier to implement.
    p.add_argument('--remote-host', default='127.0.0.1')
    p.add_argument('--listen-host', default='127.0.0.1')
    p.add_argument('--bandwidth-limit', default=None, type=int,
                   help='maximum average rate, in bytes/second')
    # Accept a bare "--debug" as well as the older "--debug VALUE" form.
    p.add_argument('--debug', nargs='?', const=True, default=None)
    p.add_argument('--debug-proxy', action='store_true',
                   help='pass debug=True to proxy.make_proxy')
    args = p.parse_args(sys.argv[1:] if argv is None else argv)
    if len(args.listen_port) != len(args.remote_port):
        sys.stderr.write(
            "must supply same amount of listening and remote ports\n")
        sys.exit(1)
    debug = bool(args.debug)
    # Materialize once: zip() is a one-shot iterator on Python 3.
    pairs = list(zip(args.listen_port, args.remote_port))
    ui = UI(pairs=pairs, bandwidth_limit=args.bandwidth_limit)
    for listen_port, remote_port in pairs:
        setup_proxy(ui=ui, listen_port=listen_port, listen_host=args.listen_host,
                    remote_addr=(args.remote_host, remote_port),
                    debug_proxy=args.debug_proxy,
                    bandwidth_limit=args.bandwidth_limit)
    gtk.main()

if __name__ == "__main__":
    # Only start the proxies and the GTK loop when run as a script.
    main()