diff --git a/public/buzzer.css b/public/buzzer.css
index 041e5ca..8b67fb6 100644
--- a/public/buzzer.css
+++ b/public/buzzer.css
@@ -1,5 +1,16 @@
body {
font-family: sans-serif;
+ overflow: hidden;
+}
+#error {
+ background-color: #fee;
+ border: 3px solid red;
+ padding: 1em;
+ font-family: monospace;
+ display: none;
+}
+#error code {
+ white-space: pre;
}
input {
width: 20em;
diff --git a/public/buzzer.html b/public/buzzer.html
index 110b12e..476d51c 100644
--- a/public/buzzer.html
+++ b/public/buzzer.html
@@ -1,6 +1,10 @@
+
+ Error:
+
+
All Players
-
+
BZZZZZ!
@@ -20,4 +26,8 @@
+
+ player name
+
+
diff --git a/public/buzzer.js b/public/buzzer.js
index c3593c4..be7d4ec 100644
--- a/public/buzzer.js
+++ b/public/buzzer.js
@@ -3,6 +3,7 @@
/* global document, window */
const crypto = window.crypto
const location = document.location
+const performance = window.performance
const storage = window.sessionStorage
// TODOs
@@ -18,17 +19,16 @@ const buzzer_key = {
const q = (selector, root) => (root || document).querySelector(selector)
const on = (event, cb) => document.addEventListener(event, cb)
-function node(type, { appendTo, cls, text, data, ...attrs } = {}) {
- let elem = document.createElement(type)
+function node(type, { appendTo, cls, text, data, style, ...attrs } = {}) {
+ let elem = typeof type === "string" ? document.createElement(type) : type
if (cls) {
elem.className = cls
}
if (text) {
elem.textContent = text
}
- for (const name in data ?? {}) {
- elem.dataset[name] = data[name]
- }
+ Object.assign(elem.dataset, data ?? {})
+ Object.assign(elem.style, style ?? {})
for (const name in attrs) {
elem.setAttribute(name, attrs[name])
}
@@ -53,11 +53,13 @@ let socket,
session_key
function hide(e) {
- q(`#${e}`).style.display = "none"
+ e = typeof e === "string" ? q(`#${e}`) : e
+ e.style.display = "none"
}
function show(e) {
- q(`#${e}`).style.display = "block"
+ e = typeof e === "string" ? q(`#${e}`) : e
+ e.style.display = "block"
}
function session_id() {
@@ -97,8 +99,9 @@ function redraw_clients(me, clients) {
return
}
clear(ul)
+ const player_tpl = q("template#player").content.firstElementChild
for (const c of clients) {
- node("li", {
+ node(player_tpl.cloneNode(), {
text: c.name,
data: { cid: c.id },
appendTo: ul,
@@ -180,18 +183,41 @@ function setup_ui() {
}
})
- q("#username").addEventListener("change", (event) => {
- send("name", event.target.value)
+ const username_el = q("#username")
+ if (storage["my_name"]) {
+ username_el.value = storage["my_name"]
+ }
+ username_el.addEventListener("change", (event) => {
+ set_name(event.target.value)
})
ul = q("#info ul")
}
+function set_name(name) {
+ storage["my_name"] = name
+ send("name", name)
+}
+
+function session_url(sid) {
+ return `${config.wsurl}/${sid}`
+}
+
+function session_id_from_url(url) {
+ const wsurl = new URL(config.wsurl)
+ const match = RegExp(`${wsurl.pathname}/([^/]+)$`).exec(url)
+ return !match ? null : match[1]
+}
+
function setup_ws() {
const sid = session_id()
- socket = new WebSocket(`${config.wsurl}/quiz/${sid}`)
+ const credentials = { id: storage["my_uid"], key: storage["my_key"] }
+ socket = new WebSocket(`${session_url(sid)}`)
socket.addEventListener("open", function (event) {
- send("name", q("#username").value)
+ if (sid === storage["my_sid"]) {
+ send("login", credentials)
+ }
+ set_name(q("#username").value)
})
socket.addEventListener("message", function (event) {
const msg = JSON.parse(event.data)
@@ -199,9 +225,10 @@ function setup_ws() {
servertime = msg.value
toffset_ms = performance.now()
} else if (msg.type === "id") {
- me = { id: msg.id, key: msg.key }
- storage["my_id"] = me.id
+ me = { id: msg.id, key: msg.key, path: msg.path }
+ storage["my_uid"] = me.id
storage["my_key"] = me.key
+ storage["my_sid"] = session_id_from_url(me.path)
redraw_clients(me, clients)
} else if (msg.type === "session_key") {
session_key = { path: msg.path, key: msg.key }
@@ -217,7 +244,7 @@ function setup_ws() {
clients = msg.value
redraw_clients(me, clients)
} else if (msg.type === "client") {
- const client = msg.value
+ const client = { name: msg.name, id: msg.id, active: msg.active }
for (const c of clients) {
if (c.id === client.id) {
c.name = client.name
@@ -227,6 +254,11 @@ function setup_ws() {
}
clients.push(client)
redraw_clients(me, clients)
+ } else if (msg.type === "error") {
+ console.error(`Error: ${msg.reason}`)
+ const errorbox = q("#error")
+ q("code", errorbox).textContent = JSON.stringify(msg, null, 2)
+ show(errorbox)
} else {
console.error(`Unknown message: ${event.data}`)
}
diff --git a/public/config.js b/public/config.js
index f853b85..ec278ad 100644
--- a/public/config.js
+++ b/public/config.js
@@ -1,3 +1,3 @@
export default {
- wsurl: "wss://quiz.dumpr.org:443",
+ wsurl: "wss://quiz.dumpr.org:443/quiz",
}
diff --git a/pyrightconfig.json b/pyrightconfig.json
new file mode 100644
index 0000000..a9ff6d7
--- /dev/null
+++ b/pyrightconfig.json
@@ -0,0 +1,3 @@
+{
+ "stubPath":"./stubs"
+}
diff --git a/quiz/__main__.py b/quiz/__main__.py
index e2e00eb..f4e9c64 100644
--- a/quiz/__main__.py
+++ b/quiz/__main__.py
@@ -1,19 +1,21 @@
import asyncio
import logging
-from . import config
-from .quiz import server
+from . import config, quiz
+
+log = logging.getLogger(__name__)
def main():
logging.basicConfig(
- format="{asctime},{msecs:03.0f} [{name}:{process}] {levelname}: {message}",
+ format="{asctime} [{name}:{process}] {levelname}: {message}",
style="{",
- datefmt="%Y-%m-%d %H:%M:%a",
level=logging.INFO,
)
- asyncio.get_event_loop().run_until_complete(server(config.ws_host, config.ws_port))
+ server = quiz.server(config.ws_host, config.ws_port)
+ log.info(f"Starting server on ws://{config.ws_host}:{config.ws_port}/")
+ asyncio.get_event_loop().run_until_complete(server)
asyncio.get_event_loop().run_forever()
diff --git a/quiz/config.py b/quiz/config.py
index a2cd872..6692d49 100644
--- a/quiz/config.py
+++ b/quiz/config.py
@@ -1,5 +1,9 @@
import os
+max_clients_per_session = 10
+max_sessions = 10
+client_timeout_s = 600 # Seconds before a client times out
+session_timeout_s = 0
path_prefix = "/quiz/"
ws_host = os.getenv("WS_HOST", "0.0.0.0")
ws_port = int(os.getenv("WS_PORT", 8765))
diff --git a/quiz/quiz.py b/quiz/quiz.py
index 1d6fdde..5180635 100644
--- a/quiz/quiz.py
+++ b/quiz/quiz.py
@@ -3,11 +3,11 @@ import logging
import unicodedata
from dataclasses import dataclass, field
from json import dumps, loads
-from secrets import token_hex
+from secrets import compare_digest, token_hex
from time import perf_counter_ns
from typing import * # pyright: reportWildcardImportFromLibrary=false
-import websockets # type: ignore
+import websockets
from . import config
@@ -25,22 +25,30 @@ def token() -> Token:
@dataclass
class Client:
- ws: Websocket
+ ws: Optional[Websocket]
path: Path
id: UserId = field(default_factory=token)
+ is_reclaimed: bool = False
name: str = ""
points: int = 0
secret: Token = field(default_factory=token)
session: Optional["Session"] = None
def __str__(self):
- return f"{ascii(self.id)[1:-1]}:{ascii(self.name)[1:-1]}"
+ s = f"{ascii(self.id)[1:-1]}:{ascii(self.name)[1:-1]}"
+ if self.is_reclaimed:
+ return "" + s
+ return s
@property
def is_admin(self):
return self.session is not None and self.session.admin == self.id
- def others(self, include_self=True) -> Iterable["Client"]:
+ @property
+ def is_connected(self):
+ return self.ws and self.ws.open
+
+ def session_clients(self, include_self=True) -> Iterable["Client"]:
if self.session is None:
return []
clients = self.session.clients.values()
@@ -48,6 +56,59 @@ class Client:
return clients
return [c for c in clients if c.id != self.id]
+ async def send(self, type_: str, **args):
+ if self.ws is not None:
+ await self.ws.send(msg(type_, **args))
+
+ @property
+ async def messages(self) -> AsyncIterable[str]:
+ if self.ws is None:
+ return
+ async for message in self.ws:
+ yield str(message)
+
+ @property
+ def is_active(self):
+ return self.ws is not None
+
+ @is_active.setter
+ def is_active(self, x: bool):
+ assert x is False
+ self.ws = None
+
+ @property
+ def info(self):
+ return {
+ "name": self.name or "",
+ "id": self.id,
+ "active": self.is_active,
+ }
+
+ def _reclaim(self, other: "Client"):
+ """Take the place of another.
+
+ This allows a Client to inherit the state of another Client. It
+ is used to allow a user to resume their previous connection,
+ basically a log-in to an existing Client."""
+ # The only thing we truly need to keep is our websocket, because
+ # all other code assumes a Client's websocket won't ever change.
+ assert (
+ not other.is_reclaimed
+ and not other.is_active
+ and self.is_active
+ and self.path == other.path
+ and self.session is other.session
+ )
+
+ # Load all relevant info from other.
+ self.id = other.id
+ self.name = other.name
+ self.points = other.points
+ self.secret = other.secret
+
+ # Invalidate other.
+ other.is_reclaimed = True
+
@dataclass
class Session:
@@ -63,8 +124,13 @@ class Session:
def get(cls, client: Client) -> "Session":
is_new = client.path not in cls.sessions
if is_new:
+ if len(cls.sessions) >= config.max_sessions:
+ raise LoginError("Too many sessions.")
+ log.info("[%s] new session: %a", client, client.path)
cls.sessions[client.path] = Session(client.path, client.id)
s = cls.sessions[client.path]
+ if len(s.clients) >= config.max_clients_per_session:
+ raise LoginError("Too many clients.")
s.clients[client.id] = client
client.session = s # Note: This creates a ref cycle.
return s
@@ -85,24 +151,47 @@ class Session:
)
del self.clients[client.id]
+ def reclaim(self, client: Client, inactive: Client):
+ """Replace an inactive client with another in the session.
+
+ This will transfer all state (including points and the
+ client's ID) from the inactive client to the claiming one.
+ """
+ assert client.session is inactive.session is self
+ del self.clients[client.id] # Remove the successor from the client pool.
+ client._reclaim(inactive) # The successor now owns the ID & other state.
+ self.clients[client.id] = client # The ref to the inactive client is replaced.
+
+ def __str__(self):
+ return f"Session(path={self.path!a}, admin={self.admin!a})"
+
+
+class LoginError(RuntimeError):
+ pass
+
def msg(type_: str, **args):
return dumps({"type": type_, **args})
-async def send_time(client: Client):
- await client.ws.send(msg("time", value=perf_counter_ns()))
+async def send_time(target: Client):
+ """Send the current server time to the target."""
+ await target.send("time", value=perf_counter_ns())
-async def send_buzz(target, client: Client, time):
- await target.ws.send(msg("buzz", client=client.id, time=time))
+async def send_buzz(target: Client, client: Client, time: int):
+ """Send target the timestamp of a buzz registered for the client."""
+ await target.send("buzz", client=client.id, time=time)
-async def send_clients(client: Client):
- if not client.session:
- return
- clients = [{"name": c.name or "", "id": c.id} for c in client.others()]
- await client.ws.send(msg("clients", value=clients))
+async def send_clients(target: Client):
+ """Send info about all connected clients of a session to the target."""
+ await target.send("clients", value=[c.info for c in target.session_clients()])
+
+
+async def send_client(target: Client, client: Client):
+ """Send info about the client to the target."""
+ await target.send("client", **client.info)
async def wait(coros, **kwds):
@@ -113,47 +202,48 @@ async def wait(coros, **kwds):
return await asyncio.wait(tasks, **kwds)
-async def broadcast_client(client: Client):
- if not client.session:
+async def broadcast_client(target: Client):
+ """Send info about the target to all clients of its session."""
+ await wait(send_client(c, target) for c in target.session_clients())
+
+
+async def broadcast_clients(session: Session):
+ """Send info about all clients of a session to all its clients.
+
+ The result is a full refresh/sync for all clients of the session.
+ Since it sends a full dataset of everything to everyone, the traffic
+ amount could be quite large, so use sparingly and with care."""
+ await wait(send_clients(c) for c in session.clients.values())
+
+
+async def broadcast_buzz(client: Client, time: int):
+ """Send all clients of a session the timestamp of a buzz registered
+ for the client."""
+ await wait(send_buzz(c, client, time) for c in client.session_clients())
+
+
+async def send_credentials(target: Client):
+ """Send their user credentials to a client."""
+ await target.send("id", id=target.id, key=target.secret, path=target.path)
+
+
+async def send_keys_to_the_city(target: Client):
+ """Send the session admin credentials to a client."""
+ if not target.session:
return
- m = msg("client", value={"name": client.name or "", "id": client.id})
- await wait(c.ws.send(m) for c in client.others())
+ await target.send("session_key", path=target.path, key=target.session.secret)
-async def broadcast_clients(client: Client):
- if not client.session:
- return
- await wait(send_clients(c) for c in client.others())
-
-
-async def broadcast_buzz(client: Client, time):
- if not client.session:
- return
- await wait(send_buzz(c, client, time) for c in client.others())
-
-
-async def send_credentials(client: Client):
- await client.ws.send(msg("id", id=client.id, key=client.secret))
-
-
-async def send_keys_to_the_city(client: Client):
- if not client.session:
- return
- await client.ws.send(
- msg("session_key", path=client.path, key=client.session.secret)
- )
-
-
-async def send_hello(client: Client):
- msgs = [send_time(client), send_credentials(client), send_clients(client)]
- if client.is_admin:
- msgs.append(send_keys_to_the_city(client))
+async def send_hello(target: Client):
+ msgs = [send_time(target), send_credentials(target), send_clients(target)]
+ if target.is_admin:
+ msgs.append(send_keys_to_the_city(target))
await wait(msgs)
-async def send_heartbeat(client: Client):
+async def send_heartbeat(target: Client):
await asyncio.sleep(5.0)
- await send_time(client)
+ await send_time(target)
def printable(s: str) -> str:
@@ -162,7 +252,7 @@ def printable(s: str) -> str:
async def handle_messages(client: Client):
- async for message in client.ws:
+ async for message in client.messages:
log.debug("[%s] got a message: %a", client, message)
mdata = loads(message)
if mdata["type"] == "buzz":
@@ -175,12 +265,36 @@ async def handle_messages(client: Client):
log.info("[%s] new name: %a", client, name)
client.name = name
await broadcast_client(client)
+ elif mdata["type"] == "login":
+ assert client.session
+ target_id = UserId(mdata["value"]["id"])
+ existent = client.session.clients.get(target_id)
+ if existent is None:
+ log.info(
+ "[%s] tried to log in as non-existent user: %a", client, target_id
+ )
+ return
+ if existent.is_active:
+ log.info("[%s] cannot log in as active user: %a", client, target_id)
+ return
+ if existent.is_reclaimed:
+ log.info("[%s] client already reclaimed: %a", client, target_id)
+ return
+ if not compare_digest(mdata["value"]["key"], existent.secret):
+ log.info(
+ "[%s] failed to log in as existing user: %a", client, target_id
+ )
+ return
+ log.info("[%s] logging in as existent user: %s", client, existent)
+ client.session.reclaim(client, inactive=existent)
+ await send_hello(client)
+ await broadcast_clients(client.session)
else:
log.error("[%s] received borked message", client)
async def juggle(client: Client):
- while client.ws.open:
+ while client.is_connected:
done, pending = await wait(
[send_heartbeat(client), handle_messages(client)],
return_when=asyncio.FIRST_COMPLETED,
@@ -198,22 +312,36 @@ async def connected(ws: Websocket, path: str):
await ws.close()
return
+ path = printable(path)
+
client = Client(ws, path)
log.info("[%s] new client on %a", client, path)
- session = Session.get(client)
- if len(session.clients) == 1:
- log.info("[%s] new session on %a", client, path)
+ try:
+ session = Session.get(client)
+ except LoginError as err:
+ log.error("[%s] Error logging in: %s", client, err)
+ await client.send("error", reason=str(err))
+ await ws.close()
+ return
try:
await send_hello(client)
await juggle(client)
finally:
log.info("[%s] client disconnected", client)
- session.remove(client)
- await broadcast_clients(client)
+ client.is_active = False
+ await broadcast_clients(session)
+ await asyncio.sleep(
+ config.client_timeout_s
+ ) # Give the user the opportunity to log-in again as the same client.
+ if not client.is_reclaimed:
+ session.remove(client)
+ await broadcast_clients(session)
+ log.info("[%s] client gone", client)
if not session.is_alive:
session.destroy()
+ log.info("[%s] session gone: %s", client, session)
def server(host: str, port: int):
diff --git a/scripts/lint b/scripts/lint
index 742eb64..17ef541 100755
--- a/scripts/lint
+++ b/scripts/lint
@@ -5,3 +5,4 @@ black "$RUN_DIR"
isort --profile black "$RUN_DIR"
prettier --write "$RUN_DIR"/public
shellcheck "$RUN_DIR"/scripts/*
+MYPYPATH="$RUN_DIR"/stubs mypy "$RUN_DIR"/quiz
diff --git a/stubs/websockets.pyi b/stubs/websockets.pyi
new file mode 100644
index 0000000..3451014
--- /dev/null
+++ b/stubs/websockets.pyi
@@ -0,0 +1,108 @@
+import asyncio
+import http
+from typing import *
+
+# typing (but mypy doesn't know it ...)
+class TracebackType: ...
+
+# websockets/typing.py
+Data = Union[str, bytes]
+Origin = NewType("Origin", str)
+Origin.__doc__ = """Value of a Origin header"""
+Subprotocol = NewType("Subprotocol", str)
+Subprotocol.__doc__ = """Subprotocol value in a Sec-WebSocket-Protocol header"""
+
+# websockets/legacy/protocol.py
+class WebSocketCommonProtocol(asyncio.Protocol):
+ @property
+ def host(self) -> Optional[str]: ...
+ @property
+ def port(self) -> Optional[int]: ...
+ @property
+ def secure(self) -> Optional[bool]: ...
+ @property
+ def open(self) -> bool: ...
+ @property
+ def closed(self) -> bool: ...
+ async def wait_closed(self) -> None: ...
+ async def __aiter__(self) -> AsyncIterator[Data]: ...
+ async def recv(self) -> Data: ...
+ async def send(
+ self, message: Union[Data, Iterable[Data], AsyncIterable[Data]]
+ ) -> None: ...
+ async def close(self, code: int = 1000, reason: str = "") -> None: ...
+ async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: ...
+
+# websockets/extensions/base.py
+class ServerExtensionFactory: ...
+
+# websockets/datastructures.py
+class Headers(MutableMapping[str, str]):
+ def __iter__(self) -> Iterator[str]: ...
+ def __len__(self) -> int: ...
+ def __getitem__(self, key: str) -> str: ...
+ def __setitem__(self, key: str, value: str) -> None: ...
+ def __delitem__(self, key: str) -> None: ...
+
+HeadersLike = Union[Headers, Mapping[str, str], Iterable[Tuple[str, str]]]
+
+# websockets/legacy/server.py
+HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]]
+HTTPResponse = Tuple[http.HTTPStatus, HeadersLike, bytes]
+
+class WebSocketServerProtocol(WebSocketCommonProtocol):
+ async def process_request(
+ self, path: str, request_headers: Headers
+ ) -> Optional[HTTPResponse]: ...
+
+class WebSocketServer:
+ def wrap(self, server: asyncio.AbstractServer) -> None: ...
+ def close(self) -> None: ...
+
+class Serve:
+ def __init__(
+ self,
+ ws_handler: Callable[[WebSocketServerProtocol, str], Awaitable[Any]],
+ host: Optional[Union[str, Sequence[str]]] = None,
+ port: Optional[int] = None,
+ *,
+ path: Optional[str] = None,
+ create_protocol: Optional[Callable[[Any], WebSocketServerProtocol]] = None,
+ ping_interval: Optional[float] = 20,
+ ping_timeout: Optional[float] = 20,
+ close_timeout: Optional[float] = None,
+ max_size: Optional[int] = 2 ** 20,
+ max_queue: Optional[int] = 2 ** 5,
+ read_limit: int = 2 ** 16,
+ write_limit: int = 2 ** 16,
+ loop: Optional[asyncio.AbstractEventLoop] = None,
+ legacy_recv: bool = False,
+ klass: Optional[Type[WebSocketServerProtocol]] = None,
+ timeout: Optional[float] = None,
+ compression: Optional[str] = "deflate",
+ origins: Optional[Sequence[Optional[Origin]]] = None,
+ extensions: Optional[Sequence[ServerExtensionFactory]] = None,
+ subprotocols: Optional[Sequence[Subprotocol]] = None,
+ extra_headers: Optional[HeadersLikeOrCallable] = None,
+ process_request: Optional[
+ Callable[[str, Headers], Awaitable[Optional[HTTPResponse]]]
+ ] = None,
+ select_subprotocol: Optional[
+ Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol]
+ ] = None,
+ unix: bool = False,
+ **kwargs: Any,
+ ) -> None: ...
+ async def __aenter__(self) -> WebSocketServer: ...
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_value: Optional[BaseException],
+ traceback: Optional[TracebackType],
+ ) -> None: ...
+ def __await__(self) -> Generator[Any, None, WebSocketServer]: ...
+ __iter__ = __await__
+
+serve = Serve
+
+WebSocketServerProtocol