quiz/quiz/quiz.py

350 lines
11 KiB
Python

import asyncio
import logging
import unicodedata
from dataclasses import dataclass, field
from http import HTTPStatus
from json import dumps, loads
from secrets import compare_digest, token_hex
from time import perf_counter_ns
from typing import * # pyright: reportWildcardImportFromLibrary=false
import websockets
from . import config
# Identifier types shared by Client and Session.
Token = NewType("Token", str)  # random hex string, see token()
UserId = Token  # client ids are minted by token() as well
Path = str  # websocket request path; each path maps to one Session
Websocket = websockets.WebSocketServerProtocol
log = logging.getLogger(__name__)
def token() -> Token:
    """Return a fresh random token: 8 random bytes as 16 hex characters."""
    raw = token_hex(8)
    return Token(raw)
@dataclass
class Client:
    """One participant's connection plus the state that can outlive it.

    ``ws`` becomes None once the connection is gone; id, name, points and
    secret can then be inherited by a new connection through
    ``Session.reclaim``. Constructed positionally as ``Client(ws, path)``.
    """

    ws: Optional[Websocket]
    path: Path
    id: UserId = field(default_factory=token)
    is_reclaimed: bool = False  # set once another Client took over this one
    name: str = ""
    points: int = 0
    secret: Token = field(default_factory=token)  # proves ownership on login
    session: Optional["Session"] = None

    def __str__(self):
        # ascii(...)[1:-1] escapes non-ASCII characters and drops the quotes.
        body = f"{ascii(self.id)[1:-1]}:{ascii(self.name)[1:-1]}"
        return ("<reclaimed>" + body) if self.is_reclaimed else body

    @property
    def is_admin(self):
        """Whether this client's id is the admin id of its session."""
        return False if self.session is None else self.session.admin == self.id

    @property
    def is_connected(self):
        """Truthy while a websocket is attached and reports itself open."""
        sock = self.ws
        return sock.open if sock else sock

    def session_clients(self, include_self=True) -> Iterable["Client"]:
        """Return the clients of this client's session (empty if none).

        With ``include_self=False`` clients sharing this client's id are
        left out.
        """
        if self.session is None:
            return []
        everyone = self.session.clients.values()
        if not include_self:
            return [other for other in everyone if other.id != self.id]
        return everyone

    async def send(self, type_: str, **args):
        """Send ``msg(type_, **args)`` over the websocket, if there is one."""
        if self.ws is None:
            return
        await self.ws.send(msg(type_, **args))

    @property
    async def messages(self) -> AsyncIterable[str]:
        """Async-iterate incoming frames as str; yields nothing without a websocket."""
        if self.ws is not None:
            async for frame in self.ws:
                yield str(frame)

    @property
    def is_active(self):
        """Whether this client still has a websocket attached."""
        return self.ws is not None

    @is_active.setter
    def is_active(self, value: bool):
        # Deactivation is one-way: only `client.is_active = False` is allowed.
        assert value is False
        self.ws = None

    @property
    def info(self):
        """Public summary of this client, as sent to other clients."""
        return dict(
            name=self.name or "<noname>",
            id=self.id,
            active=self.is_active,
        )

    def _reclaim(self, other: "Client"):
        """Take the place of another, inactive Client.

        Lets a user resume their previous connection (effectively logging
        in to an existing Client): id, name, points and secret are copied
        from ``other``, while self keeps its own websocket because all other
        code assumes a Client's websocket never changes. ``other`` is then
        marked as reclaimed.
        """
        assert not other.is_reclaimed
        assert not other.is_active
        assert self.is_active
        assert self.path == other.path
        assert self.session is other.session
        for attr in ("id", "name", "points", "secret"):
            setattr(self, attr, getattr(other, attr))
        other.is_reclaimed = True
@dataclass
class Session:
    """The group of clients connected on one websocket path.

    Live sessions are registered by path in the class-level ``sessions``
    dict. The client that creates a session becomes its admin.
    """

    path: Path
    admin: UserId  # id of the client that created the session
    secret: Token = field(default_factory=token)  # sent to the admin as "session_key"
    clients: dict[UserId, Client] = field(default_factory=dict)  # keyed by Client.id
    sessions: ClassVar[dict[Path, "Session"]] = {}  # registry of live sessions

    @classmethod
    def get(cls, client: Client) -> "Session":
        """Add client to the session for ``client.path``, creating it if needed.

        Sets ``client.session`` to the joined session.

        Raises:
            LoginError: if creating a session would exceed
                ``config.max_sessions``, or the session already holds
                ``config.max_clients_per_session`` clients.
        """
        s = cls.sessions.get(client.path)
        if s is None:
            if len(cls.sessions) >= config.max_sessions:
                raise LoginError("Too many sessions.")
            log.info("[%s] new session: %a", client, client.path)
            s = cls.sessions[client.path] = cls(client.path, client.id)
        if len(s.clients) >= config.max_clients_per_session:
            raise LoginError("Too many clients.")
        s.clients[client.id] = client
        client.session = s  # Note: This creates a ref cycle.
        return s

    def destroy(self):
        """Detach every client and unregister this session.

        Calling it again is harmless, and it never unregisters a different
        session that has since been created for the same path.
        """
        for c in list(self.clients.values()):
            self.remove(c)
        assert len(self.clients) == 0
        if Session.sessions.get(self.path) is self:
            del Session.sessions[self.path]

    @property
    def is_alive(self):
        """Whether the session still has any clients."""
        return bool(self.clients)

    def remove(self, client: Client):
        """Detach client from this session.

        The id entry is only deleted when it still refers to this very
        client: after a reclaim, the successor holds the same id and must
        not be evicted by removing the reclaimed predecessor.
        """
        client.session = None  # Breaks the client <-> session ref cycle.
        if self.clients.get(client.id) is client:
            del self.clients[client.id]

    def reclaim(self, client: Client, inactive: Client):
        """Replace an inactive client with another in the session.

        This will transfer all state (including points and the
        client's ID) from the inactive client to the claiming one.
        """
        assert client.session is inactive.session is self
        del self.clients[client.id]  # Remove the successor from the client pool.
        client._reclaim(inactive)  # The successor now owns the ID & other state.
        self.clients[client.id] = client  # The ref to the inactive client is replaced.

    def __str__(self):
        return f"Session(path={self.path!a}, admin={self.admin!a})"
class LoginError(RuntimeError):
    """Raised when a client may not join a session (e.g. capacity limits hit)."""
def msg(type_: str, **args):
    """Serialize a protocol message as JSON.

    The "type" key comes first, followed by ``args`` in call order.
    """
    payload = {"type": type_}
    payload.update(args)
    return dumps(payload)
async def send_time(target: Client):
    """Tell target the server's current ``perf_counter_ns()`` reading."""
    now = perf_counter_ns()
    await target.send("time", value=now)
async def send_buzz(target: Client, client: Client, time: int):
    """Notify target that client buzzed at the given timestamp."""
    buzzer_id = client.id
    await target.send("buzz", client=buzzer_id, time=time)
async def send_clients(target: Client):
    """Send target the info of every client in its session."""
    roster = [member.info for member in target.session_clients()]
    await target.send("clients", value=roster)
async def send_client(target: Client, client: Client):
    """Send target the info of a single client."""
    details = client.info
    await target.send("client", **details)
async def wait(coros, **kwds):
    """Schedule the given coroutines as tasks and wait for them.

    ``kwds`` are forwarded to ``asyncio.wait`` (e.g. ``return_when``).

    Returns:
        The ``(done, pending)`` sets from ``asyncio.wait``. With no
        coroutines, two empty sets are returned instead, because
        ``asyncio.wait`` rejects an empty collection.

    Exceptions raised by finished tasks are logged here. Most callers
    discard the result, so asyncio would otherwise only report them as
    "Task exception was never retrieved".
    """
    tasks = [asyncio.create_task(c) for c in coros]
    if not tasks:
        return set(), set()
    done, pending = await asyncio.wait(tasks, **kwds)
    for task in done:
        # Task.exception() raises on cancelled tasks, so check that first.
        if not task.cancelled() and task.exception() is not None:
            log.warning("background task failed", exc_info=task.exception())
    return done, pending
async def broadcast_client(target: Client):
    """Send target's info to every client of its session, concurrently."""
    recipients = target.session_clients()
    await wait([send_client(peer, target) for peer in recipients])
async def broadcast_clients(session: Session):
    """Send the full client list of a session to each of its clients.

    This is a complete refresh/sync for everyone in the session. Since
    every client receives the whole dataset, traffic grows quickly with
    session size, so use it sparingly and with care.
    """
    members = session.clients.values()
    await wait(map(send_clients, members))
async def broadcast_buzz(client: Client, time: int):
    """Tell every client of client's session that client buzzed at time."""
    peers = client.session_clients()
    await wait([send_buzz(peer, client, time) for peer in peers])
async def send_credentials(target: Client):
    """Send target its own id, secret key and path."""
    creds = {"id": target.id, "key": target.secret, "path": target.path}
    await target.send("id", **creds)
async def send_keys_to_the_city(target: Client):
    """Send target its session's admin secret; no-op without a session."""
    session = target.session
    if session:
        await target.send("session_key", path=target.path, key=session.secret)
async def send_hello(target: Client):
    """Greet target with server time, its credentials and the client list.

    Admins additionally receive the session key. All messages are sent
    concurrently.
    """
    admin_extras = [send_keys_to_the_city(target)] if target.is_admin else []
    await wait(
        [
            send_time(target),
            send_credentials(target),
            send_clients(target),
            *admin_extras,
        ]
    )
async def send_heartbeat(target: Client):
    """Wait five seconds, then send target the current server time."""
    delay_s = 5.0
    await asyncio.sleep(delay_s)
    await send_time(target)
def printable(s: str) -> str:
    """Return s without characters of Unicode general category "C" (Other).

    That covers control (Cc), format (Cf), surrogate (Cs), private-use (Co)
    and unassigned (Cn) code points.
    See https://www.unicode.org/versions/Unicode13.0.0/ch04.pdf "Table 4-4."
    """

    def keep(ch: str) -> bool:
        return unicodedata.category(ch)[0] != "C"

    return "".join(filter(keep, s))
def _parse_message(message: str) -> tuple[Any, Any]:
    """Decode one raw client message into its ``(type, value)`` pair.

    Raises ValueError (json.JSONDecodeError is a subclass) unless message is
    a JSON object with both a "type" and a "value" key.
    """
    mdata = loads(message)
    if not isinstance(mdata, dict) or "type" not in mdata or "value" not in mdata:
        raise ValueError("expected a JSON object with 'type' and 'value' keys")
    return mdata["type"], mdata["value"]


async def _handle_login(client: Client, credentials: Any) -> None:
    """Let client take over an inactive member of its own session.

    credentials is expected to be ``{"id": <user id>, "key": <its secret>}``.
    Every rejected attempt is logged and otherwise ignored.
    """
    assert client.session
    if not isinstance(credentials, dict):
        log.error("[%s] received borked message", client)
        return
    raw_id = credentials.get("id")
    key = credentials.get("key")
    if not isinstance(raw_id, str) or not isinstance(key, str):
        log.error("[%s] received borked message", client)
        return
    target_id = UserId(raw_id)
    existent = client.session.clients.get(target_id)
    if existent is None:
        log.info("[%s] tried to log in as non-existent user: %a", client, target_id)
        return
    if existent.is_active:
        log.info("[%s] cannot log in as active user: %a", client, target_id)
        return
    if existent.is_reclaimed:
        log.info("[%s] client already reclaimed: %a", client, target_id)
        return
    # compare_digest raises TypeError for str arguments containing non-ASCII
    # characters; secrets are hex, so such a key simply cannot match.
    if not (key.isascii() and compare_digest(key, existent.secret)):
        log.info("[%s] failed to log in as existing user: %a", client, target_id)
        return
    log.info("[%s] logging in as existent user: %s", client, existent)
    client.session.reclaim(client, inactive=existent)
    await send_hello(client)
    await broadcast_clients(client.session)


async def handle_messages(client: Client):
    """Process client's incoming messages until its websocket closes.

    Messages are JSON objects ``{"type": ..., "value": ...}``:

    - "buzz": broadcast ``value`` (the buzz timestamp) to the session.
    - "name": set the client's name to the printable part of ``value`` (a str).
    - "login": reclaim an inactive client, see ``_handle_login``.

    The input is untrusted: malformed, unknown or rejected messages are
    logged and skipped instead of raising out of (or ending) the loop.
    """
    async for message in client.messages:
        log.debug("[%s] got a message: %a", client, message)
        try:
            mtype, value = _parse_message(message)
        except (ValueError, RecursionError):  # RecursionError: deeply nested JSON
            log.error("[%s] received borked message", client)
            continue
        if mtype == "buzz":
            log.info("[%s] buzz: %a", client, value)
            # todo: check time against perf_counter_ns
            await broadcast_buzz(client, value)
        elif mtype == "name" and isinstance(value, str):
            name = printable(value)
            log.info("[%s] new name: %a", client, name)
            client.name = name
            await broadcast_client(client)
        elif mtype == "login":
            await _handle_login(client, value)
        else:
            log.error("[%s] received borked message", client)
async def juggle(client: Client):
    """Handle client's messages while it stays connected, with heartbeats.

    The heartbeat runs as one background task for the whole connection, so
    it never cancels a message that is midway through being handled (which
    could abort a broadcast or a login half-done). The heartbeat task is
    cancelled on exit, including when juggle itself is cancelled.
    """

    async def heartbeat():
        try:
            while client.is_connected:
                await send_heartbeat(client)
        except websockets.ConnectionClosed:
            pass  # The message loop below notices the closed socket.

    beat = asyncio.create_task(heartbeat())
    try:
        while client.is_connected:
            try:
                await handle_messages(client)
            except websockets.ConnectionClosed:
                break
            except Exception:
                # Keep serving the connection: a single bad message must not
                # drop the client. Each failure consumed one message.
                log.exception("[%s] error while handling messages", client)
    finally:
        beat.cancel()
async def connected(ws: Websocket, path: str):
    """Serve one websocket connection from handshake to cleanup.

    The client joins the session for its sanitized path. A LoginError is
    reported to the client and the socket is closed. Once the connection
    ends, the client lingers inactive for a grace period (see
    ``_linger_then_leave``).
    """
    path = printable(path)
    client = Client(ws, path)
    log.info("[%s] new client on %a", client, path)
    try:
        session = Session.get(client)
    except LoginError as err:
        log.error("[%s] Error logging in: %s", client, err)
        await client.send("error", reason=str(err))
        await ws.close()
        return
    try:
        await send_hello(client)
        await juggle(client)
    finally:
        await _linger_then_leave(client, session)


async def _linger_then_leave(client: Client, session: Session):
    """Mark client inactive, wait config.client_timeout_s, then drop it.

    Removal is skipped if another connection reclaimed client meanwhile.
    The session is destroyed once it has no clients left.
    """
    log.info("[%s] client disconnected", client)
    client.is_active = False
    await broadcast_clients(session)
    # Give the user the opportunity to log-in again as the same client.
    await asyncio.sleep(config.client_timeout_s)
    if not client.is_reclaimed:
        session.remove(client)
        await broadcast_clients(session)
        log.info("[%s] client gone", client)
    if session.is_alive:
        return
    session.destroy()
    log.info("[%s] session gone: %s", client, session)
async def check_path(path: str, request_headers) -> Optional["websockets.HTTPResponse"]:
    """Answer 403 to handshakes whose path lacks ``config.path_prefix``.

    This throws out internet spam (mass scans for security problems, etc.)
    so no resources are wasted on it. Ideally an upstream proxy enforces the
    same rule already. Returning None lets the handshake proceed.
    """
    if path.startswith(config.path_prefix):
        return None
    return HTTPStatus.FORBIDDEN, {}, b""
def server(host: str, port: int):
    """Build the quiz websocket server for host:port.

    Each connection is served by ``connected``; handshake requests are
    screened by ``check_path`` first.
    """
    return websockets.serve(
        connected,
        host,
        port,
        process_request=check_path,
    )