unwind/unwind/web.py

691 lines
20 KiB
Python

import asyncio
import contextlib
import logging
import secrets
from json.decoder import JSONDecodeError
from typing import Any, Literal, Never, TypeGuard, overload
from starlette.applications import Starlette
from starlette.authentication import (
AuthCredentials,
AuthenticationBackend,
AuthenticationError,
BaseUser,
SimpleUser,
requires,
)
from starlette.background import BackgroundTask
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.gzip import GZipMiddleware
from starlette.requests import HTTPConnection, Request
from starlette.responses import JSONResponse
from starlette.routing import Mount, Route
from . import config, db, imdb, imdb_import, web_models
from .db import close_connection_pool, find_movies, find_ratings, open_connection_pool
from .middleware.responsetime import ResponseTimeMiddleware
from .models import Access, Group, Movie, User, asplain
from .types import JSON, ULID
from .utils import b64decode, b64encode, phc_compare, phc_scrypt
log = logging.getLogger(__name__)
# XXX we probably don't need a group secret anymore, if group access is managed
# on a user level; a group secret would be a separate user with full group
# access
class AuthedUser(BaseUser):
    """User identity as presented via HTTP Basic credentials.

    Only stores the claimed user ID and the presented secret; no verification
    happens here.
    """

    def __init__(self, user_id: str, secret: str):
        self.user_id, self.secret = user_id, secret
class BearerAuthBackend(AuthenticationBackend):
    """Authentication backend for ``Bearer`` admin tokens and ``Basic`` user credentials.

    `credentials` maps a name to its admin token. A request presenting a known
    token via ``Authorization: Bearer <token>`` is granted the ``admin`` scope.
    ``Authorization: Basic <b64(user_id:secret)>`` yields an `AuthedUser`; the
    secret is NOT verified here.
    """

    def __init__(self, credentials: dict[str, str]):
        # Invert name -> token so a presented token resolves to its name.
        self.admin_tokens = {v: k for k, v in credentials.items()}

    async def authenticate(self, conn: HTTPConnection):
        """Return ``(AuthCredentials, user)`` or None if the request is anonymous.

        Raises:
            AuthenticationError: if the header is not of the form
                ``<scheme> <credentials>`` or the Basic payload cannot be decoded.
        """
        if "Authorization" not in conn.headers:
            return None

        # XXX should we remove the auth header after reading, for security reasons?
        auth = conn.headers["Authorization"]
        try:
            scheme, credentials = auth.split()
        except ValueError as err:
            raise AuthenticationError("Invalid auth credentials") from err

        roles = []
        scheme = scheme.lower()

        if scheme == "bearer":
            if credentials not in self.admin_tokens:
                return None
            user = SimpleUser(self.admin_tokens[credentials])
            roles.append("admin")

        elif scheme == "basic":
            try:
                # Split on the first colon only: the user ID cannot contain a
                # colon, but the secret may (RFC 7617).
                name, secret = b64decode(credentials).decode().split(":", 1)
            except Exception as err:
                raise AuthenticationError("Invalid auth credentials") from err
            user = AuthedUser(name, secret)

        else:
            return None

        return AuthCredentials(["authenticated", *roles]), user
def truthy(s: str | None) -> bool:
return bool(s) and s.lower() in {"1", "yes", "true"}
type _Yearcomp = Literal["<", "=", ">"]
def yearcomp(s: str) -> tuple[_Yearcomp, int] | None:
if not s:
return
comp: _Yearcomp = "="
if (prefix := s[0]) in "<=>":
comp = prefix # type: ignore
s = s[len(prefix) :]
return comp, int(s)
def as_int(
x: Any, *, max: int | None = None, min: int | None = 1, default: int | None = None
) -> int:
try:
if not isinstance(x, int):
x = int(x)
if min is not None and x < min:
return min
if max is not None and x > max:
return max
return x
except Exception:
if default is None:
raise
return default
def as_ulid(s: Any) -> ULID:
    """Convert `s` to a `ULID`.

    Raises:
        HTTPException: 422 if `s` is not a non-empty string or `ULID` rejects
            it with a ValueError.
    """
    is_candidate = isinstance(s, str) and bool(s)
    try:
        if is_candidate:
            return ULID(s)
        raise ValueError("Invalid ULID.")
    except ValueError as err:
        raise HTTPException(422, "Not a valid ULID.") from err
@overload
async def json_from_body(request: Request) -> dict[str, JSON]: ...
@overload
async def json_from_body(request: Request, keys: list[str]) -> list[JSON]: ...
async def json_from_body(
    request: Request, keys: list[str] | None = None
) -> dict[str, JSON] | list[JSON]:
    """Read the request body as a JSON object.

    An empty body is treated as an empty object. Without `keys` the whole
    object is returned; with `keys` a list of the values for those keys is
    returned, in the same order.

    Raises:
        HTTPException: 422 if the body is not decodable JSON, is not a JSON
            object, or lacks one of `keys`.
    """
    data: dict[str, JSON]
    if not await request.body():
        data = {}
    else:
        try:
            data = await request.json()
        except ValueError as err:
            # ValueError covers JSONDecodeError as well as the
            # UnicodeDecodeError raised for bodies that are not valid UTF-8.
            raise HTTPException(422, "Invalid JSON content.") from err
        if not isinstance(data, dict):
            raise HTTPException(422, f"Invalid JSON type: {type(data)!a}")

    # Only a missing `keys` selects the dict form; an explicit (even empty)
    # list always yields a list, as the overloads declare.
    if keys is None:
        return data

    try:
        return [data[k] for k in keys]
    except KeyError as err:
        raise HTTPException(422, f"Missing data for key: {err.args[0]}") from err
def is_admin(request: Request) -> bool:
    """Tell whether the request's auth scopes include ``admin``."""
    scopes = request.auth.scopes
    return "admin" in scopes
async def auth_user(request: Request) -> User | None:
    """Return the stored `User` for the request's Basic credentials.

    Returns None if the request carries no `AuthedUser`, the user ID is
    unknown, or the presented secret does not match the stored one.
    """
    claimed = request.user
    if not isinstance(claimed, AuthedUser):
        return None

    async with db.new_connection() as conn:
        stored = await db.get(conn, User, id=claimed.user_id)

    if stored and phc_compare(secret=claimed.secret, phc_string=stored.secret):
        return stored
    return None
# All endpoints registered through `route`.
_routes: list[Route] = []


def route(path: str, *, methods: list[str] | None = None, **kwds):
    """Return a decorator registering the decorated endpoint under `path`.

    Extra keyword arguments are passed on to `Route`. The decorated function
    is returned unchanged.
    """

    def decorator(func):
        _routes.append(Route(path, func, methods=methods, **kwds))
        return func

    return decorator
@route("/groups/{group_id}/ratings")
async def get_ratings_for_group(request: Request) -> JSONResponse:
    """Respond with aggregated ratings for the users of the group in the path.

    Ratings are looked up by `unwind_id` if given, else by `imdb_id`, else via
    a search built from the remaining query parameters.
    """
    group_id = as_ulid(request.path_params["group_id"])

    async with db.new_connection() as conn:
        group = await db.get(conn, Group, id=str(group_id))
    if group is None:
        return not_found()

    member_ids = {member["id"] for member in group.users}

    query = request.query_params
    imdb_id: str | None = query.get("imdb_id")
    unwind_id: str | None = query.get("unwind_id")

    async with db.new_connection() as conn:
        if unwind_id:
            rows = await db.ratings_for_movie_ids(conn, ids=[unwind_id])
        elif imdb_id:
            rows = await db.ratings_for_movie_ids(conn, imdb_ids=[imdb_id])
        else:
            rows = await find_ratings(
                conn,
                title=query.get("title"),
                media_type=query.get("media_type"),
                exact=truthy(query.get("exact")),
                ignore_tv_episodes=truthy(query.get("ignore_tv_episodes")),
                include_unrated=truthy(query.get("include_unrated")),
                yearcomp=yearcomp(query["year"]) if "year" in query else None,
                limit_rows=as_int(query.get("per_page"), max=10, default=5),
                user_ids=member_ids,
            )

    ratings = [web_models.Rating(**row) for row in rows]

    async with db.new_connection() as conn:
        awards = await db.get_awards(
            conn, imdb_ids=[rating.movie_imdb_id for rating in ratings]
        )

    aggregates = web_models.aggregate_ratings(ratings, member_ids, awards_dict=awards)
    return JSONResponse(tuple(asplain(aggr) for aggr in aggregates))
def unauthorized(reason: str = "Unauthorized"):
    """Build a 401 JSON response carrying `reason` as its error message."""
    payload = {"error": reason}
    return JSONResponse(payload, status_code=401)
def forbidden(reason: str = "Forbidden"):
    """Build a 403 JSON response carrying `reason` as its error message."""
    payload = {"error": reason}
    return JSONResponse(payload, status_code=403)
def not_found(reason: str = "Not Found"):
    """Build a 404 JSON response carrying `reason` as its error message."""
    payload = {"error": reason}
    return JSONResponse(payload, status_code=404)
def not_implemented() -> Never:
    """Abort the current request; used by endpoints that are still stubs."""
    # NOTE(review): this reports 404 rather than 501 (Not Implemented) --
    # confirm whether that is intentional.
    error = HTTPException(404, "Not yet implemented.")
    raise error
@route("/movies")
@requires(["authenticated"])
async def list_movies(request: Request) -> JSONResponse:
    """List movies, either a single one by ID or the result of a search.

    `group_id` and `user_id` select whose ratings are attached to search
    results; both are subject to access checks.
    """
    query = request.query_params
    admin = is_admin(request)
    user = await auth_user(request)

    user_ids = set()

    if group_id := query.get("group_id"):
        group_id = as_ulid(group_id)

        async with db.new_connection() as conn:
            group = await db.get(conn, Group, id=str(group_id))
        if not group:
            return not_found("Group not found.")

        if not (admin or (user and user.has_access(group_id))):
            return forbidden("No access to group.")

        user_ids.update(ULID(member["id"]) for member in group.users)

    if user_id := query.get("user_id"):
        user_id = as_ulid(user_id)

        # Currently a user may only directly access their own ratings.
        if not (admin or (user and user.id == user_id)):
            return forbidden("No access to user.")

        user_ids.add(user_id)

    imdb_id = query.get("imdb_id")
    unwind_id = query.get("unwind_id")

    if imdb_id or unwind_id:
        # XXX missing support for user_ids and user_scores
        async with db.new_connection() as conn:
            found = await db.get(conn, Movie, id=unwind_id, imdb_id=imdb_id)
        movies = [found] if found else []

        resp = [asplain(web_models.RatingAggregate.from_movie(m)) for m in movies]
        return JSONResponse(resp)

    per_page = as_int(query.get("per_page"), max=1000, default=5)
    page = as_int(query.get("page"), min=1, default=1)

    async with db.new_connection() as conn:
        movieratings = await find_movies(
            conn,
            title=query.get("title"),
            media_type=query.get("media_type"),
            exact=truthy(query.get("exact")),
            ignore_tv_episodes=truthy(query.get("ignore_tv_episodes")),
            include_unrated=truthy(query.get("include_unrated")),
            yearcomp=yearcomp(query["year"]) if "year" in query else None,
            limit_rows=per_page,
            skip_rows=(page - 1) * per_page,
            user_ids=list(user_ids),
        )

    resp = []
    for movie, ratings in movieratings:
        entry = asplain(movie)
        entry["user_scores"] = [rating.score for rating in ratings]
        resp.append(entry)

    return JSONResponse(resp)
@route("/movies", methods=["POST"])
@requires(["authenticated", "admin"])
async def add_movie(request: Request) -> JSONResponse:
    """Stub endpoint for adding a movie; always aborts via `not_implemented`."""
    not_implemented()
@route("/movies/_reload_imdb", methods=["GET"])
@requires(["authenticated", "admin"])
async def progress_for_load_imdb_movies(request: Request) -> JSONResponse:
    """Report status and progress of the recorded IMDb import, or 404 if none."""
    async with db.new_connection() as conn:
        progress = await db.get_import_progress(conn)

    if not progress:
        return JSONResponse({"status": "No import exists."}, status_code=404)

    plain = asplain(progress)
    percent, error = progress.percent, progress.error

    if error:
        status = "Error during import."
    elif percent == 0.0 and progress.stopped:
        status = "Import skipped."
    elif percent < 100:
        status = "Import is running."
    else:
        status = "Import finished."

    return JSONResponse(
        {
            "status": status,
            "progress": percent,
            "error": error,
            "started": plain["started"],
            "stopped": plain["stopped"],
        }
    )
# Serializes the check-and-set of the import progress in `load_imdb_movies`.
_import_lock = asyncio.Lock()


@route("/movies/_reload_imdb", methods=["POST"])
@requires(["authenticated", "admin"])
async def load_imdb_movies(request: Request) -> JSONResponse:
    """Start an IMDb import as a background task.

    Responds with 409 if the recorded import has not stopped yet, else 202.
    """
    force = truthy(request.query_params.get("force"))

    async with _import_lock:
        async with db.new_connection() as conn:
            progress = await db.get_import_progress(conn)

        if progress and not progress.stopped:
            running = {"status": "Import is running.", "progress": progress.percent}
            return JSONResponse(running, status_code=409)

        async with db.transaction() as conn:
            await db.set_import_progress(conn, 0)

    started = {"status": "Import started.", "progress": 0.0}
    return JSONResponse(
        started,
        background=BackgroundTask(imdb_import.load_from_web, force=force),
        status_code=202,
    )
@route("/users")
@requires(["authenticated", "admin"])
async def list_users(request: Request) -> JSONResponse:
    """Respond with the plain form of every stored user."""
    async with db.new_connection() as conn:
        all_users = await db.get_all(conn, User)

    payload = [asplain(u) for u in all_users]
    return JSONResponse(payload)
@route("/users", methods=["POST"])
@requires(["authenticated", "admin"])
async def add_user(request: Request) -> JSONResponse:
    """Create a user with a freshly generated secret.

    The response contains the new user together with the encoded secret; only
    its `phc_scrypt` form is stored on the user.
    """
    name, imdb_id = await json_from_body(request, ["name", "imdb_id"])

    # XXX restrict name
    # XXX check if imdb_id is well-formed

    secret = secrets.token_bytes()
    new_user = User(name=name, imdb_id=imdb_id, secret=phc_scrypt(secret))

    async with db.transaction() as conn:
        await db.add(conn, new_user)

    payload = {"secret": b64encode(secret), "user": asplain(new_user)}
    return JSONResponse(payload)
@route("/users/{user_id}")
@requires(["authenticated"])
async def show_user(request: Request) -> JSONResponse:
    """Show a single user with the secret redacted.

    Admins may look up any user; other callers only themselves.
    """
    user_id = as_ulid(request.path_params["user_id"])

    if not is_admin(request):
        user = await auth_user(request)
    else:
        async with db.new_connection() as conn:
            user = await db.get(conn, User, id=str(user_id))

    if not user:
        return not_found()
    if user.id != user_id:
        return forbidden()

    payload = asplain(user)
    # Redact `secret`
    payload["secret"] = None
    # Fix `groups`
    payload["groups"] = user.groups

    return JSONResponse(payload)
@route("/users/{user_id}", methods=["DELETE"])
@requires(["authenticated", "admin"])
async def remove_user(request: Request) -> JSONResponse:
    """Delete the user given in the path and respond with its plain form."""
    user_id = as_ulid(request.path_params["user_id"])

    async with db.new_connection() as conn:
        target = await db.get(conn, User, id=str(user_id))

    if not target:
        return not_found()

    async with db.transaction() as conn:
        # XXX remove user refs from groups and ratings
        await db.remove(conn, target)

    return JSONResponse(asplain(target))
@route("/users/{user_id}", methods=["PATCH"])
@requires(["authenticated"])
async def modify_user(request: Request) -> JSONResponse:
    """Update a user's `name`, `imdb_id` and/or `secret` from the JSON body.

    Admins may modify any user. Other callers may only modify themselves, and
    only their secret.
    """
    user_id = as_ulid(request.path_params["user_id"])
    admin = is_admin(request)

    if not admin:
        user = await auth_user(request)
    else:
        async with db.new_connection() as conn:
            user = await db.get(conn, User, id=str(user_id))

    if not user:
        return not_found()
    if user.id != user_id:
        return forbidden()

    data = await json_from_body(request)

    if "name" in data:
        if not admin:
            return forbidden("Changing user name is not allowed.")
        # XXX restrict name
        user.name = data["name"]

    if "imdb_id" in data:
        if not admin:
            return forbidden("Changing IMDb ID is not allowed.")
        # XXX check if imdb_id is well-formed
        user.imdb_id = data["imdb_id"]

    if "secret" in data:
        try:
            new_secret = b64decode(data["secret"])
        except Exception as err:
            raise HTTPException(422, "Invalid secret.") from err
        user.secret = phc_scrypt(new_secret)

    async with db.transaction() as conn:
        await db.update(conn, user)

    return JSONResponse(asplain(user))
def is_valid_access(x: Any) -> TypeGuard[Access]:
    """Tell whether `x` is one of the access strings "r", "i" or "w"."""
    if not isinstance(x, str):
        return False
    return x in {"r", "i", "w"}
@route("/users/{user_id}/groups", methods=["POST"])
@requires(["authenticated", "admin"])
async def add_group_to_user(request: Request) -> JSONResponse:
    """Set a user's access level for a group.

    The JSON body must provide `group` (a group ID) and `access`.
    """
    user_id = as_ulid(request.path_params["user_id"])

    async with db.new_connection() as conn:
        target = await db.get(conn, User, id=str(user_id))
    if not target:
        return not_found("User not found")

    group_id, access = await json_from_body(request, ["group", "access"])

    async with db.new_connection() as conn:
        group = await db.get(conn, Group, id=str(group_id))
    if not group:
        return not_found("Group not found")

    if not is_valid_access(access):
        raise HTTPException(422, "Invalid access level.")

    target.set_access(group_id, access)

    async with db.transaction() as conn:
        await db.update(conn, target)

    return JSONResponse(asplain(target))
@route("/users/{user_id}/ratings")
@requires(["private"])
async def ratings_for_user(request: Request) -> JSONResponse:
    """Stub endpoint for a user's ratings; always aborts via `not_implemented`."""
    # NOTE(review): the auth backend in this module only grants the
    # "authenticated" and "admin" scopes, so "private" looks unreachable.
    not_implemented()
@route("/users/{user_id}/ratings", methods=["PUT"])
@requires("authenticated")
async def set_rating_for_user(request: Request) -> JSONResponse:
    """Stub endpoint for setting a rating; always aborts via `not_implemented`."""
    not_implemented()
@route("/users/_reload_ratings", methods=["POST"])
@requires(["authenticated", "admin"])
async def load_imdb_user_ratings(request: Request) -> JSONResponse:
    """Refresh user ratings from IMDb and respond with the ratings yielded."""
    # Drain the refresh completely before converting anything.
    refreshed = []
    async for rating in imdb.refresh_user_ratings_from_imdb():
        refreshed.append(rating)

    payload = {"new_ratings": [asplain(rating) for rating in refreshed]}
    return JSONResponse(payload)
@route("/groups")
@requires(["authenticated", "admin"])
async def list_groups(request: Request) -> JSONResponse:
    """Respond with the plain form of every stored group."""
    async with db.new_connection() as conn:
        all_groups = await db.get_all(conn, Group)

    payload = [asplain(group) for group in all_groups]
    return JSONResponse(payload)
@route("/groups", methods=["POST"])
@requires(["authenticated", "admin"])
async def add_group(request: Request) -> JSONResponse:
    """Create a group named by the `name` key of the JSON body."""
    [name] = await json_from_body(request, ["name"])

    # XXX restrict name

    new_group = Group(name=name)

    async with db.transaction() as conn:
        await db.add(conn, new_group)

    return JSONResponse(asplain(new_group))
@route("/groups/{group_id}/users", methods=["POST"])
@requires(["authenticated"])
async def add_user_to_group(request: Request) -> JSONResponse:
    """Add the user given in the JSON body (`name`, `id`) to a group.

    Allowed for admins and for users with "w" access to the group. A user ID
    that is already part of the group is not added a second time.
    """
    group_id = as_ulid(request.path_params["group_id"])

    async with db.new_connection() as conn:
        group = await db.get(conn, Group, id=str(group_id))
    if not group:
        return not_found()

    if not is_admin(request):
        caller = await auth_user(request)
        if not caller:
            return not_found("User not found.")
        if not caller.has_access(group_id, "w"):
            return forbidden()

    name, user_id = await json_from_body(request, ["name", "id"])

    # XXX check if user exists
    # XXX restrict name

    already_member = any(member["id"] == user_id for member in group.users)
    if not already_member:
        group.users.append({"name": name, "id": user_id})

    async with db.transaction() as conn:
        await db.update(conn, group)

    return JSONResponse(asplain(group))
async def http_exception(request: Request, exc: Exception) -> JSONResponse:
    """Render an `HTTPException` as a JSON error response with its status code."""
    assert isinstance(exc, HTTPException)
    payload = {"error": exc.detail}
    return JSONResponse(payload, status_code=exc.status_code)
def auth_error(conn: HTTPConnection, err: Exception) -> JSONResponse:
    """Turn an authentication failure into a 401 response carrying its message."""
    message = str(err)
    return unauthorized(message)
@contextlib.asynccontextmanager
async def lifespan(app: Starlette):
    """Open the DB connection pool for the app's lifetime and close it afterwards."""
    await open_connection_pool()
    try:
        yield
    finally:
        # An exception or cancellation is re-raised at the `yield`; without
        # `finally` the pool would not be closed in that case.
        await close_connection_pool()
def create_app():
    """Build the Starlette app with its routes, middleware and error handler."""
    if config.loglevel == "DEBUG":
        logging.basicConfig(
            level=config.loglevel,
            format="%(asctime)s.%(msecs)03d [%(name)s:%(process)d] %(levelname)s: %(message)s",
            datefmt="%H:%M:%S",
        )

    log.debug("Log level: %s", config.loglevel)

    api_routes = [Mount(f"{config.api_base}v1", routes=_routes)]

    middleware = [
        Middleware(ResponseTimeMiddleware, header_name="Unwind-Elapsed"),
        Middleware(
            AuthenticationMiddleware,
            backend=BearerAuthBackend(config.api_credentials),
            on_error=auth_error,
        ),
        Middleware(
            CORSMiddleware,
            allow_origins=[config.api_cors],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        ),
        Middleware(GZipMiddleware),
    ]

    return Starlette(
        lifespan=lifespan,
        routes=api_routes,
        middleware=middleware,
        exception_handlers={HTTPException: http_exception},
    )