chore: add more typing info
This commit is contained in:
parent
76a69b6340
commit
f0f69c1954
2 changed files with 47 additions and 36 deletions
|
|
@ -3,7 +3,7 @@ import contextlib
|
|||
import logging
|
||||
import secrets
|
||||
from json.decoder import JSONDecodeError
|
||||
from typing import Literal, overload
|
||||
from typing import Any, Literal, Never, TypeGuard, overload
|
||||
|
||||
from starlette.applications import Starlette
|
||||
from starlette.authentication import (
|
||||
|
|
@ -20,14 +20,14 @@ from starlette.middleware import Middleware
|
|||
from starlette.middleware.authentication import AuthenticationMiddleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.middleware.gzip import GZipMiddleware
|
||||
from starlette.requests import HTTPConnection
|
||||
from starlette.requests import HTTPConnection, Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import Mount, Route
|
||||
|
||||
from . import config, db, imdb, imdb_import, web_models
|
||||
from .db import close_connection_pool, find_movies, find_ratings, open_connection_pool
|
||||
from .middleware.responsetime import ResponseTimeMiddleware
|
||||
from .models import Group, Movie, User, asplain
|
||||
from .models import JSON, Access, Group, Movie, User, asplain
|
||||
from .types import ULID
|
||||
from .utils import b64decode, b64encode, phc_compare, phc_scrypt
|
||||
|
||||
|
|
@ -83,11 +83,11 @@ class BearerAuthBackend(AuthenticationBackend):
|
|||
return AuthCredentials(["authenticated", *roles]), user
|
||||
|
||||
|
||||
def truthy(s: str):
|
||||
def truthy(s: str | None) -> bool:
|
||||
return bool(s) and s.lower() in {"1", "yes", "true"}
|
||||
|
||||
|
||||
_Yearcomp = Literal["<", "=", ">"]
|
||||
type _Yearcomp = Literal["<", "=", ">"]
|
||||
|
||||
|
||||
def yearcomp(s: str) -> tuple[_Yearcomp, int] | None:
|
||||
|
|
@ -103,7 +103,7 @@ def yearcomp(s: str) -> tuple[_Yearcomp, int] | None:
|
|||
|
||||
|
||||
def as_int(
|
||||
x, *, max: int | None = None, min: int | None = 1, default: int | None = None
|
||||
x: Any, *, max: int | None = None, min: int | None = 1, default: int | None = None
|
||||
) -> int:
|
||||
try:
|
||||
if not isinstance(x, int):
|
||||
|
|
@ -121,9 +121,9 @@ def as_int(
|
|||
return default
|
||||
|
||||
|
||||
def as_ulid(s: str) -> ULID:
|
||||
def as_ulid(s: Any) -> ULID:
|
||||
try:
|
||||
if not s:
|
||||
if not isinstance(s, str) or not s:
|
||||
raise ValueError("Invalid ULID.")
|
||||
|
||||
return ULID(s)
|
||||
|
|
@ -133,14 +133,17 @@ def as_ulid(s: str) -> ULID:
|
|||
|
||||
|
||||
@overload
|
||||
async def json_from_body(request) -> dict: ...
|
||||
async def json_from_body(request: Request) -> dict[str, JSON]: ...
|
||||
|
||||
|
||||
@overload
|
||||
async def json_from_body(request, keys: list[str]) -> list: ...
|
||||
async def json_from_body(request: Request, keys: list[str]) -> list[JSON]: ...
|
||||
|
||||
|
||||
async def json_from_body(request, keys: list[str] | None = None):
|
||||
async def json_from_body(
|
||||
request: Request, keys: list[str] | None = None
|
||||
) -> dict[str, JSON] | list[JSON]:
|
||||
data: dict[str, JSON]
|
||||
if not await request.body():
|
||||
data = {}
|
||||
|
||||
|
|
@ -150,6 +153,9 @@ async def json_from_body(request, keys: list[str] | None = None):
|
|||
except JSONDecodeError as err:
|
||||
raise HTTPException(422, "Invalid JSON content.") from err
|
||||
|
||||
if not isinstance(data, dict):
|
||||
raise HTTPException(422, f"Invalid JSON type: {type(data)!a}")
|
||||
|
||||
if not keys:
|
||||
return data
|
||||
|
||||
|
|
@ -159,11 +165,11 @@ async def json_from_body(request, keys: list[str] | None = None):
|
|||
raise HTTPException(422, f"Missing data for key: {err.args[0]}") from err
|
||||
|
||||
|
||||
def is_admin(request):
|
||||
def is_admin(request: Request) -> bool:
|
||||
return "admin" in request.auth.scopes
|
||||
|
||||
|
||||
async def auth_user(request) -> User | None:
|
||||
async def auth_user(request: Request) -> User | None:
|
||||
if not isinstance(request.user, AuthedUser):
|
||||
return
|
||||
|
||||
|
|
@ -192,7 +198,7 @@ def route(path: str, *, methods: list[str] | None = None, **kwds):
|
|||
|
||||
|
||||
@route("/groups/{group_id}/ratings")
|
||||
async def get_ratings_for_group(request):
|
||||
async def get_ratings_for_group(request: Request) -> JSONResponse:
|
||||
group_id = as_ulid(request.path_params["group_id"])
|
||||
|
||||
async with db.new_connection() as conn:
|
||||
|
|
@ -250,13 +256,13 @@ def not_found(reason: str = "Not Found"):
|
|||
return JSONResponse({"error": reason}, status_code=404)
|
||||
|
||||
|
||||
def not_implemented():
|
||||
def not_implemented() -> Never:
|
||||
raise HTTPException(404, "Not yet implemented.")
|
||||
|
||||
|
||||
@route("/movies")
|
||||
@requires(["authenticated"])
|
||||
async def list_movies(request):
|
||||
async def list_movies(request: Request) -> JSONResponse:
|
||||
params = request.query_params
|
||||
|
||||
user = await auth_user(request)
|
||||
|
|
@ -329,13 +335,13 @@ async def list_movies(request):
|
|||
|
||||
@route("/movies", methods=["POST"])
|
||||
@requires(["authenticated", "admin"])
|
||||
async def add_movie(request):
|
||||
async def add_movie(request: Request) -> JSONResponse:
|
||||
not_implemented()
|
||||
|
||||
|
||||
@route("/movies/_reload_imdb", methods=["GET"])
|
||||
@requires(["authenticated", "admin"])
|
||||
async def progress_for_load_imdb_movies(request):
|
||||
async def progress_for_load_imdb_movies(request: Request) -> JSONResponse:
|
||||
async with db.new_connection() as conn:
|
||||
progress = await db.get_import_progress(conn)
|
||||
if not progress:
|
||||
|
|
@ -371,7 +377,7 @@ _import_lock = asyncio.Lock()
|
|||
|
||||
@route("/movies/_reload_imdb", methods=["POST"])
|
||||
@requires(["authenticated", "admin"])
|
||||
async def load_imdb_movies(request):
|
||||
async def load_imdb_movies(request: Request) -> JSONResponse:
|
||||
params = request.query_params
|
||||
force = truthy(params.get("force"))
|
||||
|
||||
|
|
@ -395,7 +401,7 @@ async def load_imdb_movies(request):
|
|||
|
||||
@route("/users")
|
||||
@requires(["authenticated", "admin"])
|
||||
async def list_users(request):
|
||||
async def list_users(request: Request) -> JSONResponse:
|
||||
async with db.new_connection() as conn:
|
||||
users = await db.get_all(conn, User)
|
||||
|
||||
|
|
@ -404,7 +410,7 @@ async def list_users(request):
|
|||
|
||||
@route("/users", methods=["POST"])
|
||||
@requires(["authenticated", "admin"])
|
||||
async def add_user(request):
|
||||
async def add_user(request: Request) -> JSONResponse:
|
||||
name, imdb_id = await json_from_body(request, ["name", "imdb_id"])
|
||||
|
||||
# XXX restrict name
|
||||
|
|
@ -426,7 +432,7 @@ async def add_user(request):
|
|||
|
||||
@route("/users/{user_id}")
|
||||
@requires(["authenticated"])
|
||||
async def show_user(request):
|
||||
async def show_user(request: Request) -> JSONResponse:
|
||||
user_id = as_ulid(request.path_params["user_id"])
|
||||
|
||||
if is_admin(request):
|
||||
|
|
@ -455,7 +461,7 @@ async def show_user(request):
|
|||
|
||||
@route("/users/{user_id}", methods=["DELETE"])
|
||||
@requires(["authenticated", "admin"])
|
||||
async def remove_user(request):
|
||||
async def remove_user(request: Request) -> JSONResponse:
|
||||
user_id = as_ulid(request.path_params["user_id"])
|
||||
|
||||
async with db.new_connection() as conn:
|
||||
|
|
@ -473,7 +479,7 @@ async def remove_user(request):
|
|||
|
||||
@route("/users/{user_id}", methods=["PATCH"])
|
||||
@requires(["authenticated"])
|
||||
async def modify_user(request):
|
||||
async def modify_user(request: Request) -> JSONResponse:
|
||||
user_id = as_ulid(request.path_params["user_id"])
|
||||
|
||||
if is_admin(request):
|
||||
|
|
@ -520,9 +526,13 @@ async def modify_user(request):
|
|||
return JSONResponse(asplain(user))
|
||||
|
||||
|
||||
def is_valid_access(x: Any) -> TypeGuard[Access]:
|
||||
return isinstance(x, str) and x in set("riw")
|
||||
|
||||
|
||||
@route("/users/{user_id}/groups", methods=["POST"])
|
||||
@requires(["authenticated", "admin"])
|
||||
async def add_group_to_user(request):
|
||||
async def add_group_to_user(request: Request) -> JSONResponse:
|
||||
user_id = as_ulid(request.path_params["user_id"])
|
||||
|
||||
async with db.new_connection() as conn:
|
||||
|
|
@ -537,7 +547,7 @@ async def add_group_to_user(request):
|
|||
if not group:
|
||||
return not_found("Group not found")
|
||||
|
||||
if access not in set("riw"):
|
||||
if not is_valid_access(access):
|
||||
raise HTTPException(422, "Invalid access level.")
|
||||
|
||||
user.set_access(group_id, access)
|
||||
|
|
@ -549,19 +559,19 @@ async def add_group_to_user(request):
|
|||
|
||||
@route("/users/{user_id}/ratings")
|
||||
@requires(["private"])
|
||||
async def ratings_for_user(request):
|
||||
async def ratings_for_user(request: Request) -> JSONResponse:
|
||||
not_implemented()
|
||||
|
||||
|
||||
@route("/users/{user_id}/ratings", methods=["PUT"])
|
||||
@requires("authenticated")
|
||||
async def set_rating_for_user(request):
|
||||
async def set_rating_for_user(request: Request) -> JSONResponse:
|
||||
not_implemented()
|
||||
|
||||
|
||||
@route("/users/_reload_ratings", methods=["POST"])
|
||||
@requires(["authenticated", "admin"])
|
||||
async def load_imdb_user_ratings(request):
|
||||
async def load_imdb_user_ratings(request: Request) -> JSONResponse:
|
||||
ratings = [rating async for rating in imdb.refresh_user_ratings_from_imdb()]
|
||||
|
||||
return JSONResponse({"new_ratings": [asplain(r) for r in ratings]})
|
||||
|
|
@ -569,7 +579,7 @@ async def load_imdb_user_ratings(request):
|
|||
|
||||
@route("/groups")
|
||||
@requires(["authenticated", "admin"])
|
||||
async def list_groups(request):
|
||||
async def list_groups(request: Request) -> JSONResponse:
|
||||
async with db.new_connection() as conn:
|
||||
groups = await db.get_all(conn, Group)
|
||||
|
||||
|
|
@ -578,7 +588,7 @@ async def list_groups(request):
|
|||
|
||||
@route("/groups", methods=["POST"])
|
||||
@requires(["authenticated", "admin"])
|
||||
async def add_group(request):
|
||||
async def add_group(request: Request) -> JSONResponse:
|
||||
(name,) = await json_from_body(request, ["name"])
|
||||
|
||||
# XXX restrict name
|
||||
|
|
@ -592,7 +602,7 @@ async def add_group(request):
|
|||
|
||||
@route("/groups/{group_id}/users", methods=["POST"])
|
||||
@requires(["authenticated"])
|
||||
async def add_user_to_group(request):
|
||||
async def add_user_to_group(request: Request) -> JSONResponse:
|
||||
group_id = as_ulid(request.path_params["group_id"])
|
||||
async with db.new_connection() as conn:
|
||||
group = await db.get(conn, Group, id=str(group_id))
|
||||
|
|
@ -628,11 +638,12 @@ async def add_user_to_group(request):
|
|||
return JSONResponse(asplain(group))
|
||||
|
||||
|
||||
async def http_exception(request, exc):
|
||||
async def http_exception(request: Request, exc: Exception) -> JSONResponse:
|
||||
assert isinstance(exc, HTTPException)
|
||||
return JSONResponse({"error": exc.detail}, status_code=exc.status_code)
|
||||
|
||||
|
||||
def auth_error(request, err):
|
||||
def auth_error(conn: HTTPConnection, err: Exception) -> JSONResponse:
|
||||
return unauthorized(str(err))
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@ from typing import Container, Iterable
|
|||
|
||||
from . import imdb, models
|
||||
|
||||
URL = str
|
||||
Score100 = int # [0, 100]
|
||||
type URL = str
|
||||
type Score100 = int # [0, 100]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue