unwind/unwind/web.py
ducklet 32bbfe881b change /movies output format for id filtered listing
The new format is much closer to the format used by /groups/ratings.
Also allows to filter based on Unwind's ID.
2021-12-08 00:13:05 +01:00

671 lines
17 KiB
Python

import asyncio
import logging
import secrets
from json.decoder import JSONDecodeError
from typing import Literal, Optional, overload
from starlette.applications import Starlette
from starlette.authentication import (
AuthCredentials,
AuthenticationBackend,
AuthenticationError,
BaseUser,
SimpleUser,
requires,
)
from starlette.background import BackgroundTask
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.gzip import GZipMiddleware
from starlette.responses import JSONResponse
from starlette.routing import Mount, Route
from . import config, db, imdb, imdb_import
from .db import close_connection_pool, find_movies, find_ratings, open_connection_pool
from .middleware.responsetime import ResponseTimeMiddleware
from .models import Group, Movie, User, asplain
from .types import ULID
from .utils import b64decode, b64encode, phc_compare, phc_scrypt
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# XXX we probably don't need a group secret anymore, if group access is managed
# on a user level; a group secret would be a separate user with full group
# access
class AuthedUser(BaseUser):
    """Request user holding the credentials taken from a Basic auth header.

    The values are stored as given; nothing is verified here.
    """

    def __init__(self, user_id: str, secret: str):
        self.secret = secret
        self.user_id = user_id
class BearerAuthBackend(AuthenticationBackend):
    """Authentication backend accepting `Bearer` and `Basic` credentials.

    A `Bearer` token matching one of the configured credentials yields a
    `SimpleUser` and the additional "admin" scope.  `Basic` credentials yield
    an `AuthedUser`; the secret is only parsed here, it is not verified.
    """

    def __init__(self, credentials: dict[str, str]):
        # Inverted mapping (value -> key of `credentials`), so that a
        # presented token can be looked up directly.
        self.admin_tokens = {v: k for k, v in credentials.items()}

    async def authenticate(self, request):
        """Return `(AuthCredentials, user)` for the request, or None.

        None is returned if there is no Authorization header, if the scheme
        is neither `Bearer` nor `Basic`, or if a Bearer token is unknown.

        Raises AuthenticationError if the header cannot be parsed.
        """
        if "Authorization" not in request.headers:
            return

        # XXX should we remove the auth header after reading, for security reasons?
        auth = request.headers["Authorization"]
        try:
            scheme, credentials = auth.split()
        except ValueError as err:
            raise AuthenticationError("Invalid auth credentials") from err

        roles = []

        if scheme.lower() == "bearer":
            is_admin = credentials in self.admin_tokens
            if not is_admin:
                return
            name = self.admin_tokens[credentials]
            user = SimpleUser(name)
            roles.append("admin")

        elif scheme.lower() == "basic":
            try:
                # Split on the first colon only: the user-id cannot contain a
                # colon, but the secret may (RFC 7617).
                name, secret = b64decode(credentials).decode().split(":", 1)
            except Exception as err:
                raise AuthenticationError("Invalid auth credentials") from err
            user = AuthedUser(name, secret)

        else:
            return

        return AuthCredentials(["authenticated", *roles]), user
def truthy(s: str):
    """Return True if `s` is "1", "yes" or "true", ignoring case.

    Any falsy value (e.g. an empty string or None) yields False.
    """
    if not s:
        return False
    return s.lower() in {"1", "yes", "true"}
def yearcomp(s: str):
    """Parse a year filter like "2000", "<2000", "=2000" or ">2000".

    Returns a `(comparator, year)` tuple, or None if `s` is empty.  The
    comparator is "=" if `s` does not start with one of "<", "=" or ">".

    Raises HTTPException (422) if the year is not an integer.
    """
    if not s:
        return

    comp: Literal["<", "=", ">"] = "="
    if (prefix := s[0]) in "<=>":
        comp = prefix  # type: ignore
        s = s[len(prefix) :]

    try:
        year = int(s)
    except ValueError as err:
        # The value comes from the query string; report bad input as a
        # client error instead of letting it surface as a server error.
        raise HTTPException(422, "Invalid year.") from err

    return comp, year
def as_int(
    x,
    *,
    max: Optional[int] = None,
    min: Optional[int] = 1,
    default: Optional[int] = None,
):
    """Convert `x` to an int, clamped to the range given by `min` and `max`.

    A bound of None disables that side of the clamping.  If `x` cannot be
    converted, `default` is returned as is (it is not clamped).

    Raises the conversion error (TypeError, ValueError or OverflowError) if
    `x` cannot be converted and no `default` was given.
    """
    try:
        if not isinstance(x, int):
            x = int(x)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures are handled; anything else (e.g.
        # KeyboardInterrupt) must not be swallowed by the default.
        if default is None:
            raise
        return default

    if min is not None and x < min:
        return min
    if max is not None and x > max:
        return max
    return x
def as_ulid(s: str) -> ULID:
    """Return `s` converted to a ULID.

    Raises HTTPException (422) if `s` is empty or if the conversion fails
    with a ValueError.
    """
    try:
        if s:
            return ULID(s)
        raise ValueError("Invalid ULID.")
    except ValueError:
        raise HTTPException(422, "Not a valid ULID.")
@overload
async def json_from_body(request) -> dict:
    ...


@overload
async def json_from_body(request, keys: list[str]) -> list:
    ...


async def json_from_body(request, keys: Optional[list[str]] = None):
    """Return the JSON object sent as the request's body.

    An empty body is treated like an empty object.  Without `keys` the whole
    object is returned; with `keys` a list holding the value of each key, in
    the order given, is returned.

    Raises HTTPException (422) if the body is not valid JSON, if it is not a
    JSON object, or if one of `keys` is missing.
    """
    if not await request.body():
        data = {}
    else:
        try:
            data = await request.json()
        except (JSONDecodeError, UnicodeDecodeError) as err:
            raise HTTPException(422, "Invalid JSON content.") from err

    # Callers index the result by key; any other JSON value (list, string,
    # null, ...) would fail with a TypeError further down the line.
    if not isinstance(data, dict):
        raise HTTPException(422, "JSON content must be an object.")

    if not keys:
        return data

    try:
        return [data[k] for k in keys]
    except KeyError as err:
        raise HTTPException(422, f"Missing data for key: {err.args[0]}")
def is_admin(request):
    """Tell whether the request's auth scopes include "admin"."""
    granted = request.auth.scopes
    return "admin" in granted
async def auth_user(request) -> Optional[User]:
    """Return the stored user matching the request's Basic credentials.

    None is returned if the request's user is not an `AuthedUser`, if no
    user with the given ID is stored, or if the secret does not match the
    stored one.
    """
    claimed = request.user
    if not isinstance(claimed, AuthedUser):
        return None

    user = await db.get(User, id=claimed.user_id)
    if not user:
        return None

    if phc_compare(secret=claimed.secret, phc_string=user.secret):
        return user
    return None
# All routes registered through `route`, in order of registration.
_routes = []


def route(path: str, *, methods: list[str] = None, **kwds):
    """Return a decorator registering a function as the handler for `path`.

    `methods` and any further keyword arguments are passed on to `Route`.
    The decorated function itself is returned unchanged.
    """

    def decorator(func):
        _routes.append(Route(path, func, methods=methods, **kwds))
        return func

    return decorator


# Expose the registry on the decorator itself.
route.registered = _routes
@route("/groups/{group_id}/ratings")
async def get_ratings_for_group(request):
    """List movies matching the query, with the scores given by a group's users.

    Responds with 404 if the group does not exist.  The rows found are
    aggregated into one entry per IMDb ID.
    """
    group_id = as_ulid(request.path_params["group_id"])
    group = await db.get(Group, id=str(group_id))
    if not group:
        return not_found()

    member_ids = {member["id"] for member in group.users}

    query = request.query_params
    year_filter = yearcomp(query["year"]) if "year" in query else None

    rows = await find_ratings(
        title=query.get("title"),
        media_type=query.get("media_type"),
        exact=truthy(query.get("exact")),
        ignore_tv_episodes=truthy(query.get("ignore_tv_episodes")),
        include_unrated=truthy(query.get("include_unrated")),
        yearcomp=year_filter,
        limit_rows=as_int(query.get("per_page"), max=10, default=5),
        user_ids=member_ids,
    )

    by_movie = {}
    for row in rows:
        imdb_id = row["movie_imdb_id"]
        entry = by_movie.setdefault(
            imdb_id,
            {
                "canonical_title": row["canonical_title"],
                "original_title": row["original_title"],
                "year": row["release_year"],
                "link": imdb.movie_url(imdb_id),
                "user_scores": [],
                "imdb_score": row["imdb_score"],
                "imdb_votes": row["imdb_votes"],
                "media_type": row["media_type"],
            },
        )
        # Only collect scores that exist and belong to a user of the group.
        score = row["user_score"]
        if score is not None and row["user_id"] in member_ids:
            entry["user_scores"].append(score)

    return JSONResponse(tuple(by_movie.values()))
def unauthorized(reason: str = "Unauthorized"):
    """Build a 401 JSON response carrying `reason` as its error."""
    body = {"error": reason}
    return JSONResponse(body, status_code=401)
def forbidden(reason: str = "Forbidden"):
    """Build a 403 JSON response carrying `reason` as its error."""
    body = {"error": reason}
    return JSONResponse(body, status_code=403)
def not_found(reason: str = "Not Found"):
    """Build a 404 JSON response carrying `reason` as its error."""
    body = {"error": reason}
    return JSONResponse(body, status_code=404)
def not_implemented():
    """Abort the request by raising an HTTPException.

    NOTE(review): the status used is 404, not 501 -- confirm this is intended.
    """
    raise HTTPException(status_code=404, detail="Not yet implemented.")
@route("/movies")
@requires(["authenticated"])
async def list_movies(request):
    """List movies, either looked up by ID or found by a search query.

    With `imdb_id` and/or `unwind_id` given, the listing holds the matching
    movie, or is empty if there is none.  Otherwise the remaining query
    parameters are used to search, with paging via `page` and `per_page`.

    `group_id` and `user_id` select the users whose scores are included in
    search results.  Responds with 404 for an unknown group and with 403 if
    the requester may not access the group or user.
    """
    params = request.query_params

    user = await auth_user(request)

    user_ids = set()

    if group_id := params.get("group_id"):
        group_id = as_ulid(group_id)

        group = await db.get(Group, id=str(group_id))
        if not group:
            return not_found("Group not found.")

        is_allowed = is_admin(request) or user and user.has_access(group_id)
        if not is_allowed:
            return forbidden("No access to group.")

        user_ids |= {ULID(u["id"]) for u in group.users}

    if user_id := params.get("user_id"):
        user_id = as_ulid(user_id)

        # Currently a user may only directly access their own ratings.
        is_allowed = is_admin(request) or user and user.id == user_id
        if not is_allowed:
            return forbidden("No access to user.")

        user_ids |= {user_id}

    imdb_id = params.get("imdb_id")
    unwind_id = params.get("unwind_id")

    if imdb_id or unwind_id:
        # XXX missing support for user_ids and user_scores
        movie = await db.get(Movie, id=unwind_id, imdb_id=imdb_id)
        # A lookup without a match results in an empty listing; passing the
        # missing movie on to `asplain` would fail.
        movies = [movie] if movie else []

        resp = [
            {
                "unwind_id": m["id"],
                "canonical_title": m["title"],
                "imdb_score": m["imdb_score"],
                "imdb_votes": m["imdb_votes"],
                "link": imdb.movie_url(m["imdb_id"]),
                "media_type": m["media_type"],
                "original_title": m["original_title"],
                "user_scores": [],
                "year": m["release_year"],
            }
            for m in map(asplain, movies)
        ]

    else:
        per_page = as_int(params.get("per_page"), max=1000, default=5)
        page = as_int(params.get("page"), min=1, default=1)
        movieratings = await find_movies(
            title=params.get("title"),
            media_type=params.get("media_type"),
            exact=truthy(params.get("exact")),
            ignore_tv_episodes=truthy(params.get("ignore_tv_episodes")),
            include_unrated=truthy(params.get("include_unrated")),
            yearcomp=yearcomp(params["year"]) if "year" in params else None,
            limit_rows=per_page,
            skip_rows=(page - 1) * per_page,
            user_ids=list(user_ids),
        )

        resp = []
        for movie, ratings in movieratings:
            mov = asplain(movie)
            mov["user_scores"] = [rating.score for rating in ratings]
            resp.append(mov)

    return JSONResponse(resp)
@route("/movies", methods=["POST"])
@requires(["authenticated", "admin"])
async def add_movie(request):
    """Placeholder for adding a movie; always ends in `not_implemented`."""
    return not_implemented()
@route("/movies/_reload_imdb", methods=["GET"])
@requires(["authenticated", "admin"])
async def progress_for_load_imdb_movies(request):
    """Report the state of the IMDb import, as stored in the database.

    Responds with 404 if there is no import progress at all.
    """
    progress = await db.get_import_progress()
    if not progress:
        return JSONResponse({"status": "No import exists."}, status_code=404)

    plain = asplain(progress)
    percent = progress.percent
    error = progress.error

    if error:
        status = "Error during import."
    elif percent == 0.0 and progress.stopped:
        status = "Import skipped."
    elif percent < 100:
        status = "Import is running."
    else:
        status = "Import finished."

    return JSONResponse(
        {
            "status": status,
            "progress": percent,
            "error": error,
            "started": plain["started"],
            "stopped": plain["stopped"],
        }
    )
# Serializes checking for a running import and marking a new one as started.
_import_lock = asyncio.Lock()


@route("/movies/_reload_imdb", methods=["POST"])
@requires(["authenticated", "admin"])
async def load_imdb_movies(request):
    """Start the IMDb import as a background task.

    Responds with 409 if the stored progress shows an import that has not
    stopped yet, otherwise with 202.  The `force` query parameter is passed
    on to the import.
    """
    force = truthy(request.query_params.get("force"))

    async with _import_lock:
        current = await db.get_import_progress()
        if current and not current.stopped:
            busy = {"status": "Import is running.", "progress": current.percent}
            return JSONResponse(busy, status_code=409)

        await db.set_import_progress(0)

    started = {"status": "Import started.", "progress": 0.0}
    return JSONResponse(
        started,
        background=BackgroundTask(imdb_import.load_from_web, force=force),
        status_code=202,
    )
@route("/users")
@requires(["authenticated", "admin"])
async def list_users(request):
    """Respond with all stored users, each converted via `asplain`."""
    users = await db.get_all(User)
    payload = list(map(asplain, users))
    return JSONResponse(payload)
@route("/users", methods=["POST"])
@requires(["authenticated", "admin"])
async def add_user(request):
    """Create a user from the `name` and `imdb_id` given in the JSON body.

    A new random secret is generated for the user.  The response holds the
    encoded secret next to the user; only its `phc_scrypt` result is stored.
    """
    name, imdb_id = await json_from_body(request, ["name", "imdb_id"])

    # XXX restrict name
    # XXX check if imdb_id is well-formed

    raw_secret = secrets.token_bytes()
    new_user = User(name=name, imdb_id=imdb_id, secret=phc_scrypt(raw_secret))
    await db.add(new_user)

    payload = {
        "secret": b64encode(raw_secret),
        "user": asplain(new_user),
    }
    return JSONResponse(payload)
@route("/users/{user_id}")
@requires(["authenticated"])
async def show_user(request):
    """Respond with a single user, with the secret redacted.

    Admins may look up any user; everybody else only gets their own record.
    Responds with 404 if no user could be loaded and with 403 if the loaded
    user is not the one asked for.
    """
    user_id = as_ulid(request.path_params["user_id"])

    if is_admin(request):
        user = await db.get(User, id=str(user_id))
    else:
        user = await auth_user(request)

    if not user:
        return not_found()

    if not (user.id == user_id):
        return forbidden()

    payload = asplain(user)
    # Redact `secret`
    payload["secret"] = None
    # Fix `groups`
    payload["groups"] = user.groups

    return JSONResponse(payload)
@route("/users/{user_id}", methods=["DELETE"])
@requires(["authenticated", "admin"])
async def remove_user(request):
    """Remove a user and respond with the removed record.

    Responds with 404 if the user does not exist.
    """
    user_id = as_ulid(request.path_params["user_id"])

    target = await db.get(User, id=str(user_id))
    if not target:
        return not_found()

    transaction = db.shared_connection().transaction()
    async with transaction:
        # XXX remove user refs from groups and ratings
        await db.remove(target)

    return JSONResponse(asplain(target))
@route("/users/{user_id}", methods=["PATCH"])
@requires(["authenticated"])
async def modify_user(request):
    """Update a user with the fields given in the JSON body.

    Admins may modify any user, everybody else only themselves.  `name` and
    `imdb_id` may only be changed by admins; `secret` may be changed by the
    user, too.  The response holds the updated user with the secret redacted.

    Responds with 404 if no user could be loaded and with 403 if the change
    is not allowed.  Raises HTTPException (422) if the secret cannot be
    decoded.
    """
    user_id = as_ulid(request.path_params["user_id"])

    if is_admin(request):
        user = await db.get(User, id=str(user_id))
    else:
        user = await auth_user(request)

    if not user:
        return not_found()

    is_allowed = user.id == user_id
    if not is_allowed:
        return forbidden()

    data = await json_from_body(request)

    if "name" in data:
        if not is_admin(request):
            return forbidden("Changing user name is not allowed.")

        # XXX restrict name
        user.name = data["name"]

    if "imdb_id" in data:
        if not is_admin(request):
            return forbidden("Changing IMDb ID is not allowed.")

        # XXX check if imdb_id is well-formed
        user.imdb_id = data["imdb_id"]

    if "secret" in data:
        try:
            secret = b64decode(data["secret"])
        except Exception as err:
            raise HTTPException(422, "Invalid secret.") from err

        user.secret = phc_scrypt(secret)

    await db.update(user)

    resp = asplain(user)
    # Redact `secret`, same as when showing a user; the stored hash must not
    # be handed out.
    resp["secret"] = None

    return JSONResponse(resp)
@route("/users/{user_id}/groups", methods=["POST"])
@requires(["authenticated", "admin"])
async def add_group_to_user(request):
    """Set a user's access level for a group.

    The JSON body must hold the group's ID as `group` and the access level
    as `access`, which must be one of "r", "i" or "w".

    Responds with 404 if the user or the group does not exist.  Raises
    HTTPException (422) if the access level is invalid.
    """
    user_id = as_ulid(request.path_params["user_id"])

    user = await db.get(User, id=str(user_id))
    if not user:
        return not_found("User not found")

    (group_id, access) = await json_from_body(request, ["group", "access"])

    group = await db.get(Group, id=str(group_id))
    if not group:
        return not_found("Group not found")

    # Tuple membership compares by equality, so any JSON value (including
    # unhashable ones like lists) is rejected cleanly instead of raising.
    if access not in ("r", "i", "w"):
        raise HTTPException(422, "Invalid access level.")

    user.set_access(group_id, access)
    await db.update(user)

    return JSONResponse(asplain(user))
@route("/users/{user_id}/ratings")
@requires(["private"])
async def ratings_for_user(request):
    """Placeholder for listing a user's ratings; always ends in `not_implemented`."""
    return not_implemented()
@route("/users/{user_id}/ratings", methods=["PUT"])
@requires("authenticated")
async def set_rating_for_user(request):
    """Placeholder for setting a user's rating; always ends in `not_implemented`."""
    return not_implemented()
@route("/users/_reload_ratings", methods=["POST"])
@requires(["authenticated", "admin"])
async def load_imdb_user_ratings(request):
    """Refresh user ratings from IMDb and respond with the ratings yielded."""
    collected = []
    async for rating in imdb.refresh_user_ratings_from_imdb():
        collected.append(rating)

    payload = {"new_ratings": list(map(asplain, collected))}
    return JSONResponse(payload)
@route("/groups")
@requires(["authenticated", "admin"])
async def list_groups(request):
    """Respond with all stored groups, each converted via `asplain`."""
    groups = await db.get_all(Group)
    payload = list(map(asplain, groups))
    return JSONResponse(payload)
@route("/groups", methods=["POST"])
@requires(["authenticated", "admin"])
async def add_group(request):
    """Create a group named by `name` from the JSON body and respond with it."""
    [name] = await json_from_body(request, ["name"])

    # XXX restrict name

    new_group = Group(name=name)
    await db.add(new_group)

    return JSONResponse(asplain(new_group))
@route("/groups/{group_id}/users", methods=["POST"])
@requires(["authenticated"])
async def add_user_to_group(request):
    """Add the user given by `name` and `id` in the JSON body to a group.

    Admins may change any group; everybody else needs "w" access to it.
    Nothing is changed if a user with the given ID is already listed.

    Responds with 404 if the group or the requesting user is not found and
    with 403 if the requester lacks access.
    """
    group_id = as_ulid(request.path_params["group_id"])
    group = await db.get(Group, id=str(group_id))
    if not group:
        return not_found()

    if not is_admin(request):
        requester = await auth_user(request)
        if not requester:
            return not_found("User not found.")
        if not requester.has_access(group_id, "w"):
            return forbidden()

    name, user_id = await json_from_body(request, ["name", "id"])

    # XXX check if user exists
    # XXX restrict name

    already_listed = any(u["id"] == user_id for u in group.users)
    if not already_listed:
        group.users.append({"name": name, "id": user_id})
        await db.update(group)

    return JSONResponse(asplain(group))
async def http_exception(request, exc):
    """Turn an HTTPException into a JSON response with its detail and status."""
    body = {"error": exc.detail}
    return JSONResponse(body, status_code=exc.status_code)
def auth_error(request, err):
    """Turn an authentication error into a 401 response with its message."""
    reason = str(err)
    return unauthorized(reason)
def create_app():
    """Create the Starlette application.

    All registered routes are mounted below `{config.api_base}v1`.  Logging
    is only set up here if the configured log level is "DEBUG".
    """
    if config.loglevel == "DEBUG":
        logging.basicConfig(
            format="%(asctime)s.%(msecs)03d [%(name)s:%(process)d] %(levelname)s: %(message)s",
            datefmt="%H:%M:%S",
            level=config.loglevel,
        )
        log.debug(f"Log level: {config.loglevel}")

    api = Mount(f"{config.api_base}v1", routes=route.registered)

    # Keep this order; it is the order the middleware is handed to Starlette.
    middleware = [
        Middleware(ResponseTimeMiddleware, header_name="Unwind-Elapsed"),
        Middleware(
            AuthenticationMiddleware,
            backend=BearerAuthBackend(config.api_credentials),
            on_error=auth_error,
        ),
        Middleware(
            CORSMiddleware,
            allow_origins=[config.api_cors],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        ),
        Middleware(GZipMiddleware),
    ]

    return Starlette(
        on_startup=[open_connection_pool],
        on_shutdown=[close_connection_pool],
        routes=[api],
        middleware=middleware,
        exception_handlers={HTTPException: http_exception},
    )