diff --git a/unwind/config.py b/unwind/config.py index d13f59b..b97ebcb 100644 --- a/unwind/config.py +++ b/unwind/config.py @@ -11,4 +11,7 @@ loglevel = os.getenv("UNWIND_LOGLEVEL") or ("DEBUG" if debug else "INFO") storage_path = os.getenv("UNWIND_STORAGE", "./data/db.sqlite") config_path = os.getenv("UNWIND_CONFIG", "./data/config.toml") -imdb = toml.load(config_path)["imdb"] +_config = toml.load(config_path) + +imdb = _config["imdb"] +api_credentials = _config["api"]["credentials"] diff --git a/unwind/db.py b/unwind/db.py index 176db9b..22ea0b0 100644 --- a/unwind/db.py +++ b/unwind/db.py @@ -2,7 +2,7 @@ import logging import re from dataclasses import fields from pathlib import Path -from typing import Literal, Optional, Type, TypeVar, Union +from typing import Iterable, Literal, Optional, Type, TypeVar, Union import sqlalchemy from databases import Database @@ -281,6 +281,7 @@ async def find_ratings( include_unrated: bool = False, yearcomp: tuple[Literal["<", "=", ">"], int] = None, limit_rows: int = 10, + user_ids: Iterable[str] = [], ): values: dict[str, Union[int, str]] = { "limit_rows": limit_rows, @@ -317,14 +318,20 @@ async def find_ratings( if ignore_tv_episodes: conditions.append(f"{Movie._table}.media_type!='TV Episode'") + user_condition = "1=1" + if user_ids: + uvs = {f"user_id_{i}": v for i, v in enumerate(user_ids, start=1)} + values.update(uvs) + user_condition = f"{Rating._table}.user_id IN ({','.join(':'+n for n in uvs)})" + source_table = "newest_movies" ctes = [ f"""{source_table} AS ( SELECT DISTINCT {Rating._table}.movie_id FROM {Rating._table} LEFT JOIN {Movie._table} ON {Movie._table}.id={Rating._table}.movie_id - {('WHERE ' + ' AND '.join(conditions)) if conditions else ''} - ORDER BY length({Movie._table}.title) ASC, {Rating._table}.rating_date DESC + WHERE {user_condition}{(' AND ' + ' AND '.join(conditions)) if conditions else ''} + ORDER BY length({Movie._table}.title) ASC, {Rating._table}.rating_date DESC, {Movie._table}.score DESC LIMIT :limit_rows )""" ] @@ -338,7 +345,7 @@ async def find_ratings( FROM {Movie._table} WHERE id NOT IN newest_movies {('AND ' + ' AND '.join(conditions)) if conditions else ''} - ORDER BY length(title) ASC, release_year DESC + ORDER BY length(title) ASC, score DESC, release_year DESC LIMIT :limit_rows )""", f"""{source_table} AS ( @@ -356,6 +363,7 @@ async def find_ratings( SELECT {Rating._table}.score AS user_score, + {Rating._table}.user_id AS user_id, {Movie._table}.score AS imdb_score, {Movie._table}.imdb_id AS movie_imdb_id, {Movie._table}.media_type AS media_type, diff --git a/unwind/models.py b/unwind/models.py index 05fb5ae..2a60071 100644 --- a/unwind/models.py +++ b/unwind/models.py @@ -143,3 +143,13 @@ class User: id: ULID = field(default_factory=ULID) imdb_id: str = None name: str = None # canonical user name + + +@dataclass +class Group: + _table: ClassVar[str] = "groups" + + id: ULID = field(default_factory=ULID) + name: str = None + secret: str = None + users: list[dict[str, str]] = field(default_factory=list) diff --git a/unwind/sql/20210705-224139.sql b/unwind/sql/20210705-224139.sql new file mode 100644 index 0000000..e714b4e --- /dev/null +++ b/unwind/sql/20210705-224139.sql @@ -0,0 +1,8 @@ +-- add groups table + +CREATE TABLE groups ( + id TEXT PRIMARY KEY NOT NULL, + name TEXT NOT NULL, + secret TEXT NOT NULL, + users TEXT NOT NULL -- JSON array +);; diff --git a/unwind/utils.py b/unwind/utils.py new file mode 100644 index 0000000..012d1fb --- /dev/null +++ b/unwind/utils.py @@ -0,0 +1,78 @@ +import base64 +import hashlib +import secrets +from typing import Literal + + +def b64encode(b: bytes) -> str: + return base64.b64encode(b).decode().rstrip("=") + + +def b64decode(s: str) -> bytes: + return base64.b64decode(b64padded(s)) + + +def b64padded(s: str) -> str: + return s + "=" * (4 - len(s) % 4) + + +def phc_scrypt( + secret: bytes, *, salt: bytes = None, params: dict[Literal["n", "r", "p"], int] = {} +) -> str: + """Return the scrypt expanded secret in PHC string format. + + Uses somewhat sane defaults. + + For more information on the PHC string format, see + https://github.com/P-H-C/phc-string-format/blob/master/phc-sf-spec.md + """ + + if salt is None: + salt = secrets.token_bytes(16) + + n = params.get("n", 2 ** 14) # CPU/Memory cost factor + r = params.get("r", 8) # block size + p = params.get("p", 1) # parallelization factor + # maxmem = 2 * 128 * n * r * p + hashed_secret = hashlib.scrypt(secret, salt=salt, n=n, r=r, p=p) + + encoded_params = ",".join(f"{k}={v}" for k, v in {"n": n, "r": r, "p": p}.items()) + phc = "".join( + f"${x}" + for x in ["scrypt", encoded_params, b64encode(salt), b64encode(hashed_secret)] + ) + + return phc + + +def phc_compare(*, secret: str, phc_string: str) -> bool: + args = parse_phc(phc_string) + + if args["id"] != "scrypt": + raise ValueError(f"Algorithm not supported: {args['id']}") + + assert type(args["params"]) is dict + encoded = phc_scrypt(b64decode(secret), salt=args["salt"], params=args["params"]) + + return secrets.compare_digest(encoded, phc_string) + + +def parse_phc(s: str): + parts = dict.fromkeys(["id", "version", "params", "salt", "hash"]) + + _, parts["id"], *rest = s.split("$") + + if rest and rest[0].startswith("v="): + parts["version"] = rest.pop(0) + if rest and "=" in rest[0]: + parts["params"] = { + kv[0]: int(kv[1]) + for p in rest.pop(0).split(",") + if len(kv := p.split("=", 2)) == 2 + } + if rest: + parts["salt"] = b64decode(rest.pop(0)) + if rest: + parts["hash"] = b64decode(rest.pop(0)) + + return parts diff --git a/unwind/web.py b/unwind/web.py index 66fd8a9..0974f13 100644 --- a/unwind/web.py +++ b/unwind/web.py @@ -1,6 +1,6 @@ -import base64 -import binascii import logging +import secrets +from json.decoder import JSONDecodeError from typing import Literal, Optional from starlette.applications import Starlette @@ -12,6 +12,7 @@ from starlette.authentication import ( UnauthenticatedUser, requires, ) +from starlette.exceptions import HTTPException from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware from starlette.responses import JSONResponse @@ -20,7 +21,9 @@ from starlette.routing import Mount, Route from . import config, db from .db import close_connection_pool, find_ratings, open_connection_pool from .middleware.responsetime import ResponseTimeMiddleware -from .models import Movie, asplain +from .models import Group, Movie, asplain +from .types import ULID +from .utils import b64encode, phc_compare, phc_scrypt log = logging.getLogger(__name__) @@ -81,7 +84,22 @@ def as_int(x, *, max: int = None, min: Optional[int] = 1, default: int = None): return default -async def ratings(request): +def as_ulid(s: str) -> ULID: + try: + return ULID(s) + except ValueError: + raise HTTPException(422, "Not a valid ULID.") + + +async def get_ratings_for_group(request): + group_id = as_ulid(request.path_params["group_id"]) + group = await db.get(Group, id=str(group_id)) + + if not group: + return not_found() + + user_ids = {u["id"] for u in group.users} + params = request.query_params rows = await find_ratings( title=params.get("title"), @@ -91,6 +109,7 @@ async def ratings(request): include_unrated=truthy(params.get("include_unrated")), yearcomp=yearcomp(params["year"]) if "year" in params else None, limit_rows=as_int(params.get("per_page"), max=10, default=5), + user_ids=user_ids, ) aggr = {} @@ -107,7 +126,7 @@ async def ratings(request): "media_type": r["media_type"], }, ) - if r["user_score"] is not None: + if r["user_score"] is not None and r["user_id"] in user_ids: mov["user_scores"].append(r["user_score"]) resp = tuple(aggr.values()) @@ -115,7 +134,16 @@ async def ratings(request): return JSONResponse(resp) -not_found = JSONResponse({"error": "Not Found"}, status_code=404) +def unauthorized(reason: str = "Unauthorized"): + return JSONResponse({"error": reason}, status_code=401) + + +def forbidden(reason: str = "Forbidden"): + return JSONResponse({"error": reason}, status_code=403) + + +def not_found(reason: str = "Not Found"): + return JSONResponse({"error": reason}, status_code=404) async def get_movies(request): @@ -148,16 +176,73 @@ async def set_rating_for_user(request): @requires(["authenticated", "admin"]) async def add_group(request): - pass + if not await request.body(): + data = {} + else: + try: + data = await request.json() + except JSONDecodeError: + raise HTTPException(422, "Invalid JSON content.") + + try: + name = data["name"] + except KeyError as err: + raise HTTPException(422, f"Missing data for key: {err.args[0]}") + + # XXX restrict name + + secret = secrets.token_bytes() + + group = Group(name=name, secret=phc_scrypt(secret)) + await db.add(group) + + return JSONResponse( + { + "secret": b64encode(secret), + "group": asplain(group), + } + ) @requires(["authenticated", "admin"]) async def add_user_to_group(request): - request.path_params["group_id"] + group_id = as_ulid(request.path_params["group_id"]) + group = await db.get(Group, id=str(group_id)) + if not group: + return not_found() -async def get_ratings_for_group(request): - request.path_params["group_id"] + is_allowed = "admin" in request.auth.scopes or phc_compare( + secret=request.user.token, phc_string=group.secret + ) + if not is_allowed: + return forbidden() + + if not await request.body(): + data = {} + else: + try: + data = await request.json() + except JSONDecodeError: + raise HTTPException(422, "Invalid JSON content.") + + try: + name = data["name"] + user_id = data["id"] + except KeyError as err: + raise HTTPException(422, f"Missing data for key: {err.args[0]}") + + # XXX check if user exists + # XXX restrict name + + if any(u["id"] == user_id for u in group.users): + pass + else: + group.users.append({"name": name, "id": user_id}) + + await db.update(group) + + return JSONResponse(asplain(group)) def create_app(): @@ -176,7 +261,6 @@ def create_app(): Mount( "/api/v1", routes=[ - Route("/ratings", ratings), # XXX legacy, remove. Route("/movies", get_movies), Route("/movies", add_movie, methods=["POST"]), Route("/users", add_user, methods=["POST"]),