refactor to make session handling more automatic

This commit is contained in:
ducklet 2021-01-30 15:14:23 +01:00
parent f4cf26a33e
commit 476f3d7a49
3 changed files with 48 additions and 23 deletions

View file

@ -1,3 +1,5 @@
import os
path_prefix = "/quiz/"
ws_host = "0.0.0.0"
ws_port = 8765
ws_host = os.getenv("WS_HOST", "0.0.0.0")
ws_port = int(os.getenv("WS_PORT", 8765))

View file

@ -5,9 +5,9 @@ from dataclasses import dataclass, field
from json import dumps, loads
from secrets import token_hex
from time import perf_counter_ns
from typing import *
from typing import * # pyright: reportWildcardImportFromLibrary=false
import websockets
import websockets # type: ignore
from . import config
@ -15,6 +15,7 @@ log = logging.getLogger(__name__)
Path = str
Token = NewType("Token", str)
UserId = Token
Websocket = websockets.WebSocketServerProtocol
@ -26,7 +27,7 @@ def token() -> Token:
class Client:
ws: Websocket
path: Path
id: Token = field(default_factory=token)
id: UserId = field(default_factory=token)
name: str = ""
points: int = 0
secret: Token = field(default_factory=token)
@ -51,9 +52,38 @@ class Client:
@dataclass
class Session:
path: Path
admin: Token
admin: UserId
secret: Token = field(default_factory=token)
clients: dict[Token, Client] = field(default_factory=dict)
clients: dict[UserId, Client] = field(default_factory=dict)
sessions: ClassVar[dict[Path, "Session"]] = {}
@classmethod
def get(cls, client: Client) -> "Session":
is_new = client.path not in cls.sessions
if is_new:
cls.sessions[client.path] = Session(client.path, client.id)
s = cls.sessions[client.path]
s.clients[client.id] = client
client.session = s # Note: This creates a ref cycle.
return s
def destroy(self):
for c in list(self.clients.values()):
self.remove(c)
assert len(self.clients) == 0
del Session.sessions[self.path]
@property
def is_alive(self):
return bool(self.clients)
def remove(self, client: Client):
client.session = (
None # Not sure if this is necessary, but it breaks the ref cycle.
)
del self.clients[client.id]
def msg(type_: str, **args):
@ -154,14 +184,11 @@ async def juggle(client: Client):
done, pending = await wait(
[send_heartbeat(client), handle_messages(client)],
return_when=asyncio.FIRST_COMPLETED,
)
) # type: ignore # Pyright thinks wait may return None
for task in pending:
task.cancel()
sessions: dict[Path, Session] = {}
async def connected(ws: Websocket, path: str):
# We'll throw out anything not starting with a certain path prefix just to
# get rid of internet spam - mass scans for security problems, etc.
@ -174,25 +201,19 @@ async def connected(ws: Websocket, path: str):
client = Client(ws, path)
log.info("[%s] new client on %a", client, path)
if path not in sessions:
sessions[path] = Session(path, admin=client.id)
session = Session.get(client)
if len(session.clients) == 1:
log.info("[%s] new session on %a", client, path)
sessions[path].clients[client.id] = client
client.session = sessions[path] # Note: This creates a ref cycle.
try:
await send_hello(client)
await juggle(client)
finally:
log.info("[%s] client disconnected", client)
client.session = (
None # Not sure if this is necessary, but it breaks the ref cycle.
)
del sessions[path].clients[client.id]
session.remove(client)
await broadcast_clients(client)
# Clean up sessions map
if not sessions[path].clients:
del sessions[path]
if not session.is_alive:
session.destroy()
def server(host: str, port: int):