#!/usr/bin/env python
"""
Proxy several socket connections and show their bandwidth in a small GTK
window, optionally throttling traffic to a bytes-per-second limit.

Written against the PyGTK-era ``glib``/``gtk`` modules and the local ``proxy``
module; the event loop is ``gtk.main()``.
"""
import argparse
import socket
import sys
import time

import glib
import gtk

import proxy

debug = True


def dprint(x):
    """Print *x* when the module-level ``debug`` flag is set."""
    if debug:
        print(x)


class Bandwidth(object):
    """Track bytes transferred and compute their average rate.

    Two averages are exposed:
      * ``average()`` -- bytes per second over the sliding window of the
        last ``window_size`` seconds;
      * ``global_average()`` -- bytes per second from the last ``reset()``
        up to the most recent packet.

    ``average_callback(bw)``, if given, is called whenever old samples are
    expired: on every packet and once more ``window_size`` seconds after each
    packet (via a glib timeout). ``limit`` is an optional bytes-per-second cap
    consulted by ``bytes_sendable_in_limit``.
    """

    def __init__(self, window_size, average_callback=None, limit=None):
        self.window_size = window_size
        self.limit = limit
        self.average_callback = average_callback
        self.reset()

    def __str__(self):
        return "<Bandwidth window=%ss limit=%s>" % (self.window_size,
                                                     self.limit)

    def add_packet(self, now, data):
        """Record *data* transferred at time *now*; return the window average."""
        num_bytes = len(data)  # TODO - other stuff then just length
        self.total += num_bytes
        # global_average() measures from the last reset up to this point.
        self.last_time = now
        self.window.append((now, num_bytes))
        self.remove_old(now)
        # Expire this sample (and refresh the display) once it has left the
        # window, even if no further packets arrive. remove_old returns a
        # falsy value, so glib runs this timeout only once.
        glib.timeout_add_seconds(self.window_size, self.remove_old,
                                 now + self.window_size + 0.0001)
        return self.average()

    def reset(self):
        """Forget all samples and restart the since-reset average."""
        self.total = 0
        self.start_time = time.time()
        self.last_time = self.start_time
        self.window = []

    def remove_old(self, now):
        """Drop samples older than ``window_size`` seconds before *now*.

        Also used as a one-shot glib timeout callback, hence the falsy return.
        """
        self.window[:] = [sample for sample in self.window
                          if sample[0] + self.window_size >= now]
        self.on_average_updated()
        return False

    def on_average_updated(self):
        """Notify ``average_callback`` (if any) that the averages changed."""
        if self.average_callback:
            self.average_callback(self)

    def bytes_in_window(self):
        """Total bytes currently inside the sliding window."""
        # suboptimal - can update per packet
        return sum(num_bytes for _, num_bytes in self.window)

    def average(self):
        """Bytes per second over the sliding window."""
        return float(self.bytes_in_window()) / self.window_size

    def global_average(self):
        """Bytes per second from the last reset up to the latest packet.

        Returns 0 until a packet has been recorded after the reset.
        """
        elapsed = self.last_time - self.start_time
        if elapsed <= 0:
            return 0
        return float(self.total) / elapsed

    def bytes_sendable_in_limit(self, now, count):
        """
        Rate limiting model: the proxy wakes up and wants to send *count*
        bytes and asks how many of them it may send right now.

        Sends as many bytes as possible while keeping the average over the
        last ``window_size`` seconds at or below ``limit``. Returns *count*
        when no limit is set; otherwise a value in ``[0, count]``.
        """
        if not self.limit:
            return count
        self.remove_old(now)
        budget = self.limit * self.window_size - self.bytes_in_window()
        # Once the window is over budget, budget is negative. Callers slice
        # data with the result, so it must never go below zero.
        return max(0, min(count, budget))


def rate_to_string(rate):
    """Format a bytes-per-second *rate* for display."""
    if rate == 0:
        return '0'
    if rate < 1024:
        return "< 1k"
    if rate < 1024 * 1024:
        return '%3.2f KiB' % (float(rate) / 1024)
    return '%3.2f MiB' % (float(rate) / 1024 / 1024)


class UI(object):
    """GTK window showing per-port bandwidth, plus a reset button.

    Layout: a vertical box holding, for each (listen, remote) port pair, two
    rows -- the sliding-window averages (in, out) and the averages since the
    last reset (in, out) -- followed by the reset button.
    """

    def __init__(self, pairs, bandwidth_limit):
        self.pairs = pairs
        self.bandwidth_limit = bandwidth_limit
        self.window = gtk.Window()
        self.vbox = gtk.VBox()
        self.window.add(self.vbox)
        # port -> Bandwidth; both ports of a pair map to the same objects.
        self.port_in_bw = {}
        self.port_out_bw = {}
        # Bandwidth -> dict of the widgets that display it.
        self.bw_to_ui = {}
        for src, dst in self.pairs:
            self.add_in_out_pair(src, dst, self.vbox)
        self.bw_reset_button = gtk.Button('reset')
        self.bw_reset_button.connect("clicked", self.reset, None)
        self.vbox.add(self.bw_reset_button)
        self.window.set_size_request(200, 80)
        self.window.show_all()

    def add_in_out_pair(self, src, dst, vbox):
        """Add the label rows and Bandwidth trackers for one port pair."""
        hbox1 = gtk.HBox()  # window averages: in, out
        hbox2 = gtk.HBox()  # since-reset averages: in, out
        vbox.add(hbox1)
        vbox.add(hbox2)
        label_avg_in = gtk.Label('0')
        label_avg_out = gtk.Label('0')
        label_avg_in_from_reset = gtk.Label('0')
        label_avg_out_from_reset = gtk.Label('0')
        hbox1.add(label_avg_in)
        hbox1.add(label_avg_out)
        hbox2.add(label_avg_in_from_reset)
        hbox2.add(label_avg_out_from_reset)
        in_bw = Bandwidth(1, self.on_bw_average, limit=self.bandwidth_limit)
        out_bw = Bandwidth(1, self.on_bw_average, limit=self.bandwidth_limit)
        for port in (src, dst):
            self.port_in_bw[port] = in_bw
            self.port_out_bw[port] = out_bw
        # TODO - use a class
        self.bw_to_ui[in_bw] = dict(label_avg=label_avg_in,
                                    label_avg_from_reset=label_avg_in_from_reset,
                                    hbox1=hbox1, hbox2=hbox2)
        self.bw_to_ui[out_bw] = dict(label_avg=label_avg_out,
                                     label_avg_from_reset=label_avg_out_from_reset,
                                     hbox1=hbox1, hbox2=hbox2)

    def reset(self, widget, data=None):
        """'clicked' handler: restart every counter and refresh its labels."""
        for bw in self.bw_to_ui:
            bw.reset()
            self.on_bw_average(bw)

    def on_bw_average(self, bw):
        """Bandwidth callback: update the labels that display *bw*."""
        ui = self.bw_to_ui[bw]
        ui['label_avg'].set_label(rate_to_string(bw.average()))
        ui['label_avg_from_reset'].set_label(
            rate_to_string(bw.global_average()))
        self.vbox.queue_draw()

    def _lookup(self, src, dst):
        """Return the Bandwidth tracking src->dst traffic, or None."""
        if src in self.port_in_bw:
            return self.port_in_bw[src]
        return self.port_out_bw.get(dst)

    def add_packet(self, now, src, dst, data):
        """Record *data* sent src->dst at *now*; report unmonitored ports."""
        bw = self._lookup(src, dst)
        if bw is None:
            print("got packet for unmonitored src port %d (%d->%d %d#)" % (
                src, src, dst, len(data)))
            return
        bw.add_packet(now, data)

    def bandwidth(self, src, dst):
        """Return the Bandwidth tracking src->dst traffic.

        Raises Exception if neither port is monitored.
        """
        bw = self._lookup(src, dst)
        if bw is None:
            raise Exception("%s: non existant src %s and dst %s" % (
                self, src, dst))
        return bw


# How long (ms) to wait before retrying to send data that was held back
# because it was over the bandwidth limit.
LIMIT_RETRY_TIMEOUT_MS = 10


def setup_proxy(ui, listen_port, listen_host, remote_addr, debug_proxy,
                bandwidth_limit):
    """Proxy listen_host:listen_port to remote_addr, driven by glib IO watches.

    Every socket reported by the proxy gets a glib watch. Data it yields is
    released through the proxy's completer callback when it fits the limit.
    Otherwise it is queued per (src, dst) and written to the peer in chunks
    as the budget allows. Every byte released is recorded in ``ui``.
    """
    iterate_packets, handle_input, select_based_iterator, get_fds = \
        proxy.make_proxy(listen_port, remote_addr, listen_host,
                         debug=debug_proxy)
    assert len(get_fds()) == 1  # only the accepting socket
    accepter = get_fds()[0]
    added_fds = set()
    # (src, dst) -> data held back by the bandwidth limit, oldest first.
    # Invariant: while a key is present, a resend() retry is scheduled for it.
    # We live in a single threaded world, so no locking is needed.
    queued = {}

    def on_new_fd(fd):
        """Install the read watch for a newly seen proxy socket."""
        dprint("on_new_fd %s" % fd.fileno())
        # Don't add glib.IO_OUT unless you mean it.
        # corollary: when I want to really implement this (to implement non
        # blocking writes) I'll have to io_add_watch(OUT) and return False and
        # io_add_watch(~OUT)
        glib.io_add_watch(fd.fileno(),
                          glib.IO_IN | glib.IO_HUP | glib.IO_ERR,
                          on_read, fd)

    def resend(args):
        """Send as much of the (src, dst) backlog as the limit allows.

        *args* is ``(target, src, dst, now)``; a falsy ``now`` means "use the
        current time". Reschedules itself while data remains. Always returns
        False so glib treats it as a one-shot timeout.
        """
        target, src, dst, now = args
        key = (src, dst)
        data = queued.pop(key, None)
        if not data:
            return False
        if not now:
            now = time.time()
        sendable = ui.bandwidth(src, dst).bytes_sendable_in_limit(
            now, len(data))
        if sendable > 0:
            chunk = data[:sendable]
            # NOTE(review): assumes target.send() writes the whole chunk; a
            # short write would silently drop the remainder.
            target.send(chunk)
            # Record the bytes so the limiter accounts for them.
            ui.add_packet(now, src, dst, chunk)
            data = data[sendable:]
        if data:
            queued[key] = data
            glib.timeout_add(LIMIT_RETRY_TIMEOUT_MS, resend,
                             (target, src, dst, None))
        return False

    def on_read(glib_fd, condition, fd):
        """glib IO watch callback; returning False removes the watch."""
        # lame check to find out the fd is closed
        try:
            fd.fileno()
        except socket.error:
            dprint("removing socket %s" % fd)
            added_fds.discard(fd)
            return False
        dprint("called back %s, condition %s" % (fd.fileno(), condition))
        if not condition & glib.IO_IN:
            dprint("not reading from %d" % fd.fileno())
            # Keep the watch unless the other end hung up.
            if condition & glib.IO_HUP:
                return False
            return True
        result = handle_input(fd)
        update_fds()
        if not result:
            # accepter port returns nothing, don't close it.
            dprint("%s: result %s" % (fd.fileno(), repr(result)))
            return fd == accepter
        src, dst, data, other, completer = result
        key = (src, dst)
        if key in queued:
            # Older data for this direction is still waiting. Sending the new
            # data now would reorder the stream, so append it to the backlog
            # and let the retry already scheduled for this key flush it.
            dprint("%s: %s data increased from %d to %d" % (
                fd.fileno(), key, len(queued[key]),
                len(queued[key]) + len(data)))
            # NB If this ever becomes a problem use an array (scatter/gather)
            queued[key] = queued[key] + data
            return True
        now = time.time()
        if bandwidth_limit:
            # ugly - why access ui to get bw object?
            sendable = ui.bandwidth(src, dst).bytes_sendable_in_limit(
                now, len(data))
        else:
            sendable = len(data)
        if sendable == len(data):
            completer()
            dprint("%s: result %d->%d #%d" % (fd.fileno(), src, dst,
                                              len(data)))
            ui.add_packet(now, src, dst, data)
        else:
            # policy time: do we split packets or not. Let's cut.
            dprint("%s: %s->%s over bandwidth limit, sending %d/%d" % (
                fd.fileno(), src, dst, sendable, len(data)))
            queued[key] = data
            resend((other, src, dst, now))
        return True

    def update_fds():
        """Install a watch on every proxy socket that doesn't have one yet."""
        fds = get_fds()
        dprint("update_fds: %s" % ','.join(str(f.fileno()) for f in fds))
        for fd in set(fds) - added_fds:
            on_new_fd(fd)
            added_fds.add(fd)

    update_fds()


def main():
    """Parse the command line, build the UI and proxies, run the GTK loop."""
    global debug
    debug_proxy = False
    p = argparse.ArgumentParser(description="proxy multiple socket connections")
    p.add_argument('--listen-port', required=True, type=int, action='append')
    p.add_argument('--remote-port', required=True, type=int, action='append')
    # only allow same host for all ports - easier to implement.
    p.add_argument('--remote-host', default='127.0.0.1')
    p.add_argument('--listen-host', default='127.0.0.1')
    p.add_argument('--bandwidth-limit', default=None, type=int)
    p.add_argument('--debug')
    args = p.parse_args(sys.argv[1:])
    if len(args.listen_port) != len(args.remote_port):
        print("must supply same amount of listening and remote ports")
        sys.exit(1)
    debug = bool(args.debug)
    # Materialize once: UI stores the pairs and we iterate them again below.
    pairs = list(zip(args.listen_port, args.remote_port))
    ui = UI(pairs=pairs, bandwidth_limit=args.bandwidth_limit)
    for listen_port, remote_port in pairs:
        setup_proxy(ui=ui, listen_port=listen_port,
                    listen_host=args.listen_host,
                    remote_addr=(args.remote_host, remote_port),
                    debug_proxy=debug_proxy,
                    bandwidth_limit=args.bandwidth_limit)
    gtk.main()


if __name__ == '__main__':
    main()