chore: add more typing info

This commit is contained in:
ducklet 2024-05-19 11:10:08 +02:00
parent 76a69b6340
commit f0f69c1954
2 changed files with 47 additions and 36 deletions

View file

@ -3,7 +3,7 @@ import contextlib
import logging import logging
import secrets import secrets
from json.decoder import JSONDecodeError from json.decoder import JSONDecodeError
from typing import Literal, overload from typing import Any, Literal, Never, TypeGuard, overload
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.authentication import ( from starlette.authentication import (
@ -20,14 +20,14 @@ from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.middleware.cors import CORSMiddleware from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.gzip import GZipMiddleware from starlette.middleware.gzip import GZipMiddleware
from starlette.requests import HTTPConnection from starlette.requests import HTTPConnection, Request
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
from starlette.routing import Mount, Route from starlette.routing import Mount, Route
from . import config, db, imdb, imdb_import, web_models from . import config, db, imdb, imdb_import, web_models
from .db import close_connection_pool, find_movies, find_ratings, open_connection_pool from .db import close_connection_pool, find_movies, find_ratings, open_connection_pool
from .middleware.responsetime import ResponseTimeMiddleware from .middleware.responsetime import ResponseTimeMiddleware
from .models import Group, Movie, User, asplain from .models import JSON, Access, Group, Movie, User, asplain
from .types import ULID from .types import ULID
from .utils import b64decode, b64encode, phc_compare, phc_scrypt from .utils import b64decode, b64encode, phc_compare, phc_scrypt
@ -83,11 +83,11 @@ class BearerAuthBackend(AuthenticationBackend):
return AuthCredentials(["authenticated", *roles]), user return AuthCredentials(["authenticated", *roles]), user
def truthy(s: str): def truthy(s: str | None) -> bool:
return bool(s) and s.lower() in {"1", "yes", "true"} return bool(s) and s.lower() in {"1", "yes", "true"}
_Yearcomp = Literal["<", "=", ">"] type _Yearcomp = Literal["<", "=", ">"]
def yearcomp(s: str) -> tuple[_Yearcomp, int] | None: def yearcomp(s: str) -> tuple[_Yearcomp, int] | None:
@ -103,7 +103,7 @@ def yearcomp(s: str) -> tuple[_Yearcomp, int] | None:
def as_int( def as_int(
x, *, max: int | None = None, min: int | None = 1, default: int | None = None x: Any, *, max: int | None = None, min: int | None = 1, default: int | None = None
) -> int: ) -> int:
try: try:
if not isinstance(x, int): if not isinstance(x, int):
@ -121,9 +121,9 @@ def as_int(
return default return default
def as_ulid(s: str) -> ULID: def as_ulid(s: Any) -> ULID:
try: try:
if not s: if not isinstance(s, str) or not s:
raise ValueError("Invalid ULID.") raise ValueError("Invalid ULID.")
return ULID(s) return ULID(s)
@ -133,14 +133,17 @@ def as_ulid(s: str) -> ULID:
@overload @overload
async def json_from_body(request) -> dict: ... async def json_from_body(request: Request) -> dict[str, JSON]: ...
@overload @overload
async def json_from_body(request, keys: list[str]) -> list: ... async def json_from_body(request: Request, keys: list[str]) -> list[JSON]: ...
async def json_from_body(request, keys: list[str] | None = None): async def json_from_body(
request: Request, keys: list[str] | None = None
) -> dict[str, JSON] | list[JSON]:
data: dict[str, JSON]
if not await request.body(): if not await request.body():
data = {} data = {}
@ -150,6 +153,9 @@ async def json_from_body(request, keys: list[str] | None = None):
except JSONDecodeError as err: except JSONDecodeError as err:
raise HTTPException(422, "Invalid JSON content.") from err raise HTTPException(422, "Invalid JSON content.") from err
if not isinstance(data, dict):
raise HTTPException(422, f"Invalid JSON type: {type(data)!a}")
if not keys: if not keys:
return data return data
@ -159,11 +165,11 @@ async def json_from_body(request, keys: list[str] | None = None):
raise HTTPException(422, f"Missing data for key: {err.args[0]}") from err raise HTTPException(422, f"Missing data for key: {err.args[0]}") from err
def is_admin(request): def is_admin(request: Request) -> bool:
return "admin" in request.auth.scopes return "admin" in request.auth.scopes
async def auth_user(request) -> User | None: async def auth_user(request: Request) -> User | None:
if not isinstance(request.user, AuthedUser): if not isinstance(request.user, AuthedUser):
return return
@ -192,7 +198,7 @@ def route(path: str, *, methods: list[str] | None = None, **kwds):
@route("/groups/{group_id}/ratings") @route("/groups/{group_id}/ratings")
async def get_ratings_for_group(request): async def get_ratings_for_group(request: Request) -> JSONResponse:
group_id = as_ulid(request.path_params["group_id"]) group_id = as_ulid(request.path_params["group_id"])
async with db.new_connection() as conn: async with db.new_connection() as conn:
@ -250,13 +256,13 @@ def not_found(reason: str = "Not Found"):
return JSONResponse({"error": reason}, status_code=404) return JSONResponse({"error": reason}, status_code=404)
def not_implemented(): def not_implemented() -> Never:
raise HTTPException(404, "Not yet implemented.") raise HTTPException(404, "Not yet implemented.")
@route("/movies") @route("/movies")
@requires(["authenticated"]) @requires(["authenticated"])
async def list_movies(request): async def list_movies(request: Request) -> JSONResponse:
params = request.query_params params = request.query_params
user = await auth_user(request) user = await auth_user(request)
@ -329,13 +335,13 @@ async def list_movies(request):
@route("/movies", methods=["POST"]) @route("/movies", methods=["POST"])
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def add_movie(request): async def add_movie(request: Request) -> JSONResponse:
not_implemented() not_implemented()
@route("/movies/_reload_imdb", methods=["GET"]) @route("/movies/_reload_imdb", methods=["GET"])
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def progress_for_load_imdb_movies(request): async def progress_for_load_imdb_movies(request: Request) -> JSONResponse:
async with db.new_connection() as conn: async with db.new_connection() as conn:
progress = await db.get_import_progress(conn) progress = await db.get_import_progress(conn)
if not progress: if not progress:
@ -371,7 +377,7 @@ _import_lock = asyncio.Lock()
@route("/movies/_reload_imdb", methods=["POST"]) @route("/movies/_reload_imdb", methods=["POST"])
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def load_imdb_movies(request): async def load_imdb_movies(request: Request) -> JSONResponse:
params = request.query_params params = request.query_params
force = truthy(params.get("force")) force = truthy(params.get("force"))
@ -395,7 +401,7 @@ async def load_imdb_movies(request):
@route("/users") @route("/users")
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def list_users(request): async def list_users(request: Request) -> JSONResponse:
async with db.new_connection() as conn: async with db.new_connection() as conn:
users = await db.get_all(conn, User) users = await db.get_all(conn, User)
@ -404,7 +410,7 @@ async def list_users(request):
@route("/users", methods=["POST"]) @route("/users", methods=["POST"])
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def add_user(request): async def add_user(request: Request) -> JSONResponse:
name, imdb_id = await json_from_body(request, ["name", "imdb_id"]) name, imdb_id = await json_from_body(request, ["name", "imdb_id"])
# XXX restrict name # XXX restrict name
@ -426,7 +432,7 @@ async def add_user(request):
@route("/users/{user_id}") @route("/users/{user_id}")
@requires(["authenticated"]) @requires(["authenticated"])
async def show_user(request): async def show_user(request: Request) -> JSONResponse:
user_id = as_ulid(request.path_params["user_id"]) user_id = as_ulid(request.path_params["user_id"])
if is_admin(request): if is_admin(request):
@ -455,7 +461,7 @@ async def show_user(request):
@route("/users/{user_id}", methods=["DELETE"]) @route("/users/{user_id}", methods=["DELETE"])
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def remove_user(request): async def remove_user(request: Request) -> JSONResponse:
user_id = as_ulid(request.path_params["user_id"]) user_id = as_ulid(request.path_params["user_id"])
async with db.new_connection() as conn: async with db.new_connection() as conn:
@ -473,7 +479,7 @@ async def remove_user(request):
@route("/users/{user_id}", methods=["PATCH"]) @route("/users/{user_id}", methods=["PATCH"])
@requires(["authenticated"]) @requires(["authenticated"])
async def modify_user(request): async def modify_user(request: Request) -> JSONResponse:
user_id = as_ulid(request.path_params["user_id"]) user_id = as_ulid(request.path_params["user_id"])
if is_admin(request): if is_admin(request):
@ -520,9 +526,13 @@ async def modify_user(request):
return JSONResponse(asplain(user)) return JSONResponse(asplain(user))
def is_valid_access(x: Any) -> TypeGuard[Access]:
return isinstance(x, str) and x in set("riw")
@route("/users/{user_id}/groups", methods=["POST"]) @route("/users/{user_id}/groups", methods=["POST"])
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def add_group_to_user(request): async def add_group_to_user(request: Request) -> JSONResponse:
user_id = as_ulid(request.path_params["user_id"]) user_id = as_ulid(request.path_params["user_id"])
async with db.new_connection() as conn: async with db.new_connection() as conn:
@ -537,7 +547,7 @@ async def add_group_to_user(request):
if not group: if not group:
return not_found("Group not found") return not_found("Group not found")
if access not in set("riw"): if not is_valid_access(access):
raise HTTPException(422, "Invalid access level.") raise HTTPException(422, "Invalid access level.")
user.set_access(group_id, access) user.set_access(group_id, access)
@ -549,19 +559,19 @@ async def add_group_to_user(request):
@route("/users/{user_id}/ratings") @route("/users/{user_id}/ratings")
@requires(["private"]) @requires(["private"])
async def ratings_for_user(request): async def ratings_for_user(request: Request) -> JSONResponse:
not_implemented() not_implemented()
@route("/users/{user_id}/ratings", methods=["PUT"]) @route("/users/{user_id}/ratings", methods=["PUT"])
@requires("authenticated") @requires("authenticated")
async def set_rating_for_user(request): async def set_rating_for_user(request: Request) -> JSONResponse:
not_implemented() not_implemented()
@route("/users/_reload_ratings", methods=["POST"]) @route("/users/_reload_ratings", methods=["POST"])
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def load_imdb_user_ratings(request): async def load_imdb_user_ratings(request: Request) -> JSONResponse:
ratings = [rating async for rating in imdb.refresh_user_ratings_from_imdb()] ratings = [rating async for rating in imdb.refresh_user_ratings_from_imdb()]
return JSONResponse({"new_ratings": [asplain(r) for r in ratings]}) return JSONResponse({"new_ratings": [asplain(r) for r in ratings]})
@ -569,7 +579,7 @@ async def load_imdb_user_ratings(request):
@route("/groups") @route("/groups")
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def list_groups(request): async def list_groups(request: Request) -> JSONResponse:
async with db.new_connection() as conn: async with db.new_connection() as conn:
groups = await db.get_all(conn, Group) groups = await db.get_all(conn, Group)
@ -578,7 +588,7 @@ async def list_groups(request):
@route("/groups", methods=["POST"]) @route("/groups", methods=["POST"])
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def add_group(request): async def add_group(request: Request) -> JSONResponse:
(name,) = await json_from_body(request, ["name"]) (name,) = await json_from_body(request, ["name"])
# XXX restrict name # XXX restrict name
@ -592,7 +602,7 @@ async def add_group(request):
@route("/groups/{group_id}/users", methods=["POST"]) @route("/groups/{group_id}/users", methods=["POST"])
@requires(["authenticated"]) @requires(["authenticated"])
async def add_user_to_group(request): async def add_user_to_group(request: Request) -> JSONResponse:
group_id = as_ulid(request.path_params["group_id"]) group_id = as_ulid(request.path_params["group_id"])
async with db.new_connection() as conn: async with db.new_connection() as conn:
group = await db.get(conn, Group, id=str(group_id)) group = await db.get(conn, Group, id=str(group_id))
@ -628,11 +638,12 @@ async def add_user_to_group(request):
return JSONResponse(asplain(group)) return JSONResponse(asplain(group))
async def http_exception(request, exc): async def http_exception(request: Request, exc: Exception) -> JSONResponse:
assert isinstance(exc, HTTPException)
return JSONResponse({"error": exc.detail}, status_code=exc.status_code) return JSONResponse({"error": exc.detail}, status_code=exc.status_code)
def auth_error(request, err): def auth_error(conn: HTTPConnection, err: Exception) -> JSONResponse:
return unauthorized(str(err)) return unauthorized(str(err))

View file

@ -3,8 +3,8 @@ from typing import Container, Iterable
from . import imdb, models from . import imdb, models
URL = str type URL = str
Score100 = int # [0, 100] type Score100 = int # [0, 100]
@dataclass @dataclass