refactor to make session handling more automatic

This commit is contained in:
ducklet 2021-01-30 15:14:23 +01:00
parent f4cf26a33e
commit 476f3d7a49
3 changed files with 48 additions and 23 deletions

View file

@ -1,3 +1,5 @@
import os
path_prefix = "/quiz/" path_prefix = "/quiz/"
ws_host = "0.0.0.0" ws_host = os.getenv("WS_HOST", "0.0.0.0")
ws_port = 8765 ws_port = int(os.getenv("WS_PORT", 8765))

View file

@ -5,9 +5,9 @@ from dataclasses import dataclass, field
from json import dumps, loads from json import dumps, loads
from secrets import token_hex from secrets import token_hex
from time import perf_counter_ns from time import perf_counter_ns
from typing import * from typing import * # pyright: reportWildcardImportFromLibrary=false
import websockets import websockets # type: ignore
from . import config from . import config
@ -15,6 +15,7 @@ log = logging.getLogger(__name__)
Path = str Path = str
Token = NewType("Token", str) Token = NewType("Token", str)
UserId = Token
Websocket = websockets.WebSocketServerProtocol Websocket = websockets.WebSocketServerProtocol
@ -26,7 +27,7 @@ def token() -> Token:
class Client: class Client:
ws: Websocket ws: Websocket
path: Path path: Path
id: Token = field(default_factory=token) id: UserId = field(default_factory=token)
name: str = "" name: str = ""
points: int = 0 points: int = 0
secret: Token = field(default_factory=token) secret: Token = field(default_factory=token)
@ -51,9 +52,38 @@ class Client:
@dataclass @dataclass
class Session: class Session:
path: Path path: Path
admin: Token admin: UserId
secret: Token = field(default_factory=token) secret: Token = field(default_factory=token)
clients: dict[Token, Client] = field(default_factory=dict) clients: dict[UserId, Client] = field(default_factory=dict)
sessions: ClassVar[dict[Path, "Session"]] = {}
@classmethod
def get(cls, client: Client) -> "Session":
is_new = client.path not in cls.sessions
if is_new:
cls.sessions[client.path] = Session(client.path, client.id)
s = cls.sessions[client.path]
s.clients[client.id] = client
client.session = s # Note: This creates a ref cycle.
return s
def destroy(self):
for c in list(self.clients.values()):
self.remove(c)
assert len(self.clients) == 0
del Session.sessions[self.path]
@property
def is_alive(self):
return bool(self.clients)
def remove(self, client: Client):
client.session = (
None # Not sure if this is necessary, but it breaks the ref cycle.
)
del self.clients[client.id]
def msg(type_: str, **args): def msg(type_: str, **args):
@ -154,14 +184,11 @@ async def juggle(client: Client):
done, pending = await wait( done, pending = await wait(
[send_heartbeat(client), handle_messages(client)], [send_heartbeat(client), handle_messages(client)],
return_when=asyncio.FIRST_COMPLETED, return_when=asyncio.FIRST_COMPLETED,
) ) # type: ignore # Pyright thinks wait may return None
for task in pending: for task in pending:
task.cancel() task.cancel()
sessions: dict[Path, Session] = {}
async def connected(ws: Websocket, path: str): async def connected(ws: Websocket, path: str):
# We'll throw out anything not starting with a certain path prefix just to # We'll throw out anything not starting with a certain path prefix just to
# get rid of internet spam - mass scans for security problems, etc. # get rid of internet spam - mass scans for security problems, etc.
@ -174,25 +201,19 @@ async def connected(ws: Websocket, path: str):
client = Client(ws, path) client = Client(ws, path)
log.info("[%s] new client on %a", client, path) log.info("[%s] new client on %a", client, path)
if path not in sessions: session = Session.get(client)
sessions[path] = Session(path, admin=client.id) if len(session.clients) == 1:
log.info("[%s] new session on %a", client, path) log.info("[%s] new session on %a", client, path)
sessions[path].clients[client.id] = client
client.session = sessions[path] # Note: This creates a ref cycle.
try: try:
await send_hello(client) await send_hello(client)
await juggle(client) await juggle(client)
finally: finally:
log.info("[%s] client disconnected", client) log.info("[%s] client disconnected", client)
client.session = ( session.remove(client)
None # Not sure if this is necessary, but it breaks the ref cycle.
)
del sessions[path].clients[client.id]
await broadcast_clients(client) await broadcast_clients(client)
# Clean up sessions map if not session.is_alive:
if not sessions[path].clients: session.destroy()
del sessions[path]
def server(host: str, port: int): def server(host: str, port: int):

View file

@ -3,12 +3,14 @@
image=$(cat "$RUN_DIR"/.dockerimage) image=$(cat "$RUN_DIR"/.dockerimage)
tag=latest tag=latest
pubport=8765
set -x set -x
exec docker run --init --name dumpr-quiz-ws \ exec docker run --init --name dumpr-quiz-ws \
--rm \ --rm \
--read-only \ --read-only \
--label org.dumpr.quiz.service=ws \ --label org.dumpr.quiz.service=ws \
-p 8765:8765 \ -p "$pubport":8765 \
-v "$RUN_DIR":/var/quiz:ro \ -v "$RUN_DIR":/var/quiz:ro \
"$image":"$tag" "$image":"$tag"