add basic session & credential dealing
This commit is contained in:
parent
f6bf544f54
commit
f4cf26a33e
7 changed files with 128 additions and 52 deletions
|
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
import logging
|
||||
|
||||
from . import config
|
||||
from .quiz import server
|
||||
|
||||
|
||||
|
|
@ -12,7 +13,7 @@ def main():
|
|||
level=logging.INFO,
|
||||
)
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(server())
|
||||
asyncio.get_event_loop().run_until_complete(server(config.ws_host, config.ws_port))
|
||||
asyncio.get_event_loop().run_forever()
|
||||
|
||||
|
||||
|
|
|
|||
3
quiz/config.py
Normal file
3
quiz/config.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
path_prefix = "/quiz/"
|
||||
ws_host = "0.0.0.0"
|
||||
ws_port = 8765
|
||||
149
quiz/quiz.py
149
quiz/quiz.py
|
|
@ -1,7 +1,6 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import unicodedata
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from json import dumps, loads
|
||||
from secrets import token_hex
|
||||
|
|
@ -10,39 +9,70 @@ from typing import *
|
|||
|
||||
import websockets
|
||||
|
||||
from . import config
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
Path = str
|
||||
Token = NewType("Token", str)
|
||||
Websocket = websockets.WebSocketServerProtocol
|
||||
|
||||
|
||||
def token() -> Token:
|
||||
return Token(token_hex(8))
|
||||
|
||||
|
||||
@dataclass
|
||||
class Client:
|
||||
ws: Websocket
|
||||
path: str
|
||||
id: str = field(default_factory=lambda: token_hex(8))
|
||||
path: Path
|
||||
id: Token = field(default_factory=token)
|
||||
name: str = ""
|
||||
points: int = 0
|
||||
secret: Token = field(default_factory=token)
|
||||
session: Optional["Session"] = None
|
||||
|
||||
def __str__(self):
|
||||
return f"{ascii(self.id)[1:-1]}:{ascii(self.name)[1:-1]}"
|
||||
|
||||
@property
|
||||
def is_admin(self):
|
||||
return self.session is not None and self.session.admin == self.id
|
||||
|
||||
sessions: dict[str, dict[str, Client]] = defaultdict(dict)
|
||||
|
||||
log = logging.getLogger("buzzer")
|
||||
def others(self, include_self=True) -> Iterable["Client"]:
|
||||
if self.session is None:
|
||||
return []
|
||||
clients = self.session.clients.values()
|
||||
if include_self:
|
||||
return clients
|
||||
return [c for c in clients if c.id != self.id]
|
||||
|
||||
|
||||
async def send_time(client):
|
||||
await client.ws.send(dumps({"type": "time", "value": perf_counter_ns()}))
|
||||
@dataclass
|
||||
class Session:
|
||||
path: Path
|
||||
admin: Token
|
||||
secret: Token = field(default_factory=token)
|
||||
clients: dict[Token, Client] = field(default_factory=dict)
|
||||
|
||||
|
||||
async def send_buzz(target, client, time):
|
||||
await target.ws.send(dumps({"type": "buzz", "client": client.id, "time": time}))
|
||||
def msg(type_: str, **args):
|
||||
return dumps({"type": type_, **args})
|
||||
|
||||
|
||||
async def send_clients(client):
|
||||
clients = [
|
||||
{"name": c.name or "<noname>", "id": c.id}
|
||||
for c in sessions[client.path].values()
|
||||
]
|
||||
await client.ws.send(dumps({"type": "clients", "value": clients}))
|
||||
async def send_time(client: Client):
|
||||
await client.ws.send(msg("time", value=perf_counter_ns()))
|
||||
|
||||
|
||||
async def send_buzz(target, client: Client, time):
|
||||
await target.ws.send(msg("buzz", client=client.id, time=time))
|
||||
|
||||
|
||||
async def send_clients(client: Client):
|
||||
if not client.session:
|
||||
return
|
||||
clients = [{"name": c.name or "<noname>", "id": c.id} for c in client.others()]
|
||||
await client.ws.send(msg("clients", value=clients))
|
||||
|
||||
|
||||
async def wait(coros, **kwds):
|
||||
|
|
@ -53,33 +83,45 @@ async def wait(coros, **kwds):
|
|||
return await asyncio.wait(tasks, **kwds)
|
||||
|
||||
|
||||
async def broadcast_client(client):
|
||||
msg = dumps(
|
||||
{
|
||||
"type": "client",
|
||||
"value": {"name": client.name or "<noname>", "id": client.id},
|
||||
}
|
||||
async def broadcast_client(client: Client):
|
||||
if not client.session:
|
||||
return
|
||||
m = msg("client", value={"name": client.name or "<noname>", "id": client.id})
|
||||
await wait(c.ws.send(m) for c in client.others())
|
||||
|
||||
|
||||
async def broadcast_clients(client: Client):
|
||||
if not client.session:
|
||||
return
|
||||
await wait(send_clients(c) for c in client.others())
|
||||
|
||||
|
||||
async def broadcast_buzz(client: Client, time):
|
||||
if not client.session:
|
||||
return
|
||||
await wait(send_buzz(c, client, time) for c in client.others())
|
||||
|
||||
|
||||
async def send_credentials(client: Client):
|
||||
await client.ws.send(msg("id", id=client.id, key=client.secret))
|
||||
|
||||
|
||||
async def send_keys_to_the_city(client: Client):
|
||||
if not client.session:
|
||||
return
|
||||
await client.ws.send(
|
||||
msg("session_key", path=client.path, key=client.session.secret)
|
||||
)
|
||||
await wait(c.ws.send(msg) for c in sessions[client.path].values())
|
||||
|
||||
|
||||
async def broadcast_clients(client):
|
||||
await wait(send_clients(c) for c in sessions[client.path].values())
|
||||
async def send_hello(client: Client):
|
||||
msgs = [send_time(client), send_credentials(client), send_clients(client)]
|
||||
if client.is_admin:
|
||||
msgs.append(send_keys_to_the_city(client))
|
||||
await wait(msgs)
|
||||
|
||||
|
||||
async def broadcast_buzz(client, time):
|
||||
await wait(send_buzz(c, client, time) for c in sessions[client.path].values())
|
||||
|
||||
|
||||
async def send_id(client):
|
||||
await client.ws.send(dumps({"type": "id", "value": client.id}))
|
||||
|
||||
|
||||
async def send_hello(client):
|
||||
await wait([send_time(client), send_id(client), send_clients(client)])
|
||||
|
||||
|
||||
async def send_heartbeat(client):
|
||||
async def send_heartbeat(client: Client):
|
||||
await asyncio.sleep(5.0)
|
||||
await send_time(client)
|
||||
|
||||
|
|
@ -89,7 +131,7 @@ def printable(s: str) -> str:
|
|||
return "".join(c for c in s if not unicodedata.category(c).startswith("C"))
|
||||
|
||||
|
||||
async def handle_messages(client):
|
||||
async def handle_messages(client: Client):
|
||||
async for message in client.ws:
|
||||
log.debug("[%s] got a message: %a", client, message)
|
||||
mdata = loads(message)
|
||||
|
|
@ -107,7 +149,7 @@ async def handle_messages(client):
|
|||
log.error("[%s] received borked message", client)
|
||||
|
||||
|
||||
async def juggle(client):
|
||||
async def juggle(client: Client):
|
||||
while client.ws.open:
|
||||
done, pending = await wait(
|
||||
[send_heartbeat(client), handle_messages(client)],
|
||||
|
|
@ -117,24 +159,41 @@ async def juggle(client):
|
|||
task.cancel()
|
||||
|
||||
|
||||
sessions: dict[Path, Session] = {}
|
||||
|
||||
|
||||
async def connected(ws: Websocket, path: str):
|
||||
if not path.startswith("/quiz/"):
|
||||
# We'll throw out anything not starting with a certain path prefix just to
|
||||
# get rid of internet spam - mass scans for security problems, etc.
|
||||
# No need to waste resources on this kinda crap.
|
||||
# Ideally the same rule should already be enforced by an upstream proxy.
|
||||
if not path.startswith(config.path_prefix):
|
||||
await ws.close()
|
||||
return
|
||||
|
||||
client = Client(ws, path)
|
||||
log.info("[%s] new client on %a", client, path)
|
||||
sessions[path][client.id] = client
|
||||
|
||||
if path not in sessions:
|
||||
sessions[path] = Session(path, admin=client.id)
|
||||
log.info("[%s] new session on %a", client, path)
|
||||
sessions[path].clients[client.id] = client
|
||||
client.session = sessions[path] # Note: This creates a ref cycle.
|
||||
|
||||
try:
|
||||
await send_hello(client)
|
||||
await juggle(client)
|
||||
finally:
|
||||
log.info("[%s] client disconnected", client)
|
||||
del sessions[path][client.id]
|
||||
client.session = (
|
||||
None # Not sure if this is necessary, but it breaks the ref cycle.
|
||||
)
|
||||
del sessions[path].clients[client.id]
|
||||
await broadcast_clients(client)
|
||||
# Clean up sessions map
|
||||
if not sessions[path]:
|
||||
if not sessions[path].clients:
|
||||
del sessions[path]
|
||||
|
||||
|
||||
def server(host="0.0.0.0", port=8765):
|
||||
def server(host: str, port: int):
|
||||
return websockets.serve(connected, host, port)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue