replace legacy ratings route with group ratings

This commit is contained in:
ducklet 2021-07-08 09:48:54 +02:00
parent a39a0e6442
commit 75391b1ca2
6 changed files with 207 additions and 16 deletions

View file

@ -11,4 +11,7 @@ loglevel = os.getenv("UNWIND_LOGLEVEL") or ("DEBUG" if debug else "INFO")
storage_path = os.getenv("UNWIND_STORAGE", "./data/db.sqlite") storage_path = os.getenv("UNWIND_STORAGE", "./data/db.sqlite")
config_path = os.getenv("UNWIND_CONFIG", "./data/config.toml") config_path = os.getenv("UNWIND_CONFIG", "./data/config.toml")
imdb = toml.load(config_path)["imdb"] _config = toml.load(config_path)
imdb = _config["imdb"]
api_credentials = _config["api"]["credentials"]

View file

@ -2,7 +2,7 @@ import logging
import re import re
from dataclasses import fields from dataclasses import fields
from pathlib import Path from pathlib import Path
from typing import Literal, Optional, Type, TypeVar, Union from typing import Iterable, Literal, Optional, Type, TypeVar, Union
import sqlalchemy import sqlalchemy
from databases import Database from databases import Database
@ -281,6 +281,7 @@ async def find_ratings(
include_unrated: bool = False, include_unrated: bool = False,
yearcomp: tuple[Literal["<", "=", ">"], int] = None, yearcomp: tuple[Literal["<", "=", ">"], int] = None,
limit_rows: int = 10, limit_rows: int = 10,
user_ids: Iterable[str] = [],
): ):
values: dict[str, Union[int, str]] = { values: dict[str, Union[int, str]] = {
"limit_rows": limit_rows, "limit_rows": limit_rows,
@ -317,14 +318,20 @@ async def find_ratings(
if ignore_tv_episodes: if ignore_tv_episodes:
conditions.append(f"{Movie._table}.media_type!='TV Episode'") conditions.append(f"{Movie._table}.media_type!='TV Episode'")
user_condition = "1=1"
if user_ids:
uvs = {f"user_id_{i}": v for i, v in enumerate(user_ids, start=1)}
values.update(uvs)
user_condition = f"{Rating._table}.user_id IN ({','.join(':'+n for n in uvs)})"
source_table = "newest_movies" source_table = "newest_movies"
ctes = [ ctes = [
f"""{source_table} AS ( f"""{source_table} AS (
SELECT DISTINCT {Rating._table}.movie_id SELECT DISTINCT {Rating._table}.movie_id
FROM {Rating._table} FROM {Rating._table}
LEFT JOIN {Movie._table} ON {Movie._table}.id={Rating._table}.movie_id LEFT JOIN {Movie._table} ON {Movie._table}.id={Rating._table}.movie_id
{('WHERE ' + ' AND '.join(conditions)) if conditions else ''} WHERE {user_condition}{(' AND ' + ' AND '.join(conditions)) if conditions else ''}
ORDER BY length({Movie._table}.title) ASC, {Rating._table}.rating_date DESC ORDER BY length({Movie._table}.title) ASC, {Rating._table}.rating_date DESC, {Movie._table}.score DESC
LIMIT :limit_rows LIMIT :limit_rows
)""" )"""
] ]
@ -338,7 +345,7 @@ async def find_ratings(
FROM {Movie._table} FROM {Movie._table}
WHERE id NOT IN newest_movies WHERE id NOT IN newest_movies
{('AND ' + ' AND '.join(conditions)) if conditions else ''} {('AND ' + ' AND '.join(conditions)) if conditions else ''}
ORDER BY length(title) ASC, release_year DESC ORDER BY length(title) ASC, score DESC, release_year DESC
LIMIT :limit_rows LIMIT :limit_rows
)""", )""",
f"""{source_table} AS ( f"""{source_table} AS (
@ -356,6 +363,7 @@ async def find_ratings(
SELECT SELECT
{Rating._table}.score AS user_score, {Rating._table}.score AS user_score,
{Rating._table}.user_id AS user_id,
{Movie._table}.score AS imdb_score, {Movie._table}.score AS imdb_score,
{Movie._table}.imdb_id AS movie_imdb_id, {Movie._table}.imdb_id AS movie_imdb_id,
{Movie._table}.media_type AS media_type, {Movie._table}.media_type AS media_type,

View file

@ -143,3 +143,13 @@ class User:
id: ULID = field(default_factory=ULID) id: ULID = field(default_factory=ULID)
imdb_id: str = None imdb_id: str = None
name: str = None # canonical user name name: str = None # canonical user name
@dataclass
class Group:
_table: ClassVar[str] = "groups"
id: ULID = field(default_factory=ULID)
name: str = None
secret: str = None
users: list[dict[str, str]] = field(default_factory=list)

View file

@ -0,0 +1,8 @@
-- add groups table
CREATE TABLE groups (
id TEXT PRIMARY KEY NOT NULL,
name TEXT NOT NULL,
secret TEXT NOT NULL,
users TEXT NOT NULL -- JSON array
);;

78
unwind/utils.py Normal file
View file

@ -0,0 +1,78 @@
import base64
import hashlib
import secrets
from typing import Literal
def b64encode(b: bytes) -> str:
return base64.b64encode(b).decode().rstrip("=")
def b64decode(s: str) -> bytes:
return base64.b64decode(b64padded(s))
def b64padded(s: str) -> str:
return s + "=" * (4 - len(s) % 4)
def phc_scrypt(
secret: bytes, *, salt: bytes = None, params: dict[Literal["n", "r", "p"], int] = {}
) -> str:
"""Return the scrypt expanded secret in PHC string format.
Uses somewhat sane defaults.
For more information on the PHC string format, see
https://github.com/P-H-C/phc-string-format/blob/master/phc-sf-spec.md
"""
if salt is None:
salt = secrets.token_bytes(16)
n = params.get("n", 2 ** 14) # CPU/Memory cost factor
r = params.get("r", 8) # block size
p = params.get("p", 1) # parallelization factor
# maxmem = 2 * 128 * n * r * p
hashed_secret = hashlib.scrypt(secret, salt=salt, n=n, r=r, p=p)
encoded_params = ",".join(f"{k}={v}" for k, v in {"n": n, "r": r, "p": p}.items())
phc = "".join(
f"${x}"
for x in ["scrypt", encoded_params, b64encode(salt), b64encode(hashed_secret)]
)
return phc
def phc_compare(*, secret: str, phc_string: str) -> bool:
args = parse_phc(phc_string)
if args["id"] != "scrypt":
raise ValueError(f"Algorithm not supported: {args['id']}")
assert type(args["params"]) is dict
encoded = phc_scrypt(b64decode(secret), salt=args["salt"], params=args["params"])
return secrets.compare_digest(encoded, phc_string)
def parse_phc(s: str):
parts = dict.fromkeys(["id", "version", "params", "salt", "hash"])
_, parts["id"], *rest = s.split("$")
if rest and rest[0].startswith("v="):
parts["version"] = rest.pop(0)
if rest and "=" in rest[0]:
parts["params"] = {
kv[0]: int(kv[1])
for p in rest.pop(0).split(",")
if len(kv := p.split("=", 2)) == 2
}
if rest:
parts["salt"] = b64decode(rest.pop(0))
if rest:
parts["hash"] = b64decode(rest.pop(0))
return parts

View file

@ -1,6 +1,6 @@
import base64
import binascii
import logging import logging
import secrets
from json.decoder import JSONDecodeError
from typing import Literal, Optional from typing import Literal, Optional
from starlette.applications import Starlette from starlette.applications import Starlette
@ -12,6 +12,7 @@ from starlette.authentication import (
UnauthenticatedUser, UnauthenticatedUser,
requires, requires,
) )
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
@ -20,7 +21,9 @@ from starlette.routing import Mount, Route
from . import config, db from . import config, db
from .db import close_connection_pool, find_ratings, open_connection_pool from .db import close_connection_pool, find_ratings, open_connection_pool
from .middleware.responsetime import ResponseTimeMiddleware from .middleware.responsetime import ResponseTimeMiddleware
from .models import Movie, asplain from .models import Group, Movie, asplain
from .types import ULID
from .utils import b64encode, phc_compare, phc_scrypt
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -81,7 +84,22 @@ def as_int(x, *, max: int = None, min: Optional[int] = 1, default: int = None):
return default return default
async def ratings(request): def as_ulid(s: str) -> ULID:
try:
return ULID(s)
except ValueError:
raise HTTPException(422, "Not a valid ULID.")
async def get_ratings_for_group(request):
group_id = as_ulid(request.path_params["group_id"])
group = await db.get(Group, id=str(group_id))
if not group:
return not_found()
user_ids = {u["id"] for u in group.users}
params = request.query_params params = request.query_params
rows = await find_ratings( rows = await find_ratings(
title=params.get("title"), title=params.get("title"),
@ -91,6 +109,7 @@ async def ratings(request):
include_unrated=truthy(params.get("include_unrated")), include_unrated=truthy(params.get("include_unrated")),
yearcomp=yearcomp(params["year"]) if "year" in params else None, yearcomp=yearcomp(params["year"]) if "year" in params else None,
limit_rows=as_int(params.get("per_page"), max=10, default=5), limit_rows=as_int(params.get("per_page"), max=10, default=5),
user_ids=user_ids,
) )
aggr = {} aggr = {}
@ -107,7 +126,7 @@ async def ratings(request):
"media_type": r["media_type"], "media_type": r["media_type"],
}, },
) )
if r["user_score"] is not None: if r["user_score"] is not None and r["user_id"] in user_ids:
mov["user_scores"].append(r["user_score"]) mov["user_scores"].append(r["user_score"])
resp = tuple(aggr.values()) resp = tuple(aggr.values())
@ -115,7 +134,16 @@ async def ratings(request):
return JSONResponse(resp) return JSONResponse(resp)
not_found = JSONResponse({"error": "Not Found"}, status_code=404) def unauthorized(reason: str = "Unauthorized"):
return JSONResponse({"error": reason}, status_code=401)
def forbidden(reason: str = "Forbidden"):
return JSONResponse({"error": reason}, status_code=403)
def not_found(reason: str = "Not Found"):
return JSONResponse({"error": reason}, status_code=404)
async def get_movies(request): async def get_movies(request):
@ -148,16 +176,73 @@ async def set_rating_for_user(request):
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def add_group(request): async def add_group(request):
pass if not await request.body():
data = {}
else:
try:
data = await request.json()
except JSONDecodeError:
raise HTTPException(422, "Invalid JSON content.")
try:
name = data["name"]
except KeyError as err:
raise HTTPException(422, f"Missing data for key: {err.args[0]}")
# XXX restrict name
secret = secrets.token_bytes()
group = Group(name=name, secret=phc_scrypt(secret))
await db.add(group)
return JSONResponse(
{
"secret": b64encode(secret),
"group": asplain(group),
}
)
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def add_user_to_group(request): async def add_user_to_group(request):
request.path_params["group_id"] group_id = as_ulid(request.path_params["group_id"])
group = await db.get(Group, id=str(group_id))
if not group:
return not_found()
async def get_ratings_for_group(request): is_allowed = "admin" in request.auth.scopes or phc_compare(
request.path_params["group_id"] secret=request.user.token, phc_string=group.secret
)
if not is_allowed:
return forbidden()
if not await request.body():
data = {}
else:
try:
data = await request.json()
except JSONDecodeError:
raise HTTPException(422, "Invalid JSON content.")
try:
name = data["name"]
user_id = data["id"]
except KeyError as err:
raise HTTPException(422, f"Missing data for key: {err.args[0]}")
# XXX check if user exists
# XXX restrict name
if any(u["id"] == user_id for u in group.users):
pass
else:
group.users.append({"name": name, "id": user_id})
await db.update(group)
return JSONResponse(asplain(group))
def create_app(): def create_app():
@ -176,7 +261,6 @@ def create_app():
Mount( Mount(
"/api/v1", "/api/v1",
routes=[ routes=[
Route("/ratings", ratings), # XXX legacy, remove.
Route("/movies", get_movies), Route("/movies", get_movies),
Route("/movies", add_movie, methods=["POST"]), Route("/movies", add_movie, methods=["POST"]),
Route("/users", add_user, methods=["POST"]), Route("/users", add_user, methods=["POST"]),