diff options
author | Matthew Waters <matthew@centricular.com> | 2018-09-10 18:08:15 +1000 |
---|---|---|
committer | Matthew Waters <ystreet00@gmail.com> | 2020-05-06 06:01:57 +0000 |
commit | bc821a85d4b418d2c4ff670f76e203fc0bfc3290 (patch) | |
tree | 2937312caedeb8026d2a1043754d3f8dc2177d28 | |
parent | 37cf0dffb54be3a7ba319fb17a49aa265c5e92b5 (diff) |
tests: first pass at some basic browser tests
-rw-r--r-- | webrtc/meson.build | 11 | ||||
-rw-r--r-- | webrtc/sendrecv/gst/meson.build | 4 | ||||
-rw-r--r-- | webrtc/sendrecv/gst/tests/basic.py | 144 | ||||
-rw-r--r-- | webrtc/sendrecv/gst/tests/meson.build | 15 | ||||
-rw-r--r-- | webrtc/sendrecv/gst/webrtc_sendrecv.py (renamed from webrtc/sendrecv/gst/webrtc-sendrecv.py) | 28 | ||||
-rwxr-xr-x | webrtc/signalling/generate_cert.sh | 11 | ||||
-rw-r--r-- | webrtc/signalling/meson.build | 8 | ||||
-rwxr-xr-x | webrtc/signalling/simple-server.py | 286 | ||||
-rwxr-xr-x | webrtc/signalling/simple_server.py | 291 |
9 files changed, 505 insertions, 293 deletions
diff --git a/webrtc/meson.build b/webrtc/meson.build index 482b824..3ee35f7 100644 --- a/webrtc/meson.build +++ b/webrtc/meson.build @@ -1,5 +1,6 @@ project('gstwebrtc-demo', 'c', meson_version : '>= 0.48', + license: 'BSD-2-Clause', default_options : [ 'warning_level=1', 'buildtype=debug' ]) @@ -24,5 +25,15 @@ libsoup_dep = dependency('libsoup-2.4', version : '>=2.48', json_glib_dep = dependency('json-glib-1.0', fallback : ['json-glib', 'json_glib_dep']) + +py3_mod = import('python3') +py3 = py3_mod.find_python() + +py3_version = py3_mod.language_version() +if py3_version.version_compare('< 3.6') + error('Could not find a sufficient python version required: 3.6, found {}'.format(py3_version)) +endif + subdir('multiparty-sendrecv') +subdir('signalling') subdir('sendrecv') diff --git a/webrtc/sendrecv/gst/meson.build b/webrtc/sendrecv/gst/meson.build index 7150e42..85950ad 100644 --- a/webrtc/sendrecv/gst/meson.build +++ b/webrtc/sendrecv/gst/meson.build @@ -1,3 +1,7 @@ executable('webrtc-sendrecv', 'webrtc-sendrecv.c', dependencies : [gst_dep, gstsdp_dep, gstwebrtc_dep, libsoup_dep, json_glib_dep ]) + +webrtc_py = files('webrtc_sendrecv.py') + +subdir('tests') diff --git a/webrtc/sendrecv/gst/tests/basic.py b/webrtc/sendrecv/gst/tests/basic.py new file mode 100644 index 0000000..84ce81d --- /dev/null +++ b/webrtc/sendrecv/gst/tests/basic.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 + +import os +import unittest +from selenium import webdriver +from selenium.webdriver.support.wait import WebDriverWait +from selenium.webdriver.firefox.firefox_profile import FirefoxProfile +from selenium.webdriver.chrome.options import Options as COptions +import webrtc_sendrecv as webrtc +import simple_server as sserver +import asyncio +import threading +import signal + +import gi +gi.require_version('Gst', '1.0') +from gi.repository import Gst + +thread = None +stop = None +server = None + +class AsyncIOThread(threading.Thread): + def __init__ (self, loop): + threading.Thread.__init__(self) + self.loop = loop + + def run(self): + asyncio.set_event_loop(self.loop) + self.loop.run_forever() + self.loop.close() + print ("closed loop") + + def stop_thread(self): + self.loop.call_soon_threadsafe(self.loop.stop) + +async def run_until(server, stop_token): + async with server: + await stop_token + print ("run_until done") + +def setUpModule(): + global thread, server + Gst.init(None) + cacerts_path = os.environ.get('TEST_CA_CERT_PATH') + loop = asyncio.new_event_loop() + + thread = AsyncIOThread(loop) + thread.start() + server = sserver.WebRTCSimpleServer('127.0.0.1', 8443, 20, False, cacerts_path) + def f(): + global stop + stop = asyncio.ensure_future(server.run()) + loop.call_soon_threadsafe(f) + +def tearDownModule(): + global thread, stop + stop.cancel() + thread.stop_thread() + thread.join() + print("thread joined") + +def valid_int(n): + if isinstance(n, int): + return True + if isinstance(n, str): + try: + num = int(n) + return True + except: + return False + return False + +def create_firefox_driver(): + capabilities = webdriver.DesiredCapabilities().FIREFOX.copy() + capabilities['acceptSslCerts'] = True + capabilities['acceptInsecureCerts'] = True + profile = FirefoxProfile() + profile.set_preference ('media.navigator.streams.fake', True) + profile.set_preference ('media.navigator.permission.disabled', True) + + return webdriver.Firefox(firefox_profile=profile, capabilities=capabilities) + +def create_chrome_driver(): + capabilities = webdriver.DesiredCapabilities().CHROME.copy() + capabilities['acceptSslCerts'] = True + capabilities['acceptInsecureCerts'] = True + copts = COptions() + copts.add_argument('--allow-file-access-from-files') + copts.add_argument('--use-fake-ui-for-media-stream') + copts.add_argument('--use-fake-device-for-media-stream') + copts.add_argument('--enable-blink-features=RTCUnifiedPlanByDefault') + + return webdriver.Chrome(options=copts, desired_capabilities=capabilities) + +class ServerConnectionTestCase(unittest.TestCase): + def setUp(self): + self.browser = create_firefox_driver() +# self.browser = create_chrome_driver() + self.addCleanup(self.browser.quit) + self.html_source = os.environ.get('TEST_HTML_SOURCE') + self.assertIsNot(self.html_source, None) + self.assertNotEqual(self.html_source, '') + self.html_source = 'file://' + self.html_source + '/index.html' + + def get_peer_id(self): + self.browser.get(self.html_source) + peer_id = WebDriverWait(self.browser, 5).until( + lambda x: x.find_element_by_id('peer-id'), + message='Peer-id element was never seen' + ) + WebDriverWait (self.browser, 5).until( + lambda x: valid_int(peer_id.text), + message='Peer-id never became a number' + ) + return int(peer_id.text) + + def testPeerID(self): + self.get_peer_id() + + def testPerformCall(self): + loop = asyncio.new_event_loop() + thread = AsyncIOThread(loop) + thread.start() + peer_id = self.get_peer_id() + client = webrtc.WebRTCClient(peer_id + 1, peer_id, 'wss://127.0.0.1:8443') + + async def do_things(): + await client.connect() + async def stop_after(client, delay): + await asyncio.sleep(delay) + await client.stop() + future = asyncio.ensure_future (stop_after (client, 5)) + res = await client.loop() + thread.stop_thread() + return res + + res = asyncio.run_coroutine_threadsafe(do_things(), loop).result() + thread.join() + print ("client thread joined") + self.assertEqual(res, 0) + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/webrtc/sendrecv/gst/tests/meson.build b/webrtc/sendrecv/gst/tests/meson.build new file mode 100644 index 0000000..dfca6eb --- /dev/null +++ b/webrtc/sendrecv/gst/tests/meson.build @@ -0,0 +1,15 @@ +tests = [ + ['basic', 'basic.py'], +] + +test_deps = [certs] + +foreach elem : tests + test(elem.get(0), + py3, + depends: test_deps, + args : files(elem.get(1)), + env : ['PYTHONPATH=' + join_paths(meson.source_root(), 'sendrecv', 'gst') + ':' + join_paths(meson.source_root(), 'signalling'), + 'TEST_HTML_SOURCE=' + join_paths(meson.source_root(), 'sendrecv', 'js'), + 'TEST_CA_CERT_PATH=' + join_paths(meson.build_root(), 'signalling')]) +endforeach diff --git a/webrtc/sendrecv/gst/webrtc-sendrecv.py b/webrtc/sendrecv/gst/webrtc_sendrecv.py index b19bc3f..b101e8c 100644 --- a/webrtc/sendrecv/gst/webrtc-sendrecv.py +++ b/webrtc/sendrecv/gst/webrtc_sendrecv.py @@ -23,6 +23,8 @@ webrtcbin name=sendrecv bundle-policy=max-bundle stun-server=stun://stun.l.googl queue ! application/x-rtp,media=audio,encoding-name=OPUS,payload=96 ! sendrecv. ''' +from websockets.version import version as wsv + class WebRTCClient: def __init__(self, id_, peer_id, server): self.id_ = id_ @@ -32,10 +34,11 @@ class WebRTCClient: self.peer_id = peer_id self.server = server or 'wss://webrtc.nirbheek.in:8443' + async def connect(self): sslctx = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH) self.conn = await websockets.connect(self.server, ssl=sslctx) - await self.conn.send('HELLO %d' % our_id) + await self.conn.send('HELLO %d' % self.id_) async def setup_call(self): await self.conn.send('SESSION {}'.format(self.peer_id)) @@ -46,6 +49,7 @@ class WebRTCClient: msg = json.dumps({'sdp': {'type': 'offer', 'sdp': text}}) loop = asyncio.new_event_loop() loop.run_until_complete(self.conn.send(msg)) + loop.close() def on_offer_created(self, promise, _, __): promise.wait() @@ -64,6 +68,7 @@ class WebRTCClient: icemsg = json.dumps({'ice': {'candidate': candidate, 'sdpMLineIndex': mlineindex}}) loop = asyncio.new_event_loop() loop.run_until_complete(self.conn.send(icemsg)) + loop.close() def on_incoming_decodebin_stream(self, _, pad): if not pad.has_current_caps(): @@ -113,7 +118,7 @@ class WebRTCClient: self.webrtc.connect('pad-added', self.on_incoming_stream) self.pipe.set_state(Gst.State.PLAYING) - async def handle_sdp(self, message): + def handle_sdp(self, message): assert (self.webrtc) msg = json.loads(message) if 'sdp' in msg: @@ -133,6 +138,11 @@ class WebRTCClient: sdpmlineindex = ice['sdpMLineIndex'] self.webrtc.emit('add-ice-candidate', sdpmlineindex, candidate) + def close_pipeline(self): + self.pipe.set_state(Gst.State.NULL) + self.pipe = None + self.webrtc = None + async def loop(self): assert self.conn async for message in self.conn: @@ -142,11 +152,18 @@ class WebRTCClient: self.start_pipeline() elif message.startswith('ERROR'): print (message) + self.close_pipeline() return 1 else: - await self.handle_sdp(message) + self.handle_sdp(message) + self.close_pipeline() return 0 + async def stop(self): + if self.conn: + await self.conn.close() + self.conn = None + def check_plugins(): needed = ["opus", "vpx", "nice", "webrtc", "dtls", "srtp", "rtp", @@ -168,6 +185,7 @@ if __name__=='__main__': args = parser.parse_args() our_id = random.randrange(10, 10000) c = WebRTCClient(our_id, args.peerid, args.server) - asyncio.get_event_loop().run_until_complete(c.connect()) - res = asyncio.get_event_loop().run_until_complete(c.loop()) + loop = asyncio.get_event_loop() + loop.run_until_complete(c.connect()) + res = loop.run_until_complete(c.loop()) sys.exit(res) diff --git a/webrtc/signalling/generate_cert.sh b/webrtc/signalling/generate_cert.sh index 68a4b96..7f4084f 100755 --- a/webrtc/signalling/generate_cert.sh +++ b/webrtc/signalling/generate_cert.sh @@ -1,3 +1,10 @@ -#! /bin/bash +#! /bin/sh -openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -days 365 -nodes +BASE_DIR=$(dirname $0) + +OUTDIR="" +if [ $# -eq 1 ]; then + OUTDIR=$1/ +fi + +openssl req -x509 -newkey rsa:4096 -keyout ${OUTDIR}key.pem -out ${OUTDIR}cert.pem -days 365 -nodes -subj "/CN=example.com" diff --git a/webrtc/signalling/meson.build b/webrtc/signalling/meson.build new file mode 100644 index 0000000..43a53bf --- /dev/null +++ b/webrtc/signalling/meson.build @@ -0,0 +1,8 @@ +generate_certs = find_program('generate_cert.sh') +certs = custom_target( + 'generate-certs', + command: [generate_certs, '@OUTDIR@'], + output : ['key.pem', 'cert.pem'] +) + +simple_server = files('simple_server.py') diff --git a/webrtc/signalling/simple-server.py b/webrtc/signalling/simple-server.py deleted file mode 100755 index b337eae..0000000 --- a/webrtc/signalling/simple-server.py +++ /dev/null @@ -1,286 +0,0 @@ -#!/usr/bin/env python3 -# -# Example 1-1 call signalling server -# -# Copyright (C) 2017 Centricular Ltd. -# -# Author: Nirbheek Chauhan <nirbheek@centricular.com> -# - -import os -import sys -import ssl -import logging -import asyncio -import websockets -import argparse -import http - -from concurrent.futures._base import TimeoutError - -parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) -# See: host, port in https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.create_server -parser.add_argument('--addr', default='', help='Address to listen on (default: all interfaces, both ipv4 and ipv6)') -parser.add_argument('--port', default=8443, type=int, help='Port to listen on') -parser.add_argument('--keepalive-timeout', dest='keepalive_timeout', default=30, type=int, help='Timeout for keepalive (in seconds)') -parser.add_argument('--cert-path', default=os.path.dirname(__file__)) -parser.add_argument('--disable-ssl', default=False, help='Disable ssl', action='store_true') -parser.add_argument('--health', default='/health', help='Health check route') - -options = parser.parse_args(sys.argv[1:]) - -ADDR_PORT = (options.addr, options.port) -KEEPALIVE_TIMEOUT = options.keepalive_timeout - -############### Global data ############### - -# Format: {uid: (Peer WebSocketServerProtocol, -# remote_address, -# <'session'|room_id|None>)} -peers = dict() -# Format: {caller_uid: callee_uid, -# callee_uid: caller_uid} -# Bidirectional mapping between the two peers -sessions = dict() -# Format: {room_id: {peer1_id, peer2_id, peer3_id, ...}} -# Room dict with a set of peers in each room -rooms = dict() - -############### Helper functions ############### - -async def health_check(path, request_headers): - if path == options.health: - return http.HTTPStatus.OK, [], b"OK\n" - -async def recv_msg_ping(ws, raddr): - ''' - Wait for a message forever, and send a regular ping to prevent bad routers - from closing the connection. - ''' - msg = None - while msg is None: - try: - msg = await asyncio.wait_for(ws.recv(), KEEPALIVE_TIMEOUT) - except TimeoutError: - print('Sending keepalive ping to {!r} in recv'.format(raddr)) - await ws.ping() - return msg - -async def disconnect(ws, peer_id): - ''' - Remove @peer_id from the list of sessions and close our connection to it. - This informs the peer that the session and all calls have ended, and it - must reconnect. - ''' - global sessions - if peer_id in sessions: - del sessions[peer_id] - # Close connection - if ws and ws.open: - # Don't care about errors - asyncio.ensure_future(ws.close(reason='hangup')) - -async def cleanup_session(uid): - if uid in sessions: - other_id = sessions[uid] - del sessions[uid] - print("Cleaned up {} session".format(uid)) - if other_id in sessions: - del sessions[other_id] - print("Also cleaned up {} session".format(other_id)) - # If there was a session with this peer, also - # close the connection to reset its state. - if other_id in peers: - print("Closing connection to {}".format(other_id)) - wso, oaddr, _ = peers[other_id] - del peers[other_id] - await wso.close() - -async def cleanup_room(uid, room_id): - room_peers = rooms[room_id] - if uid not in room_peers: - return - room_peers.remove(uid) - for pid in room_peers: - wsp, paddr, _ = peers[pid] - msg = 'ROOM_PEER_LEFT {}'.format(uid) - print('room {}: {} -> {}: {}'.format(room_id, uid, pid, msg)) - await wsp.send(msg) - -async def remove_peer(uid): - await cleanup_session(uid) - if uid in peers: - ws, raddr, status = peers[uid] - if status and status != 'session': - await cleanup_room(uid, status) - del peers[uid] - await ws.close() - print("Disconnected from peer {!r} at {!r}".format(uid, raddr)) - -############### Handler functions ############### - -async def connection_handler(ws, uid): - global peers, sessions, rooms - raddr = ws.remote_address - peer_status = None - peers[uid] = [ws, raddr, peer_status] - print("Registered peer {!r} at {!r}".format(uid, raddr)) - while True: - # Receive command, wait forever if necessary - msg = await recv_msg_ping(ws, raddr) - # Update current status - peer_status = peers[uid][2] - # We are in a session or a room, messages must be relayed - if peer_status is not None: - # We're in a session, route message to connected peer - if peer_status == 'session': - other_id = sessions[uid] - wso, oaddr, status = peers[other_id] - assert(status == 'session') - print("{} -> {}: {}".format(uid, other_id, msg)) - await wso.send(msg) - # We're in a room, accept room-specific commands - elif peer_status: - # ROOM_PEER_MSG peer_id MSG - if msg.startswith('ROOM_PEER_MSG'): - _, other_id, msg = msg.split(maxsplit=2) - if other_id not in peers: - await ws.send('ERROR peer {!r} not found' - ''.format(other_id)) - continue - wso, oaddr, status = peers[other_id] - if status != room_id: - await ws.send('ERROR peer {!r} is not in the room' - ''.format(other_id)) - continue - msg = 'ROOM_PEER_MSG {} {}'.format(uid, msg) - print('room {}: {} -> {}: {}'.format(room_id, uid, other_id, msg)) - await wso.send(msg) - elif msg == 'ROOM_PEER_LIST': - room_id = peers[peer_id][2] - room_peers = ' '.join([pid for pid in rooms[room_id] if pid != peer_id]) - msg = 'ROOM_PEER_LIST {}'.format(room_peers) - print('room {}: -> {}: {}'.format(room_id, uid, msg)) - await ws.send(msg) - else: - await ws.send('ERROR invalid msg, already in room') - continue - else: - raise AssertionError('Unknown peer status {!r}'.format(peer_status)) - # Requested a session with a specific peer - elif msg.startswith('SESSION'): - print("{!r} command {!r}".format(uid, msg)) - _, callee_id = msg.split(maxsplit=1) - if callee_id not in peers: - await ws.send('ERROR peer {!r} not found'.format(callee_id)) - continue - if peer_status is not None: - await ws.send('ERROR peer {!r} busy'.format(callee_id)) - continue - await ws.send('SESSION_OK') - wsc = peers[callee_id][0] - print('Session from {!r} ({!r}) to {!r} ({!r})' - ''.format(uid, raddr, callee_id, wsc.remote_address)) - # Register session - peers[uid][2] = peer_status = 'session' - sessions[uid] = callee_id - peers[callee_id][2] = 'session' - sessions[callee_id] = uid - # Requested joining or creation of a room - elif msg.startswith('ROOM'): - print('{!r} command {!r}'.format(uid, msg)) - _, room_id = msg.split(maxsplit=1) - # Room name cannot be 'session', empty, or contain whitespace - if room_id == 'session' or room_id.split() != [room_id]: - await ws.send('ERROR invalid room id {!r}'.format(room_id)) - continue - if room_id in rooms: - if uid in rooms[room_id]: - raise AssertionError('How did we accept a ROOM command ' - 'despite already being in a room?') - else: - # Create room if required - rooms[room_id] = set() - room_peers = ' '.join([pid for pid in rooms[room_id]]) - await ws.send('ROOM_OK {}'.format(room_peers)) - # Enter room - peers[uid][2] = peer_status = room_id - rooms[room_id].add(uid) - for pid in rooms[room_id]: - if pid == uid: - continue - wsp, paddr, _ = peers[pid] - msg = 'ROOM_PEER_JOINED {}'.format(uid) - print('room {}: {} -> {}: {}'.format(room_id, uid, pid, msg)) - await wsp.send(msg) - else: - print('Ignoring unknown message {!r} from {!r}'.format(msg, uid)) - -async def hello_peer(ws): - ''' - Exchange hello, register peer - ''' - raddr = ws.remote_address - hello = await ws.recv() - hello, uid = hello.split(maxsplit=1) - if hello != 'HELLO': - await ws.close(code=1002, reason='invalid protocol') - raise Exception("Invalid hello from {!r}".format(raddr)) - if not uid or uid in peers or uid.split() != [uid]: # no whitespace - await ws.close(code=1002, reason='invalid peer uid') - raise Exception("Invalid uid {!r} from {!r}".format(uid, raddr)) - # Send back a HELLO - await ws.send('HELLO') - return uid - -async def handler(ws, path): - ''' - All incoming messages are handled here. @path is unused. - ''' - raddr = ws.remote_address - print("Connected to {!r}".format(raddr)) - peer_id = await hello_peer(ws) - try: - await connection_handler(ws, peer_id) - except websockets.ConnectionClosed: - print("Connection to peer {!r} closed, exiting handler".format(raddr)) - finally: - await remove_peer(peer_id) - -sslctx = None -if not options.disable_ssl: - # Create an SSL context to be used by the websocket server - certpath = options.cert_path - print('Using TLS with keys in {!r}'.format(certpath)) - if 'letsencrypt' in certpath: - chain_pem = os.path.join(certpath, 'fullchain.pem') - key_pem = os.path.join(certpath, 'privkey.pem') - else: - chain_pem = os.path.join(certpath, 'cert.pem') - key_pem = os.path.join(certpath, 'key.pem') - - sslctx = ssl.create_default_context() - try: - sslctx.load_cert_chain(chain_pem, keyfile=key_pem) - except FileNotFoundError: - print("Certificates not found, did you run generate_cert.sh?") - sys.exit(1) - # FIXME - sslctx.check_hostname = False - sslctx.verify_mode = ssl.CERT_NONE - -print("Listening on https://{}:{}".format(*ADDR_PORT)) -# Websocket server -wsd = websockets.serve(handler, *ADDR_PORT, ssl=sslctx, process_request=health_check, - # Maximum number of messages that websockets will pop - # off the asyncio and OS buffers per connection. See: - # https://websockets.readthedocs.io/en/stable/api.html#websockets.protocol.WebSocketCommonProtocol - max_queue=16) - -logger = logging.getLogger('websockets.server') - -logger.setLevel(logging.ERROR) -logger.addHandler(logging.StreamHandler()) - -asyncio.get_event_loop().run_until_complete(wsd) -asyncio.get_event_loop().run_forever() diff --git a/webrtc/signalling/simple_server.py b/webrtc/signalling/simple_server.py new file mode 100755 index 0000000..ead3034 --- /dev/null +++ b/webrtc/signalling/simple_server.py @@ -0,0 +1,291 @@ +#!/usr/bin/env python3 +# +# Example 1-1 call signalling server +# +# Copyright (C) 2017 Centricular Ltd. +# +# Author: Nirbheek Chauhan <nirbheek@centricular.com> +# + +import os +import sys +import ssl +import logging +import asyncio +import websockets +import argparse +import http + +from concurrent.futures._base import TimeoutError + +class WebRTCSimpleServer(object): + + def __init__(self, addr, port, keepalive_timeout, disable_ssl, certpath, health_path=None): + ############### Global data ############### + + # Format: {uid: (Peer WebSocketServerProtocol, + # remote_address, + # <'session'|room_id|None>)} + self.peers = dict() + # Format: {caller_uid: callee_uid, + # callee_uid: caller_uid} + # Bidirectional mapping between the two peers + self.sessions = dict() + # Format: {room_id: {peer1_id, peer2_id, peer3_id, ...}} + # Room dict with a set of peers in each room + self.rooms = dict() + + self.keepalive_timeout = keepalive_timeout + self.addr = addr + self.port = port + self.disable_ssl = disable_ssl + self.certpath = certpath + self.health_path = health_path + + ############### Helper functions ############### + + async def health_check(self, path, request_headers): + if path == self.health_part: + return http.HTTPStatus.OK, [], b"OK\n" + return None + + async def recv_msg_ping(self, ws, raddr): + ''' + Wait for a message forever, and send a regular ping to prevent bad routers + from closing the connection. + ''' + msg = None + while msg is None: + try: + msg = await asyncio.wait_for(ws.recv(), self.keepalive_timeout) + except TimeoutError: + print('Sending keepalive ping to {!r} in recv'.format(raddr)) + await ws.ping() + return msg + + async def cleanup_session(self, uid): + if uid in self.sessions: + other_id = self.sessions[uid] + del self.sessions[uid] + print("Cleaned up {} session".format(uid)) + if other_id in self.sessions: + del self.sessions[other_id] + print("Also cleaned up {} session".format(other_id)) + # If there was a session with this peer, also + # close the connection to reset its state. + if other_id in self.peers: + print("Closing connection to {}".format(other_id)) + wso, oaddr, _ = self.peers[other_id] + del self.peers[other_id] + await wso.close() + + async def cleanup_room(self, uid, room_id): + room_peers = self.rooms[room_id] + if uid not in room_peers: + return + room_peers.remove(uid) + for pid in room_peers: + wsp, paddr, _ = self.peers[pid] + msg = 'ROOM_PEER_LEFT {}'.format(uid) + print('room {}: {} -> {}: {}'.format(room_id, uid, pid, msg)) + await wsp.send(msg) + + async def remove_peer(self, uid): + await self.cleanup_session(uid) + if uid in self.peers: + ws, raddr, status = self.peers[uid] + if status and status != 'session': + await self.cleanup_room(uid, status) + del self.peers[uid] + await ws.close() + print("Disconnected from peer {!r} at {!r}".format(uid, raddr)) + + ############### Handler functions ############### + + + async def connection_handler(self, ws, uid): + raddr = ws.remote_address + peer_status = None + self.peers[uid] = [ws, raddr, peer_status] + print("Registered peer {!r} at {!r}".format(uid, raddr)) + while True: + # Receive command, wait forever if necessary + msg = await self.recv_msg_ping(ws, raddr) + # Update current status + peer_status = self.peers[uid][2] + # We are in a session or a room, messages must be relayed + if peer_status is not None: + # We're in a session, route message to connected peer + if peer_status == 'session': + other_id = self.sessions[uid] + wso, oaddr, status = self.peers[other_id] + assert(status == 'session') + print("{} -> {}: {}".format(uid, other_id, msg)) + await wso.send(msg) + # We're in a room, accept room-specific commands + elif peer_status: + # ROOM_PEER_MSG peer_id MSG + if msg.startswith('ROOM_PEER_MSG'): + _, other_id, msg = msg.split(maxsplit=2) + if other_id not in self.peers: + await ws.send('ERROR peer {!r} not found' + ''.format(other_id)) + continue + wso, oaddr, status = self.peers[other_id] + if status != room_id: + await ws.send('ERROR peer {!r} is not in the room' + ''.format(other_id)) + continue + msg = 'ROOM_PEER_MSG {} {}'.format(uid, msg) + print('room {}: {} -> {}: {}'.format(room_id, uid, other_id, msg)) + await wso.send(msg) + elif msg == 'ROOM_PEER_LIST': + room_id = self.peers[peer_id][2] + room_peers = ' '.join([pid for pid in self.rooms[room_id] if pid != peer_id]) + msg = 'ROOM_PEER_LIST {}'.format(room_peers) + print('room {}: -> {}: {}'.format(room_id, uid, msg)) + await ws.send(msg) + else: + await ws.send('ERROR invalid msg, already in room') + continue + else: + raise AssertionError('Unknown peer status {!r}'.format(peer_status)) + # Requested a session with a specific peer + elif msg.startswith('SESSION'): + print("{!r} command {!r}".format(uid, msg)) + _, callee_id = msg.split(maxsplit=1) + if callee_id not in self.peers: + await ws.send('ERROR peer {!r} not found'.format(callee_id)) + continue + if peer_status is not None: + await ws.send('ERROR peer {!r} busy'.format(callee_id)) + continue + await ws.send('SESSION_OK') + wsc = self.peers[callee_id][0] + print('Session from {!r} ({!r}) to {!r} ({!r})' + ''.format(uid, raddr, callee_id, wsc.remote_address)) + # Register session + self.peers[uid][2] = peer_status = 'session' + self.sessions[uid] = callee_id + self.peers[callee_id][2] = 'session' + self.sessions[callee_id] = uid + # Requested joining or creation of a room + elif msg.startswith('ROOM'): + print('{!r} command {!r}'.format(uid, msg)) + _, room_id = msg.split(maxsplit=1) + # Room name cannot be 'session', empty, or contain whitespace + if room_id == 'session' or room_id.split() != [room_id]: + await ws.send('ERROR invalid room id {!r}'.format(room_id)) + continue + if room_id in self.rooms: + if uid in self.rooms[room_id]: + raise AssertionError('How did we accept a ROOM command ' + 'despite already being in a room?') + else: + # Create room if required + self.rooms[room_id] = set() + room_peers = ' '.join([pid for pid in self.rooms[room_id]]) + await ws.send('ROOM_OK {}'.format(room_peers)) + # Enter room + self.peers[uid][2] = peer_status = room_id + self.rooms[room_id].add(uid) + for pid in self.rooms[room_id]: + if pid == uid: + continue + wsp, paddr, _ = self.peers[pid] + msg = 'ROOM_PEER_JOINED {}'.format(uid) + print('room {}: {} -> {}: {}'.format(room_id, uid, pid, msg)) + await wsp.send(msg) + else: + print('Ignoring unknown message {!r} from {!r}'.format(msg, uid)) + + async def hello_peer(self, ws): + ''' + Exchange hello, register peer + ''' + raddr = ws.remote_address + hello = await ws.recv() + hello, uid = hello.split(maxsplit=1) + if hello != 'HELLO': + await ws.close(code=1002, reason='invalid protocol') + raise Exception("Invalid hello from {!r}".format(raddr)) + if not uid or uid in self.peers or uid.split() != [uid]: # no whitespace + await ws.close(code=1002, reason='invalid peer uid') + raise Exception("Invalid uid {!r} from {!r}".format(uid, raddr)) + # Send back a HELLO + await ws.send('HELLO') + return uid + + def run(self): + sslctx = None + if not self.disable_ssl: + # Create an SSL context to be used by the websocket server + print('Using TLS with keys in {!r}'.format(self.certpath)) + if 'letsencrypt' in self.certpath: + chain_pem = os.path.join(self.certpath, 'fullchain.pem') + key_pem = os.path.join(self.certpath, 'privkey.pem') + else: + chain_pem = os.path.join(self.certpath, 'cert.pem') + key_pem = os.path.join(self.certpath, 'key.pem') + + sslctx = ssl.create_default_context() + try: + sslctx.load_cert_chain(chain_pem, keyfile=key_pem) + except FileNotFoundError: + print("Certificates not found, did you run generate_cert.sh?") + sys.exit(1) + # FIXME + sslctx.check_hostname = False + sslctx.verify_mode = ssl.CERT_NONE + + async def handler(ws, path): + ''' + All incoming messages are handled here. @path is unused. + ''' + raddr = ws.remote_address + print("Connected to {!r}".format(raddr)) + peer_id = await self.hello_peer(ws) + try: + await self.connection_handler(ws, peer_id) + except websockets.ConnectionClosed: + print("Connection to peer {!r} closed, exiting handler".format(raddr)) + finally: + await self.remove_peer(peer_id) + + print("Listening on https://{}:{}".format(self.addr, self.port)) + # Websocket server + wsd = websockets.serve(handler, self.addr, self.port, ssl=sslctx, process_request=self.health_check if self.health_path else None, + # Maximum number of messages that websockets will pop + # off the asyncio and OS buffers per connection. See: + # https://websockets.readthedocs.io/en/stable/api.html#websockets.protocol.WebSocketCommonProtocol + max_queue=16) + + logger = logging.getLogger('websockets.server') + + logger.setLevel(logging.ERROR) + logger.addHandler(logging.StreamHandler()) + + return wsd + +def main(): + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + # See: host, port in https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.create_server + parser.add_argument('--addr', default='', help='Address to listen on (default: all interfaces, both ipv4 and ipv6)') + parser.add_argument('--port', default=8443, type=int, help='Port to listen on') + parser.add_argument('--keepalive-timeout', dest='keepalive_timeout', default=30, type=int, help='Timeout for keepalive (in seconds)') + parser.add_argument('--cert-path', default=os.path.dirname(__file__)) + parser.add_argument('--disable-ssl', default=False, help='Disable ssl', action='store_true') + parser.add_argument('--health', default='/health', help='Health check route') + + options = parser.parse_args(sys.argv[1:]) + + loop = asyncio.get_event_loop() + + r = WebRTCSimpleServer(options.addr, options.port, options.keepalive_timeout, options.disable_ssl, options.cert_path) + + loop.run_until_complete (r.run()) + loop.run_forever () + print ("Goodbye!") + +if __name__ == "__main__": + main() |