add user session reclaiming

In its current state the implementation should allow a user to resume
their session if the websocket connection is reset, for whatever reason.
This could be expanded to allow session sharing (multiple agents logging
in to the same client), or manual session resume via some sort of
password (encode uid & key to some pass-phrase kinda thing, or QR code).
This commit is contained in:
ducklet 2021-01-31 00:19:35 +01:00
parent 476f3d7a49
commit 4908b1fc6e
10 changed files with 376 additions and 77 deletions

View file

@ -1,5 +1,16 @@
body { body {
font-family: sans-serif; font-family: sans-serif;
overflow: hidden;
}
#error {
background-color: #fee;
border: 3px solid red;
padding: 1em;
font-family: monospace;
display: none;
}
#error code {
white-space: pre;
} }
input { input {
width: 20em; width: 20em;

View file

@ -1,6 +1,10 @@
<meta charset="utf-8" /> <meta charset="utf-8" />
<link rel="stylesheet" type="text/css" href="buzzer.css" /> <link rel="stylesheet" type="text/css" href="buzzer.css" />
<body> <body>
<div id="error">
Error:
<code></code>
</div>
<div id="info"> <div id="info">
<label <label
>You: >You:
@ -11,7 +15,9 @@
placeholder="Please put your name here ..." placeholder="Please put your name here ..."
/></label> /></label>
<h2>All Players</h2> <h2>All Players</h2>
<ul></ul> <ul>
<!-- players will be inserted here -->
</ul>
</div> </div>
<div id="buzzbox"> <div id="buzzbox">
<p id="active">BZZZZZ!</p> <p id="active">BZZZZZ!</p>
@ -20,4 +26,8 @@
</div> </div>
</body> </body>
<template id="player">
<li data-cid="client ID">player name</li>
</template>
<script type="module" src="./buzzer.js"></script> <script type="module" src="./buzzer.js"></script>

View file

@ -3,6 +3,7 @@
/* global document, window */ /* global document, window */
const crypto = window.crypto const crypto = window.crypto
const location = document.location const location = document.location
const performance = window.performance
const storage = window.sessionStorage const storage = window.sessionStorage
// TODOs // TODOs
@ -18,17 +19,16 @@ const buzzer_key = {
const q = (selector, root) => (root || document).querySelector(selector) const q = (selector, root) => (root || document).querySelector(selector)
const on = (event, cb) => document.addEventListener(event, cb) const on = (event, cb) => document.addEventListener(event, cb)
function node(type, { appendTo, cls, text, data, ...attrs } = {}) { function node(type, { appendTo, cls, text, data, style, ...attrs } = {}) {
let elem = document.createElement(type) let elem = typeof type === "string" ? document.createElement(type) : type
if (cls) { if (cls) {
elem.className = cls elem.className = cls
} }
if (text) { if (text) {
elem.textContent = text elem.textContent = text
} }
for (const name in data ?? {}) { Object.assign(elem.dataset, data ?? {})
elem.dataset[name] = data[name] Object.assign(elem.style, style ?? {})
}
for (const name in attrs) { for (const name in attrs) {
elem.setAttribute(name, attrs[name]) elem.setAttribute(name, attrs[name])
} }
@ -53,11 +53,13 @@ let socket,
session_key session_key
function hide(e) { function hide(e) {
q(`#${e}`).style.display = "none" e = typeof e === "string" ? q(`#${e}`) : e
e.style.display = "none"
} }
function show(e) { function show(e) {
q(`#${e}`).style.display = "block" e = typeof e === "string" ? q(`#${e}`) : e
e.style.display = "block"
} }
function session_id() { function session_id() {
@ -97,8 +99,9 @@ function redraw_clients(me, clients) {
return return
} }
clear(ul) clear(ul)
const player_tpl = q("template#player").content.firstElementChild
for (const c of clients) { for (const c of clients) {
node("li", { node(player_tpl.cloneNode(), {
text: c.name, text: c.name,
data: { cid: c.id }, data: { cid: c.id },
appendTo: ul, appendTo: ul,
@ -180,18 +183,41 @@ function setup_ui() {
} }
}) })
q("#username").addEventListener("change", (event) => { const username_el = q("#username")
send("name", event.target.value) if (storage["my_name"]) {
username_el.value = storage["my_name"]
}
username_el.addEventListener("change", (event) => {
set_name(event.target.value)
}) })
ul = q("#info ul") ul = q("#info ul")
} }
function set_name(name) {
storage["my_name"] = name
send("name", name)
}
function session_url(sid) {
return `${config.wsurl}/${sid}`
}
function session_id_from_url(url) {
const wsurl = new URL(config.wsurl)
const match = RegExp(`${wsurl.pathname}/([^/]+)$`).exec(url)
return !match ? null : match[1]
}
function setup_ws() { function setup_ws() {
const sid = session_id() const sid = session_id()
socket = new WebSocket(`${config.wsurl}/quiz/${sid}`) const credentials = { id: storage["my_uid"], key: storage["my_key"] }
socket = new WebSocket(`${session_url(sid)}`)
socket.addEventListener("open", function (event) { socket.addEventListener("open", function (event) {
send("name", q("#username").value) if (sid === storage["my_sid"]) {
send("login", credentials)
}
set_name(q("#username").value)
}) })
socket.addEventListener("message", function (event) { socket.addEventListener("message", function (event) {
const msg = JSON.parse(event.data) const msg = JSON.parse(event.data)
@ -199,9 +225,10 @@ function setup_ws() {
servertime = msg.value servertime = msg.value
toffset_ms = performance.now() toffset_ms = performance.now()
} else if (msg.type === "id") { } else if (msg.type === "id") {
me = { id: msg.id, key: msg.key } me = { id: msg.id, key: msg.key, path: msg.path }
storage["my_id"] = me.id storage["my_uid"] = me.id
storage["my_key"] = me.key storage["my_key"] = me.key
storage["my_sid"] = session_id_from_url(me.path)
redraw_clients(me, clients) redraw_clients(me, clients)
} else if (msg.type === "session_key") { } else if (msg.type === "session_key") {
session_key = { path: msg.path, key: msg.key } session_key = { path: msg.path, key: msg.key }
@ -217,7 +244,7 @@ function setup_ws() {
clients = msg.value clients = msg.value
redraw_clients(me, clients) redraw_clients(me, clients)
} else if (msg.type === "client") { } else if (msg.type === "client") {
const client = msg.value const client = { name: msg.name, id: msg.id, active: msg.active }
for (const c of clients) { for (const c of clients) {
if (c.id === client.id) { if (c.id === client.id) {
c.name = client.name c.name = client.name
@ -227,6 +254,11 @@ function setup_ws() {
} }
clients.push(client) clients.push(client)
redraw_clients(me, clients) redraw_clients(me, clients)
} else if (msg.type === "error") {
console.error(`Error: ${msg.reason}`)
const errorbox = q("#error")
q("code", errorbox).textContent = JSON.stringify(msg, null, 2)
show(errorbox)
} else { } else {
console.error(`Unknown message: ${event.data}`) console.error(`Unknown message: ${event.data}`)
} }

View file

@ -1,3 +1,3 @@
export default { export default {
wsurl: "wss://quiz.dumpr.org:443", wsurl: "wss://quiz.dumpr.org:443/quiz",
} }

3
pyrightconfig.json Normal file
View file

@ -0,0 +1,3 @@
{
"stubPath":"./stubs"
}

View file

@ -1,19 +1,21 @@
import asyncio import asyncio
import logging import logging
from . import config from . import config, quiz
from .quiz import server
log = logging.getLogger(__name__)
def main(): def main():
logging.basicConfig( logging.basicConfig(
format="{asctime},{msecs:03.0f} [{name}:{process}] {levelname}: {message}", format="{asctime} [{name}:{process}] {levelname}: {message}",
style="{", style="{",
datefmt="%Y-%m-%d %H:%M:%a",
level=logging.INFO, level=logging.INFO,
) )
asyncio.get_event_loop().run_until_complete(server(config.ws_host, config.ws_port)) server = quiz.server(config.ws_host, config.ws_port)
log.info(f"Starting server on ws://{config.ws_host}:{config.ws_port}/")
asyncio.get_event_loop().run_until_complete(server)
asyncio.get_event_loop().run_forever() asyncio.get_event_loop().run_forever()

View file

@ -1,5 +1,9 @@
import os import os
max_clients_per_session = 10
max_sessions = 10
client_timeout_s = 600 # Seconds before a client times out
session_timeout_s = 0
path_prefix = "/quiz/" path_prefix = "/quiz/"
ws_host = os.getenv("WS_HOST", "0.0.0.0") ws_host = os.getenv("WS_HOST", "0.0.0.0")
ws_port = int(os.getenv("WS_PORT", 8765)) ws_port = int(os.getenv("WS_PORT", 8765))

View file

@ -3,11 +3,11 @@ import logging
import unicodedata import unicodedata
from dataclasses import dataclass, field from dataclasses import dataclass, field
from json import dumps, loads from json import dumps, loads
from secrets import token_hex from secrets import compare_digest, token_hex
from time import perf_counter_ns from time import perf_counter_ns
from typing import * # pyright: reportWildcardImportFromLibrary=false from typing import * # pyright: reportWildcardImportFromLibrary=false
import websockets # type: ignore import websockets
from . import config from . import config
@ -25,22 +25,30 @@ def token() -> Token:
@dataclass @dataclass
class Client: class Client:
ws: Websocket ws: Optional[Websocket]
path: Path path: Path
id: UserId = field(default_factory=token) id: UserId = field(default_factory=token)
is_reclaimed: bool = False
name: str = "" name: str = ""
points: int = 0 points: int = 0
secret: Token = field(default_factory=token) secret: Token = field(default_factory=token)
session: Optional["Session"] = None session: Optional["Session"] = None
def __str__(self): def __str__(self):
return f"{ascii(self.id)[1:-1]}:{ascii(self.name)[1:-1]}" s = f"{ascii(self.id)[1:-1]}:{ascii(self.name)[1:-1]}"
if self.is_reclaimed:
return "<reclaimed>" + s
return s
@property @property
def is_admin(self): def is_admin(self):
return self.session is not None and self.session.admin == self.id return self.session is not None and self.session.admin == self.id
def others(self, include_self=True) -> Iterable["Client"]: @property
def is_connected(self):
return self.ws and self.ws.open
def session_clients(self, include_self=True) -> Iterable["Client"]:
if self.session is None: if self.session is None:
return [] return []
clients = self.session.clients.values() clients = self.session.clients.values()
@ -48,6 +56,59 @@ class Client:
return clients return clients
return [c for c in clients if c.id != self.id] return [c for c in clients if c.id != self.id]
async def send(self, type_: str, **args):
if self.ws is not None:
await self.ws.send(msg(type_, **args))
@property
async def messages(self) -> AsyncIterable[str]:
if self.ws is None:
return
async for message in self.ws:
yield str(message)
@property
def is_active(self):
return self.ws is not None
@is_active.setter
def is_active(self, x: bool):
assert x is False
self.ws = None
@property
def info(self):
return {
"name": self.name or "<noname>",
"id": self.id,
"active": self.is_active,
}
def _reclaim(self, other: "Client"):
"""Take the place of another.
This allows a Client to inherit the state of another Client. It
is used to allow a user to resume their previous connection,
basically a log-in to an existing Client."""
# The only thing we truly need to keep is our websocket, because
# all other code assumes a Client's websocket won't ever change.
assert (
not other.is_reclaimed
and not other.is_active
and self.is_active
and self.path == other.path
and self.session is other.session
)
# Load all relevant info from other.
self.id = other.id
self.name = other.name
self.points = other.points
self.secret = other.secret
# Invalidate other.
other.is_reclaimed = True
@dataclass @dataclass
class Session: class Session:
@ -63,8 +124,13 @@ class Session:
def get(cls, client: Client) -> "Session": def get(cls, client: Client) -> "Session":
is_new = client.path not in cls.sessions is_new = client.path not in cls.sessions
if is_new: if is_new:
if len(cls.sessions) >= config.max_sessions:
raise LoginError("Too many sessions.")
log.info("[%s] new session: %a", client, client.path)
cls.sessions[client.path] = Session(client.path, client.id) cls.sessions[client.path] = Session(client.path, client.id)
s = cls.sessions[client.path] s = cls.sessions[client.path]
if len(s.clients) >= config.max_clients_per_session:
raise LoginError("Too many clients.")
s.clients[client.id] = client s.clients[client.id] = client
client.session = s # Note: This creates a ref cycle. client.session = s # Note: This creates a ref cycle.
return s return s
@ -85,24 +151,47 @@ class Session:
) )
del self.clients[client.id] del self.clients[client.id]
def reclaim(self, client: Client, inactive: Client):
"""Replace an inactive client with another in the session.
This will transfer all state (including points and the
client's ID) from the inactive client to the claiming one.
"""
assert client.session is inactive.session is self
del self.clients[client.id] # Remove the successor from the client pool.
client._reclaim(inactive) # The successor now owns the ID & other state.
self.clients[client.id] = client # The ref to the inactive client is replaced.
def __str__(self):
return f"Session(path={self.path!a}, admin={self.admin!a})"
class LoginError(RuntimeError):
pass
def msg(type_: str, **args): def msg(type_: str, **args):
return dumps({"type": type_, **args}) return dumps({"type": type_, **args})
async def send_time(client: Client): async def send_time(target: Client):
await client.ws.send(msg("time", value=perf_counter_ns())) """Send the current server time to the target."""
await target.send("time", value=perf_counter_ns())
async def send_buzz(target, client: Client, time): async def send_buzz(target: Client, client: Client, time: int):
await target.ws.send(msg("buzz", client=client.id, time=time)) """Send target the timestamp of a buzz registered for the client."""
await target.send("buzz", client=client.id, time=time)
async def send_clients(client: Client): async def send_clients(target: Client):
if not client.session: """Send info about all connected clients of a session to the target."""
return await target.send("clients", value=[c.info for c in target.session_clients()])
clients = [{"name": c.name or "<noname>", "id": c.id} for c in client.others()]
await client.ws.send(msg("clients", value=clients))
async def send_client(target: Client, client: Client):
"""Send info about the client to the target."""
await target.send("client", **client.info)
async def wait(coros, **kwds): async def wait(coros, **kwds):
@ -113,47 +202,48 @@ async def wait(coros, **kwds):
return await asyncio.wait(tasks, **kwds) return await asyncio.wait(tasks, **kwds)
async def broadcast_client(client: Client): async def broadcast_client(target: Client):
if not client.session: """Send info about the target to all clients of its session."""
await wait(send_client(c, target) for c in target.session_clients())
async def broadcast_clients(session: Session):
"""Send info about all clients of a session to all its clients.
The result is a full refresh/sync for all clients of the session.
Since it sends a full dataset of everything to everyone, the traffic
amount could be quite large, so use sparingly and with care."""
await wait(send_clients(c) for c in session.clients.values())
async def broadcast_buzz(client: Client, time: int):
"""Send all clients of a session the timestamp of a buzz registered
for the client."""
await wait(send_buzz(c, client, time) for c in client.session_clients())
async def send_credentials(target: Client):
"""Send their user credentials to a client."""
await target.send("id", id=target.id, key=target.secret, path=target.path)
async def send_keys_to_the_city(target: Client):
"""Send the session admin credentials to a client."""
if not target.session:
return return
m = msg("client", value={"name": client.name or "<noname>", "id": client.id}) await target.send("session_key", path=target.path, key=target.session.secret)
await wait(c.ws.send(m) for c in client.others())
async def broadcast_clients(client: Client): async def send_hello(target: Client):
if not client.session: msgs = [send_time(target), send_credentials(target), send_clients(target)]
return if target.is_admin:
await wait(send_clients(c) for c in client.others()) msgs.append(send_keys_to_the_city(target))
async def broadcast_buzz(client: Client, time):
if not client.session:
return
await wait(send_buzz(c, client, time) for c in client.others())
async def send_credentials(client: Client):
await client.ws.send(msg("id", id=client.id, key=client.secret))
async def send_keys_to_the_city(client: Client):
if not client.session:
return
await client.ws.send(
msg("session_key", path=client.path, key=client.session.secret)
)
async def send_hello(client: Client):
msgs = [send_time(client), send_credentials(client), send_clients(client)]
if client.is_admin:
msgs.append(send_keys_to_the_city(client))
await wait(msgs) await wait(msgs)
async def send_heartbeat(client: Client): async def send_heartbeat(target: Client):
await asyncio.sleep(5.0) await asyncio.sleep(5.0)
await send_time(client) await send_time(target)
def printable(s: str) -> str: def printable(s: str) -> str:
@ -162,7 +252,7 @@ def printable(s: str) -> str:
async def handle_messages(client: Client): async def handle_messages(client: Client):
async for message in client.ws: async for message in client.messages:
log.debug("[%s] got a message: %a", client, message) log.debug("[%s] got a message: %a", client, message)
mdata = loads(message) mdata = loads(message)
if mdata["type"] == "buzz": if mdata["type"] == "buzz":
@ -175,12 +265,36 @@ async def handle_messages(client: Client):
log.info("[%s] new name: %a", client, name) log.info("[%s] new name: %a", client, name)
client.name = name client.name = name
await broadcast_client(client) await broadcast_client(client)
elif mdata["type"] == "login":
assert client.session
target_id = UserId(mdata["value"]["id"])
existent = client.session.clients.get(target_id)
if existent is None:
log.info(
"[%s] tried to log in as non-existent user: %a", client, target_id
)
return
if existent.is_active:
log.info("[%s] cannot log in as active user: %a", client, target_id)
return
if existent.is_reclaimed:
log.info("[%s] client already reclaimed: %a", client, target_id)
return
if not compare_digest(mdata["value"]["key"], existent.secret):
log.info(
"[%s] failed to log in as existing user: %a", client, target_id
)
return
log.info("[%s] logging in as existent user: %s", client, existent)
client.session.reclaim(client, inactive=existent)
await send_hello(client)
await broadcast_clients(client.session)
else: else:
log.error("[%s] received borked message", client) log.error("[%s] received borked message", client)
async def juggle(client: Client): async def juggle(client: Client):
while client.ws.open: while client.is_connected:
done, pending = await wait( done, pending = await wait(
[send_heartbeat(client), handle_messages(client)], [send_heartbeat(client), handle_messages(client)],
return_when=asyncio.FIRST_COMPLETED, return_when=asyncio.FIRST_COMPLETED,
@ -198,22 +312,36 @@ async def connected(ws: Websocket, path: str):
await ws.close() await ws.close()
return return
path = printable(path)
client = Client(ws, path) client = Client(ws, path)
log.info("[%s] new client on %a", client, path) log.info("[%s] new client on %a", client, path)
try:
session = Session.get(client) session = Session.get(client)
if len(session.clients) == 1: except LoginError as err:
log.info("[%s] new session on %a", client, path) log.error("[%s] Error logging in: %s", client, err)
await client.send("error", reason=str(err))
await ws.close()
return
try: try:
await send_hello(client) await send_hello(client)
await juggle(client) await juggle(client)
finally: finally:
log.info("[%s] client disconnected", client) log.info("[%s] client disconnected", client)
client.is_active = False
await broadcast_clients(session)
await asyncio.sleep(
config.client_timeout_s
) # Give the user the opportunity to log-in again as the same client.
if not client.is_reclaimed:
session.remove(client) session.remove(client)
await broadcast_clients(client) await broadcast_clients(session)
log.info("[%s] client gone", client)
if not session.is_alive: if not session.is_alive:
session.destroy() session.destroy()
log.info("[%s] session gone: %s", client, session)
def server(host: str, port: int): def server(host: str, port: int):

View file

@ -5,3 +5,4 @@ black "$RUN_DIR"
isort --profile black "$RUN_DIR" isort --profile black "$RUN_DIR"
prettier --write "$RUN_DIR"/public prettier --write "$RUN_DIR"/public
shellcheck "$RUN_DIR"/scripts/* shellcheck "$RUN_DIR"/scripts/*
MYPYPATH="$RUN_DIR"/stubs mypy "$RUN_DIR"/quiz

108
stubs/websockets.pyi Normal file
View file

@ -0,0 +1,108 @@
import asyncio
import http
from typing import *
# typing (but mypy doesn't know it ...)
class TracebackType: ...
# websockets/typing.py
Data = Union[str, bytes]
Origin = NewType("Origin", str)
Origin.__doc__ = """Value of a Origin header"""
Subprotocol = NewType("Subprotocol", str)
Subprotocol.__doc__ = """Subprotocol value in a Sec-WebSocket-Protocol header"""
# websockets/legacy/protocol.py
class WebSocketCommonProtocol(asyncio.Protocol):
@property
def host(self) -> Optional[str]: ...
@property
def port(self) -> Optional[int]: ...
@property
def secure(self) -> Optional[bool]: ...
@property
def open(self) -> bool: ...
@property
def closed(self) -> bool: ...
async def wait_closed(self) -> None: ...
async def __aiter__(self) -> AsyncIterator[Data]: ...
async def recv(self) -> Data: ...
async def send(
self, message: Union[Data, Iterable[Data], AsyncIterable[Data]]
) -> None: ...
async def close(self, code: int = 1000, reason: str = "") -> None: ...
async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: ...
# websockets/extensions/base.py
class ServerExtensionFactory: ...
# websockets/datastructures.py
class Headers(MutableMapping[str, str]):
def __iter__(self) -> Iterator[str]: ...
def __len__(self) -> int: ...
def __getitem__(self, key: str) -> str: ...
def __setitem__(self, key: str, value: str) -> None: ...
def __delitem__(self, key: str) -> None: ...
HeadersLike = Union[Headers, Mapping[str, str], Iterable[Tuple[str, str]]]
# websockets/legacy/server.py
HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]]
HTTPResponse = Tuple[http.HTTPStatus, HeadersLike, bytes]
class WebSocketServerProtocol(WebSocketCommonProtocol):
async def process_request(
self, path: str, request_headers: Headers
) -> Optional[HTTPResponse]: ...
class WebSocketServer:
def wrap(self, server: asyncio.AbstractServer) -> None: ...
def close(self) -> None: ...
class Serve:
def __init__(
self,
ws_handler: Callable[[WebSocketServerProtocol, str], Awaitable[Any]],
host: Optional[Union[str, Sequence[str]]] = None,
port: Optional[int] = None,
*,
path: Optional[str] = None,
create_protocol: Optional[Callable[[Any], WebSocketServerProtocol]] = None,
ping_interval: Optional[float] = 20,
ping_timeout: Optional[float] = 20,
close_timeout: Optional[float] = None,
max_size: Optional[int] = 2 ** 20,
max_queue: Optional[int] = 2 ** 5,
read_limit: int = 2 ** 16,
write_limit: int = 2 ** 16,
loop: Optional[asyncio.AbstractEventLoop] = None,
legacy_recv: bool = False,
klass: Optional[Type[WebSocketServerProtocol]] = None,
timeout: Optional[float] = None,
compression: Optional[str] = "deflate",
origins: Optional[Sequence[Optional[Origin]]] = None,
extensions: Optional[Sequence[ServerExtensionFactory]] = None,
subprotocols: Optional[Sequence[Subprotocol]] = None,
extra_headers: Optional[HeadersLikeOrCallable] = None,
process_request: Optional[
Callable[[str, Headers], Awaitable[Optional[HTTPResponse]]]
] = None,
select_subprotocol: Optional[
Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol]
] = None,
unix: bool = False,
**kwargs: Any,
) -> None: ...
async def __aenter__(self) -> WebSocketServer: ...
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None: ...
def __await__(self) -> Generator[Any, None, WebSocketServer]: ...
__iter__ = __await__
serve = Serve
WebSocketServerProtocol