chore: add more typing info
This commit is contained in:
parent
76a69b6340
commit
f0f69c1954
2 changed files with 47 additions and 36 deletions
|
|
@ -3,7 +3,7 @@ import contextlib
|
||||||
import logging
|
import logging
|
||||||
import secrets
|
import secrets
|
||||||
from json.decoder import JSONDecodeError
|
from json.decoder import JSONDecodeError
|
||||||
from typing import Literal, overload
|
from typing import Any, Literal, Never, TypeGuard, overload
|
||||||
|
|
||||||
from starlette.applications import Starlette
|
from starlette.applications import Starlette
|
||||||
from starlette.authentication import (
|
from starlette.authentication import (
|
||||||
|
|
@ -20,14 +20,14 @@ from starlette.middleware import Middleware
|
||||||
from starlette.middleware.authentication import AuthenticationMiddleware
|
from starlette.middleware.authentication import AuthenticationMiddleware
|
||||||
from starlette.middleware.cors import CORSMiddleware
|
from starlette.middleware.cors import CORSMiddleware
|
||||||
from starlette.middleware.gzip import GZipMiddleware
|
from starlette.middleware.gzip import GZipMiddleware
|
||||||
from starlette.requests import HTTPConnection
|
from starlette.requests import HTTPConnection, Request
|
||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
from starlette.routing import Mount, Route
|
from starlette.routing import Mount, Route
|
||||||
|
|
||||||
from . import config, db, imdb, imdb_import, web_models
|
from . import config, db, imdb, imdb_import, web_models
|
||||||
from .db import close_connection_pool, find_movies, find_ratings, open_connection_pool
|
from .db import close_connection_pool, find_movies, find_ratings, open_connection_pool
|
||||||
from .middleware.responsetime import ResponseTimeMiddleware
|
from .middleware.responsetime import ResponseTimeMiddleware
|
||||||
from .models import Group, Movie, User, asplain
|
from .models import JSON, Access, Group, Movie, User, asplain
|
||||||
from .types import ULID
|
from .types import ULID
|
||||||
from .utils import b64decode, b64encode, phc_compare, phc_scrypt
|
from .utils import b64decode, b64encode, phc_compare, phc_scrypt
|
||||||
|
|
||||||
|
|
@ -83,11 +83,11 @@ class BearerAuthBackend(AuthenticationBackend):
|
||||||
return AuthCredentials(["authenticated", *roles]), user
|
return AuthCredentials(["authenticated", *roles]), user
|
||||||
|
|
||||||
|
|
||||||
def truthy(s: str):
|
def truthy(s: str | None) -> bool:
|
||||||
return bool(s) and s.lower() in {"1", "yes", "true"}
|
return bool(s) and s.lower() in {"1", "yes", "true"}
|
||||||
|
|
||||||
|
|
||||||
_Yearcomp = Literal["<", "=", ">"]
|
type _Yearcomp = Literal["<", "=", ">"]
|
||||||
|
|
||||||
|
|
||||||
def yearcomp(s: str) -> tuple[_Yearcomp, int] | None:
|
def yearcomp(s: str) -> tuple[_Yearcomp, int] | None:
|
||||||
|
|
@ -103,7 +103,7 @@ def yearcomp(s: str) -> tuple[_Yearcomp, int] | None:
|
||||||
|
|
||||||
|
|
||||||
def as_int(
|
def as_int(
|
||||||
x, *, max: int | None = None, min: int | None = 1, default: int | None = None
|
x: Any, *, max: int | None = None, min: int | None = 1, default: int | None = None
|
||||||
) -> int:
|
) -> int:
|
||||||
try:
|
try:
|
||||||
if not isinstance(x, int):
|
if not isinstance(x, int):
|
||||||
|
|
@ -121,9 +121,9 @@ def as_int(
|
||||||
return default
|
return default
|
||||||
|
|
||||||
|
|
||||||
def as_ulid(s: str) -> ULID:
|
def as_ulid(s: Any) -> ULID:
|
||||||
try:
|
try:
|
||||||
if not s:
|
if not isinstance(s, str) or not s:
|
||||||
raise ValueError("Invalid ULID.")
|
raise ValueError("Invalid ULID.")
|
||||||
|
|
||||||
return ULID(s)
|
return ULID(s)
|
||||||
|
|
@ -133,14 +133,17 @@ def as_ulid(s: str) -> ULID:
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
async def json_from_body(request) -> dict: ...
|
async def json_from_body(request: Request) -> dict[str, JSON]: ...
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
async def json_from_body(request, keys: list[str]) -> list: ...
|
async def json_from_body(request: Request, keys: list[str]) -> list[JSON]: ...
|
||||||
|
|
||||||
|
|
||||||
async def json_from_body(request, keys: list[str] | None = None):
|
async def json_from_body(
|
||||||
|
request: Request, keys: list[str] | None = None
|
||||||
|
) -> dict[str, JSON] | list[JSON]:
|
||||||
|
data: dict[str, JSON]
|
||||||
if not await request.body():
|
if not await request.body():
|
||||||
data = {}
|
data = {}
|
||||||
|
|
||||||
|
|
@ -150,6 +153,9 @@ async def json_from_body(request, keys: list[str] | None = None):
|
||||||
except JSONDecodeError as err:
|
except JSONDecodeError as err:
|
||||||
raise HTTPException(422, "Invalid JSON content.") from err
|
raise HTTPException(422, "Invalid JSON content.") from err
|
||||||
|
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
raise HTTPException(422, f"Invalid JSON type: {type(data)!a}")
|
||||||
|
|
||||||
if not keys:
|
if not keys:
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
@ -159,11 +165,11 @@ async def json_from_body(request, keys: list[str] | None = None):
|
||||||
raise HTTPException(422, f"Missing data for key: {err.args[0]}") from err
|
raise HTTPException(422, f"Missing data for key: {err.args[0]}") from err
|
||||||
|
|
||||||
|
|
||||||
def is_admin(request):
|
def is_admin(request: Request) -> bool:
|
||||||
return "admin" in request.auth.scopes
|
return "admin" in request.auth.scopes
|
||||||
|
|
||||||
|
|
||||||
async def auth_user(request) -> User | None:
|
async def auth_user(request: Request) -> User | None:
|
||||||
if not isinstance(request.user, AuthedUser):
|
if not isinstance(request.user, AuthedUser):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -192,7 +198,7 @@ def route(path: str, *, methods: list[str] | None = None, **kwds):
|
||||||
|
|
||||||
|
|
||||||
@route("/groups/{group_id}/ratings")
|
@route("/groups/{group_id}/ratings")
|
||||||
async def get_ratings_for_group(request):
|
async def get_ratings_for_group(request: Request) -> JSONResponse:
|
||||||
group_id = as_ulid(request.path_params["group_id"])
|
group_id = as_ulid(request.path_params["group_id"])
|
||||||
|
|
||||||
async with db.new_connection() as conn:
|
async with db.new_connection() as conn:
|
||||||
|
|
@ -250,13 +256,13 @@ def not_found(reason: str = "Not Found"):
|
||||||
return JSONResponse({"error": reason}, status_code=404)
|
return JSONResponse({"error": reason}, status_code=404)
|
||||||
|
|
||||||
|
|
||||||
def not_implemented():
|
def not_implemented() -> Never:
|
||||||
raise HTTPException(404, "Not yet implemented.")
|
raise HTTPException(404, "Not yet implemented.")
|
||||||
|
|
||||||
|
|
||||||
@route("/movies")
|
@route("/movies")
|
||||||
@requires(["authenticated"])
|
@requires(["authenticated"])
|
||||||
async def list_movies(request):
|
async def list_movies(request: Request) -> JSONResponse:
|
||||||
params = request.query_params
|
params = request.query_params
|
||||||
|
|
||||||
user = await auth_user(request)
|
user = await auth_user(request)
|
||||||
|
|
@ -329,13 +335,13 @@ async def list_movies(request):
|
||||||
|
|
||||||
@route("/movies", methods=["POST"])
|
@route("/movies", methods=["POST"])
|
||||||
@requires(["authenticated", "admin"])
|
@requires(["authenticated", "admin"])
|
||||||
async def add_movie(request):
|
async def add_movie(request: Request) -> JSONResponse:
|
||||||
not_implemented()
|
not_implemented()
|
||||||
|
|
||||||
|
|
||||||
@route("/movies/_reload_imdb", methods=["GET"])
|
@route("/movies/_reload_imdb", methods=["GET"])
|
||||||
@requires(["authenticated", "admin"])
|
@requires(["authenticated", "admin"])
|
||||||
async def progress_for_load_imdb_movies(request):
|
async def progress_for_load_imdb_movies(request: Request) -> JSONResponse:
|
||||||
async with db.new_connection() as conn:
|
async with db.new_connection() as conn:
|
||||||
progress = await db.get_import_progress(conn)
|
progress = await db.get_import_progress(conn)
|
||||||
if not progress:
|
if not progress:
|
||||||
|
|
@ -371,7 +377,7 @@ _import_lock = asyncio.Lock()
|
||||||
|
|
||||||
@route("/movies/_reload_imdb", methods=["POST"])
|
@route("/movies/_reload_imdb", methods=["POST"])
|
||||||
@requires(["authenticated", "admin"])
|
@requires(["authenticated", "admin"])
|
||||||
async def load_imdb_movies(request):
|
async def load_imdb_movies(request: Request) -> JSONResponse:
|
||||||
params = request.query_params
|
params = request.query_params
|
||||||
force = truthy(params.get("force"))
|
force = truthy(params.get("force"))
|
||||||
|
|
||||||
|
|
@ -395,7 +401,7 @@ async def load_imdb_movies(request):
|
||||||
|
|
||||||
@route("/users")
|
@route("/users")
|
||||||
@requires(["authenticated", "admin"])
|
@requires(["authenticated", "admin"])
|
||||||
async def list_users(request):
|
async def list_users(request: Request) -> JSONResponse:
|
||||||
async with db.new_connection() as conn:
|
async with db.new_connection() as conn:
|
||||||
users = await db.get_all(conn, User)
|
users = await db.get_all(conn, User)
|
||||||
|
|
||||||
|
|
@ -404,7 +410,7 @@ async def list_users(request):
|
||||||
|
|
||||||
@route("/users", methods=["POST"])
|
@route("/users", methods=["POST"])
|
||||||
@requires(["authenticated", "admin"])
|
@requires(["authenticated", "admin"])
|
||||||
async def add_user(request):
|
async def add_user(request: Request) -> JSONResponse:
|
||||||
name, imdb_id = await json_from_body(request, ["name", "imdb_id"])
|
name, imdb_id = await json_from_body(request, ["name", "imdb_id"])
|
||||||
|
|
||||||
# XXX restrict name
|
# XXX restrict name
|
||||||
|
|
@ -426,7 +432,7 @@ async def add_user(request):
|
||||||
|
|
||||||
@route("/users/{user_id}")
|
@route("/users/{user_id}")
|
||||||
@requires(["authenticated"])
|
@requires(["authenticated"])
|
||||||
async def show_user(request):
|
async def show_user(request: Request) -> JSONResponse:
|
||||||
user_id = as_ulid(request.path_params["user_id"])
|
user_id = as_ulid(request.path_params["user_id"])
|
||||||
|
|
||||||
if is_admin(request):
|
if is_admin(request):
|
||||||
|
|
@ -455,7 +461,7 @@ async def show_user(request):
|
||||||
|
|
||||||
@route("/users/{user_id}", methods=["DELETE"])
|
@route("/users/{user_id}", methods=["DELETE"])
|
||||||
@requires(["authenticated", "admin"])
|
@requires(["authenticated", "admin"])
|
||||||
async def remove_user(request):
|
async def remove_user(request: Request) -> JSONResponse:
|
||||||
user_id = as_ulid(request.path_params["user_id"])
|
user_id = as_ulid(request.path_params["user_id"])
|
||||||
|
|
||||||
async with db.new_connection() as conn:
|
async with db.new_connection() as conn:
|
||||||
|
|
@ -473,7 +479,7 @@ async def remove_user(request):
|
||||||
|
|
||||||
@route("/users/{user_id}", methods=["PATCH"])
|
@route("/users/{user_id}", methods=["PATCH"])
|
||||||
@requires(["authenticated"])
|
@requires(["authenticated"])
|
||||||
async def modify_user(request):
|
async def modify_user(request: Request) -> JSONResponse:
|
||||||
user_id = as_ulid(request.path_params["user_id"])
|
user_id = as_ulid(request.path_params["user_id"])
|
||||||
|
|
||||||
if is_admin(request):
|
if is_admin(request):
|
||||||
|
|
@ -520,9 +526,13 @@ async def modify_user(request):
|
||||||
return JSONResponse(asplain(user))
|
return JSONResponse(asplain(user))
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_access(x: Any) -> TypeGuard[Access]:
|
||||||
|
return isinstance(x, str) and x in set("riw")
|
||||||
|
|
||||||
|
|
||||||
@route("/users/{user_id}/groups", methods=["POST"])
|
@route("/users/{user_id}/groups", methods=["POST"])
|
||||||
@requires(["authenticated", "admin"])
|
@requires(["authenticated", "admin"])
|
||||||
async def add_group_to_user(request):
|
async def add_group_to_user(request: Request) -> JSONResponse:
|
||||||
user_id = as_ulid(request.path_params["user_id"])
|
user_id = as_ulid(request.path_params["user_id"])
|
||||||
|
|
||||||
async with db.new_connection() as conn:
|
async with db.new_connection() as conn:
|
||||||
|
|
@ -537,7 +547,7 @@ async def add_group_to_user(request):
|
||||||
if not group:
|
if not group:
|
||||||
return not_found("Group not found")
|
return not_found("Group not found")
|
||||||
|
|
||||||
if access not in set("riw"):
|
if not is_valid_access(access):
|
||||||
raise HTTPException(422, "Invalid access level.")
|
raise HTTPException(422, "Invalid access level.")
|
||||||
|
|
||||||
user.set_access(group_id, access)
|
user.set_access(group_id, access)
|
||||||
|
|
@ -549,19 +559,19 @@ async def add_group_to_user(request):
|
||||||
|
|
||||||
@route("/users/{user_id}/ratings")
|
@route("/users/{user_id}/ratings")
|
||||||
@requires(["private"])
|
@requires(["private"])
|
||||||
async def ratings_for_user(request):
|
async def ratings_for_user(request: Request) -> JSONResponse:
|
||||||
not_implemented()
|
not_implemented()
|
||||||
|
|
||||||
|
|
||||||
@route("/users/{user_id}/ratings", methods=["PUT"])
|
@route("/users/{user_id}/ratings", methods=["PUT"])
|
||||||
@requires("authenticated")
|
@requires("authenticated")
|
||||||
async def set_rating_for_user(request):
|
async def set_rating_for_user(request: Request) -> JSONResponse:
|
||||||
not_implemented()
|
not_implemented()
|
||||||
|
|
||||||
|
|
||||||
@route("/users/_reload_ratings", methods=["POST"])
|
@route("/users/_reload_ratings", methods=["POST"])
|
||||||
@requires(["authenticated", "admin"])
|
@requires(["authenticated", "admin"])
|
||||||
async def load_imdb_user_ratings(request):
|
async def load_imdb_user_ratings(request: Request) -> JSONResponse:
|
||||||
ratings = [rating async for rating in imdb.refresh_user_ratings_from_imdb()]
|
ratings = [rating async for rating in imdb.refresh_user_ratings_from_imdb()]
|
||||||
|
|
||||||
return JSONResponse({"new_ratings": [asplain(r) for r in ratings]})
|
return JSONResponse({"new_ratings": [asplain(r) for r in ratings]})
|
||||||
|
|
@ -569,7 +579,7 @@ async def load_imdb_user_ratings(request):
|
||||||
|
|
||||||
@route("/groups")
|
@route("/groups")
|
||||||
@requires(["authenticated", "admin"])
|
@requires(["authenticated", "admin"])
|
||||||
async def list_groups(request):
|
async def list_groups(request: Request) -> JSONResponse:
|
||||||
async with db.new_connection() as conn:
|
async with db.new_connection() as conn:
|
||||||
groups = await db.get_all(conn, Group)
|
groups = await db.get_all(conn, Group)
|
||||||
|
|
||||||
|
|
@ -578,7 +588,7 @@ async def list_groups(request):
|
||||||
|
|
||||||
@route("/groups", methods=["POST"])
|
@route("/groups", methods=["POST"])
|
||||||
@requires(["authenticated", "admin"])
|
@requires(["authenticated", "admin"])
|
||||||
async def add_group(request):
|
async def add_group(request: Request) -> JSONResponse:
|
||||||
(name,) = await json_from_body(request, ["name"])
|
(name,) = await json_from_body(request, ["name"])
|
||||||
|
|
||||||
# XXX restrict name
|
# XXX restrict name
|
||||||
|
|
@ -592,7 +602,7 @@ async def add_group(request):
|
||||||
|
|
||||||
@route("/groups/{group_id}/users", methods=["POST"])
|
@route("/groups/{group_id}/users", methods=["POST"])
|
||||||
@requires(["authenticated"])
|
@requires(["authenticated"])
|
||||||
async def add_user_to_group(request):
|
async def add_user_to_group(request: Request) -> JSONResponse:
|
||||||
group_id = as_ulid(request.path_params["group_id"])
|
group_id = as_ulid(request.path_params["group_id"])
|
||||||
async with db.new_connection() as conn:
|
async with db.new_connection() as conn:
|
||||||
group = await db.get(conn, Group, id=str(group_id))
|
group = await db.get(conn, Group, id=str(group_id))
|
||||||
|
|
@ -628,11 +638,12 @@ async def add_user_to_group(request):
|
||||||
return JSONResponse(asplain(group))
|
return JSONResponse(asplain(group))
|
||||||
|
|
||||||
|
|
||||||
async def http_exception(request, exc):
|
async def http_exception(request: Request, exc: Exception) -> JSONResponse:
|
||||||
|
assert isinstance(exc, HTTPException)
|
||||||
return JSONResponse({"error": exc.detail}, status_code=exc.status_code)
|
return JSONResponse({"error": exc.detail}, status_code=exc.status_code)
|
||||||
|
|
||||||
|
|
||||||
def auth_error(request, err):
|
def auth_error(conn: HTTPConnection, err: Exception) -> JSONResponse:
|
||||||
return unauthorized(str(err))
|
return unauthorized(str(err))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,8 @@ from typing import Container, Iterable
|
||||||
|
|
||||||
from . import imdb, models
|
from . import imdb, models
|
||||||
|
|
||||||
URL = str
|
type URL = str
|
||||||
Score100 = int # [0, 100]
|
type Score100 = int # [0, 100]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue