unwind/unwind/web.py

329 lines
8.3 KiB
Python

import logging
import secrets
from json.decoder import JSONDecodeError
from typing import Literal, Optional
from starlette.applications import Starlette
from starlette.authentication import (
AuthCredentials,
AuthenticationBackend,
AuthenticationError,
BaseUser,
SimpleUser,
requires,
)
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.responses import JSONResponse
from starlette.routing import Mount, Route
from . import config, db
from .db import close_connection_pool, find_ratings, open_connection_pool
from .middleware.responsetime import ResponseTimeMiddleware
from .models import Group, Movie, User, asplain
from .types import ULID
from .utils import b64encode, phc_compare, phc_scrypt
# Logger for this module. create_app() calls logging.basicConfig() only when
# config.loglevel == "DEBUG"; otherwise no handler is configured in this file.
log = logging.getLogger(__name__)
class BearingUser(BaseUser):
    """An authenticated, non-admin caller identified only by its bearer token.

    BearerAuthBackend returns this for any bearer token that is not one of the
    configured admin tokens. Handlers such as add_user_to_group read ``token``
    to check it against a group's stored secret hash.

    Starlette's ``BaseUser`` leaves ``is_authenticated`` and ``display_name``
    unimplemented (they raise ``NotImplementedError``). They are provided here
    so generic code such as ``request.user.is_authenticated`` keeps working.
    """

    def __init__(self, token):
        # The raw bearer token exactly as presented by the client.
        self.token = token

    @property
    def is_authenticated(self) -> bool:
        # Only constructed for requests that presented a bearer token, and the
        # backend always grants the "authenticated" scope alongside it.
        return True

    @property
    def display_name(self) -> str:
        # Deliberately empty: the token is a credential (it is compared
        # against group secrets) and must not leak via the display name.
        return ""
class BearerAuthBackend(AuthenticationBackend):
    """Authenticate requests carrying an ``Authorization: Bearer <token>`` header.

    Tokens listed in *credentials* authenticate as an admin ``SimpleUser``
    (scopes ``authenticated`` and ``admin``). Any other bearer token yields a
    ``BearingUser`` with only the ``authenticated`` scope. Requests without an
    Authorization header, or using a different scheme, are left
    unauthenticated.
    """

    def __init__(self, credentials: dict[str, str]):
        # credentials maps a username to its admin token. Invert it so the
        # token a client presents can be looked up directly.
        self.admin_tokens = {v: k for k, v in credentials.items()}

    async def authenticate(self, request):
        """Return ``(AuthCredentials, user)`` for a bearer request, else ``None``.

        Raises:
            AuthenticationError: the header is empty, or uses the Bearer
                scheme without exactly one token.
        """
        auth = request.headers.get("Authorization")
        if auth is None:
            return None

        parts = auth.split()
        if not parts:
            raise AuthenticationError("Invalid auth credentials")

        # Check the scheme before validating the shape. Other schemes (e.g.
        # Digest) legitimately contain spaces and are simply not ours to handle.
        if parts[0].lower() != "bearer":
            return None

        if len(parts) != 2:
            raise AuthenticationError("Invalid auth credentials")
        token = parts[1]

        roles = []
        if token in self.admin_tokens:
            user = SimpleUser(self.admin_tokens[token])
            roles.append("admin")
        else:
            user = BearingUser(token)
        return AuthCredentials(["authenticated", *roles]), user
def imdb_url(imdb_id: str):
    """Return the IMDb title page URL for *imdb_id* (e.g. ``tt0111161``)."""
    template = "https://www.imdb.com/title/{}/"
    return template.format(imdb_id)
def truthy(s: Optional[str]) -> bool:
    """Interpret a query-string flag as a boolean.

    Only ``"1"``, ``"yes"`` and ``"true"`` (case-insensitive) count as true.
    Anything else is false, including ``None`` (callers pass
    ``params.get(...)`` for absent parameters) and the empty string.
    """
    if not s:
        return False
    return s.lower() in {"1", "yes", "true"}
def yearcomp(s: str):
    """Parse a year filter into a ``(comparator, year)`` pair.

    An optional single leading ``<``, ``=`` or ``>`` selects the comparator.
    Without one, ``"="`` is assumed. An empty (or ``None``) input yields
    ``None``.

    Raises:
        ValueError: the text after the optional comparator is not an integer.
    """
    if not s:
        return None
    op = s[0]
    if op not in "<=>":
        return "=", int(s)
    return op, int(s[1:])
def as_int(
    x,
    *,
    max: Optional[int] = None,
    min: Optional[int] = 1,
    default: Optional[int] = None,
):
    """Coerce *x* to an int and clamp it into ``[min, max]``.

    Args:
        x: Value to convert. Typically a query-string parameter, i.e. a
            ``str``, or ``None`` when absent.
        max: Upper bound; ``None`` disables the check.
        min: Lower bound (default 1); ``None`` disables the check.
        default: Returned (unclamped) when *x* cannot be converted. When
            ``None``, the conversion error propagates instead.

    Raises:
        TypeError, ValueError, OverflowError: *x* is not convertible and no
            *default* was given.
    """
    # NB: ``min``/``max`` shadow the builtins but are part of the keyword API.
    if not isinstance(x, int):
        try:
            x = int(x)
        except (TypeError, ValueError, OverflowError):
            # Only conversion failures fall back to the default. Anything
            # else (including KeyboardInterrupt) must not be swallowed.
            if default is None:
                raise
            return default
    if min is not None and x < min:
        return min
    if max is not None and x > max:
        return max
    return x
def as_ulid(s: str) -> ULID:
    """Parse *s* as a ULID, turning a parse failure into an HTTP 422."""
    try:
        parsed = ULID(s)
    except ValueError:
        raise HTTPException(422, "Not a valid ULID.")
    return parsed
# Routes collected by @route, in definition order; create_app() mounts them.
_routes = []


def route(path: str, *, methods: list[str] = None, **kwds):
    """Decorator factory that registers the decorated endpoint as a ``Route``.

    The function itself is returned unchanged. Decorators listed below
    ``@route`` (e.g. ``@requires``) have already been applied, so their
    wrapper is what gets registered.
    """

    def register(endpoint):
        _routes.append(Route(path, endpoint, methods=methods, **kwds))
        return endpoint

    return register


# Expose the registry on the decorator itself so create_app() can mount it.
route.registered = _routes
@route("/groups/{group_id}/ratings")
async def get_ratings_for_group(request):
    """List ratings by the members of a group, aggregated per movie.

    Path params:
        group_id: ULID of the group (422 if malformed, 404 if unknown).

    Query params (all optional):
        title, media_type: passed through to ``find_ratings``.
        exact, ignore_tv_episodes, include_unrated: boolean flags (``truthy``).
        year: year filter such as ``1999`` or ``<2000`` (422 if malformed).
        per_page: row limit, clamped to 1..10; defaults to 5 if missing or
            not a number.

    Returns a JSON list with one object per movie. Its ``user_scores``
    collects the non-null scores given by group members.
    """
    group_id = as_ulid(request.path_params["group_id"])
    group = await db.get(Group, id=str(group_id))
    if not group:
        return not_found()

    user_ids = {u["id"] for u in group.users}
    params = request.query_params

    year_filter = None
    if "year" in params:
        try:
            year_filter = yearcomp(params["year"])
        except ValueError:
            # Malformed input such as "abc" or "<" would otherwise surface
            # as an unhandled 500.
            raise HTTPException(422, "Not a valid year filter.")

    rows = await find_ratings(
        title=params.get("title"),
        media_type=params.get("media_type"),
        exact=truthy(params.get("exact")),
        ignore_tv_episodes=truthy(params.get("ignore_tv_episodes")),
        include_unrated=truthy(params.get("include_unrated")),
        yearcomp=year_filter,
        limit_rows=as_int(params.get("per_page"), max=10, default=5),
        user_ids=user_ids,
    )

    # One entry per movie. Additional rows for the same movie only add scores.
    movies = {}
    for row in rows:
        imdb_id = row["movie_imdb_id"]
        entry = movies.get(imdb_id)
        if entry is None:
            entry = movies[imdb_id] = {
                "canonical_title": row["canonical_title"],
                "original_title": row["original_title"],
                "year": row["release_year"],
                "link": imdb_url(imdb_id),
                "user_scores": [],
                "imdb_score": row["imdb_score"],
                "media_type": row["media_type"],
            }
        if row["user_score"] is not None and row["user_id"] in user_ids:
            entry["user_scores"].append(row["user_score"])

    return JSONResponse(tuple(movies.values()))
def unauthorized(reason: str = "Unauthorized"):
    """Build a 401 JSON response of the form ``{"error": reason}``."""
    body = {"error": reason}
    return JSONResponse(body, status_code=401)
def forbidden(reason: str = "Forbidden"):
    """Build a 403 JSON response of the form ``{"error": reason}``."""
    body = {"error": reason}
    return JSONResponse(body, status_code=403)
def not_found(reason: str = "Not Found"):
    """Build a 404 JSON response of the form ``{"error": reason}``."""
    body = {"error": reason}
    return JSONResponse(body, status_code=404)
@route("/movies")
@requires(["private"])
async def get_movies(request):
    """Look up a movie by the ``imdb_id`` query parameter.

    Responds with a JSON list holding the matching movie, or an empty list
    when there is none.

    NOTE(review): BearerAuthBackend in this module only grants the
    "authenticated" and "admin" scopes. Unless another backend grants
    "private", this route rejects every caller; confirm that is intended.
    """
    movie = await db.get(Movie, imdb_id=request.query_params.get("imdb_id"))
    if not movie:
        return JSONResponse([])
    return JSONResponse([asplain(movie)])
@route("/movies", methods=["POST"])
@requires(["authenticated", "admin"])
async def add_movie(request):
    """Create a movie (admin only). Not implemented yet.

    Starlette endpoints must return a response; returning ``None`` fails
    with a 500. This placeholder raises a 501 instead, which the module's
    HTTPException handler renders as ``{"error": "Not Implemented"}``.
    """
    raise HTTPException(501)
@route("/users")
@requires(["authenticated", "admin"])
async def list_users(request):
    """Return every user as a JSON list (admin only)."""
    all_users = await db.get_all(User)
    plain_users = list(map(asplain, all_users))
    return JSONResponse(plain_users)
@route("/users", methods=["POST"])
@requires(["authenticated", "admin"])
async def add_user(request):
    """Create a user (admin only). Not implemented yet.

    Starlette endpoints must return a response; returning ``None`` fails
    with a 500. This placeholder raises a 501 instead, which the module's
    HTTPException handler renders as ``{"error": "Not Implemented"}``.
    """
    raise HTTPException(501)
@route("/users/{user_id}/ratings")
@requires(["private"])
async def ratings_for_user(request):
    """List a user's ratings. Not implemented yet.

    Starlette endpoints must return a response; returning ``None`` fails
    with a 500. This placeholder raises a 501 instead, which the module's
    HTTPException handler renders as ``{"error": "Not Implemented"}``.
    """
    raise HTTPException(501)
@route("/users/{user_id}/ratings", methods=["PUT"])
@requires(["authenticated"])
async def set_rating_for_user(request):
    """Set a rating for a user. Not implemented yet.

    Starlette endpoints must return a response; returning ``None`` fails
    with a 500. This placeholder raises a 501 instead, which the module's
    HTTPException handler renders as ``{"error": "Not Implemented"}``.
    """
    raise HTTPException(501)
@route("/groups", methods=["POST"])
@requires(["authenticated", "admin"])
async def add_group(request):
    """Create a group from a JSON body ``{"name": ...}`` (admin only).

    A random secret is generated for the group. Only its ``phc_scrypt`` hash
    is stored, so the base64-encoded plain secret in this response is the
    only place it appears. The response also includes the new group.

    Responds 422 when the body is not valid JSON, is not a JSON object, or
    lacks ``name``.
    """
    data = {}
    if await request.body():
        try:
            data = await request.json()
        except (JSONDecodeError, UnicodeDecodeError):
            # UnicodeDecodeError: a non-UTF-8 body is not a JSONDecodeError.
            raise HTTPException(422, "Invalid JSON content.")
    if not isinstance(data, dict):
        # A list/string/number body would make data["name"] raise TypeError.
        raise HTTPException(422, "JSON content must be an object.")

    try:
        name = data["name"]
    except KeyError as err:
        raise HTTPException(422, f"Missing data for key: {err.args[0]}")

    # XXX restrict name
    secret = secrets.token_bytes()
    group = Group(name=name, secret=phc_scrypt(secret))
    await db.add(group)

    return JSONResponse(
        {
            "secret": b64encode(secret),
            "group": asplain(group),
        }
    )
@route("/groups/{group_id}/users", methods=["POST"])
@requires(["authenticated"])
async def add_user_to_group(request):
    """Add a user given as JSON ``{"name": ..., "id": ...}`` to a group.

    Admins are always allowed. Other callers must present a bearer token
    that ``phc_compare`` accepts against the group's stored secret hash.
    Adding an id that is already a member leaves the group unchanged.
    Responds with the (possibly updated) group.

    Errors: 422 malformed group id, 404 unknown group, 403 not allowed,
    422 invalid/non-object JSON body or missing keys.
    """
    group_id = as_ulid(request.path_params["group_id"])
    group = await db.get(Group, id=str(group_id))
    if not group:
        return not_found()

    # Short-circuit for admins: their SimpleUser has no ``token`` attribute.
    is_allowed = "admin" in request.auth.scopes or phc_compare(
        secret=request.user.token, phc_string=group.secret
    )
    if not is_allowed:
        return forbidden()

    data = {}
    if await request.body():
        try:
            data = await request.json()
        except (JSONDecodeError, UnicodeDecodeError):
            # UnicodeDecodeError: a non-UTF-8 body is not a JSONDecodeError.
            raise HTTPException(422, "Invalid JSON content.")
    if not isinstance(data, dict):
        # A list/string/number body would make data["name"] raise TypeError.
        raise HTTPException(422, "JSON content must be an object.")

    try:
        name = data["name"]
        user_id = data["id"]
    except KeyError as err:
        raise HTTPException(422, f"Missing data for key: {err.args[0]}")

    # XXX check if user exists
    # XXX restrict name
    if not any(u["id"] == user_id for u in group.users):
        group.users.append({"name": name, "id": user_id})
        # Only write back when the membership actually changed.
        await db.update(group)

    return JSONResponse(asplain(group))
async def http_exception(request, exc):
    """Render an ``HTTPException`` as ``{"error": <detail>}`` with its status."""
    payload = {"error": exc.detail}
    status = exc.status_code
    return JSONResponse(payload, status_code=status)
def auth_error(request, err):
    """Turn an authentication failure from the auth middleware into a JSON 401."""
    message = str(err)
    return unauthorized(message)
def create_app():
    """Build the Starlette application.

    Every endpoint registered with ``@route`` is mounted under ``/api/v1``.
    The DB connection pool is opened on startup and closed on shutdown.
    Requests pass through ResponseTimeMiddleware (configured with the
    ``Unwind-Elapsed`` header name, presumably for the elapsed time) and then
    bearer-token authentication against ``config.api_credentials``.
    HTTPExceptions are rendered as JSON by ``http_exception``.
    """
    if config.loglevel == "DEBUG":
        logging.basicConfig(
            format="%(asctime)s.%(msecs)03d [%(name)s:%(process)d] %(levelname)s: %(message)s",
            datefmt="%H:%M:%S",
            level=config.loglevel,
        )
    log.debug(f"Log level: {config.loglevel}")

    api_v1 = Mount("/api/v1", routes=route.registered)

    # Listed outermost first: response timing wraps authentication.
    stack = [
        Middleware(ResponseTimeMiddleware, header_name="Unwind-Elapsed"),
        Middleware(
            AuthenticationMiddleware,
            backend=BearerAuthBackend(config.api_credentials),
            on_error=auth_error,
        ),
    ]

    return Starlette(
        on_startup=[open_connection_pool],
        on_shutdown=[close_connection_pool],
        routes=[api_v1],
        middleware=stack,
        exception_handlers={HTTPException: http_exception},
    )