194 lines
5.4 KiB
Python
194 lines
5.4 KiB
Python
import base64
|
|
import binascii
|
|
import logging
|
|
from typing import Literal, Optional
|
|
|
|
from starlette.applications import Starlette
|
|
from starlette.authentication import (
|
|
AuthCredentials,
|
|
AuthenticationBackend,
|
|
AuthenticationError,
|
|
SimpleUser,
|
|
UnauthenticatedUser,
|
|
requires,
|
|
)
|
|
from starlette.middleware import Middleware
|
|
from starlette.middleware.authentication import AuthenticationMiddleware
|
|
from starlette.responses import JSONResponse
|
|
from starlette.routing import Mount, Route
|
|
|
|
from . import config, db
|
|
from .db import close_connection_pool, find_ratings, open_connection_pool
|
|
from .models import Movie, asplain
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
log = logging.getLogger(__name__)
|
|
|
|
|
|
class BasicAuthBackend(AuthenticationBackend):
    """Authentication backend that reads HTTP Basic credentials.

    NOTE: the password is parsed but never checked (see the TODO in
    ``authenticate``), so every well-formed Basic header is accepted.
    """

    async def authenticate(self, request):
        """Derive credentials and a user from the ``Authorization`` header.

        Returns:
            ``None`` when the header is absent or its scheme is not
            ``basic`` (compared case-insensitively); otherwise a pair of
            ``AuthCredentials(["authenticated"])`` and a ``SimpleUser``
            built from the decoded username.

        Raises:
            AuthenticationError: If the header does not split into exactly
                two whitespace-separated parts, or the credentials are not
                valid base64-encoded ASCII.
        """
        if "Authorization" not in request.headers:
            return None

        auth = request.headers["Authorization"]
        try:
            scheme, credentials = auth.split()
            if scheme.lower() != "basic":
                return None
            decoded = base64.b64decode(credentials).decode("ascii")
        except (ValueError, UnicodeDecodeError, binascii.Error) as exc:
            # Chain the cause so the underlying parse/decode failure is not
            # lost from the traceback.
            raise AuthenticationError("Invalid basic auth credentials") from exc

        # Everything before the first ":" is the username; the rest is the
        # (currently unchecked) password.
        username, _, _password = decoded.partition(":")
        # TODO: You'd want to verify the username and password here.
        return AuthCredentials(["authenticated"]), SimpleUser(username)
|
|
|
|
|
|
def imdb_url(imdb_id: str):
    """Return the IMDb title-page URL for the given IMDb id."""
    url = f"https://www.imdb.com/title/{imdb_id}/"
    return url
|
|
|
|
|
|
def truthy(s: str):
    """Tell whether *s* spells an affirmative flag value.

    Falsy input (empty string, ``None``) gives ``False``.  Anything else is
    lower-cased and matched against ``1``, ``yes`` and ``true``.
    """
    if not s:
        return False
    return s.lower() in {"1", "yes", "true"}
|
|
|
|
|
|
def yearcomp(s: str):
    """Split a year filter string into a comparison operator and a year.

    A single leading ``<``, ``=`` or ``>`` selects the operator; when none
    is present the operator is ``=``.  The remaining text is converted with
    ``int``, so a non-numeric remainder raises ``ValueError``.

    Returns ``None`` for empty input, otherwise ``(operator, year)``.
    """
    if not s:
        return None

    op: Literal["<", "=", ">"]
    first = s[0]
    if first in ("<", "=", ">"):
        op = first  # type: ignore
        digits = s[1:]
    else:
        op = "="
        digits = s
    return op, int(digits)
|
|
|
|
|
|
def as_int(
    x,
    *,
    max: Optional[int] = None,
    min: Optional[int] = 1,
    default: Optional[int] = None,
):
    """Convert *x* to an int clamped to the ``[min, max]`` range.

    Args:
        x: Value to convert; non-int values go through ``int(x)``.
        max: Upper bound, or ``None`` for no upper bound.
        min: Lower bound (1 unless overridden), or ``None`` for no lower
            bound.
        default: Value returned when *x* cannot be converted.  It is
            returned as-is, without clamping.

    Returns:
        The converted value, ``min`` if it is below ``min``, ``max`` if it
        is above ``max``, or ``default`` if conversion failed.

    Raises:
        TypeError, ValueError, OverflowError: The conversion error is
            re-raised when ``default`` is ``None``.
    """
    try:
        if not isinstance(x, int):
            x = int(x)
    except (TypeError, ValueError, OverflowError):
        # Only genuine conversion failures fall back to the default; a bare
        # ``except`` would also swallow unrelated errors and interrupts.
        if default is None:
            raise
        return default

    if min is not None and x < min:
        return min
    if max is not None and x > max:
        return max
    return x
|
|
|
|
|
|
async def ratings(request):
    """Return movies matching the query parameters, with their user scores.

    Query parameters read: ``title``, ``media_type``, ``ignore_tv_episodes``,
    ``include_unrated``, ``year`` (optionally prefixed with ``<``, ``=`` or
    ``>``) and ``per_page`` (clamped to 1..10, falling back to 5 when missing
    or not an integer).

    Rows from ``find_ratings`` are grouped by ``movie_imdb_id``; each movie
    appears once and collects every non-``None`` ``user_score`` of its rows.

    Responds with a JSON array of movie objects, or with a 400 JSON error
    when ``year`` cannot be parsed.
    """
    params = request.query_params

    # Parse the year filter up front: a malformed value is a client error,
    # not a server failure.
    try:
        year_filter = yearcomp(params["year"]) if "year" in params else None
    except ValueError:
        return JSONResponse({"error": "Invalid year"}, status_code=400)

    rows = await find_ratings(
        title=params.get("title"),
        media_type=params.get("media_type"),
        ignore_tv_episodes=truthy(params.get("ignore_tv_episodes")),
        include_unrated=truthy(params.get("include_unrated")),
        yearcomp=year_filter,
        limit_rows=as_int(params.get("per_page"), max=10, default=5),
    )

    aggr = {}
    for r in rows:
        movie_id = r["movie_imdb_id"]
        mov = aggr.get(movie_id)
        if mov is None:
            # First row for this movie: build its entry once instead of
            # constructing a throwaway dict for every row.
            mov = aggr[movie_id] = {
                "canonical_title": r["canonical_title"],
                "original_title": r["original_title"],
                "year": r["release_year"],
                "link": imdb_url(movie_id),
                "user_scores": [],
                "imdb_score": r["imdb_score"],
                "media_type": r["media_type"],
            }
        if r["user_score"] is not None:
            mov["user_scores"].append(r["user_score"])

    resp = tuple(aggr.values())

    return JSONResponse(resp)
|
|
|
|
|
|
# Pre-built JSON response with status 404 and an ``{"error": "Not Found"}``
# body.  NOTE(review): no handler visible in this file references it.
not_found = JSONResponse({"error": "Not Found"}, status_code=404)
|
|
|
|
|
|
async def get_movies(request):
    """Look up one movie by the ``imdb_id`` query parameter.

    Responds with a JSON list containing ``asplain(movie)`` when ``db.get``
    returns something truthy, and with an empty JSON list otherwise.
    """
    wanted_id = request.query_params.get("imdb_id")
    found = await db.get(Movie, imdb_id=wanted_id)

    if found:
        payload = [asplain(found)]
    else:
        payload = []
    return JSONResponse(payload)
|
|
|
|
|
|
@requires(["authenticated", "admin"])
async def add_movie(request):
    """Placeholder handler with no behaviour yet; it returns ``None``.

    Guarded by the ``authenticated`` and ``admin`` scopes.
    """
    return None
|
|
|
|
|
|
@requires(["authenticated", "admin"])
async def add_user(request):
    """Placeholder handler with no behaviour yet; it returns ``None``.

    Guarded by the ``authenticated`` and ``admin`` scopes.
    """
    return None
|
|
|
|
|
|
async def ratings_for_user(request):
    """Stub handler that only reads the ``user_id`` path parameter.

    The value is discarded and ``None`` is returned; a missing parameter
    raises ``KeyError``.
    """
    _user_id = request.path_params["user_id"]
    return None
|
|
|
|
|
|
@requires("authenticated")
async def set_rating_for_user(request):
    """Stub handler that only reads the ``user_id`` path parameter.

    Guarded by the ``authenticated`` scope.  The value is discarded and
    ``None`` is returned; a missing parameter raises ``KeyError``.
    """
    _user_id = request.path_params["user_id"]
    return None
|
|
|
|
|
|
@requires(["authenticated", "admin"])
async def add_group(request):
    """Placeholder handler with no behaviour yet; it returns ``None``.

    Guarded by the ``authenticated`` and ``admin`` scopes.
    """
    return None
|
|
|
|
|
|
@requires(["authenticated", "admin"])
async def add_user_to_group(request):
    """Stub handler that only reads the ``group_id`` path parameter.

    Guarded by the ``authenticated`` and ``admin`` scopes.  The value is
    discarded and ``None`` is returned; a missing parameter raises
    ``KeyError``.
    """
    _group_id = request.path_params["group_id"]
    return None
|
|
|
|
|
|
async def get_ratings_for_group(request):
    """Stub handler that only reads the ``group_id`` path parameter.

    The value is discarded and ``None`` is returned; a missing parameter
    raises ``KeyError``.
    """
    _group_id = request.path_params["group_id"]
    return None
|
|
|
|
|
|
def create_app():
    """Build and return the Starlette application.

    When ``config.loglevel`` is ``"DEBUG"``, root logging is configured with
    a timestamped format first.  The app mounts every route under
    ``/api/v1``, installs ``AuthenticationMiddleware`` with a
    ``BasicAuthBackend``, and opens/closes the DB connection pool on
    startup/shutdown.
    """
    debug_enabled = config.loglevel == "DEBUG"
    if debug_enabled:
        logging.basicConfig(
            level=config.loglevel,
            datefmt="%H:%M:%S",
            format="%(asctime)s.%(msecs)03d [%(name)s:%(process)d] %(levelname)s: %(message)s",
        )
    log.debug(f"Log level: {config.loglevel}")

    user_ratings_path = "/users/{user_id}/ratings"
    api_routes = [
        Route("/ratings", ratings),  # XXX legacy, remove.
        Route("/movies", get_movies),
        Route("/movies", add_movie, methods=["POST"]),
        Route("/users", add_user, methods=["POST"]),
        Route(user_ratings_path, ratings_for_user),
        Route(user_ratings_path, set_rating_for_user, methods=["PUT"]),
        Route("/groups", add_group, methods=["POST"]),
        Route("/groups/{group_id}/users", add_user_to_group, methods=["POST"]),
        Route("/groups/{group_id}/ratings", get_ratings_for_group),
    ]
    auth_middleware = Middleware(
        AuthenticationMiddleware, backend=BasicAuthBackend()
    )

    return Starlette(
        routes=[Mount("/api/v1", routes=api_routes)],
        middleware=[auth_middleware],
        on_startup=[open_connection_pool],
        on_shutdown=[close_connection_pool],
    )
|