diff --git a/public/buzzer.css b/public/buzzer.css index 041e5ca..8b67fb6 100644 --- a/public/buzzer.css +++ b/public/buzzer.css @@ -1,5 +1,16 @@ body { font-family: sans-serif; + overflow: hidden; +} +#error { + background-color: #fee; + border: 3px solid red; + padding: 1em; + font-family: monospace; + display: none; +} +#error code { + white-space: pre; } input { width: 20em; diff --git a/public/buzzer.html b/public/buzzer.html index 110b12e..476d51c 100644 --- a/public/buzzer.html +++ b/public/buzzer.html @@ -1,6 +1,10 @@ +
+ Error: + +

All Players

- +

BZZZZZ!

@@ -20,4 +26,8 @@
+ + diff --git a/public/buzzer.js b/public/buzzer.js index c3593c4..be7d4ec 100644 --- a/public/buzzer.js +++ b/public/buzzer.js @@ -3,6 +3,7 @@ /* global document, window */ const crypto = window.crypto const location = document.location +const performance = window.performance const storage = window.sessionStorage // TODOs @@ -18,17 +19,16 @@ const buzzer_key = { const q = (selector, root) => (root || document).querySelector(selector) const on = (event, cb) => document.addEventListener(event, cb) -function node(type, { appendTo, cls, text, data, ...attrs } = {}) { - let elem = document.createElement(type) +function node(type, { appendTo, cls, text, data, style, ...attrs } = {}) { + let elem = typeof type === "string" ? document.createElement(type) : type if (cls) { elem.className = cls } if (text) { elem.textContent = text } - for (const name in data ?? {}) { - elem.dataset[name] = data[name] - } + Object.assign(elem.dataset, data ?? {}) + Object.assign(elem.style, style ?? {}) for (const name in attrs) { elem.setAttribute(name, attrs[name]) } @@ -53,11 +53,13 @@ let socket, session_key function hide(e) { - q(`#${e}`).style.display = "none" + e = typeof e === "string" ? q(`#${e}`) : e + e.style.display = "none" } function show(e) { - q(`#${e}`).style.display = "block" + e = typeof e === "string" ? q(`#${e}`) : e + e.style.display = "block" } function session_id() { @@ -97,8 +99,9 @@ function redraw_clients(me, clients) { return } clear(ul) + const player_tpl = q("template#player").content.firstElementChild for (const c of clients) { - node("li", { + node(player_tpl.cloneNode(), { text: c.name, data: { cid: c.id }, appendTo: ul, @@ -180,18 +183,41 @@ function setup_ui() { } }) - q("#username").addEventListener("change", (event) => { - send("name", event.target.value) + const username_el = q("#username") + if (storage["my_name"]) { + username_el.value = storage["my_name"] + } + username_el.addEventListener("change", (event) => { + set_name(event.target.value) }) ul = q("#info ul") } +function set_name(name) { + storage["my_name"] = name + send("name", name) +} + +function session_url(sid) { + return `${config.wsurl}/${sid}` +} + +function session_id_from_url(url) { + const wsurl = new URL(config.wsurl) + const match = RegExp(`${wsurl.pathname}/([^/]+)$`).exec(url) + return !match ? null : match[1] +} + function setup_ws() { const sid = session_id() - socket = new WebSocket(`${config.wsurl}/quiz/${sid}`) + const credentials = { id: storage["my_uid"], key: storage["my_key"] } + socket = new WebSocket(`${session_url(sid)}`) socket.addEventListener("open", function (event) { - send("name", q("#username").value) + if (sid === storage["my_sid"]) { + send("login", credentials) + } + set_name(q("#username").value) }) socket.addEventListener("message", function (event) { const msg = JSON.parse(event.data) @@ -199,9 +225,10 @@ function setup_ws() { servertime = msg.value toffset_ms = performance.now() } else if (msg.type === "id") { - me = { id: msg.id, key: msg.key } - storage["my_id"] = me.id + me = { id: msg.id, key: msg.key, path: msg.path } + storage["my_uid"] = me.id storage["my_key"] = me.key + storage["my_sid"] = session_id_from_url(me.path) redraw_clients(me, clients) } else if (msg.type === "session_key") { session_key = { path: msg.path, key: msg.key } @@ -217,7 +244,7 @@ function setup_ws() { clients = msg.value redraw_clients(me, clients) } else if (msg.type === "client") { - const client = msg.value + const client = { name: msg.name, id: msg.id, active: msg.active } for (const c of clients) { if (c.id === client.id) { c.name = client.name @@ -227,6 +254,11 @@ function setup_ws() { } clients.push(client) redraw_clients(me, clients) + } else if (msg.type === "error") { + console.error(`Error: ${msg.reason}`) + const errorbox = q("#error") + q("code", errorbox).textContent = JSON.stringify(msg, null, 2) + show(errorbox) } else { console.error(`Unknown message: ${event.data}`) } diff --git a/public/config.js b/public/config.js index f853b85..ec278ad 100644 --- a/public/config.js +++ b/public/config.js @@ -1,3 +1,3 @@ export default { - wsurl: "wss://quiz.dumpr.org:443", + wsurl: "wss://quiz.dumpr.org:443/quiz", } diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 0000000..a9ff6d7 --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,3 @@ +{ + "stubPath":"./stubs" +} diff --git a/quiz/__main__.py b/quiz/__main__.py index e2e00eb..f4e9c64 100644 --- a/quiz/__main__.py +++ b/quiz/__main__.py @@ -1,19 +1,21 @@ import asyncio import logging -from . import config -from .quiz import server +from . import config, quiz + +log = logging.getLogger(__name__) def main(): logging.basicConfig( - format="{asctime},{msecs:03.0f} [{name}:{process}] {levelname}: {message}", + format="{asctime} [{name}:{process}] {levelname}: {message}", style="{", - datefmt="%Y-%m-%d %H:%M:%a", level=logging.INFO, ) - asyncio.get_event_loop().run_until_complete(server(config.ws_host, config.ws_port)) + server = quiz.server(config.ws_host, config.ws_port) + log.info(f"Starting server on ws://{config.ws_host}:{config.ws_port}/") + asyncio.get_event_loop().run_until_complete(server) asyncio.get_event_loop().run_forever() diff --git a/quiz/config.py b/quiz/config.py index a2cd872..6692d49 100644 --- a/quiz/config.py +++ b/quiz/config.py @@ -1,5 +1,9 @@ import os +max_clients_per_session = 10 +max_sessions = 10 +client_timeout_s = 600 # Seconds before a client times out +session_timeout_s = 0 path_prefix = "/quiz/" ws_host = os.getenv("WS_HOST", "0.0.0.0") ws_port = int(os.getenv("WS_PORT", 8765)) diff --git a/quiz/quiz.py b/quiz/quiz.py index 1d6fdde..5180635 100644 --- a/quiz/quiz.py +++ b/quiz/quiz.py @@ -3,11 +3,11 @@ import logging import unicodedata from dataclasses import dataclass, field from json import dumps, loads -from secrets import token_hex +from secrets import compare_digest, token_hex from time import perf_counter_ns from typing import * # pyright: reportWildcardImportFromLibrary=false -import websockets # type: ignore +import websockets from . import config @@ -25,22 +25,30 @@ def token() -> Token: @dataclass class Client: - ws: Websocket + ws: Optional[Websocket] path: Path id: UserId = field(default_factory=token) + is_reclaimed: bool = False name: str = "" points: int = 0 secret: Token = field(default_factory=token) session: Optional["Session"] = None def __str__(self): - return f"{ascii(self.id)[1:-1]}:{ascii(self.name)[1:-1]}" + s = f"{ascii(self.id)[1:-1]}:{ascii(self.name)[1:-1]}" + if self.is_reclaimed: + return "" + s + return s @property def is_admin(self): return self.session is not None and self.session.admin == self.id - def others(self, include_self=True) -> Iterable["Client"]: + @property + def is_connected(self): + return self.ws and self.ws.open + + def session_clients(self, include_self=True) -> Iterable["Client"]: if self.session is None: return [] clients = self.session.clients.values() @@ -48,6 +56,59 @@ class Client: return clients return [c for c in clients if c.id != self.id] + async def send(self, type_: str, **args): + if self.ws is not None: + await self.ws.send(msg(type_, **args)) + + @property + async def messages(self) -> AsyncIterable[str]: + if self.ws is None: + return + async for message in self.ws: + yield str(message) + + @property + def is_active(self): + return self.ws is not None + + @is_active.setter + def is_active(self, x: bool): + assert x is False + self.ws = None + + @property + def info(self): + return { + "name": self.name or "", + "id": self.id, + "active": self.is_active, + } + + def _reclaim(self, other: "Client"): + """Take the place of another. + + This allows a Client to inherit the state of another Client. It + is used to allow a user to resume their previous connection, + basically a log-in to an existing Client.""" + # The only thing we truly need to keep is our websocket, because + # all other code assumes a Client's websocket won't ever change. + assert ( + not other.is_reclaimed + and not other.is_active + and self.is_active + and self.path == other.path + and self.session is other.session + ) + + # Load all relevant info from other. + self.id = other.id + self.name = other.name + self.points = other.points + self.secret = other.secret + + # Invalidate other. + other.is_reclaimed = True + @dataclass class Session: @@ -63,8 +124,13 @@ class Session: def get(cls, client: Client) -> "Session": is_new = client.path not in cls.sessions if is_new: + if len(cls.sessions) >= config.max_sessions: + raise LoginError("Too many sessions.") + log.info("[%s] new session: %a", client, client.path) cls.sessions[client.path] = Session(client.path, client.id) s = cls.sessions[client.path] + if len(s.clients) >= config.max_clients_per_session: + raise LoginError("Too many clients.") s.clients[client.id] = client client.session = s # Note: This creates a ref cycle. return s @@ -85,24 +151,47 @@ class Session: ) del self.clients[client.id] + def reclaim(self, client: Client, inactive: Client): + """Replace an inactive client with another in the session. + + This will transfer all state (including points and the + client's ID) from the inactive client to the claiming one. + """ + assert client.session is inactive.session is self + del self.clients[client.id] # Remove the successor from the client pool. + client._reclaim(inactive) # The successor now owns the ID & other state. + self.clients[client.id] = client # The ref to the inactive client is replaced. + + def __str__(self): + return f"Session(path={self.path!a}, admin={self.admin!a})" + + +class LoginError(RuntimeError): + pass + def msg(type_: str, **args): return dumps({"type": type_, **args}) -async def send_time(client: Client): - await client.ws.send(msg("time", value=perf_counter_ns())) +async def send_time(target: Client): + """Send the current server time to the target.""" + await target.send("time", value=perf_counter_ns()) -async def send_buzz(target, client: Client, time): - await target.ws.send(msg("buzz", client=client.id, time=time)) +async def send_buzz(target: Client, client: Client, time: int): + """Send target the timestamp of a buzz registered for the client.""" + await target.send("buzz", client=client.id, time=time) -async def send_clients(client: Client): - if not client.session: - return - clients = [{"name": c.name or "", "id": c.id} for c in client.others()] - await client.ws.send(msg("clients", value=clients)) +async def send_clients(target: Client): + """Send info about all connected clients of a session to the target.""" + await target.send("clients", value=[c.info for c in target.session_clients()]) + + +async def send_client(target: Client, client: Client): + """Send info about the client to the target.""" + await target.send("client", **client.info) async def wait(coros, **kwds): @@ -113,47 +202,48 @@ async def wait(coros, **kwds): return await asyncio.wait(tasks, **kwds) -async def broadcast_client(client: Client): - if not client.session: +async def broadcast_client(target: Client): + """Send info about the target to all clients of its session.""" + await wait(send_client(c, target) for c in target.session_clients()) + + +async def broadcast_clients(session: Session): + """Send info about all clients of a session to all its clients. + + The result is a full refresh/sync for all clients of the session. + Since it sends a full dataset of everything to everyone, the traffic + amount could be quite large, so use sparingly and with care.""" + await wait(send_clients(c) for c in session.clients.values()) + + +async def broadcast_buzz(client: Client, time: int): + """Send all clients of a session the timestamp of a buzz registered + for the client.""" + await wait(send_buzz(c, client, time) for c in client.session_clients()) + + +async def send_credentials(target: Client): + """Send their user credentials to a client.""" + await target.send("id", id=target.id, key=target.secret, path=target.path) + + +async def send_keys_to_the_city(target: Client): + """Send the session admin credentials to a client.""" + if not target.session: return - m = msg("client", value={"name": client.name or "", "id": client.id}) - await wait(c.ws.send(m) for c in client.others()) + await target.send("session_key", path=target.path, key=target.session.secret) -async def broadcast_clients(client: Client): - if not client.session: - return - await wait(send_clients(c) for c in client.others()) - - -async def broadcast_buzz(client: Client, time): - if not client.session: - return - await wait(send_buzz(c, client, time) for c in client.others()) - - -async def send_credentials(client: Client): - await client.ws.send(msg("id", id=client.id, key=client.secret)) - - -async def send_keys_to_the_city(client: Client): - if not client.session: - return - await client.ws.send( - msg("session_key", path=client.path, key=client.session.secret) - ) - - -async def send_hello(client: Client): - msgs = [send_time(client), send_credentials(client), send_clients(client)] - if client.is_admin: - msgs.append(send_keys_to_the_city(client)) +async def send_hello(target: Client): + msgs = [send_time(target), send_credentials(target), send_clients(target)] + if target.is_admin: + msgs.append(send_keys_to_the_city(target)) await wait(msgs) -async def send_heartbeat(client: Client): +async def send_heartbeat(target: Client): await asyncio.sleep(5.0) - await send_time(client) + await send_time(target) def printable(s: str) -> str: @@ -162,7 +252,7 @@ def printable(s: str) -> str: async def handle_messages(client: Client): - async for message in client.ws: + async for message in client.messages: log.debug("[%s] got a message: %a", client, message) mdata = loads(message) if mdata["type"] == "buzz": @@ -175,12 +265,36 @@ async def handle_messages(client: Client): log.info("[%s] new name: %a", client, name) client.name = name await broadcast_client(client) + elif mdata["type"] == "login": + assert client.session + target_id = UserId(mdata["value"]["id"]) + existent = client.session.clients.get(target_id) + if existent is None: + log.info( + "[%s] tried to log in as non-existent user: %a", client, target_id + ) + return + if existent.is_active: + log.info("[%s] cannot log in as active user: %a", client, target_id) + return + if existent.is_reclaimed: + log.info("[%s] client already reclaimed: %a", client, target_id) + return + if not compare_digest(mdata["value"]["key"], existent.secret): + log.info( + "[%s] failed to log in as existing user: %a", client, target_id + ) + return + log.info("[%s] logging in as existent user: %s", client, existent) + client.session.reclaim(client, inactive=existent) + await send_hello(client) + await broadcast_clients(client.session) else: log.error("[%s] received borked message", client) async def juggle(client: Client): - while client.ws.open: + while client.is_connected: done, pending = await wait( [send_heartbeat(client), handle_messages(client)], return_when=asyncio.FIRST_COMPLETED, @@ -198,22 +312,36 @@ async def connected(ws: Websocket, path: str): await ws.close() return + path = printable(path) + client = Client(ws, path) log.info("[%s] new client on %a", client, path) - session = Session.get(client) - if len(session.clients) == 1: - log.info("[%s] new session on %a", client, path) + try: + session = Session.get(client) + except LoginError as err: + log.error("[%s] Error logging in: %s", client, err) + await client.send("error", reason=str(err)) + await ws.close() + return try: await send_hello(client) await juggle(client) finally: log.info("[%s] client disconnected", client) - session.remove(client) - await broadcast_clients(client) + client.is_active = False + await broadcast_clients(session) + await asyncio.sleep( + config.client_timeout_s + ) # Give the user the opportunity to log-in again as the same client. + if not client.is_reclaimed: + session.remove(client) + await broadcast_clients(session) + log.info("[%s] client gone", client) if not session.is_alive: session.destroy() + log.info("[%s] session gone: %s", client, session) def server(host: str, port: int): diff --git a/scripts/lint b/scripts/lint index 742eb64..17ef541 100755 --- a/scripts/lint +++ b/scripts/lint @@ -5,3 +5,4 @@ black "$RUN_DIR" isort --profile black "$RUN_DIR" prettier --write "$RUN_DIR"/public shellcheck "$RUN_DIR"/scripts/* +MYPYPATH="$RUN_DIR"/stubs mypy "$RUN_DIR"/quiz diff --git a/stubs/websockets.pyi b/stubs/websockets.pyi new file mode 100644 index 0000000..3451014 --- /dev/null +++ b/stubs/websockets.pyi @@ -0,0 +1,108 @@ +import asyncio +import http +from typing import * + +# typing (but mypy doesn't know it ...) +class TracebackType: ... + +# websockets/typing.py +Data = Union[str, bytes] +Origin = NewType("Origin", str) +Origin.__doc__ = """Value of a Origin header""" +Subprotocol = NewType("Subprotocol", str) +Subprotocol.__doc__ = """Subprotocol value in a Sec-WebSocket-Protocol header""" + +# websockets/legacy/protocol.py +class WebSocketCommonProtocol(asyncio.Protocol): + @property + def host(self) -> Optional[str]: ... + @property + def port(self) -> Optional[int]: ... + @property + def secure(self) -> Optional[bool]: ... + @property + def open(self) -> bool: ... + @property + def closed(self) -> bool: ... + async def wait_closed(self) -> None: ... + async def __aiter__(self) -> AsyncIterator[Data]: ... + async def recv(self) -> Data: ... + async def send( + self, message: Union[Data, Iterable[Data], AsyncIterable[Data]] + ) -> None: ... + async def close(self, code: int = 1000, reason: str = "") -> None: ... + async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: ... + +# websockets/extensions/base.py +class ServerExtensionFactory: ... + +# websockets/datastructures.py +class Headers(MutableMapping[str, str]): + def __iter__(self) -> Iterator[str]: ... + def __len__(self) -> int: ... + def __getitem__(self, key: str) -> str: ... + def __setitem__(self, key: str, value: str) -> None: ... + def __delitem__(self, key: str) -> None: ... + +HeadersLike = Union[Headers, Mapping[str, str], Iterable[Tuple[str, str]]] + +# websockets/legacy/server.py +HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]] +HTTPResponse = Tuple[http.HTTPStatus, HeadersLike, bytes] + +class WebSocketServerProtocol(WebSocketCommonProtocol): + async def process_request( + self, path: str, request_headers: Headers + ) -> Optional[HTTPResponse]: ... + +class WebSocketServer: + def wrap(self, server: asyncio.AbstractServer) -> None: ... + def close(self) -> None: ... + +class Serve: + def __init__( + self, + ws_handler: Callable[[WebSocketServerProtocol, str], Awaitable[Any]], + host: Optional[Union[str, Sequence[str]]] = None, + port: Optional[int] = None, + *, + path: Optional[str] = None, + create_protocol: Optional[Callable[[Any], WebSocketServerProtocol]] = None, + ping_interval: Optional[float] = 20, + ping_timeout: Optional[float] = 20, + close_timeout: Optional[float] = None, + max_size: Optional[int] = 2 ** 20, + max_queue: Optional[int] = 2 ** 5, + read_limit: int = 2 ** 16, + write_limit: int = 2 ** 16, + loop: Optional[asyncio.AbstractEventLoop] = None, + legacy_recv: bool = False, + klass: Optional[Type[WebSocketServerProtocol]] = None, + timeout: Optional[float] = None, + compression: Optional[str] = "deflate", + origins: Optional[Sequence[Optional[Origin]]] = None, + extensions: Optional[Sequence[ServerExtensionFactory]] = None, + subprotocols: Optional[Sequence[Subprotocol]] = None, + extra_headers: Optional[HeadersLikeOrCallable] = None, + process_request: Optional[ + Callable[[str, Headers], Awaitable[Optional[HTTPResponse]]] + ] = None, + select_subprotocol: Optional[ + Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] + ] = None, + unix: bool = False, + **kwargs: Any, + ) -> None: ... + async def __aenter__(self) -> WebSocketServer: ... + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: ... + def __await__(self) -> Generator[Any, None, WebSocketServer]: ... + __iter__ = __await__ + +serve = Serve + +WebSocketServerProtocol