replace legacy ratings route with group ratings
This commit is contained in:
parent
a39a0e6442
commit
75391b1ca2
6 changed files with 207 additions and 16 deletions
|
|
@ -11,4 +11,7 @@ loglevel = os.getenv("UNWIND_LOGLEVEL") or ("DEBUG" if debug else "INFO")
|
|||
storage_path = os.getenv("UNWIND_STORAGE", "./data/db.sqlite")
|
||||
config_path = os.getenv("UNWIND_CONFIG", "./data/config.toml")
|
||||
|
||||
imdb = toml.load(config_path)["imdb"]
|
||||
_config = toml.load(config_path)
|
||||
|
||||
imdb = _config["imdb"]
|
||||
api_credentials = _config["api"]["credentials"]
|
||||
|
|
|
|||
16
unwind/db.py
16
unwind/db.py
|
|
@ -2,7 +2,7 @@ import logging
|
|||
import re
|
||||
from dataclasses import fields
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, Type, TypeVar, Union
|
||||
from typing import Iterable, Literal, Optional, Type, TypeVar, Union
|
||||
|
||||
import sqlalchemy
|
||||
from databases import Database
|
||||
|
|
@ -281,6 +281,7 @@ async def find_ratings(
|
|||
include_unrated: bool = False,
|
||||
yearcomp: tuple[Literal["<", "=", ">"], int] = None,
|
||||
limit_rows: int = 10,
|
||||
user_ids: Iterable[str] = [],
|
||||
):
|
||||
values: dict[str, Union[int, str]] = {
|
||||
"limit_rows": limit_rows,
|
||||
|
|
@ -317,14 +318,20 @@ async def find_ratings(
|
|||
if ignore_tv_episodes:
|
||||
conditions.append(f"{Movie._table}.media_type!='TV Episode'")
|
||||
|
||||
user_condition = "1=1"
|
||||
if user_ids:
|
||||
uvs = {f"user_id_{i}": v for i, v in enumerate(user_ids, start=1)}
|
||||
values.update(uvs)
|
||||
user_condition = f"{Rating._table}.user_id IN ({','.join(':'+n for n in uvs)})"
|
||||
|
||||
source_table = "newest_movies"
|
||||
ctes = [
|
||||
f"""{source_table} AS (
|
||||
SELECT DISTINCT {Rating._table}.movie_id
|
||||
FROM {Rating._table}
|
||||
LEFT JOIN {Movie._table} ON {Movie._table}.id={Rating._table}.movie_id
|
||||
{('WHERE ' + ' AND '.join(conditions)) if conditions else ''}
|
||||
ORDER BY length({Movie._table}.title) ASC, {Rating._table}.rating_date DESC
|
||||
WHERE {user_condition}{(' AND ' + ' AND '.join(conditions)) if conditions else ''}
|
||||
ORDER BY length({Movie._table}.title) ASC, {Rating._table}.rating_date DESC, {Movie._table}.score DESC
|
||||
LIMIT :limit_rows
|
||||
)"""
|
||||
]
|
||||
|
|
@ -338,7 +345,7 @@ async def find_ratings(
|
|||
FROM {Movie._table}
|
||||
WHERE id NOT IN newest_movies
|
||||
{('AND ' + ' AND '.join(conditions)) if conditions else ''}
|
||||
ORDER BY length(title) ASC, release_year DESC
|
||||
ORDER BY length(title) ASC, score DESC, release_year DESC
|
||||
LIMIT :limit_rows
|
||||
)""",
|
||||
f"""{source_table} AS (
|
||||
|
|
@ -356,6 +363,7 @@ async def find_ratings(
|
|||
|
||||
SELECT
|
||||
{Rating._table}.score AS user_score,
|
||||
{Rating._table}.user_id AS user_id,
|
||||
{Movie._table}.score AS imdb_score,
|
||||
{Movie._table}.imdb_id AS movie_imdb_id,
|
||||
{Movie._table}.media_type AS media_type,
|
||||
|
|
|
|||
|
|
@ -143,3 +143,13 @@ class User:
|
|||
id: ULID = field(default_factory=ULID)
|
||||
imdb_id: str = None
|
||||
name: str = None # canonical user name
|
||||
|
||||
|
||||
@dataclass
|
||||
class Group:
|
||||
_table: ClassVar[str] = "groups"
|
||||
|
||||
id: ULID = field(default_factory=ULID)
|
||||
name: str = None
|
||||
secret: str = None
|
||||
users: list[dict[str, str]] = field(default_factory=list)
|
||||
|
|
|
|||
8
unwind/sql/20210705-224139.sql
Normal file
8
unwind/sql/20210705-224139.sql
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
-- add groups table
|
||||
|
||||
CREATE TABLE groups (
|
||||
id TEXT PRIMARY KEY NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
secret TEXT NOT NULL,
|
||||
users TEXT NOT NULL -- JSON array
|
||||
);;
|
||||
78
unwind/utils.py
Normal file
78
unwind/utils.py
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
from typing import Literal
|
||||
|
||||
|
||||
def b64encode(b: bytes) -> str:
|
||||
return base64.b64encode(b).decode().rstrip("=")
|
||||
|
||||
|
||||
def b64decode(s: str) -> bytes:
|
||||
return base64.b64decode(b64padded(s))
|
||||
|
||||
|
||||
def b64padded(s: str) -> str:
|
||||
return s + "=" * (4 - len(s) % 4)
|
||||
|
||||
|
||||
def phc_scrypt(
|
||||
secret: bytes, *, salt: bytes = None, params: dict[Literal["n", "r", "p"], int] = {}
|
||||
) -> str:
|
||||
"""Return the scrypt expanded secret in PHC string format.
|
||||
|
||||
Uses somewhat sane defaults.
|
||||
|
||||
For more information on the PHC string format, see
|
||||
https://github.com/P-H-C/phc-string-format/blob/master/phc-sf-spec.md
|
||||
"""
|
||||
|
||||
if salt is None:
|
||||
salt = secrets.token_bytes(16)
|
||||
|
||||
n = params.get("n", 2 ** 14) # CPU/Memory cost factor
|
||||
r = params.get("r", 8) # block size
|
||||
p = params.get("p", 1) # parallelization factor
|
||||
# maxmem = 2 * 128 * n * r * p
|
||||
hashed_secret = hashlib.scrypt(secret, salt=salt, n=n, r=r, p=p)
|
||||
|
||||
encoded_params = ",".join(f"{k}={v}" for k, v in {"n": n, "r": r, "p": p}.items())
|
||||
phc = "".join(
|
||||
f"${x}"
|
||||
for x in ["scrypt", encoded_params, b64encode(salt), b64encode(hashed_secret)]
|
||||
)
|
||||
|
||||
return phc
|
||||
|
||||
|
||||
def phc_compare(*, secret: str, phc_string: str) -> bool:
|
||||
args = parse_phc(phc_string)
|
||||
|
||||
if args["id"] != "scrypt":
|
||||
raise ValueError(f"Algorithm not supported: {args['id']}")
|
||||
|
||||
assert type(args["params"]) is dict
|
||||
encoded = phc_scrypt(b64decode(secret), salt=args["salt"], params=args["params"])
|
||||
|
||||
return secrets.compare_digest(encoded, phc_string)
|
||||
|
||||
|
||||
def parse_phc(s: str):
|
||||
parts = dict.fromkeys(["id", "version", "params", "salt", "hash"])
|
||||
|
||||
_, parts["id"], *rest = s.split("$")
|
||||
|
||||
if rest and rest[0].startswith("v="):
|
||||
parts["version"] = rest.pop(0)
|
||||
if rest and "=" in rest[0]:
|
||||
parts["params"] = {
|
||||
kv[0]: int(kv[1])
|
||||
for p in rest.pop(0).split(",")
|
||||
if len(kv := p.split("=", 2)) == 2
|
||||
}
|
||||
if rest:
|
||||
parts["salt"] = b64decode(rest.pop(0))
|
||||
if rest:
|
||||
parts["hash"] = b64decode(rest.pop(0))
|
||||
|
||||
return parts
|
||||
106
unwind/web.py
106
unwind/web.py
|
|
@ -1,6 +1,6 @@
|
|||
import base64
|
||||
import binascii
|
||||
import logging
|
||||
import secrets
|
||||
from json.decoder import JSONDecodeError
|
||||
from typing import Literal, Optional
|
||||
|
||||
from starlette.applications import Starlette
|
||||
|
|
@ -12,6 +12,7 @@ from starlette.authentication import (
|
|||
UnauthenticatedUser,
|
||||
requires,
|
||||
)
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.middleware.authentication import AuthenticationMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
|
|
@ -20,7 +21,9 @@ from starlette.routing import Mount, Route
|
|||
from . import config, db
|
||||
from .db import close_connection_pool, find_ratings, open_connection_pool
|
||||
from .middleware.responsetime import ResponseTimeMiddleware
|
||||
from .models import Movie, asplain
|
||||
from .models import Group, Movie, asplain
|
||||
from .types import ULID
|
||||
from .utils import b64encode, phc_compare, phc_scrypt
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -81,7 +84,22 @@ def as_int(x, *, max: int = None, min: Optional[int] = 1, default: int = None):
|
|||
return default
|
||||
|
||||
|
||||
async def ratings(request):
|
||||
def as_ulid(s: str) -> ULID:
|
||||
try:
|
||||
return ULID(s)
|
||||
except ValueError:
|
||||
raise HTTPException(422, "Not a valid ULID.")
|
||||
|
||||
|
||||
async def get_ratings_for_group(request):
|
||||
group_id = as_ulid(request.path_params["group_id"])
|
||||
group = await db.get(Group, id=str(group_id))
|
||||
|
||||
if not group:
|
||||
return not_found()
|
||||
|
||||
user_ids = {u["id"] for u in group.users}
|
||||
|
||||
params = request.query_params
|
||||
rows = await find_ratings(
|
||||
title=params.get("title"),
|
||||
|
|
@ -91,6 +109,7 @@ async def ratings(request):
|
|||
include_unrated=truthy(params.get("include_unrated")),
|
||||
yearcomp=yearcomp(params["year"]) if "year" in params else None,
|
||||
limit_rows=as_int(params.get("per_page"), max=10, default=5),
|
||||
user_ids=user_ids,
|
||||
)
|
||||
|
||||
aggr = {}
|
||||
|
|
@ -107,7 +126,7 @@ async def ratings(request):
|
|||
"media_type": r["media_type"],
|
||||
},
|
||||
)
|
||||
if r["user_score"] is not None:
|
||||
if r["user_score"] is not None and r["user_id"] in user_ids:
|
||||
mov["user_scores"].append(r["user_score"])
|
||||
|
||||
resp = tuple(aggr.values())
|
||||
|
|
@ -115,7 +134,16 @@ async def ratings(request):
|
|||
return JSONResponse(resp)
|
||||
|
||||
|
||||
not_found = JSONResponse({"error": "Not Found"}, status_code=404)
|
||||
def unauthorized(reason: str = "Unauthorized"):
|
||||
return JSONResponse({"error": reason}, status_code=401)
|
||||
|
||||
|
||||
def forbidden(reason: str = "Forbidden"):
|
||||
return JSONResponse({"error": reason}, status_code=403)
|
||||
|
||||
|
||||
def not_found(reason: str = "Not Found"):
|
||||
return JSONResponse({"error": reason}, status_code=404)
|
||||
|
||||
|
||||
async def get_movies(request):
|
||||
|
|
@ -148,16 +176,73 @@ async def set_rating_for_user(request):
|
|||
|
||||
@requires(["authenticated", "admin"])
|
||||
async def add_group(request):
|
||||
pass
|
||||
if not await request.body():
|
||||
data = {}
|
||||
else:
|
||||
try:
|
||||
data = await request.json()
|
||||
except JSONDecodeError:
|
||||
raise HTTPException(422, "Invalid JSON content.")
|
||||
|
||||
try:
|
||||
name = data["name"]
|
||||
except KeyError as err:
|
||||
raise HTTPException(422, f"Missing data for key: {err.args[0]}")
|
||||
|
||||
# XXX restrict name
|
||||
|
||||
secret = secrets.token_bytes()
|
||||
|
||||
group = Group(name=name, secret=phc_scrypt(secret))
|
||||
await db.add(group)
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
"secret": b64encode(secret),
|
||||
"group": asplain(group),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@requires(["authenticated", "admin"])
|
||||
async def add_user_to_group(request):
|
||||
request.path_params["group_id"]
|
||||
group_id = as_ulid(request.path_params["group_id"])
|
||||
group = await db.get(Group, id=str(group_id))
|
||||
|
||||
if not group:
|
||||
return not_found()
|
||||
|
||||
async def get_ratings_for_group(request):
|
||||
request.path_params["group_id"]
|
||||
is_allowed = "admin" in request.auth.scopes or phc_compare(
|
||||
secret=request.user.token, phc_string=group.secret
|
||||
)
|
||||
if not is_allowed:
|
||||
return forbidden()
|
||||
|
||||
if not await request.body():
|
||||
data = {}
|
||||
else:
|
||||
try:
|
||||
data = await request.json()
|
||||
except JSONDecodeError:
|
||||
raise HTTPException(422, "Invalid JSON content.")
|
||||
|
||||
try:
|
||||
name = data["name"]
|
||||
user_id = data["id"]
|
||||
except KeyError as err:
|
||||
raise HTTPException(422, f"Missing data for key: {err.args[0]}")
|
||||
|
||||
# XXX check if user exists
|
||||
# XXX restrict name
|
||||
|
||||
if any(u["id"] == user_id for u in group.users):
|
||||
pass
|
||||
else:
|
||||
group.users.append({"name": name, "id": user_id})
|
||||
|
||||
await db.update(group)
|
||||
|
||||
return JSONResponse(asplain(group))
|
||||
|
||||
|
||||
def create_app():
|
||||
|
|
@ -176,7 +261,6 @@ def create_app():
|
|||
Mount(
|
||||
"/api/v1",
|
||||
routes=[
|
||||
Route("/ratings", ratings), # XXX legacy, remove.
|
||||
Route("/movies", get_movies),
|
||||
Route("/movies", add_movie, methods=["POST"]),
|
||||
Route("/users", add_user, methods=["POST"]),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue