chore: add more typing info

This commit is contained in:
ducklet 2024-05-19 11:10:08 +02:00
parent 76a69b6340
commit f0f69c1954
2 changed files with 47 additions and 36 deletions

View file

@ -3,7 +3,7 @@ import contextlib
import logging
import secrets
from json.decoder import JSONDecodeError
from typing import Literal, overload
from typing import Any, Literal, Never, TypeGuard, overload
from starlette.applications import Starlette
from starlette.authentication import (
@ -20,14 +20,14 @@ from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.gzip import GZipMiddleware
from starlette.requests import HTTPConnection
from starlette.requests import HTTPConnection, Request
from starlette.responses import JSONResponse
from starlette.routing import Mount, Route
from . import config, db, imdb, imdb_import, web_models
from .db import close_connection_pool, find_movies, find_ratings, open_connection_pool
from .middleware.responsetime import ResponseTimeMiddleware
from .models import Group, Movie, User, asplain
from .models import JSON, Access, Group, Movie, User, asplain
from .types import ULID
from .utils import b64decode, b64encode, phc_compare, phc_scrypt
@ -83,11 +83,11 @@ class BearerAuthBackend(AuthenticationBackend):
return AuthCredentials(["authenticated", *roles]), user
def truthy(s: str):
def truthy(s: str | None) -> bool:
return bool(s) and s.lower() in {"1", "yes", "true"}
_Yearcomp = Literal["<", "=", ">"]
type _Yearcomp = Literal["<", "=", ">"]
def yearcomp(s: str) -> tuple[_Yearcomp, int] | None:
@ -103,7 +103,7 @@ def yearcomp(s: str) -> tuple[_Yearcomp, int] | None:
def as_int(
x, *, max: int | None = None, min: int | None = 1, default: int | None = None
x: Any, *, max: int | None = None, min: int | None = 1, default: int | None = None
) -> int:
try:
if not isinstance(x, int):
@ -121,9 +121,9 @@ def as_int(
return default
def as_ulid(s: str) -> ULID:
def as_ulid(s: Any) -> ULID:
try:
if not s:
if not isinstance(s, str) or not s:
raise ValueError("Invalid ULID.")
return ULID(s)
@ -133,14 +133,17 @@ def as_ulid(s: str) -> ULID:
@overload
async def json_from_body(request) -> dict: ...
async def json_from_body(request: Request) -> dict[str, JSON]: ...
@overload
async def json_from_body(request, keys: list[str]) -> list: ...
async def json_from_body(request: Request, keys: list[str]) -> list[JSON]: ...
async def json_from_body(request, keys: list[str] | None = None):
async def json_from_body(
request: Request, keys: list[str] | None = None
) -> dict[str, JSON] | list[JSON]:
data: dict[str, JSON]
if not await request.body():
data = {}
@ -150,6 +153,9 @@ async def json_from_body(request, keys: list[str] | None = None):
except JSONDecodeError as err:
raise HTTPException(422, "Invalid JSON content.") from err
if not isinstance(data, dict):
raise HTTPException(422, f"Invalid JSON type: {type(data)!a}")
if not keys:
return data
@ -159,11 +165,11 @@ async def json_from_body(request, keys: list[str] | None = None):
raise HTTPException(422, f"Missing data for key: {err.args[0]}") from err
def is_admin(request):
def is_admin(request: Request) -> bool:
return "admin" in request.auth.scopes
async def auth_user(request) -> User | None:
async def auth_user(request: Request) -> User | None:
if not isinstance(request.user, AuthedUser):
return
@ -192,7 +198,7 @@ def route(path: str, *, methods: list[str] | None = None, **kwds):
@route("/groups/{group_id}/ratings")
async def get_ratings_for_group(request):
async def get_ratings_for_group(request: Request) -> JSONResponse:
group_id = as_ulid(request.path_params["group_id"])
async with db.new_connection() as conn:
@ -250,13 +256,13 @@ def not_found(reason: str = "Not Found"):
return JSONResponse({"error": reason}, status_code=404)
def not_implemented():
def not_implemented() -> Never:
raise HTTPException(404, "Not yet implemented.")
@route("/movies")
@requires(["authenticated"])
async def list_movies(request):
async def list_movies(request: Request) -> JSONResponse:
params = request.query_params
user = await auth_user(request)
@ -329,13 +335,13 @@ async def list_movies(request):
@route("/movies", methods=["POST"])
@requires(["authenticated", "admin"])
async def add_movie(request):
async def add_movie(request: Request) -> JSONResponse:
not_implemented()
@route("/movies/_reload_imdb", methods=["GET"])
@requires(["authenticated", "admin"])
async def progress_for_load_imdb_movies(request):
async def progress_for_load_imdb_movies(request: Request) -> JSONResponse:
async with db.new_connection() as conn:
progress = await db.get_import_progress(conn)
if not progress:
@ -371,7 +377,7 @@ _import_lock = asyncio.Lock()
@route("/movies/_reload_imdb", methods=["POST"])
@requires(["authenticated", "admin"])
async def load_imdb_movies(request):
async def load_imdb_movies(request: Request) -> JSONResponse:
params = request.query_params
force = truthy(params.get("force"))
@ -395,7 +401,7 @@ async def load_imdb_movies(request):
@route("/users")
@requires(["authenticated", "admin"])
async def list_users(request):
async def list_users(request: Request) -> JSONResponse:
async with db.new_connection() as conn:
users = await db.get_all(conn, User)
@ -404,7 +410,7 @@ async def list_users(request):
@route("/users", methods=["POST"])
@requires(["authenticated", "admin"])
async def add_user(request):
async def add_user(request: Request) -> JSONResponse:
name, imdb_id = await json_from_body(request, ["name", "imdb_id"])
# XXX restrict name
@ -426,7 +432,7 @@ async def add_user(request):
@route("/users/{user_id}")
@requires(["authenticated"])
async def show_user(request):
async def show_user(request: Request) -> JSONResponse:
user_id = as_ulid(request.path_params["user_id"])
if is_admin(request):
@ -455,7 +461,7 @@ async def show_user(request):
@route("/users/{user_id}", methods=["DELETE"])
@requires(["authenticated", "admin"])
async def remove_user(request):
async def remove_user(request: Request) -> JSONResponse:
user_id = as_ulid(request.path_params["user_id"])
async with db.new_connection() as conn:
@ -473,7 +479,7 @@ async def remove_user(request):
@route("/users/{user_id}", methods=["PATCH"])
@requires(["authenticated"])
async def modify_user(request):
async def modify_user(request: Request) -> JSONResponse:
user_id = as_ulid(request.path_params["user_id"])
if is_admin(request):
@ -520,9 +526,13 @@ async def modify_user(request):
return JSONResponse(asplain(user))
def is_valid_access(x: Any) -> TypeGuard[Access]:
return isinstance(x, str) and x in set("riw")
@route("/users/{user_id}/groups", methods=["POST"])
@requires(["authenticated", "admin"])
async def add_group_to_user(request):
async def add_group_to_user(request: Request) -> JSONResponse:
user_id = as_ulid(request.path_params["user_id"])
async with db.new_connection() as conn:
@ -537,7 +547,7 @@ async def add_group_to_user(request):
if not group:
return not_found("Group not found")
if access not in set("riw"):
if not is_valid_access(access):
raise HTTPException(422, "Invalid access level.")
user.set_access(group_id, access)
@ -549,19 +559,19 @@ async def add_group_to_user(request):
@route("/users/{user_id}/ratings")
@requires(["private"])
async def ratings_for_user(request):
async def ratings_for_user(request: Request) -> JSONResponse:
not_implemented()
@route("/users/{user_id}/ratings", methods=["PUT"])
@requires("authenticated")
async def set_rating_for_user(request):
async def set_rating_for_user(request: Request) -> JSONResponse:
not_implemented()
@route("/users/_reload_ratings", methods=["POST"])
@requires(["authenticated", "admin"])
async def load_imdb_user_ratings(request):
async def load_imdb_user_ratings(request: Request) -> JSONResponse:
ratings = [rating async for rating in imdb.refresh_user_ratings_from_imdb()]
return JSONResponse({"new_ratings": [asplain(r) for r in ratings]})
@ -569,7 +579,7 @@ async def load_imdb_user_ratings(request):
@route("/groups")
@requires(["authenticated", "admin"])
async def list_groups(request):
async def list_groups(request: Request) -> JSONResponse:
async with db.new_connection() as conn:
groups = await db.get_all(conn, Group)
@ -578,7 +588,7 @@ async def list_groups(request):
@route("/groups", methods=["POST"])
@requires(["authenticated", "admin"])
async def add_group(request):
async def add_group(request: Request) -> JSONResponse:
(name,) = await json_from_body(request, ["name"])
# XXX restrict name
@ -592,7 +602,7 @@ async def add_group(request):
@route("/groups/{group_id}/users", methods=["POST"])
@requires(["authenticated"])
async def add_user_to_group(request):
async def add_user_to_group(request: Request) -> JSONResponse:
group_id = as_ulid(request.path_params["group_id"])
async with db.new_connection() as conn:
group = await db.get(conn, Group, id=str(group_id))
@ -628,11 +638,12 @@ async def add_user_to_group(request):
return JSONResponse(asplain(group))
async def http_exception(request, exc):
async def http_exception(request: Request, exc: Exception) -> JSONResponse:
assert isinstance(exc, HTTPException)
return JSONResponse({"error": exc.detail}, status_code=exc.status_code)
def auth_error(request, err):
def auth_error(conn: HTTPConnection, err: Exception) -> JSONResponse:
return unauthorized(str(err))

View file

@ -3,8 +3,8 @@ from typing import Container, Iterable
from . import imdb, models
URL = str
Score100 = int # [0, 100]
type URL = str
type Score100 = int # [0, 100]
@dataclass