replace legacy ratings route with group ratings

This commit is contained in:
ducklet 2021-07-08 09:48:54 +02:00
parent a39a0e6442
commit 75391b1ca2
6 changed files with 207 additions and 16 deletions

View file

@ -11,4 +11,7 @@ loglevel = os.getenv("UNWIND_LOGLEVEL") or ("DEBUG" if debug else "INFO")
storage_path = os.getenv("UNWIND_STORAGE", "./data/db.sqlite")
config_path = os.getenv("UNWIND_CONFIG", "./data/config.toml")
imdb = toml.load(config_path)["imdb"]
_config = toml.load(config_path)
imdb = _config["imdb"]
api_credentials = _config["api"]["credentials"]

View file

@ -2,7 +2,7 @@ import logging
import re
from dataclasses import fields
from pathlib import Path
from typing import Literal, Optional, Type, TypeVar, Union
from typing import Iterable, Literal, Optional, Type, TypeVar, Union
import sqlalchemy
from databases import Database
@ -281,6 +281,7 @@ async def find_ratings(
include_unrated: bool = False,
yearcomp: tuple[Literal["<", "=", ">"], int] = None,
limit_rows: int = 10,
user_ids: Iterable[str] = [],
):
values: dict[str, Union[int, str]] = {
"limit_rows": limit_rows,
@ -317,14 +318,20 @@ async def find_ratings(
if ignore_tv_episodes:
conditions.append(f"{Movie._table}.media_type!='TV Episode'")
user_condition = "1=1"
if user_ids:
uvs = {f"user_id_{i}": v for i, v in enumerate(user_ids, start=1)}
values.update(uvs)
user_condition = f"{Rating._table}.user_id IN ({','.join(':'+n for n in uvs)})"
source_table = "newest_movies"
ctes = [
f"""{source_table} AS (
SELECT DISTINCT {Rating._table}.movie_id
FROM {Rating._table}
LEFT JOIN {Movie._table} ON {Movie._table}.id={Rating._table}.movie_id
{('WHERE ' + ' AND '.join(conditions)) if conditions else ''}
ORDER BY length({Movie._table}.title) ASC, {Rating._table}.rating_date DESC
WHERE {user_condition}{(' AND ' + ' AND '.join(conditions)) if conditions else ''}
ORDER BY length({Movie._table}.title) ASC, {Rating._table}.rating_date DESC, {Movie._table}.score DESC
LIMIT :limit_rows
)"""
]
@ -338,7 +345,7 @@ async def find_ratings(
FROM {Movie._table}
WHERE id NOT IN newest_movies
{('AND ' + ' AND '.join(conditions)) if conditions else ''}
ORDER BY length(title) ASC, release_year DESC
ORDER BY length(title) ASC, score DESC, release_year DESC
LIMIT :limit_rows
)""",
f"""{source_table} AS (
@ -356,6 +363,7 @@ async def find_ratings(
SELECT
{Rating._table}.score AS user_score,
{Rating._table}.user_id AS user_id,
{Movie._table}.score AS imdb_score,
{Movie._table}.imdb_id AS movie_imdb_id,
{Movie._table}.media_type AS media_type,

View file

@ -143,3 +143,13 @@ class User:
id: ULID = field(default_factory=ULID)
imdb_id: str = None
name: str = None # canonical user name
@dataclass
class Group:
_table: ClassVar[str] = "groups"
id: ULID = field(default_factory=ULID)
name: str = None
secret: str = None
users: list[dict[str, str]] = field(default_factory=list)

View file

@ -0,0 +1,8 @@
-- add groups table
CREATE TABLE groups (
id TEXT PRIMARY KEY NOT NULL,
name TEXT NOT NULL,
secret TEXT NOT NULL,
users TEXT NOT NULL -- JSON array
);;

78
unwind/utils.py Normal file
View file

@ -0,0 +1,78 @@
import base64
import hashlib
import secrets
from typing import Literal
def b64encode(b: bytes) -> str:
return base64.b64encode(b).decode().rstrip("=")
def b64decode(s: str) -> bytes:
return base64.b64decode(b64padded(s))
def b64padded(s: str) -> str:
return s + "=" * (4 - len(s) % 4)
def phc_scrypt(
secret: bytes, *, salt: bytes = None, params: dict[Literal["n", "r", "p"], int] = {}
) -> str:
"""Return the scrypt expanded secret in PHC string format.
Uses somewhat sane defaults.
For more information on the PHC string format, see
https://github.com/P-H-C/phc-string-format/blob/master/phc-sf-spec.md
"""
if salt is None:
salt = secrets.token_bytes(16)
n = params.get("n", 2 ** 14) # CPU/Memory cost factor
r = params.get("r", 8) # block size
p = params.get("p", 1) # parallelization factor
# maxmem = 2 * 128 * n * r * p
hashed_secret = hashlib.scrypt(secret, salt=salt, n=n, r=r, p=p)
encoded_params = ",".join(f"{k}={v}" for k, v in {"n": n, "r": r, "p": p}.items())
phc = "".join(
f"${x}"
for x in ["scrypt", encoded_params, b64encode(salt), b64encode(hashed_secret)]
)
return phc
def phc_compare(*, secret: str, phc_string: str) -> bool:
args = parse_phc(phc_string)
if args["id"] != "scrypt":
raise ValueError(f"Algorithm not supported: {args['id']}")
assert type(args["params"]) is dict
encoded = phc_scrypt(b64decode(secret), salt=args["salt"], params=args["params"])
return secrets.compare_digest(encoded, phc_string)
def parse_phc(s: str):
parts = dict.fromkeys(["id", "version", "params", "salt", "hash"])
_, parts["id"], *rest = s.split("$")
if rest and rest[0].startswith("v="):
parts["version"] = rest.pop(0)
if rest and "=" in rest[0]:
parts["params"] = {
kv[0]: int(kv[1])
for p in rest.pop(0).split(",")
if len(kv := p.split("=", 2)) == 2
}
if rest:
parts["salt"] = b64decode(rest.pop(0))
if rest:
parts["hash"] = b64decode(rest.pop(0))
return parts

View file

@ -1,6 +1,6 @@
import base64
import binascii
import logging
import secrets
from json.decoder import JSONDecodeError
from typing import Literal, Optional
from starlette.applications import Starlette
@ -12,6 +12,7 @@ from starlette.authentication import (
UnauthenticatedUser,
requires,
)
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.responses import JSONResponse
@ -20,7 +21,9 @@ from starlette.routing import Mount, Route
from . import config, db
from .db import close_connection_pool, find_ratings, open_connection_pool
from .middleware.responsetime import ResponseTimeMiddleware
from .models import Movie, asplain
from .models import Group, Movie, asplain
from .types import ULID
from .utils import b64encode, phc_compare, phc_scrypt
log = logging.getLogger(__name__)
@ -81,7 +84,22 @@ def as_int(x, *, max: int = None, min: Optional[int] = 1, default: int = None):
return default
async def ratings(request):
def as_ulid(s: str) -> ULID:
try:
return ULID(s)
except ValueError:
raise HTTPException(422, "Not a valid ULID.")
async def get_ratings_for_group(request):
group_id = as_ulid(request.path_params["group_id"])
group = await db.get(Group, id=str(group_id))
if not group:
return not_found()
user_ids = {u["id"] for u in group.users}
params = request.query_params
rows = await find_ratings(
title=params.get("title"),
@ -91,6 +109,7 @@ async def ratings(request):
include_unrated=truthy(params.get("include_unrated")),
yearcomp=yearcomp(params["year"]) if "year" in params else None,
limit_rows=as_int(params.get("per_page"), max=10, default=5),
user_ids=user_ids,
)
aggr = {}
@ -107,7 +126,7 @@ async def ratings(request):
"media_type": r["media_type"],
},
)
if r["user_score"] is not None:
if r["user_score"] is not None and r["user_id"] in user_ids:
mov["user_scores"].append(r["user_score"])
resp = tuple(aggr.values())
@ -115,7 +134,16 @@ async def ratings(request):
return JSONResponse(resp)
not_found = JSONResponse({"error": "Not Found"}, status_code=404)
def unauthorized(reason: str = "Unauthorized"):
return JSONResponse({"error": reason}, status_code=401)
def forbidden(reason: str = "Forbidden"):
return JSONResponse({"error": reason}, status_code=403)
def not_found(reason: str = "Not Found"):
return JSONResponse({"error": reason}, status_code=404)
async def get_movies(request):
@ -148,16 +176,73 @@ async def set_rating_for_user(request):
@requires(["authenticated", "admin"])
async def add_group(request):
pass
if not await request.body():
data = {}
else:
try:
data = await request.json()
except JSONDecodeError:
raise HTTPException(422, "Invalid JSON content.")
try:
name = data["name"]
except KeyError as err:
raise HTTPException(422, f"Missing data for key: {err.args[0]}")
# XXX restrict name
secret = secrets.token_bytes()
group = Group(name=name, secret=phc_scrypt(secret))
await db.add(group)
return JSONResponse(
{
"secret": b64encode(secret),
"group": asplain(group),
}
)
@requires(["authenticated", "admin"])
async def add_user_to_group(request):
request.path_params["group_id"]
group_id = as_ulid(request.path_params["group_id"])
group = await db.get(Group, id=str(group_id))
if not group:
return not_found()
async def get_ratings_for_group(request):
request.path_params["group_id"]
is_allowed = "admin" in request.auth.scopes or phc_compare(
secret=request.user.token, phc_string=group.secret
)
if not is_allowed:
return forbidden()
if not await request.body():
data = {}
else:
try:
data = await request.json()
except JSONDecodeError:
raise HTTPException(422, "Invalid JSON content.")
try:
name = data["name"]
user_id = data["id"]
except KeyError as err:
raise HTTPException(422, f"Missing data for key: {err.args[0]}")
# XXX check if user exists
# XXX restrict name
if any(u["id"] == user_id for u in group.users):
pass
else:
group.users.append({"name": name, "id": user_id})
await db.update(group)
return JSONResponse(asplain(group))
def create_app():
@ -176,7 +261,6 @@ def create_app():
Mount(
"/api/v1",
routes=[
Route("/ratings", ratings), # XXX legacy, remove.
Route("/movies", get_movies),
Route("/movies", add_movie, methods=["POST"]),
Route("/users", add_user, methods=["POST"]),