676 lines
18 KiB
Python
676 lines
18 KiB
Python
import asyncio
|
|
import contextlib
|
|
import logging
|
|
import secrets
|
|
from json.decoder import JSONDecodeError
|
|
from typing import Literal, overload
|
|
|
|
from starlette.applications import Starlette
|
|
from starlette.authentication import (
|
|
AuthCredentials,
|
|
AuthenticationBackend,
|
|
AuthenticationError,
|
|
BaseUser,
|
|
SimpleUser,
|
|
requires,
|
|
)
|
|
from starlette.background import BackgroundTask
|
|
from starlette.exceptions import HTTPException
|
|
from starlette.middleware import Middleware
|
|
from starlette.middleware.authentication import AuthenticationMiddleware
|
|
from starlette.middleware.cors import CORSMiddleware
|
|
from starlette.middleware.gzip import GZipMiddleware
|
|
from starlette.responses import JSONResponse
|
|
from starlette.routing import Mount, Route
|
|
|
|
from . import config, db, imdb, imdb_import, web_models
|
|
from .db import close_connection_pool, find_movies, find_ratings, open_connection_pool
|
|
from .middleware.responsetime import ResponseTimeMiddleware
|
|
from .models import Group, Movie, User, asplain
|
|
from .types import ULID
|
|
from .utils import b64decode, b64encode, phc_compare, phc_scrypt
|
|
|
|
# Module-level logger; ``create_app`` uses it to report the configured log level.
log = logging.getLogger(__name__)
|
|
|
|
# XXX we probably don't need a group secret anymore, if group access is managed
|
|
# on a user level; a group secret would be a separate user with full group
|
|
# access
|
|
|
|
|
|
class AuthedUser(BaseUser):
    """Credentials taken from an HTTP Basic ``Authorization`` header.

    ``BearerAuthBackend`` creates these without checking the secret;
    ``auth_user`` compares ``secret`` against the stored hash later.
    """

    def __init__(self, user_id: str, secret: str):
        self.user_id = user_id
        self.secret = secret

    # Starlette's ``BaseUser`` declares the following properties but raises
    # ``NotImplementedError`` from them, so any generic inspection of
    # ``request.user`` (e.g. ``request.user.is_authenticated``) used to crash.
    @property
    def is_authenticated(self) -> bool:
        # Matches the "authenticated" scope the backend grants for Basic auth.
        return True

    @property
    def display_name(self) -> str:
        return self.user_id

    @property
    def identity(self) -> str:
        return self.user_id
|
|
|
|
|
|
class BearerAuthBackend(AuthenticationBackend):
    """Authenticate requests from the ``Authorization`` header.

    - ``Bearer <token>``: the token must be one of the configured admin
      tokens; the request then gets the ``admin`` scope.
    - ``Basic <base64(user_id:secret)>``: the credentials are only parsed
      here and wrapped in an ``AuthedUser``; the secret is not checked here.

    Any other scheme, or an unknown bearer token, leaves the request
    unauthenticated. A malformed header raises ``AuthenticationError``.
    """

    def __init__(self, credentials: dict[str, str]):
        # ``credentials`` maps names to tokens; invert it for lookup by token.
        self.admin_tokens = {v: k for k, v in credentials.items()}

    async def authenticate(self, request):
        if "Authorization" not in request.headers:
            return

        # XXX should we remove the auth header after reading, for security reasons?

        auth = request.headers["Authorization"]
        try:
            scheme, credentials = auth.split()
        except ValueError:
            raise AuthenticationError("Invalid auth credentials")

        roles = []

        if scheme.lower() == "bearer":
            is_admin = credentials in self.admin_tokens
            if not is_admin:
                return
            name = self.admin_tokens[credentials]
            user = SimpleUser(name)
            roles.append("admin")

        elif scheme.lower() == "basic":
            try:
                # RFC 7617: the user-id cannot contain a colon but the password
                # may, so split only at the first one.
                name, secret = b64decode(credentials).decode().split(":", 1)
            except Exception as err:
                # Narrowed from a bare ``except:`` so cancellation and other
                # ``BaseException``s are not turned into auth failures.
                raise AuthenticationError("Invalid auth credentials") from err
            user = AuthedUser(name, secret)

        else:
            return

        return AuthCredentials(["authenticated", *roles]), user
|
|
|
|
|
|
def truthy(s: str):
    """Interpret a query-string flag.

    Only ``"1"``, ``"yes"`` and ``"true"`` (case-insensitive) are true; empty
    strings and ``None`` (a missing parameter) are false.
    """
    if not s:
        return False
    return s.lower() in ("1", "yes", "true")
|
|
|
|
|
|
_Yearcomp = Literal["<", "=", ">"]
|
|
|
|
|
|
def yearcomp(s: str) -> tuple[_Yearcomp, int] | None:
|
|
if not s:
|
|
return
|
|
|
|
comp: _Yearcomp = "="
|
|
if (prefix := s[0]) in "<=>":
|
|
comp = prefix # type: ignore
|
|
s = s[len(prefix) :]
|
|
|
|
return comp, int(s)
|
|
|
|
|
|
def as_int(
|
|
x, *, max: int | None = None, min: int | None = 1, default: int | None = None
|
|
) -> int:
|
|
try:
|
|
if not isinstance(x, int):
|
|
x = int(x)
|
|
if min is not None and x < min:
|
|
return min
|
|
if max is not None and x > max:
|
|
return max
|
|
return x
|
|
|
|
except:
|
|
if default is None:
|
|
raise
|
|
|
|
return default
|
|
|
|
|
|
def as_ulid(s: str) -> ULID:
    """Convert a request parameter to a ``ULID``.

    Empty values and anything ``ULID`` rejects with ``ValueError`` become an
    HTTP 422 error.
    """
    try:
        if s:
            return ULID(s)
        raise ValueError("Invalid ULID.")
    except ValueError:
        raise HTTPException(422, "Not a valid ULID.")
|
|
|
|
|
|
@overload
|
|
async def json_from_body(request) -> dict: ...
|
|
|
|
|
|
@overload
|
|
async def json_from_body(request, keys: list[str]) -> list: ...
|
|
|
|
|
|
async def json_from_body(request, keys: list[str] | None = None):
|
|
if not await request.body():
|
|
data = {}
|
|
|
|
else:
|
|
try:
|
|
data = await request.json()
|
|
except JSONDecodeError:
|
|
raise HTTPException(422, "Invalid JSON content.")
|
|
|
|
if not keys:
|
|
return data
|
|
|
|
try:
|
|
return [data[k] for k in keys]
|
|
except KeyError as err:
|
|
raise HTTPException(422, f"Missing data for key: {err.args[0]}")
|
|
|
|
|
|
def is_admin(request):
    """Return whether the request was granted the ``admin`` auth scope."""
    granted = request.auth.scopes
    return "admin" in granted
|
|
|
|
|
|
async def auth_user(request) -> User | None:
    """Return the ``User`` identified by the request's HTTP Basic credentials.

    Returns ``None`` in three cases: the request did not use Basic auth, the
    user id is unknown, or the secret does not match the stored hash.
    """
    credentials = request.user
    if not isinstance(credentials, AuthedUser):
        return None

    async with db.new_connection() as conn:
        candidate = await db.get(conn, User, id=credentials.user_id)

    if not candidate:
        return None

    # ``candidate.secret`` is the stored PHC hash, not the plain secret.
    if phc_compare(secret=credentials.secret, phc_string=candidate.secret):
        return candidate
    return None
|
|
|
|
|
|
# Endpoints registered via ``@route``; mounted under the API prefix in ``create_app``.
_routes: list[Route] = []


def route(path: str, *, methods: list[str] | None = None, **kwds):
    """Register the decorated endpoint as a ``Route`` for ``path``.

    Extra keyword arguments are forwarded to ``Route``. The endpoint itself is
    returned unchanged, so further decorators can be stacked beneath this one.
    """

    def register(endpoint):
        _routes.append(Route(path, endpoint, methods=methods, **kwds))
        return endpoint

    return register
|
|
|
|
|
|
@route("/groups/{group_id}/ratings")
async def get_ratings_for_group(request):
    """Return the aggregated ratings of a group's members.

    Movies are selected directly via ``unwind_id`` or ``imdb_id``, or else
    searched with ``title``, ``media_type``, ``exact``, ``ignore_tv_episodes``,
    ``include_unrated``, ``year`` and ``per_page`` (at most 10, default 5).
    """
    group_id = as_ulid(request.path_params["group_id"])

    async with db.new_connection() as conn:
        group = await db.get(conn, Group, id=str(group_id))
        if group is None:
            return not_found()

    member_ids = {member["id"] for member in group.users}

    query = request.query_params
    imdb_id: str | None = query.get("imdb_id")
    unwind_id: str | None = query.get("unwind_id")

    # A direct movie id takes precedence over searching; ``unwind_id`` wins
    # when both ids are given.
    async with db.new_connection() as conn:
        if unwind_id:
            rows = await db.ratings_for_movie_ids(conn, ids=[unwind_id])
        elif imdb_id:
            rows = await db.ratings_for_movie_ids(conn, imdb_ids=[imdb_id])
        else:
            rows = await find_ratings(
                conn,
                title=query.get("title"),
                media_type=query.get("media_type"),
                exact=truthy(query.get("exact")),
                ignore_tv_episodes=truthy(query.get("ignore_tv_episodes")),
                include_unrated=truthy(query.get("include_unrated")),
                yearcomp=yearcomp(query["year"]) if "year" in query else None,
                limit_rows=as_int(query.get("per_page"), max=10, default=5),
                user_ids=member_ids,
            )

    aggregated = web_models.aggregate_ratings(
        (web_models.Rating(**row) for row in rows), member_ids
    )
    return JSONResponse(tuple(asplain(entry) for entry in aggregated))
|
|
|
|
|
|
def _error_response(reason: str, status_code: int):
    """Build the ``{"error": reason}`` JSON body shared by the helpers below."""
    return JSONResponse({"error": reason}, status_code=status_code)


def unauthorized(reason: str = "Unauthorized"):
    """Return a 401 JSON error response."""
    return _error_response(reason, 401)


def forbidden(reason: str = "Forbidden"):
    """Return a 403 JSON error response."""
    return _error_response(reason, 403)


def not_found(reason: str = "Not Found"):
    """Return a 404 JSON error response."""
    return _error_response(reason, 404)
|
|
|
|
|
|
def not_implemented():
    """Abort the current request because its endpoint is only a stub.

    Raises:
        HTTPException: always, with 501 Not Implemented. The route does exist,
            so the 404 used previously misreported it as missing.
    """
    raise HTTPException(501, "Not yet implemented.")
|
|
|
|
|
|
@route("/movies")
@requires(["authenticated"])
async def list_movies(request):
    """Search movies, optionally including the scores of selected users.

    ``group_id`` adds every member of that group and ``user_id`` adds a single
    user to the set whose scores are reported. Each requires admin rights or
    access to that group or user. ``imdb_id`` / ``unwind_id`` fetch a single
    movie instead of searching. Search results are paged via ``page`` and
    ``per_page`` (at most 1000, default 5).
    """
    query = request.query_params

    caller = await auth_user(request)

    selected_users = set()

    if raw_group_id := query.get("group_id"):
        group_id = as_ulid(raw_group_id)

        async with db.new_connection() as conn:
            group = await db.get(conn, Group, id=str(group_id))
        if not group:
            return not_found("Group not found.")

        if not (is_admin(request) or (caller and caller.has_access(group_id))):
            return forbidden("No access to group.")

        selected_users |= {ULID(member["id"]) for member in group.users}

    if raw_user_id := query.get("user_id"):
        user_id = as_ulid(raw_user_id)

        # Currently a user may only directly access their own ratings.
        if not (is_admin(request) or (caller and caller.id == user_id)):
            return forbidden("No access to user.")

        selected_users |= {user_id}

    imdb_id = query.get("imdb_id")
    unwind_id = query.get("unwind_id")

    if imdb_id or unwind_id:
        # XXX missing support for user_ids and user_scores
        async with db.new_connection() as conn:
            movie = await db.get(conn, Movie, id=unwind_id, imdb_id=imdb_id)

        payload = (
            [asplain(web_models.RatingAggregate.from_movie(movie))] if movie else []
        )
    else:
        per_page = as_int(query.get("per_page"), max=1000, default=5)
        page = as_int(query.get("page"), min=1, default=1)

        async with db.new_connection() as conn:
            results = await find_movies(
                conn,
                title=query.get("title"),
                media_type=query.get("media_type"),
                exact=truthy(query.get("exact")),
                ignore_tv_episodes=truthy(query.get("ignore_tv_episodes")),
                include_unrated=truthy(query.get("include_unrated")),
                yearcomp=yearcomp(query["year"]) if "year" in query else None,
                limit_rows=per_page,
                skip_rows=(page - 1) * per_page,
                user_ids=list(selected_users),
            )

        payload = []
        for movie, ratings in results:
            entry = asplain(movie)
            entry["user_scores"] = [r.score for r in ratings]
            payload.append(entry)

    return JSONResponse(payload)
|
|
|
|
|
|
@route("/movies", methods=["POST"])
@requires(["authenticated", "admin"])
async def add_movie(request):
    """Admin-only endpoint for adding a movie; only a stub so far."""
    return not_implemented()
|
|
|
|
|
|
@route("/movies/_reload_imdb", methods=["GET"])
@requires(["authenticated", "admin"])
async def progress_for_load_imdb_movies(request):
    """Report the state of the most recent IMDb import (404 if there is none)."""
    async with db.new_connection() as conn:
        progress = await db.get_import_progress(conn)
    if not progress:
        return JSONResponse({"status": "No import exists."}, status_code=404)

    plain = asplain(progress)
    percent = progress.percent
    error = progress.error

    if error:
        status = "Error during import."
    elif percent == 0.0 and progress.stopped:
        # Stopped without having made any progress.
        status = "Import skipped."
    elif percent < 100:
        status = "Import is running."
    else:
        status = "Import finished."

    return JSONResponse(
        {
            "status": status,
            "progress": percent,
            "error": error,
            "started": plain["started"],
            "stopped": plain["stopped"],
        }
    )
|
|
|
|
|
|
# Serializes the check-then-start sequence in ``load_imdb_movies`` so that two
# concurrent requests cannot both start an import.
_import_lock = asyncio.Lock()


@route("/movies/_reload_imdb", methods=["POST"])
@requires(["authenticated", "admin"])
async def load_imdb_movies(request):
    """Start a background IMDb import unless one is already running.

    The ``force`` query flag (parsed with ``truthy``) is passed through to the
    importer. Responds 409 while an import is in progress, otherwise 202.
    """
    force = truthy(request.query_params.get("force"))

    async with _import_lock:
        async with db.new_connection() as conn:
            current = await db.get_import_progress(conn)
        if current and not current.stopped:
            return JSONResponse(
                {"status": "Import is running.", "progress": current.percent},
                status_code=409,
            )

        # Reset progress to 0 while still holding the lock, i.e. before the
        # background task starts.
        async with db.transaction() as conn:
            await db.set_import_progress(conn, 0)

    job = BackgroundTask(imdb_import.load_from_web, force=force)
    return JSONResponse(
        {"status": "Import started.", "progress": 0.0},
        background=job,
        status_code=202,
    )
|
|
|
|
|
|
@route("/users")
@requires(["authenticated", "admin"])
async def list_users(request):
    """List all users (admin only).

    The stored ``secret`` (the hash produced by ``phc_scrypt``) is redacted
    from each entry, as ``show_user`` already does. Hashes have no use on the
    client and could be attacked offline if leaked.
    """
    async with db.new_connection() as conn:
        users = await db.get_all(conn, User)

    resp = []
    for user in users:
        plain = asplain(user)
        plain["secret"] = None
        resp.append(plain)
    return JSONResponse(resp)
|
|
|
|
|
|
@route("/users", methods=["POST"])
@requires(["authenticated", "admin"])
async def add_user(request):
    """Create a user from ``{"name": ..., "imdb_id": ...}`` (admin only).

    A random secret is generated and returned once, base64-encoded. Only its
    ``phc_scrypt`` hash is stored, and that hash is redacted from the response.
    """
    name, imdb_id = await json_from_body(request, ["name", "imdb_id"])

    # XXX restrict name
    # XXX check if imdb_id is well-formed

    secret = secrets.token_bytes()

    user = User(name=name, imdb_id=imdb_id, secret=phc_scrypt(secret))
    async with db.transaction() as conn:
        await db.add(conn, user)

    plain_user = asplain(user)
    plain_user["secret"] = None  # never expose the stored hash

    return JSONResponse(
        {
            "secret": b64encode(secret),
            "user": plain_user,
        }
    )
|
|
|
|
|
|
@route("/users/{user_id}")
@requires(["authenticated"])
async def show_user(request):
    """Show one user with the stored secret hash redacted.

    Admins may look up anyone. Other callers only see themselves: 404 if
    their Basic credentials don't resolve, 403 for any other user id.
    """
    user_id = as_ulid(request.path_params["user_id"])

    if not is_admin(request):
        user = await auth_user(request)
    else:
        async with db.new_connection() as conn:
            user = await db.get(conn, User, id=str(user_id))

    if not user:
        return not_found()

    is_self = user.id == user_id
    if not is_self:
        return forbidden()

    resp = asplain(user)
    # Redact the stored password hash.
    resp["secret"] = None
    # Use the raw ``groups`` value rather than the one ``asplain`` produced.
    resp["groups"] = user.groups

    return JSONResponse(resp)
|
|
|
|
|
|
@route("/users/{user_id}", methods=["DELETE"])
@requires(["authenticated", "admin"])
async def remove_user(request):
    """Delete a user (admin only) and return the removed record.

    The stored password hash is redacted from the response, matching
    ``show_user``.
    """
    user_id = as_ulid(request.path_params["user_id"])

    async with db.new_connection() as conn:
        user = await db.get(conn, User, id=str(user_id))
    if not user:
        return not_found()

    async with db.transaction() as conn:
        # XXX remove user refs from groups and ratings

        await db.remove(conn, user)

    resp = asplain(user)
    resp["secret"] = None
    return JSONResponse(resp)
|
|
|
|
|
|
@route("/users/{user_id}", methods=["PATCH"])
@requires(["authenticated"])
async def modify_user(request):
    """Update a user from a JSON body.

    Admins may change ``name``, ``imdb_id`` and ``secret`` of any user.
    Other callers may only change their own ``secret``. ``secret`` must be
    base64-encoded and non-empty.

    Returns the updated user with the stored password hash redacted.
    """
    user_id = as_ulid(request.path_params["user_id"])

    if is_admin(request):
        async with db.new_connection() as conn:
            user = await db.get(conn, User, id=str(user_id))
    else:
        user = await auth_user(request)

    if not user:
        return not_found()

    is_allowed = user.id == user_id
    if not is_allowed:
        return forbidden()

    data = await json_from_body(request)

    if "name" in data:
        if not is_admin(request):
            return forbidden("Changing user name is not allowed.")

        # XXX restrict name
        user.name = data["name"]

    if "imdb_id" in data:
        if not is_admin(request):
            return forbidden("Changing IMDb ID is not allowed.")

        # XXX check if imdb_id is well-formed
        user.imdb_id = data["imdb_id"]

    if "secret" in data:
        try:
            secret = b64decode(data["secret"])
        except Exception as err:
            # Narrowed from a bare ``except:`` so cancellation isn't swallowed.
            raise HTTPException(422, "Invalid secret.") from err
        if not secret:
            # Reject empty secrets rather than storing a hash of nothing.
            raise HTTPException(422, "Invalid secret.")

        user.secret = phc_scrypt(secret)

    async with db.transaction() as conn:
        await db.update(conn, user)

    # Do not echo the (possibly just changed) password hash back.
    resp = asplain(user)
    resp["secret"] = None
    return JSONResponse(resp)
|
|
|
|
|
|
@route("/users/{user_id}/groups", methods=["POST"])
@requires(["authenticated", "admin"])
async def add_group_to_user(request):
    """Grant a user access to a group (admin only).

    Expects ``{"group": <group id>, "access": <level>}``, where ``access`` is
    one of ``"r"``, ``"i"`` or ``"w"``. Returns the updated user with the
    stored password hash redacted.
    """
    user_id = as_ulid(request.path_params["user_id"])

    async with db.new_connection() as conn:
        user = await db.get(conn, User, id=str(user_id))
    if not user:
        return not_found("User not found")

    (group_id, access) = await json_from_body(request, ["group", "access"])

    async with db.new_connection() as conn:
        group = await db.get(conn, Group, id=str(group_id))
    if not group:
        return not_found("Group not found")

    # ``access`` can be any JSON value. An unhashable one (array/object) used
    # to make the set membership test raise TypeError, i.e. a 500.
    if not isinstance(access, str) or access not in set("riw"):
        raise HTTPException(422, "Invalid access level.")

    user.set_access(group_id, access)
    async with db.transaction() as conn:
        await db.update(conn, user)

    resp = asplain(user)
    resp["secret"] = None
    return JSONResponse(resp)
|
|
|
|
|
|
@route("/users/{user_id}/ratings")
@requires(["private"])
async def ratings_for_user(request):
    """List a user's ratings; only a stub so far.

    NOTE(review): ``BearerAuthBackend`` only grants the ``authenticated`` and
    ``admin`` scopes, so ``requires(["private"])`` rejects every request before
    this body runs.
    """
    return not_implemented()
|
|
|
|
|
|
@route("/users/{user_id}/ratings", methods=["PUT"])
@requires("authenticated")
async def set_rating_for_user(request):
    """Store a rating for a user; only a stub so far."""
    return not_implemented()
|
|
|
|
|
|
@route("/users/_reload_ratings", methods=["POST"])
@requires(["authenticated", "admin"])
async def load_imdb_user_ratings(request):
    """Refresh user ratings from IMDb (admin only).

    Responds with everything the refresh yielded, under ``new_ratings``.
    """
    fetched = [item async for item in imdb.refresh_user_ratings_from_imdb()]

    payload = {"new_ratings": list(map(asplain, fetched))}
    return JSONResponse(payload)
|
|
|
|
|
|
@route("/groups")
@requires(["authenticated", "admin"])
async def list_groups(request):
    """List all groups (admin only)."""
    async with db.new_connection() as conn:
        all_groups = await db.get_all(conn, Group)

    return JSONResponse(list(map(asplain, all_groups)))
|
|
|
|
|
|
@route("/groups", methods=["POST"])
@requires(["authenticated", "admin"])
async def add_group(request):
    """Create a group from ``{"name": ...}`` (admin only) and return it."""
    [name] = await json_from_body(request, ["name"])

    # XXX restrict name

    new_group = Group(name=name)
    async with db.transaction() as conn:
        await db.add(conn, new_group)

    return JSONResponse(asplain(new_group))
|
|
|
|
|
|
@route("/groups/{group_id}/users", methods=["POST"])
@requires(["authenticated"])
async def add_user_to_group(request):
    """Add a member ``{"name": ..., "id": ...}`` to a group.

    Allowed for admins and for users with ``"w"`` access to the group. If the
    id is already a member, nothing is appended, but the group is still saved.
    """
    group_id = as_ulid(request.path_params["group_id"])
    async with db.new_connection() as conn:
        group = await db.get(conn, Group, id=str(group_id))

    if not group:
        return not_found()

    if not is_admin(request):
        caller = await auth_user(request)
        if not caller:
            return not_found("User not found.")

        if not caller.has_access(group_id, "w"):
            return forbidden()

    name, user_id = await json_from_body(request, ["name", "id"])

    # XXX check if user exists
    # XXX restrict name

    already_member = any(member["id"] == user_id for member in group.users)
    if not already_member:
        group.users.append({"name": name, "id": user_id})

    async with db.transaction() as conn:
        await db.update(conn, group)

    return JSONResponse(asplain(group))
|
|
|
|
|
|
async def http_exception(request, exc):
    """Render an ``HTTPException`` as ``{"error": detail}`` with its status code."""
    body = {"error": exc.detail}
    return JSONResponse(body, status_code=exc.status_code)
|
|
|
|
|
|
def auth_error(request, err):
    """``on_error`` hook for the authentication middleware: reply with a 401."""
    message = str(err)
    return unauthorized(message)
|
|
|
|
|
|
@contextlib.asynccontextmanager
async def lifespan(app: Starlette):
    """Keep the database connection pool open for the application's lifetime.

    The pool is closed in a ``finally`` block. That way it is also released
    when an exception or cancellation is raised at ``yield``, not only on a
    clean shutdown.
    """
    await open_connection_pool()
    try:
        yield
    finally:
        await close_connection_pool()
|
|
|
|
|
|
def create_app():
    """Build the Starlette application.

    Endpoints registered with ``@route`` are mounted under
    ``f"{config.api_base}v1"``. Entries in ``middleware`` wrap each other in
    list order, so the first entry is the outermost.
    """
    if config.loglevel == "DEBUG":
        logging.basicConfig(
            format="%(asctime)s.%(msecs)03d [%(name)s:%(process)d] %(levelname)s: %(message)s",
            datefmt="%H:%M:%S",
            level=config.loglevel,
        )
    # Lazy %-formatting: the message is only built if DEBUG is enabled.
    log.debug("Log level: %s", config.loglevel)

    return Starlette(
        lifespan=lifespan,
        routes=[
            Mount(f"{config.api_base}v1", routes=_routes),
        ],
        middleware=[
            Middleware(ResponseTimeMiddleware, header_name="Unwind-Elapsed"),
            # CORS must wrap the authentication middleware. AuthenticationMiddleware
            # answers auth failures itself (via ``auth_error``) without calling
            # inner layers, so with CORS inside it those 401s carried no CORS
            # headers and browsers could not read them.
            Middleware(
                CORSMiddleware,
                allow_origins=[config.api_cors],
                allow_credentials=True,
                allow_methods=["*"],
                allow_headers=["*"],
            ),
            Middleware(
                AuthenticationMiddleware,
                backend=BearerAuthBackend(config.api_credentials),
                on_error=auth_error,
            ),
            Middleware(GZipMiddleware),
        ],
        exception_handlers={HTTPException: http_exception},
    )
|