add user session reclaiming

In its current state the implementation should allow a user to resume
their session if the websocket connection is reset, for whatever reason.
This could be expanded to allow session sharing (multiple agents logging
in to the same client), or manual session resume via some sort of
password (encode uid & key to some pass-phrase kinda thing, or QR code).
This commit is contained in:
ducklet 2021-01-31 00:19:35 +01:00
parent 476f3d7a49
commit 4908b1fc6e
10 changed files with 376 additions and 77 deletions

View file

@ -1,5 +1,16 @@
body {
font-family: sans-serif;
overflow: hidden;
}
#error {
background-color: #fee;
border: 3px solid red;
padding: 1em;
font-family: monospace;
display: none;
}
#error code {
white-space: pre;
}
input {
width: 20em;

View file

@ -1,6 +1,10 @@
<meta charset="utf-8" />
<link rel="stylesheet" type="text/css" href="buzzer.css" />
<body>
<div id="error">
Error:
<code></code>
</div>
<div id="info">
<label
>You:
@ -11,7 +15,9 @@
placeholder="Please put your name here ..."
/></label>
<h2>All Players</h2>
<ul></ul>
<ul>
<!-- players will be inserted here -->
</ul>
</div>
<div id="buzzbox">
<p id="active">BZZZZZ!</p>
@ -20,4 +26,8 @@
</div>
</body>
<template id="player">
<li data-cid="client ID">player name</li>
</template>
<script type="module" src="./buzzer.js"></script>

View file

@ -3,6 +3,7 @@
/* global document, window */
const crypto = window.crypto
const location = document.location
const performance = window.performance
const storage = window.sessionStorage
// TODOs
@ -18,17 +19,16 @@ const buzzer_key = {
const q = (selector, root) => (root || document).querySelector(selector)
const on = (event, cb) => document.addEventListener(event, cb)
function node(type, { appendTo, cls, text, data, ...attrs } = {}) {
let elem = document.createElement(type)
function node(type, { appendTo, cls, text, data, style, ...attrs } = {}) {
let elem = typeof type === "string" ? document.createElement(type) : type
if (cls) {
elem.className = cls
}
if (text) {
elem.textContent = text
}
for (const name in data ?? {}) {
elem.dataset[name] = data[name]
}
Object.assign(elem.dataset, data ?? {})
Object.assign(elem.style, style ?? {})
for (const name in attrs) {
elem.setAttribute(name, attrs[name])
}
@ -53,11 +53,13 @@ let socket,
session_key
function hide(e) {
q(`#${e}`).style.display = "none"
e = typeof e === "string" ? q(`#${e}`) : e
e.style.display = "none"
}
function show(e) {
q(`#${e}`).style.display = "block"
e = typeof e === "string" ? q(`#${e}`) : e
e.style.display = "block"
}
function session_id() {
@ -97,8 +99,9 @@ function redraw_clients(me, clients) {
return
}
clear(ul)
const player_tpl = q("template#player").content.firstElementChild
for (const c of clients) {
node("li", {
node(player_tpl.cloneNode(), {
text: c.name,
data: { cid: c.id },
appendTo: ul,
@ -180,18 +183,41 @@ function setup_ui() {
}
})
q("#username").addEventListener("change", (event) => {
send("name", event.target.value)
const username_el = q("#username")
if (storage["my_name"]) {
username_el.value = storage["my_name"]
}
username_el.addEventListener("change", (event) => {
set_name(event.target.value)
})
ul = q("#info ul")
}
function set_name(name) {
storage["my_name"] = name
send("name", name)
}
function session_url(sid) {
return `${config.wsurl}/${sid}`
}
function session_id_from_url(url) {
const wsurl = new URL(config.wsurl)
const match = RegExp(`${wsurl.pathname}/([^/]+)$`).exec(url)
return !match ? null : match[1]
}
function setup_ws() {
const sid = session_id()
socket = new WebSocket(`${config.wsurl}/quiz/${sid}`)
const credentials = { id: storage["my_uid"], key: storage["my_key"] }
socket = new WebSocket(`${session_url(sid)}`)
socket.addEventListener("open", function (event) {
send("name", q("#username").value)
if (sid === storage["my_sid"]) {
send("login", credentials)
}
set_name(q("#username").value)
})
socket.addEventListener("message", function (event) {
const msg = JSON.parse(event.data)
@ -199,9 +225,10 @@ function setup_ws() {
servertime = msg.value
toffset_ms = performance.now()
} else if (msg.type === "id") {
me = { id: msg.id, key: msg.key }
storage["my_id"] = me.id
me = { id: msg.id, key: msg.key, path: msg.path }
storage["my_uid"] = me.id
storage["my_key"] = me.key
storage["my_sid"] = session_id_from_url(me.path)
redraw_clients(me, clients)
} else if (msg.type === "session_key") {
session_key = { path: msg.path, key: msg.key }
@ -217,7 +244,7 @@ function setup_ws() {
clients = msg.value
redraw_clients(me, clients)
} else if (msg.type === "client") {
const client = msg.value
const client = { name: msg.name, id: msg.id, active: msg.active }
for (const c of clients) {
if (c.id === client.id) {
c.name = client.name
@ -227,6 +254,11 @@ function setup_ws() {
}
clients.push(client)
redraw_clients(me, clients)
} else if (msg.type === "error") {
console.error(`Error: ${msg.reason}`)
const errorbox = q("#error")
q("code", errorbox).textContent = JSON.stringify(msg, null, 2)
show(errorbox)
} else {
console.error(`Unknown message: ${event.data}`)
}

View file

@ -1,3 +1,3 @@
export default {
wsurl: "wss://quiz.dumpr.org:443",
wsurl: "wss://quiz.dumpr.org:443/quiz",
}

3
pyrightconfig.json Normal file
View file

@ -0,0 +1,3 @@
{
"stubPath":"./stubs"
}

View file

@ -1,19 +1,21 @@
import asyncio
import logging
from . import config
from .quiz import server
from . import config, quiz
log = logging.getLogger(__name__)
def main():
logging.basicConfig(
format="{asctime},{msecs:03.0f} [{name}:{process}] {levelname}: {message}",
format="{asctime} [{name}:{process}] {levelname}: {message}",
style="{",
datefmt="%Y-%m-%d %H:%M:%a",
level=logging.INFO,
)
asyncio.get_event_loop().run_until_complete(server(config.ws_host, config.ws_port))
server = quiz.server(config.ws_host, config.ws_port)
log.info(f"Starting server on ws://{config.ws_host}:{config.ws_port}/")
asyncio.get_event_loop().run_until_complete(server)
asyncio.get_event_loop().run_forever()

View file

@ -1,5 +1,9 @@
import os
max_clients_per_session = 10
max_sessions = 10
client_timeout_s = 600 # Seconds before a client times out
session_timeout_s = 0
path_prefix = "/quiz/"
ws_host = os.getenv("WS_HOST", "0.0.0.0")
ws_port = int(os.getenv("WS_PORT", 8765))

View file

@ -3,11 +3,11 @@ import logging
import unicodedata
from dataclasses import dataclass, field
from json import dumps, loads
from secrets import token_hex
from secrets import compare_digest, token_hex
from time import perf_counter_ns
from typing import * # pyright: reportWildcardImportFromLibrary=false
import websockets # type: ignore
import websockets
from . import config
@ -25,22 +25,30 @@ def token() -> Token:
@dataclass
class Client:
ws: Websocket
ws: Optional[Websocket]
path: Path
id: UserId = field(default_factory=token)
is_reclaimed: bool = False
name: str = ""
points: int = 0
secret: Token = field(default_factory=token)
session: Optional["Session"] = None
def __str__(self):
return f"{ascii(self.id)[1:-1]}:{ascii(self.name)[1:-1]}"
s = f"{ascii(self.id)[1:-1]}:{ascii(self.name)[1:-1]}"
if self.is_reclaimed:
return "<reclaimed>" + s
return s
@property
def is_admin(self):
return self.session is not None and self.session.admin == self.id
def others(self, include_self=True) -> Iterable["Client"]:
@property
def is_connected(self):
return self.ws and self.ws.open
def session_clients(self, include_self=True) -> Iterable["Client"]:
if self.session is None:
return []
clients = self.session.clients.values()
@ -48,6 +56,59 @@ class Client:
return clients
return [c for c in clients if c.id != self.id]
async def send(self, type_: str, **args):
if self.ws is not None:
await self.ws.send(msg(type_, **args))
@property
async def messages(self) -> AsyncIterable[str]:
if self.ws is None:
return
async for message in self.ws:
yield str(message)
@property
def is_active(self):
return self.ws is not None
@is_active.setter
def is_active(self, x: bool):
assert x is False
self.ws = None
@property
def info(self):
return {
"name": self.name or "<noname>",
"id": self.id,
"active": self.is_active,
}
def _reclaim(self, other: "Client"):
"""Take the place of another.
This allows a Client to inherit the state of another Client. It
is used to allow a user to resume their previous connection,
basically a log-in to an existing Client."""
# The only thing we truly need to keep is our websocket, because
# all other code assumes a Client's websocket won't ever change.
assert (
not other.is_reclaimed
and not other.is_active
and self.is_active
and self.path == other.path
and self.session is other.session
)
# Load all relevant info from other.
self.id = other.id
self.name = other.name
self.points = other.points
self.secret = other.secret
# Invalidate other.
other.is_reclaimed = True
@dataclass
class Session:
@ -63,8 +124,13 @@ class Session:
def get(cls, client: Client) -> "Session":
is_new = client.path not in cls.sessions
if is_new:
if len(cls.sessions) >= config.max_sessions:
raise LoginError("Too many sessions.")
log.info("[%s] new session: %a", client, client.path)
cls.sessions[client.path] = Session(client.path, client.id)
s = cls.sessions[client.path]
if len(s.clients) >= config.max_clients_per_session:
raise LoginError("Too many clients.")
s.clients[client.id] = client
client.session = s # Note: This creates a ref cycle.
return s
@ -85,24 +151,47 @@ class Session:
)
del self.clients[client.id]
def reclaim(self, client: Client, inactive: Client):
"""Replace an inactive client with another in the session.
This will transfer all state (including points and the
client's ID) from the inactive client to the claiming one.
"""
assert client.session is inactive.session is self
del self.clients[client.id] # Remove the successor from the client pool.
client._reclaim(inactive) # The successor now owns the ID & other state.
self.clients[client.id] = client # The ref to the inactive client is replaced.
def __str__(self):
return f"Session(path={self.path!a}, admin={self.admin!a})"
class LoginError(RuntimeError):
pass
def msg(type_: str, **args):
return dumps({"type": type_, **args})
async def send_time(client: Client):
await client.ws.send(msg("time", value=perf_counter_ns()))
async def send_time(target: Client):
"""Send the current server time to the target."""
await target.send("time", value=perf_counter_ns())
async def send_buzz(target, client: Client, time):
await target.ws.send(msg("buzz", client=client.id, time=time))
async def send_buzz(target: Client, client: Client, time: int):
"""Send target the timestamp of a buzz registered for the client."""
await target.send("buzz", client=client.id, time=time)
async def send_clients(client: Client):
if not client.session:
return
clients = [{"name": c.name or "<noname>", "id": c.id} for c in client.others()]
await client.ws.send(msg("clients", value=clients))
async def send_clients(target: Client):
"""Send info about all connected clients of a session to the target."""
await target.send("clients", value=[c.info for c in target.session_clients()])
async def send_client(target: Client, client: Client):
"""Send info about the client to the target."""
await target.send("client", **client.info)
async def wait(coros, **kwds):
@ -113,47 +202,48 @@ async def wait(coros, **kwds):
return await asyncio.wait(tasks, **kwds)
async def broadcast_client(client: Client):
if not client.session:
async def broadcast_client(target: Client):
"""Send info about the target to all clients of its session."""
await wait(send_client(c, target) for c in target.session_clients())
async def broadcast_clients(session: Session):
"""Send info about all clients of a session to all its clients.
The result is a full refresh/sync for all clients of the session.
Since it sends a full dataset of everything to everyone, the traffic
amount could be quite large, so use sparingly and with care."""
await wait(send_clients(c) for c in session.clients.values())
async def broadcast_buzz(client: Client, time: int):
"""Send all clients of a session the timestamp of a buzz registered
for the client."""
await wait(send_buzz(c, client, time) for c in client.session_clients())
async def send_credentials(target: Client):
"""Send their user credentials to a client."""
await target.send("id", id=target.id, key=target.secret, path=target.path)
async def send_keys_to_the_city(target: Client):
"""Send the session admin credentials to a client."""
if not target.session:
return
m = msg("client", value={"name": client.name or "<noname>", "id": client.id})
await wait(c.ws.send(m) for c in client.others())
await target.send("session_key", path=target.path, key=target.session.secret)
async def broadcast_clients(client: Client):
if not client.session:
return
await wait(send_clients(c) for c in client.others())
async def broadcast_buzz(client: Client, time):
if not client.session:
return
await wait(send_buzz(c, client, time) for c in client.others())
async def send_credentials(client: Client):
await client.ws.send(msg("id", id=client.id, key=client.secret))
async def send_keys_to_the_city(client: Client):
if not client.session:
return
await client.ws.send(
msg("session_key", path=client.path, key=client.session.secret)
)
async def send_hello(client: Client):
msgs = [send_time(client), send_credentials(client), send_clients(client)]
if client.is_admin:
msgs.append(send_keys_to_the_city(client))
async def send_hello(target: Client):
msgs = [send_time(target), send_credentials(target), send_clients(target)]
if target.is_admin:
msgs.append(send_keys_to_the_city(target))
await wait(msgs)
async def send_heartbeat(client: Client):
async def send_heartbeat(target: Client):
await asyncio.sleep(5.0)
await send_time(client)
await send_time(target)
def printable(s: str) -> str:
@ -162,7 +252,7 @@ def printable(s: str) -> str:
async def handle_messages(client: Client):
async for message in client.ws:
async for message in client.messages:
log.debug("[%s] got a message: %a", client, message)
mdata = loads(message)
if mdata["type"] == "buzz":
@ -175,12 +265,36 @@ async def handle_messages(client: Client):
log.info("[%s] new name: %a", client, name)
client.name = name
await broadcast_client(client)
elif mdata["type"] == "login":
assert client.session
target_id = UserId(mdata["value"]["id"])
existent = client.session.clients.get(target_id)
if existent is None:
log.info(
"[%s] tried to log in as non-existent user: %a", client, target_id
)
return
if existent.is_active:
log.info("[%s] cannot log in as active user: %a", client, target_id)
return
if existent.is_reclaimed:
log.info("[%s] client already reclaimed: %a", client, target_id)
return
if not compare_digest(mdata["value"]["key"], existent.secret):
log.info(
"[%s] failed to log in as existing user: %a", client, target_id
)
return
log.info("[%s] logging in as existent user: %s", client, existent)
client.session.reclaim(client, inactive=existent)
await send_hello(client)
await broadcast_clients(client.session)
else:
log.error("[%s] received borked message", client)
async def juggle(client: Client):
while client.ws.open:
while client.is_connected:
done, pending = await wait(
[send_heartbeat(client), handle_messages(client)],
return_when=asyncio.FIRST_COMPLETED,
@ -198,22 +312,36 @@ async def connected(ws: Websocket, path: str):
await ws.close()
return
path = printable(path)
client = Client(ws, path)
log.info("[%s] new client on %a", client, path)
try:
session = Session.get(client)
if len(session.clients) == 1:
log.info("[%s] new session on %a", client, path)
except LoginError as err:
log.error("[%s] Error logging in: %s", client, err)
await client.send("error", reason=str(err))
await ws.close()
return
try:
await send_hello(client)
await juggle(client)
finally:
log.info("[%s] client disconnected", client)
client.is_active = False
await broadcast_clients(session)
await asyncio.sleep(
config.client_timeout_s
) # Give the user the opportunity to log-in again as the same client.
if not client.is_reclaimed:
session.remove(client)
await broadcast_clients(client)
await broadcast_clients(session)
log.info("[%s] client gone", client)
if not session.is_alive:
session.destroy()
log.info("[%s] session gone: %s", client, session)
def server(host: str, port: int):

View file

@ -5,3 +5,4 @@ black "$RUN_DIR"
isort --profile black "$RUN_DIR"
prettier --write "$RUN_DIR"/public
shellcheck "$RUN_DIR"/scripts/*
MYPYPATH="$RUN_DIR"/stubs mypy "$RUN_DIR"/quiz

108
stubs/websockets.pyi Normal file
View file

@ -0,0 +1,108 @@
import asyncio
import http
from typing import *
# typing (but mypy doesn't know it ...)
class TracebackType: ...
# websockets/typing.py
Data = Union[str, bytes]
Origin = NewType("Origin", str)
Origin.__doc__ = """Value of a Origin header"""
Subprotocol = NewType("Subprotocol", str)
Subprotocol.__doc__ = """Subprotocol value in a Sec-WebSocket-Protocol header"""
# websockets/legacy/protocol.py
class WebSocketCommonProtocol(asyncio.Protocol):
@property
def host(self) -> Optional[str]: ...
@property
def port(self) -> Optional[int]: ...
@property
def secure(self) -> Optional[bool]: ...
@property
def open(self) -> bool: ...
@property
def closed(self) -> bool: ...
async def wait_closed(self) -> None: ...
async def __aiter__(self) -> AsyncIterator[Data]: ...
async def recv(self) -> Data: ...
async def send(
self, message: Union[Data, Iterable[Data], AsyncIterable[Data]]
) -> None: ...
async def close(self, code: int = 1000, reason: str = "") -> None: ...
async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: ...
# websockets/extensions/base.py
class ServerExtensionFactory: ...
# websockets/datastructures.py
class Headers(MutableMapping[str, str]):
def __iter__(self) -> Iterator[str]: ...
def __len__(self) -> int: ...
def __getitem__(self, key: str) -> str: ...
def __setitem__(self, key: str, value: str) -> None: ...
def __delitem__(self, key: str) -> None: ...
HeadersLike = Union[Headers, Mapping[str, str], Iterable[Tuple[str, str]]]
# websockets/legacy/server.py
HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]]
HTTPResponse = Tuple[http.HTTPStatus, HeadersLike, bytes]
class WebSocketServerProtocol(WebSocketCommonProtocol):
async def process_request(
self, path: str, request_headers: Headers
) -> Optional[HTTPResponse]: ...
class WebSocketServer:
def wrap(self, server: asyncio.AbstractServer) -> None: ...
def close(self) -> None: ...
class Serve:
def __init__(
self,
ws_handler: Callable[[WebSocketServerProtocol, str], Awaitable[Any]],
host: Optional[Union[str, Sequence[str]]] = None,
port: Optional[int] = None,
*,
path: Optional[str] = None,
create_protocol: Optional[Callable[[Any], WebSocketServerProtocol]] = None,
ping_interval: Optional[float] = 20,
ping_timeout: Optional[float] = 20,
close_timeout: Optional[float] = None,
max_size: Optional[int] = 2 ** 20,
max_queue: Optional[int] = 2 ** 5,
read_limit: int = 2 ** 16,
write_limit: int = 2 ** 16,
loop: Optional[asyncio.AbstractEventLoop] = None,
legacy_recv: bool = False,
klass: Optional[Type[WebSocketServerProtocol]] = None,
timeout: Optional[float] = None,
compression: Optional[str] = "deflate",
origins: Optional[Sequence[Optional[Origin]]] = None,
extensions: Optional[Sequence[ServerExtensionFactory]] = None,
subprotocols: Optional[Sequence[Subprotocol]] = None,
extra_headers: Optional[HeadersLikeOrCallable] = None,
process_request: Optional[
Callable[[str, Headers], Awaitable[Optional[HTTPResponse]]]
] = None,
select_subprotocol: Optional[
Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol]
] = None,
unix: bool = False,
**kwargs: Any,
) -> None: ...
async def __aenter__(self) -> WebSocketServer: ...
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None: ...
def __await__(self) -> Generator[Any, None, WebSocketServer]: ...
__iter__ = __await__
serve = Serve
WebSocketServerProtocol