refactor to make session handling more automatic
This commit is contained in:
parent
f4cf26a33e
commit
476f3d7a49
3 changed files with 48 additions and 23 deletions
|
|
@ -1,3 +1,5 @@
|
|||
import os
|
||||
|
||||
path_prefix = "/quiz/"
|
||||
ws_host = "0.0.0.0"
|
||||
ws_port = 8765
|
||||
ws_host = os.getenv("WS_HOST", "0.0.0.0")
|
||||
ws_port = int(os.getenv("WS_PORT", 8765))
|
||||
|
|
|
|||
61
quiz/quiz.py
61
quiz/quiz.py
|
|
@ -5,9 +5,9 @@ from dataclasses import dataclass, field
|
|||
from json import dumps, loads
|
||||
from secrets import token_hex
|
||||
from time import perf_counter_ns
|
||||
from typing import *
|
||||
from typing import * # pyright: reportWildcardImportFromLibrary=false
|
||||
|
||||
import websockets
|
||||
import websockets # type: ignore
|
||||
|
||||
from . import config
|
||||
|
||||
|
|
@ -15,6 +15,7 @@ log = logging.getLogger(__name__)
|
|||
|
||||
Path = str
|
||||
Token = NewType("Token", str)
|
||||
UserId = Token
|
||||
Websocket = websockets.WebSocketServerProtocol
|
||||
|
||||
|
||||
|
|
@ -26,7 +27,7 @@ def token() -> Token:
|
|||
class Client:
|
||||
ws: Websocket
|
||||
path: Path
|
||||
id: Token = field(default_factory=token)
|
||||
id: UserId = field(default_factory=token)
|
||||
name: str = ""
|
||||
points: int = 0
|
||||
secret: Token = field(default_factory=token)
|
||||
|
|
@ -51,9 +52,38 @@ class Client:
|
|||
@dataclass
|
||||
class Session:
|
||||
path: Path
|
||||
admin: Token
|
||||
admin: UserId
|
||||
|
||||
secret: Token = field(default_factory=token)
|
||||
clients: dict[Token, Client] = field(default_factory=dict)
|
||||
clients: dict[UserId, Client] = field(default_factory=dict)
|
||||
|
||||
sessions: ClassVar[dict[Path, "Session"]] = {}
|
||||
|
||||
@classmethod
|
||||
def get(cls, client: Client) -> "Session":
|
||||
is_new = client.path not in cls.sessions
|
||||
if is_new:
|
||||
cls.sessions[client.path] = Session(client.path, client.id)
|
||||
s = cls.sessions[client.path]
|
||||
s.clients[client.id] = client
|
||||
client.session = s # Note: This creates a ref cycle.
|
||||
return s
|
||||
|
||||
def destroy(self):
|
||||
for c in list(self.clients.values()):
|
||||
self.remove(c)
|
||||
assert len(self.clients) == 0
|
||||
del Session.sessions[self.path]
|
||||
|
||||
@property
|
||||
def is_alive(self):
|
||||
return bool(self.clients)
|
||||
|
||||
def remove(self, client: Client):
|
||||
client.session = (
|
||||
None # Not sure if this is necessary, but it breaks the ref cycle.
|
||||
)
|
||||
del self.clients[client.id]
|
||||
|
||||
|
||||
def msg(type_: str, **args):
|
||||
|
|
@ -154,14 +184,11 @@ async def juggle(client: Client):
|
|||
done, pending = await wait(
|
||||
[send_heartbeat(client), handle_messages(client)],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
) # type: ignore # Pyright thinks wait may return None
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
|
||||
sessions: dict[Path, Session] = {}
|
||||
|
||||
|
||||
async def connected(ws: Websocket, path: str):
|
||||
# We'll throw out anything not starting with a certain path prefix just to
|
||||
# get rid of internet spam - mass scans for security problems, etc.
|
||||
|
|
@ -174,25 +201,19 @@ async def connected(ws: Websocket, path: str):
|
|||
client = Client(ws, path)
|
||||
log.info("[%s] new client on %a", client, path)
|
||||
|
||||
if path not in sessions:
|
||||
sessions[path] = Session(path, admin=client.id)
|
||||
session = Session.get(client)
|
||||
if len(session.clients) == 1:
|
||||
log.info("[%s] new session on %a", client, path)
|
||||
sessions[path].clients[client.id] = client
|
||||
client.session = sessions[path] # Note: This creates a ref cycle.
|
||||
|
||||
try:
|
||||
await send_hello(client)
|
||||
await juggle(client)
|
||||
finally:
|
||||
log.info("[%s] client disconnected", client)
|
||||
client.session = (
|
||||
None # Not sure if this is necessary, but it breaks the ref cycle.
|
||||
)
|
||||
del sessions[path].clients[client.id]
|
||||
session.remove(client)
|
||||
await broadcast_clients(client)
|
||||
# Clean up sessions map
|
||||
if not sessions[path].clients:
|
||||
del sessions[path]
|
||||
if not session.is_alive:
|
||||
session.destroy()
|
||||
|
||||
|
||||
def server(host: str, port: int):
|
||||
|
|
|
|||
|
|
@ -3,12 +3,14 @@
|
|||
image=$(cat "$RUN_DIR"/.dockerimage)
|
||||
tag=latest
|
||||
|
||||
pubport=8765
|
||||
|
||||
set -x
|
||||
|
||||
exec docker run --init --name dumpr-quiz-ws \
|
||||
--rm \
|
||||
--read-only \
|
||||
--label org.dumpr.quiz.service=ws \
|
||||
-p 8765:8765 \
|
||||
-p "$pubport":8765 \
|
||||
-v "$RUN_DIR":/var/quiz:ro \
|
||||
"$image":"$tag"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue