From f0f69c1954b756c3347d92300875b3736756e3fa Mon Sep 17 00:00:00 2001 From: ducklet Date: Sun, 19 May 2024 11:10:08 +0200 Subject: [PATCH] chore: add more typing info --- unwind/web.py | 79 +++++++++++++++++++++++++------------------- unwind/web_models.py | 4 +-- 2 files changed, 47 insertions(+), 36 deletions(-) diff --git a/unwind/web.py b/unwind/web.py index b4ba575..6a2a0fc 100644 --- a/unwind/web.py +++ b/unwind/web.py @@ -3,7 +3,7 @@ import contextlib import logging import secrets from json.decoder import JSONDecodeError -from typing import Literal, overload +from typing import Any, Literal, Never, TypeGuard, overload from starlette.applications import Starlette from starlette.authentication import ( @@ -20,14 +20,14 @@ from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware from starlette.middleware.cors import CORSMiddleware from starlette.middleware.gzip import GZipMiddleware -from starlette.requests import HTTPConnection +from starlette.requests import HTTPConnection, Request from starlette.responses import JSONResponse from starlette.routing import Mount, Route from . import config, db, imdb, imdb_import, web_models from .db import close_connection_pool, find_movies, find_ratings, open_connection_pool from .middleware.responsetime import ResponseTimeMiddleware -from .models import Group, Movie, User, asplain +from .models import JSON, Access, Group, Movie, User, asplain from .types import ULID from .utils import b64decode, b64encode, phc_compare, phc_scrypt @@ -83,11 +83,11 @@ class BearerAuthBackend(AuthenticationBackend): return AuthCredentials(["authenticated", *roles]), user -def truthy(s: str): +def truthy(s: str | None) -> bool: return bool(s) and s.lower() in {"1", "yes", "true"} -_Yearcomp = Literal["<", "=", ">"] +type _Yearcomp = Literal["<", "=", ">"] def yearcomp(s: str) -> tuple[_Yearcomp, int] | None: @@ -103,7 +103,7 @@ def yearcomp(s: str) -> tuple[_Yearcomp, int] | None: def as_int( - x, *, max: int | None = None, min: int | None = 1, default: int | None = None + x: Any, *, max: int | None = None, min: int | None = 1, default: int | None = None ) -> int: try: if not isinstance(x, int): @@ -121,9 +121,9 @@ def as_int( return default -def as_ulid(s: str) -> ULID: +def as_ulid(s: Any) -> ULID: try: - if not s: + if not isinstance(s, str) or not s: raise ValueError("Invalid ULID.") return ULID(s) @@ -133,14 +133,17 @@ def as_ulid(s: str) -> ULID: @overload -async def json_from_body(request) -> dict: ... +async def json_from_body(request: Request) -> dict[str, JSON]: ... @overload -async def json_from_body(request, keys: list[str]) -> list: ... +async def json_from_body(request: Request, keys: list[str]) -> list[JSON]: ... -async def json_from_body(request, keys: list[str] | None = None): +async def json_from_body( + request: Request, keys: list[str] | None = None +) -> dict[str, JSON] | list[JSON]: + data: dict[str, JSON] if not await request.body(): data = {} @@ -150,6 +153,9 @@ async def json_from_body(request, keys: list[str] | None = None): except JSONDecodeError as err: raise HTTPException(422, "Invalid JSON content.") from err + if not isinstance(data, dict): + raise HTTPException(422, f"Invalid JSON type: {type(data)!a}") + if not keys: return data @@ -159,11 +165,11 @@ async def json_from_body(request, keys: list[str] | None = None): raise HTTPException(422, f"Missing data for key: {err.args[0]}") from err -def is_admin(request): +def is_admin(request: Request) -> bool: return "admin" in request.auth.scopes -async def auth_user(request) -> User | None: +async def auth_user(request: Request) -> User | None: if not isinstance(request.user, AuthedUser): return @@ -192,7 +198,7 @@ def route(path: str, *, methods: list[str] | None = None, **kwds): @route("/groups/{group_id}/ratings") -async def get_ratings_for_group(request): +async def get_ratings_for_group(request: Request) -> JSONResponse: group_id = as_ulid(request.path_params["group_id"]) async with db.new_connection() as conn: @@ -250,13 +256,13 @@ def not_found(reason: str = "Not Found"): return JSONResponse({"error": reason}, status_code=404) -def not_implemented(): +def not_implemented() -> Never: raise HTTPException(404, "Not yet implemented.") @route("/movies") @requires(["authenticated"]) -async def list_movies(request): +async def list_movies(request: Request) -> JSONResponse: params = request.query_params user = await auth_user(request) @@ -329,13 +335,13 @@ async def list_movies(request): @route("/movies", methods=["POST"]) @requires(["authenticated", "admin"]) -async def add_movie(request): +async def add_movie(request: Request) -> JSONResponse: not_implemented() @route("/movies/_reload_imdb", methods=["GET"]) @requires(["authenticated", "admin"]) -async def progress_for_load_imdb_movies(request): +async def progress_for_load_imdb_movies(request: Request) -> JSONResponse: async with db.new_connection() as conn: progress = await db.get_import_progress(conn) if not progress: @@ -371,7 +377,7 @@ _import_lock = asyncio.Lock() @route("/movies/_reload_imdb", methods=["POST"]) @requires(["authenticated", "admin"]) -async def load_imdb_movies(request): +async def load_imdb_movies(request: Request) -> JSONResponse: params = request.query_params force = truthy(params.get("force")) @@ -395,7 +401,7 @@ async def load_imdb_movies(request): @route("/users") @requires(["authenticated", "admin"]) -async def list_users(request): +async def list_users(request: Request) -> JSONResponse: async with db.new_connection() as conn: users = await db.get_all(conn, User) @@ -404,7 +410,7 @@ async def list_users(request): @route("/users", methods=["POST"]) @requires(["authenticated", "admin"]) -async def add_user(request): +async def add_user(request: Request) -> JSONResponse: name, imdb_id = await json_from_body(request, ["name", "imdb_id"]) # XXX restrict name @@ -426,7 +432,7 @@ async def add_user(request): @route("/users/{user_id}") @requires(["authenticated"]) -async def show_user(request): +async def show_user(request: Request) -> JSONResponse: user_id = as_ulid(request.path_params["user_id"]) if is_admin(request): @@ -455,7 +461,7 @@ async def show_user(request): @route("/users/{user_id}", methods=["DELETE"]) @requires(["authenticated", "admin"]) -async def remove_user(request): +async def remove_user(request: Request) -> JSONResponse: user_id = as_ulid(request.path_params["user_id"]) async with db.new_connection() as conn: @@ -473,7 +479,7 @@ async def remove_user(request): @route("/users/{user_id}", methods=["PATCH"]) @requires(["authenticated"]) -async def modify_user(request): +async def modify_user(request: Request) -> JSONResponse: user_id = as_ulid(request.path_params["user_id"]) if is_admin(request): @@ -520,9 +526,13 @@ async def modify_user(request): return JSONResponse(asplain(user)) +def is_valid_access(x: Any) -> TypeGuard[Access]: + return isinstance(x, str) and x in set("riw") + + @route("/users/{user_id}/groups", methods=["POST"]) @requires(["authenticated", "admin"]) -async def add_group_to_user(request): +async def add_group_to_user(request: Request) -> JSONResponse: user_id = as_ulid(request.path_params["user_id"]) async with db.new_connection() as conn: @@ -537,7 +547,7 @@ async def add_group_to_user(request): if not group: return not_found("Group not found") - if access not in set("riw"): + if not is_valid_access(access): raise HTTPException(422, "Invalid access level.") user.set_access(group_id, access) @@ -549,19 +559,19 @@ async def add_group_to_user(request): @route("/users/{user_id}/ratings") @requires(["private"]) -async def ratings_for_user(request): +async def ratings_for_user(request: Request) -> JSONResponse: not_implemented() @route("/users/{user_id}/ratings", methods=["PUT"]) @requires("authenticated") -async def set_rating_for_user(request): +async def set_rating_for_user(request: Request) -> JSONResponse: not_implemented() @route("/users/_reload_ratings", methods=["POST"]) @requires(["authenticated", "admin"]) -async def load_imdb_user_ratings(request): +async def load_imdb_user_ratings(request: Request) -> JSONResponse: ratings = [rating async for rating in imdb.refresh_user_ratings_from_imdb()] return JSONResponse({"new_ratings": [asplain(r) for r in ratings]}) @@ -569,7 +579,7 @@ async def load_imdb_user_ratings(request): @route("/groups") @requires(["authenticated", "admin"]) -async def list_groups(request): +async def list_groups(request: Request) -> JSONResponse: async with db.new_connection() as conn: groups = await db.get_all(conn, Group) @@ -578,7 +588,7 @@ async def list_groups(request): @route("/groups", methods=["POST"]) @requires(["authenticated", "admin"]) -async def add_group(request): +async def add_group(request: Request) -> JSONResponse: (name,) = await json_from_body(request, ["name"]) # XXX restrict name @@ -592,7 +602,7 @@ async def add_group(request): @route("/groups/{group_id}/users", methods=["POST"]) @requires(["authenticated"]) -async def add_user_to_group(request): +async def add_user_to_group(request: Request) -> JSONResponse: group_id = as_ulid(request.path_params["group_id"]) async with db.new_connection() as conn: group = await db.get(conn, Group, id=str(group_id)) @@ -628,11 +638,12 @@ async def add_user_to_group(request): return JSONResponse(asplain(group)) -async def http_exception(request, exc): +async def http_exception(request: Request, exc: Exception) -> JSONResponse: + assert isinstance(exc, HTTPException) return JSONResponse({"error": exc.detail}, status_code=exc.status_code) -def auth_error(request, err): +def auth_error(conn: HTTPConnection, err: Exception) -> JSONResponse: return unauthorized(str(err)) diff --git a/unwind/web_models.py b/unwind/web_models.py index 6e83e1d..6a2c331 100644 --- a/unwind/web_models.py +++ b/unwind/web_models.py @@ -3,8 +3,8 @@ from typing import Container, Iterable from . import imdb, models -URL = str -Score100 = int # [0, 100] +type URL = str +type Score100 = int # [0, 100] @dataclass