350 lines
11 KiB
Python
350 lines
11 KiB
Python
import asyncio
|
|
import logging
|
|
import unicodedata
|
|
from dataclasses import dataclass, field
|
|
from http import HTTPStatus
|
|
from json import dumps, loads
|
|
from secrets import compare_digest, token_hex
|
|
from time import perf_counter_ns
|
|
from typing import * # pyright: reportWildcardImportFromLibrary=false
|
|
|
|
import websockets
|
|
|
|
from . import config
|
|
|
|
# Type aliases used throughout this module.
Path = str  # the request path a client connected on; sessions are keyed by it
Websocket = websockets.WebSocketServerProtocol
Token = NewType("Token", str)  # random hex string, produced by token()
UserId = Token  # client IDs are generated exactly like secrets

log = logging.getLogger(__name__)
|
|
|
|
|
|
def token() -> Token:
    """Return a new random token: 8 random bytes rendered as 16 hex digits."""
    raw = token_hex(8)
    return Token(raw)
|
|
|
|
|
|
@dataclass
class Client:
    """One websocket connection and the per-user state attached to it.

    After its websocket is gone (``ws is None``) a Client can be taken over
    by a newer connection of the same session via ``_reclaim``.
    """

    ws: Optional[Websocket]
    path: Path
    id: UserId = field(default_factory=token)
    # Set once another Client has taken over this one's identity.
    is_reclaimed: bool = False
    name: str = ""
    points: int = 0
    # Secret the user must present to log in as this client again.
    secret: Token = field(default_factory=token)
    session: Optional["Session"] = None

    def __str__(self):
        s = f"{ascii(self.id)[1:-1]}:{ascii(self.name)[1:-1]}"
        if self.is_reclaimed:
            return "<reclaimed>" + s
        return s

    @property
    def is_admin(self) -> bool:
        """Whether this client is the admin of its session."""
        return self.session is not None and self.session.admin == self.id

    @property
    def is_connected(self) -> bool:
        """Whether the client has a websocket and that websocket is open."""
        # bool(): ``self.ws and ...`` alone would yield None when ws is None.
        return bool(self.ws and self.ws.open)

    def session_clients(self, include_self=True) -> Iterable["Client"]:
        """Return the clients of this client's session.

        Empty if the client has no session. With ``include_self=False``,
        clients sharing this client's ID are left out.
        """
        if self.session is None:
            return []
        clients = self.session.clients.values()
        if include_self:
            return clients
        return [c for c in clients if c.id != self.id]

    async def send(self, type_: str, **args):
        """Send a ``msg(type_, **args)`` JSON message; no-op without a websocket."""
        if self.ws is not None:
            await self.ws.send(msg(type_, **args))

    @property
    async def messages(self) -> AsyncIterable[str]:
        """Incoming messages as str; yields nothing if there is no websocket.

        Each access starts a new iteration over the websocket.
        """
        if self.ws is None:
            return
        async for message in self.ws:
            yield str(message)

    @property
    def is_active(self) -> bool:
        """Whether the client still holds its websocket."""
        return self.ws is not None

    @is_active.setter
    def is_active(self, x: bool):
        """Deactivate the client by dropping its websocket.

        Only ``False`` is accepted; a client cannot be re-activated.

        Raises:
            ValueError: if ``x`` is anything but ``False``.
        """
        # Explicit check instead of ``assert``, which is stripped under -O.
        if x is not False:
            raise ValueError("a Client can only be deactivated")
        self.ws = None

    @property
    def info(self):
        """Public info about this client, as sent to other clients."""
        return {
            "name": self.name or "<noname>",
            "id": self.id,
            "active": self.is_active,
        }

    def _reclaim(self, other: "Client"):
        """Take the place of another.

        This allows a Client to inherit the state of another Client. It
        is used to allow a user to resume their previous connection,
        basically a log-in to an existing Client."""
        # The only thing we truly need to keep is our websocket, because
        # all other code assumes a Client's websocket won't ever change.
        assert (
            not other.is_reclaimed
            and not other.is_active
            and self.is_active
            and self.path == other.path
            and self.session is other.session
        )

        # Load all relevant info from other.
        self.id = other.id
        self.name = other.name
        self.points = other.points
        self.secret = other.secret

        # Invalidate other.
        other.is_reclaimed = True
|
|
|
|
|
|
@dataclass
class Session:
    """All clients connected on the same path, plus the session's admin.

    Live sessions are registered in the class-level ``sessions`` dict,
    keyed by path.
    """

    path: Path
    # ID of the client that created the session.
    admin: UserId

    # Sent only to the admin, by send_keys_to_the_city.
    secret: Token = field(default_factory=token)
    clients: dict[UserId, Client] = field(default_factory=dict)

    # Registry of live sessions, keyed by path.
    sessions: ClassVar[dict[Path, "Session"]] = {}

    @classmethod
    def get(cls, client: Client) -> "Session":
        """Add ``client`` to the session for its path, creating it if needed.

        A client that creates a session becomes its admin.

        Raises:
            LoginError: if creating a session would exceed
                ``config.max_sessions``, or the session already holds
                ``config.max_clients_per_session`` clients.
        """
        s = cls.sessions.get(client.path)
        is_new = s is None
        if s is None:
            if len(cls.sessions) >= config.max_sessions:
                raise LoginError("Too many sessions.")
            s = cls(client.path, client.id)
        if len(s.clients) >= config.max_clients_per_session:
            # Checked before a new session is registered, so a rejected
            # client can't leave behind an empty session nobody destroys.
            raise LoginError("Too many clients.")
        if is_new:
            log.info("[%s] new session: %a", client, client.path)
            cls.sessions[client.path] = s
        s.clients[client.id] = client
        client.session = s  # Note: This creates a ref cycle.
        return s

    def destroy(self):
        """Remove all clients and unregister this session.

        Safe to call more than once, e.g. from two disconnect handlers that
        both found the session empty.
        """
        for c in list(self.clients.values()):
            self.remove(c)
        assert len(self.clients) == 0
        # Only unregister if the registry still points at *this* session: it
        # may already have been destroyed, and a new session may have been
        # created on the same path in the meantime.
        if Session.sessions.get(self.path) is self:
            del Session.sessions[self.path]

    @property
    def is_alive(self):
        """Whether the session still has any clients."""
        return bool(self.clients)

    def remove(self, client: Client):
        """Detach ``client`` from the session.

        Raises:
            KeyError: if no client with ``client.id`` is registered.
        """
        client.session = (
            None  # Not sure if this is necessary, but it breaks the ref cycle.
        )
        del self.clients[client.id]

    def reclaim(self, client: Client, inactive: Client):
        """Replace an inactive client with another in the session.

        This will transfer all state (including points and the
        client's ID) from the inactive client to the claiming one.
        """
        assert client.session is inactive.session is self
        del self.clients[client.id]  # Remove the successor from the client pool.
        client._reclaim(inactive)  # The successor now owns the ID & other state.
        self.clients[client.id] = client  # The ref to the inactive client is replaced.

    def __str__(self):
        return f"Session(path={self.path!a}, admin={self.admin!a})"
|
|
|
|
|
|
class LoginError(RuntimeError):
    """Raised when a client cannot join a session, e.g. a limit is reached."""
|
|
|
|
|
|
def msg(type_: str, **args):
    """Serialize a message of kind ``type_`` with payload ``args`` to JSON.

    ``"type"`` is the first key of the resulting object; a ``type`` entry in
    ``args`` overrides its value.
    """
    payload = {"type": type_}
    payload.update(args)
    return dumps(payload)
|
|
|
|
|
|
async def send_time(target: Client):
    """Send the target the current server time.

    The value is ``perf_counter_ns()``: nanoseconds on a monotonic clock
    with an arbitrary reference point.
    """
    now = perf_counter_ns()
    await target.send("time", value=now)
|
|
|
|
|
|
async def send_buzz(target: Client, client: Client, time: int):
    """Tell target that ``client`` buzzed at timestamp ``time``."""
    buzzer_id = client.id
    await target.send("buzz", client=buzzer_id, time=time)
|
|
|
|
|
|
async def send_clients(target: Client):
    """Send target the info of every client in its session, itself included."""
    infos = [peer.info for peer in target.session_clients()]
    await target.send("clients", value=infos)
|
|
|
|
|
|
async def send_client(target: Client, client: Client):
    """Send target the public info of ``client``."""
    payload = client.info
    await target.send("client", **payload)
|
|
|
|
|
|
async def wait(coros, **kwds):
    """Schedule the given coroutines as tasks and wait for them.

    ``kwds`` are forwarded to ``asyncio.wait`` (e.g. ``return_when``).
    Returns the ``(done, pending)`` pair from ``asyncio.wait``. With no
    coroutines it returns ``(set(), set())`` instead, since
    ``asyncio.wait`` rejects an empty collection.

    Task exceptions are retrieved and logged at debug level. Fire-and-forget
    broadcasts to peers that disconnected meanwhile therefore don't trigger
    asyncio's "Task exception was never retrieved" errors. Callers can still
    inspect ``task.exception()`` on the returned tasks.
    """
    tasks = [asyncio.create_task(c) for c in coros]
    if not tasks:
        return set(), set()
    for task in tasks:
        task.add_done_callback(_consume_task_exception)
    return await asyncio.wait(tasks, **kwds)


def _consume_task_exception(task: "asyncio.Task") -> None:
    """Done-callback: retrieve (and log) a finished task's exception, if any."""
    if task.cancelled():
        return
    err = task.exception()
    if err is not None:
        log.debug("background task failed: %r", err)
|
|
|
|
|
|
async def broadcast_client(target: Client):
    """Send the info of ``target`` to every client of its session, itself included."""
    peers = target.session_clients()
    await wait([send_client(peer, target) for peer in peers])
|
|
|
|
|
|
async def broadcast_clients(session: Session):
    """Send every client of the session the info of all of its clients.

    This is a full refresh/sync of the session. Everything goes to everyone,
    so traffic grows quadratically with the client count; use sparingly.
    """
    members = session.clients.values()
    await wait([send_clients(member) for member in members])
|
|
|
|
|
|
async def broadcast_buzz(client: Client, time: int):
    """Tell every client of the session, the buzzer included, that ``client`` buzzed at ``time``."""
    peers = client.session_clients()
    await wait([send_buzz(peer, client, time) for peer in peers])
|
|
|
|
|
|
async def send_credentials(target: Client):
    """Send a client its own ID, secret key and path.

    These are what it needs to log in as itself again later.
    """
    user_id, user_key, user_path = target.id, target.secret, target.path
    await target.send("id", id=user_id, key=user_key, path=user_path)
|
|
|
|
|
|
async def send_keys_to_the_city(target: Client):
    """Send a client its session's secret key; no-op without a session."""
    session = target.session
    if session:
        await target.send("session_key", path=target.path, key=session.secret)
|
|
|
|
|
|
async def send_hello(target: Client):
    """Greet a (re)connected client.

    Sends the server time, the client's credentials and the session's client
    list; the session admin also gets the session key.
    """
    greetings = [send_time(target), send_credentials(target), send_clients(target)]
    admin_extras = [send_keys_to_the_city(target)] if target.is_admin else []
    await wait(greetings + admin_extras)
|
|
|
|
|
|
async def send_heartbeat(target: Client):
    """Wait five seconds, then send the target the current server time."""
    delay_s = 5.0
    await asyncio.sleep(delay_s)
    await send_time(target)
|
|
|
|
|
|
def printable(s: str) -> str:
    """Return ``s`` without characters from Unicode's "Other" (C*) categories.

    Those are control (Cc), format (Cf), surrogate (Cs), private-use (Co)
    and unassigned (Cn) code points.
    """
    # See https://www.unicode.org/versions/Unicode13.0.0/ch04.pdf "Table 4-4."
    kept = []
    for ch in s:
        if unicodedata.category(ch)[0] != "C":
            kept.append(ch)
    return "".join(kept)
|
|
|
|
|
|
async def handle_messages(client: Client):
    """Process a client's incoming messages until its websocket closes.

    Messages are JSON objects with a ``"type"`` key:

    * ``"buzz"``: ``"value"`` (the buzz timestamp) is broadcast to the session.
    * ``"name"``: ``"value"`` (a string) becomes the client's name after
      non-printable characters are stripped; the new info is broadcast.
    * ``"login"``: ``"value"`` is ``{"id": <str>, "key": <str>}``; see
      ``_handle_login``.

    All of this comes straight off the network. Anything malformed (invalid
    JSON, non-object payload, missing or mistyped fields, unknown type) is
    logged and skipped rather than raised out of the loop.
    """
    async for message in client.messages:
        log.debug("[%s] got a message: %a", client, message)
        try:
            mdata = loads(message)
        except (ValueError, RecursionError):  # JSONDecodeError is a ValueError
            mdata = None
        if not isinstance(mdata, dict):
            log.error("[%s] received borked message", client)
            continue

        type_ = mdata.get("type")
        value = mdata.get("value")
        if type_ == "buzz" and "value" in mdata:
            log.info("[%s] buzz: %a", client, value)
            # todo: check time against perf_counter_ns
            await broadcast_buzz(client, value)
        elif type_ == "name" and isinstance(value, str):
            name = printable(value)
            log.info("[%s] new name: %a", client, name)
            client.name = name
            await broadcast_client(client)
        elif (
            type_ == "login"
            and isinstance(value, dict)
            and isinstance(value.get("id"), str)
            and isinstance(value.get("key"), str)
        ):
            await _handle_login(client, UserId(value["id"]), value["key"])
        else:
            log.error("[%s] received borked message", client)


async def _handle_login(client: Client, target_id: UserId, key: str):
    """Let ``client`` take over the inactive client ``target_id`` of its session.

    The target must exist, be inactive and not yet reclaimed, and ``key``
    must match its secret. On success the client is greeted again and the
    whole session is re-synced. On failure the attempt is only logged.
    """
    assert client.session
    existent = client.session.clients.get(target_id)
    if existent is None:
        log.info("[%s] tried to log in as non-existent user: %a", client, target_id)
        return
    if existent.is_active:
        log.info("[%s] cannot log in as active user: %a", client, target_id)
        return
    if existent.is_reclaimed:
        log.info("[%s] client already reclaimed: %a", client, target_id)
        return
    # compare_digest raises TypeError for non-ASCII str arguments, so compare
    # bytes instead. surrogatepass keeps lone surrogates (valid in JSON
    # strings) from raising UnicodeEncodeError.
    given = key.encode("utf-8", "surrogatepass")
    expected = existent.secret.encode("utf-8", "surrogatepass")
    if not compare_digest(given, expected):
        log.info("[%s] failed to log in as existing user: %a", client, target_id)
        return
    log.info("[%s] logging in as existent user: %s", client, existent)
    client.session.reclaim(client, inactive=existent)
    await send_hello(client)
    await broadcast_clients(client.session)
|
|
|
|
|
|
async def juggle(client: Client):
    """Serve a connected client: handle its messages and send it heartbeats.

    Heartbeats run in a background task for as long as the client stays
    connected. Message handling runs in the foreground, so a heartbeat never
    cancels a message that is half-way processed. If message handling stops
    or fails while the websocket is still open, it is restarted.
    """
    heartbeat = asyncio.create_task(_heartbeats(client))
    try:
        while client.is_connected:
            try:
                await handle_messages(client)
            except websockets.ConnectionClosed:
                break
            except Exception:
                # Keep serving the client; one bad message shouldn't drop it.
                log.exception("[%s] error while handling messages", client)
    finally:
        # Also runs if juggle itself is cancelled, so the heartbeat task
        # never outlives the connection handler.
        heartbeat.cancel()


async def _heartbeats(client: Client):
    """Send the client a heartbeat roughly every five seconds until it disconnects."""
    try:
        while client.is_connected:
            await send_heartbeat(client)
    except websockets.ConnectionClosed:
        pass  # Dropped mid-send; juggle notices the closed socket by itself.
|
|
|
|
|
|
async def connected(ws: Websocket, path: str):
    """Serve one websocket connection from handshake to cleanup.

    Joins (or creates) the session for ``path``, greets the client and serves
    it until the connection ends. The client is then marked inactive but
    kept in its session for ``config.client_timeout_s`` seconds, so the user
    can log in again and reclaim it. If nobody does, the client is removed,
    and the session is destroyed once it has no clients left.
    """
    # The path shows up in logs and is the session key, so strip
    # control/format characters first.
    path = printable(path)

    client = Client(ws, path)
    log.info("[%s] new client on %a", client, path)

    try:
        session = Session.get(client)
    except LoginError as err:
        # A session/client limit was hit: tell the client why, then hang up.
        log.error("[%s] Error logging in: %s", client, err)
        await client.send("error", reason=str(err))
        await ws.close()
        return

    try:
        await send_hello(client)
        await juggle(client)
    finally:
        log.info("[%s] client disconnected", client)
        client.is_active = False  # Drops the websocket reference (ws = None).
        await broadcast_clients(session)
        await asyncio.sleep(
            config.client_timeout_s
        )  # Give the user the opportunity to log-in again as the same client.
        # If another connection reclaimed this client, the entry under its ID
        # in session.clients now belongs to the successor; leave it alone.
        if not client.is_reclaimed:
            session.remove(client)
            await broadcast_clients(session)
            log.info("[%s] client gone", client)
            # NOTE(review): the await above can yield to another disconnect
            # handler of the same session, which may destroy the session
            # first. This call can then run destroy() on an already-destroyed
            # session; confirm Session.destroy tolerates that.
            if not session.is_alive:
                session.destroy()
                log.info("[%s] session gone: %s", client, session)
|
|
|
|
|
|
async def check_path(path: str, request_headers) -> Optional["websockets.HTTPResponse"]:
    """Answer 403 Forbidden for request paths outside ``config.path_prefix``.

    Passed to ``websockets.serve`` as ``process_request``. Returns ``None``
    for accepted paths and a ``(status, headers, body)`` tuple otherwise.
    """
    # Anything outside our prefix is internet spam (mass scans for security
    # problems etc.) and not worth spending resources on. Ideally an
    # upstream proxy already enforces the same rule.
    if path.startswith(config.path_prefix):
        return None
    return (HTTPStatus.FORBIDDEN, {}, b"")
|
|
|
|
|
|
def server(host: str, port: int):
    """Build the websocket server for ``host``:``port``.

    Connections are handled by ``connected``, and requests are screened by
    ``check_path`` first. Returns the un-awaited result of ``websockets.serve``.
    """
    handler = connected
    return websockets.serve(handler, host, port, process_request=check_path)
|