329 lines
8.3 KiB
Python
329 lines
8.3 KiB
Python
import logging
|
|
import secrets
|
|
from json.decoder import JSONDecodeError
|
|
from typing import Literal, Optional
|
|
|
|
from starlette.applications import Starlette
|
|
from starlette.authentication import (
|
|
AuthCredentials,
|
|
AuthenticationBackend,
|
|
AuthenticationError,
|
|
BaseUser,
|
|
SimpleUser,
|
|
requires,
|
|
)
|
|
from starlette.exceptions import HTTPException
|
|
from starlette.middleware import Middleware
|
|
from starlette.middleware.authentication import AuthenticationMiddleware
|
|
from starlette.responses import JSONResponse
|
|
from starlette.routing import Mount, Route
|
|
|
|
from . import config, db
|
|
from .db import close_connection_pool, find_ratings, open_connection_pool
|
|
from .middleware.responsetime import ResponseTimeMiddleware
|
|
from .models import Group, Movie, User, asplain
|
|
from .types import ULID
|
|
from .utils import b64encode, phc_compare, phc_scrypt
|
|
|
|
# Module-level logger; create_app() calls logging.basicConfig when
# config.loglevel is "DEBUG".
log = logging.getLogger(name=__name__)
|
|
|
|
|
|
class BearingUser(BaseUser):
    """A non-admin caller identified only by an opaque bearer token.

    The raw token is kept on ``self.token`` so endpoints can check it
    themselves (``add_user_to_group`` compares it against a group's stored
    secret via ``phc_compare``).

    Starlette's ``BaseUser`` declares ``is_authenticated``, ``display_name``
    and ``identity`` as properties that raise ``NotImplementedError``.
    They are implemented here so that code inspecting ``request.user`` does
    not crash for token users.
    """

    def __init__(self, token):
        self.token = token

    @property
    def is_authenticated(self) -> bool:
        # BearerAuthBackend builds this user only after a Bearer token was
        # presented, together with the "authenticated" scope.
        return True

    @property
    def display_name(self) -> str:
        # Deliberately not derived from the token: the token is a credential
        # and must not leak into logs or templates.
        return ""

    @property
    def identity(self) -> str:
        # No stable identity is known for a bare token user.
        return ""
|
|
|
|
|
|
class BearerAuthBackend(AuthenticationBackend):
    """Authenticate requests that carry an ``Authorization: Bearer <token>`` header.

    ``credentials`` maps admin names to admin tokens. Presenting one of those
    tokens grants the ``authenticated`` and ``admin`` scopes and a
    ``SimpleUser`` named after the admin. Any other bearer token grants only
    ``authenticated`` and is wrapped in a ``BearingUser`` so endpoints can
    verify it themselves.
    """

    def __init__(self, credentials: dict[str, str]):
        # Inverted so that a presented token maps back to its admin's name.
        self.admin_tokens = {v: k for k, v in credentials.items()}

    def _admin_name(self, token: str) -> Optional[str]:
        """Return the admin name owning ``token``, or None if it is not an admin token.

        Every configured token is checked with ``secrets.compare_digest``
        (on UTF-8 bytes, so non-ASCII header values cannot raise) instead of
        an ordinary equality/dict lookup. The comparison therefore does not
        short-circuit on the first differing character.
        """
        presented = token.encode()
        match = None
        for admin_token, name in self.admin_tokens.items():
            if secrets.compare_digest(presented, admin_token.encode()):
                match = name
        return match

    async def authenticate(self, request):
        """Return ``(AuthCredentials, user)`` for Bearer requests, else None.

        Requests without an ``Authorization`` header, or using a different
        scheme (Basic, Digest, ...), are left unauthenticated (None).

        Raises:
            AuthenticationError: if the header is empty, or if a Bearer header
                does not consist of exactly the scheme and one token.
        """
        auth = request.headers.get("Authorization")
        if auth is None:
            return None

        parts = auth.split()
        if not parts:
            raise AuthenticationError("Invalid auth credentials")

        scheme, *rest = parts
        # Check the scheme before the shape: other schemes (e.g. Digest)
        # legitimately contain spaces and are not this backend's business.
        if scheme.lower() != "bearer":
            return None
        if len(rest) != 1:
            raise AuthenticationError("Invalid auth credentials")
        token = rest[0]

        admin_name = self._admin_name(token)
        if admin_name is not None:
            return AuthCredentials(["authenticated", "admin"]), SimpleUser(admin_name)

        return AuthCredentials(["authenticated"]), BearingUser(token)
|
|
|
|
|
|
def imdb_url(imdb_id: str):
    """Return the IMDb title page URL for ``imdb_id`` (e.g. ``"tt0111161"``)."""
    return "https://www.imdb.com/title/{}/".format(imdb_id)
|
|
|
|
|
|
def truthy(s: str):
    """Interpret a query-string flag.

    ``"1"``, ``"yes"`` and ``"true"`` (case-insensitive) are True. Anything
    else, including a missing (None) or empty value, is False.
    """
    if not s:
        return False
    return s.lower() in ("1", "yes", "true")
|
|
|
|
|
|
def yearcomp(s: str):
    """Parse a year filter such as ``"2001"``, ``"<1990"``, ``"=1984"`` or ``">2010"``.

    Returns:
        ``(comparator, year)``, where comparator is ``"<"``, ``"="`` or
        ``">"``; a bare number means ``"="``. Returns None when ``s`` is
        empty or None.

    Raises:
        ValueError: if the text after the optional comparator is not an integer.
    """
    if not s:
        return None

    head = s[0]
    if head in "<=>":
        return head, int(s[1:])
    return "=", int(s)
|
|
|
|
|
|
def as_int(
    x,
    *,
    max: Optional[int] = None,
    min: Optional[int] = 1,
    default: Optional[int] = None,
):
    """Coerce ``x`` to an ``int`` and clamp it to ``[min, max]``.

    The keyword names shadow the ``min``/``max`` builtins; they are kept
    because callers pass them by name.

    Args:
        x: Value to convert. Ints are used as-is; anything else goes through
            ``int()`` (e.g. a query-string value, or None when absent).
        max: Upper bound, or None for no upper bound.
        min: Lower bound (default 1), or None for no lower bound. It is
            checked before ``max``.
        default: Returned, unclamped, when ``x`` cannot be converted. When
            None, the conversion error is re-raised instead.

    Raises:
        TypeError, ValueError, OverflowError: if ``x`` cannot be converted
            and no ``default`` was given.
    """
    try:
        value = x if isinstance(x, int) else int(x)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures fall back to the default; unrelated
        # errors (and KeyboardInterrupt etc.) propagate.
        if default is None:
            raise
        return default

    if min is not None and value < min:
        return min
    if max is not None and value > max:
        return max
    return value
|
|
|
|
|
|
def as_ulid(s: str) -> ULID:
    """Parse ``s`` as a ULID, turning malformed input into an HTTP 422 error.

    Raises:
        HTTPException: status 422 if ``ULID(s)`` rejects ``s`` with ValueError.
    """
    try:
        return ULID(s)
    except ValueError as err:
        # Chain explicitly so the parse failure is recorded as the cause.
        raise HTTPException(422, "Not a valid ULID.") from err
|
|
|
|
|
|
# Every endpoint declared with @route is recorded here; the list is exposed as
# ``route.registered`` and mounted by create_app().
_routes = []


def route(path: str, *, methods: list[str] = None, **kwds):
    """Decorator factory that registers the decorated endpoint as a Starlette ``Route``.

    A ``Route(path, endpoint, methods=methods, **kwds)`` is appended to
    ``_routes``. The decorated function itself is returned unchanged.
    """

    def register(endpoint):
        _routes.append(Route(path, endpoint, methods=methods, **kwds))
        return endpoint

    return register


route.registered = _routes
|
|
|
|
|
|
@route("/groups/{group_id}/ratings")
async def get_ratings_for_group(request):
    """Return the ratings of a group's members, aggregated per movie.

    Path params:
        group_id: ULID of the group (422 if malformed, 404 if unknown).

    Query params (all optional): ``title``, ``media_type``; the flags
    ``exact``, ``ignore_tv_episodes`` and ``include_unrated`` (parsed with
    ``truthy``); ``year`` (``[<=>]YYYY``, parsed with ``yearcomp``, 422 if
    malformed); ``per_page`` (clamped to 1..10, default 5).

    Responds with a JSON list of movie objects. Each object's ``user_scores``
    holds the non-null scores given by the group's members.
    """
    group_id = as_ulid(request.path_params["group_id"])
    group = await db.get(Group, id=str(group_id))

    if not group:
        return not_found()

    user_ids = {u["id"] for u in group.users}

    params = request.query_params

    try:
        year_filter = yearcomp(params["year"]) if "year" in params else None
    except ValueError as err:
        # yearcomp raises ValueError for non-numeric years; report it as a
        # client error rather than an unhandled 500.
        raise HTTPException(422, "Invalid year filter.") from err

    rows = await find_ratings(
        title=params.get("title"),
        media_type=params.get("media_type"),
        exact=truthy(params.get("exact")),
        ignore_tv_episodes=truthy(params.get("ignore_tv_episodes")),
        include_unrated=truthy(params.get("include_unrated")),
        yearcomp=year_filter,
        limit_rows=as_int(params.get("per_page"), max=10, default=5),
        user_ids=user_ids,
    )

    aggr = {}
    for r in rows:
        imdb_id = r["movie_imdb_id"]
        mov = aggr.get(imdb_id)
        if mov is None:
            # First row for this movie: build its summary once.
            mov = aggr[imdb_id] = {
                "canonical_title": r["canonical_title"],
                "original_title": r["original_title"],
                "year": r["release_year"],
                "link": imdb_url(imdb_id),
                "user_scores": [],
                "imdb_score": r["imdb_score"],
                "media_type": r["media_type"],
            }
        if r["user_score"] is not None and r["user_id"] in user_ids:
            mov["user_scores"].append(r["user_score"])

    return JSONResponse(tuple(aggr.values()))
|
|
|
|
|
|
def unauthorized(reason: str = "Unauthorized"):
    """Build a 401 JSON response whose body is ``{"error": reason}``."""
    return JSONResponse(content={"error": reason}, status_code=401)
|
|
|
|
|
|
def forbidden(reason: str = "Forbidden"):
    """Build a 403 JSON response whose body is ``{"error": reason}``."""
    return JSONResponse(content={"error": reason}, status_code=403)
|
|
|
|
|
|
def not_found(reason: str = "Not Found"):
    """Build a 404 JSON response whose body is ``{"error": reason}``."""
    return JSONResponse(content={"error": reason}, status_code=404)
|
|
|
|
|
|
@route("/movies")
@requires(["private"])
async def get_movies(request):
    """Look up a movie by the ``imdb_id`` query parameter.

    Responds with a JSON list containing the matching movie, or an empty list.

    NOTE(review): BearerAuthBackend in this module only ever grants the
    "authenticated" and "admin" scopes, so nothing appears to satisfy
    ``requires(["private"])`` here. Confirm that this endpoint being
    effectively disabled is intentional.
    """
    movie = await db.get(Movie, imdb_id=request.query_params.get("imdb_id"))
    if not movie:
        return JSONResponse([])
    return JSONResponse([asplain(movie)])
|
|
|
|
|
|
@route("/movies", methods=["POST"])
@requires(["authenticated", "admin"])
async def add_movie(request):
    """Create a movie (admin only). Not implemented yet.

    A Starlette endpoint must return a response object; returning None makes
    the framework fail while sending it (a 500). Until the endpoint exists,
    answer an explicit 501 instead.
    """
    return JSONResponse({"error": "Not Implemented"}, status_code=501)
|
|
|
|
|
|
@route("/users")
@requires(["authenticated", "admin"])
async def list_users(request):
    """Return every ``User`` as a JSON list of their ``asplain`` forms (admin only)."""
    plain_users = list(map(asplain, await db.get_all(User)))
    return JSONResponse(plain_users)
|
|
|
|
|
|
@route("/users", methods=["POST"])
@requires(["authenticated", "admin"])
async def add_user(request):
    """Create a user (admin only). Not implemented yet.

    A Starlette endpoint must return a response object; returning None makes
    the framework fail while sending it (a 500). Until the endpoint exists,
    answer an explicit 501 instead.
    """
    return JSONResponse({"error": "Not Implemented"}, status_code=501)
|
|
|
|
|
|
@route("/users/{user_id}/ratings")
@requires(["private"])
async def ratings_for_user(request):
    """List a user's ratings. Not implemented yet.

    The ``user_id`` path parameter is currently unused. A Starlette endpoint
    must return a response object; returning None makes the framework fail
    while sending it (a 500). Until the endpoint exists, answer an explicit
    501 instead.

    NOTE(review): "private" is not a scope BearerAuthBackend in this module
    grants; confirm the intended access rule.
    """
    return JSONResponse({"error": "Not Implemented"}, status_code=501)
|
|
|
|
|
|
@route("/users/{user_id}/ratings", methods=["PUT"])
@requires("authenticated")
async def set_rating_for_user(request):
    """Set a rating for a user. Not implemented yet.

    The ``user_id`` path parameter is currently unused. A Starlette endpoint
    must return a response object; returning None makes the framework fail
    while sending it (a 500). Until the endpoint exists, answer an explicit
    501 instead.
    """
    return JSONResponse({"error": "Not Implemented"}, status_code=501)
|
|
|
|
|
|
@route("/groups", methods=["POST"])
@requires(["authenticated", "admin"])
async def add_group(request):
    """Create a group from a JSON body ``{"name": ...}`` (admin only).

    A random secret is generated with ``secrets.token_bytes()``. The group
    is stored with ``phc_scrypt(secret)``, and the response contains the raw
    secret base64-encoded alongside the new group.

    Raises:
        HTTPException: 422 when the body is not valid JSON, is not a JSON
            object, or lacks the ``name`` key.
    """
    if not await request.body():
        data = {}
    else:
        try:
            data = await request.json()
        except ValueError as err:
            # Covers JSONDecodeError (malformed JSON) and UnicodeDecodeError
            # (undecodable bytes); both subclass ValueError.
            raise HTTPException(422, "Invalid JSON content.") from err

    # Valid JSON that is not an object (list, string, number) would otherwise
    # fail the key lookup below with a TypeError, i.e. a 500.
    if not isinstance(data, dict):
        raise HTTPException(422, "JSON content must be an object.")

    try:
        name = data["name"]
    except KeyError as err:
        raise HTTPException(422, f"Missing data for key: {err.args[0]}") from err

    # XXX restrict name

    secret = secrets.token_bytes()

    group = Group(name=name, secret=phc_scrypt(secret))
    await db.add(group)

    return JSONResponse(
        {
            "secret": b64encode(secret),
            "group": asplain(group),
        }
    )
|
|
|
|
|
|
@route("/groups/{group_id}/users", methods=["POST"])
@requires(["authenticated"])
async def add_user_to_group(request):
    """Add a member ``{"id": ..., "name": ...}`` to a group.

    Allowed for admins, or for callers whose bearer token matches the group's
    secret according to ``phc_compare``. If a member with the same ``id`` is
    already present, the member list is left unchanged. Responds with the
    (re-saved) group.

    Raises:
        HTTPException: 422 for a malformed group id, a body that is not a
            JSON object, or missing ``name``/``id`` keys.
    """
    group_id = as_ulid(request.path_params["group_id"])
    group = await db.get(Group, id=str(group_id))

    if not group:
        return not_found()

    # Admins skip the secret check. With BearerAuthBackend, every other
    # authenticated caller is a BearingUser carrying its raw token.
    is_allowed = "admin" in request.auth.scopes or phc_compare(
        secret=request.user.token, phc_string=group.secret
    )
    if not is_allowed:
        return forbidden()

    if not await request.body():
        data = {}
    else:
        try:
            data = await request.json()
        except ValueError as err:
            # Covers JSONDecodeError (malformed JSON) and UnicodeDecodeError
            # (undecodable bytes); both subclass ValueError.
            raise HTTPException(422, "Invalid JSON content.") from err

    # Valid JSON that is not an object would otherwise fail the key lookups
    # below with a TypeError, i.e. a 500.
    if not isinstance(data, dict):
        raise HTTPException(422, "JSON content must be an object.")

    try:
        name = data["name"]
        user_id = data["id"]
    except KeyError as err:
        raise HTTPException(422, f"Missing data for key: {err.args[0]}") from err

    # XXX check if user exists
    # XXX restrict name

    if not any(u["id"] == user_id for u in group.users):
        group.users.append({"name": name, "id": user_id})

    await db.update(group)

    return JSONResponse(asplain(group))
|
|
|
|
|
|
async def http_exception(request, exc):
    """Exception handler rendering an ``HTTPException`` as ``{"error": detail}``."""
    payload = {"error": exc.detail}
    return JSONResponse(payload, status_code=exc.status_code)
|
|
|
|
|
|
def auth_error(request, err):
    """``on_error`` hook for AuthenticationMiddleware: reply 401 with the error's text."""
    message = str(err)
    return unauthorized(message)
|
|
|
|
|
|
def create_app():
    """Build and return the Starlette application.

    Configures verbose logging when ``config.loglevel`` is ``"DEBUG"``. The
    app opens and closes the DB connection pool on startup and shutdown, and
    mounts every ``@route``-registered endpoint under ``/api/v1``. It also
    installs the response-time and bearer-auth middleware and renders
    ``HTTPException`` as JSON.
    """
    if config.loglevel == "DEBUG":
        logging.basicConfig(
            format="%(asctime)s.%(msecs)03d [%(name)s:%(process)d] %(levelname)s: %(message)s",
            datefmt="%H:%M:%S",
            level=config.loglevel,
        )
    log.debug("Log level: %s", config.loglevel)

    api = Mount("/api/v1", routes=route.registered)
    middleware = [
        # Outermost: time the whole request, including authentication.
        Middleware(ResponseTimeMiddleware, header_name="Unwind-Elapsed"),
        Middleware(
            AuthenticationMiddleware,
            backend=BearerAuthBackend(config.api_credentials),
            on_error=auth_error,
        ),
    ]

    return Starlette(
        on_startup=[open_connection_pool],
        on_shutdown=[close_connection_pool],
        routes=[api],
        middleware=middleware,
        exception_handlers={HTTPException: http_exception},
    )
|