diff --git a/unwind/db.py b/unwind/db.py index e54754b..13217e9 100644 --- a/unwind/db.py +++ b/unwind/db.py @@ -4,7 +4,7 @@ import logging import re import threading from pathlib import Path -from typing import Iterable, Literal, Optional, Type, TypeVar, Union +from typing import Any, Iterable, Literal, Optional, Type, TypeVar, Union import sqlalchemy from databases import Database @@ -487,7 +487,28 @@ async def find_ratings( ) movie_ids += tuple(r["movie_id"] for r in rows) - sqlin, sqlin_vals = sql_in(f"{Movie._table}.id", movie_ids) + return await ratings_for_movie_ids(ids=movie_ids) + + +async def ratings_for_movie_ids( + ids: Iterable[ULID | str] = [], imdb_ids: Iterable[str] = [] +) -> Iterable[dict[str, Any]]: + conds: list[str] = [] + vals: dict[str, str] = {} + + if ids: + sqlin, sqlin_vals = sql_in(f"{Movie._table}.id", (str(x) for x in ids)) + conds.append(sqlin) + vals.update(sqlin_vals) + + if imdb_ids: + sqlin, sqlin_vals = sql_in(f"{Movie._table}.imdb_id", imdb_ids) + conds.append(sqlin) + vals.update(sqlin_vals) + + if not conds: + return [] + query = f""" SELECT {Rating._table}.score AS user_score, @@ -501,11 +522,11 @@ async def find_ratings( {Movie._table}.release_year AS release_year FROM {Movie._table} LEFT JOIN {Rating._table} ON {Movie._table}.id={Rating._table}.movie_id - WHERE {sqlin} + WHERE {(' OR '.join(conds))} """ async with locked_connection() as conn: - rows = await conn.fetch_all(bindparams(query, {**values, **sqlin_vals})) + rows = await conn.fetch_all(bindparams(query, vals)) return tuple(dict(r) for r in rows) diff --git a/unwind/web.py b/unwind/web.py index 1a4a279..e194c10 100644 --- a/unwind/web.py +++ b/unwind/web.py @@ -200,16 +200,29 @@ async def get_ratings_for_group(request): user_ids = {u["id"] for u in group.users} params = request.query_params - rows = await find_ratings( - title=params.get("title"), - media_type=params.get("media_type"), - exact=truthy(params.get("exact")), - ignore_tv_episodes=truthy(params.get("ignore_tv_episodes")), - include_unrated=truthy(params.get("include_unrated")), - yearcomp=yearcomp(params["year"]) if "year" in params else None, - limit_rows=as_int(params.get("per_page"), max=10, default=5), - user_ids=user_ids, - ) + + imdb_id: str | None = params.get("imdb_id") + unwind_id: str | None = params.get("unwind_id") + + # if (imdb_id or unwind_id) and (movie := await db.get(Movie, id=unwind_id, imdb_id=imdb_id)): + if unwind_id: + rows = await db.ratings_for_movie_ids(ids=[unwind_id]) + + elif imdb_id: + rows = await db.ratings_for_movie_ids(imdb_ids=[imdb_id]) + + else: + rows = await find_ratings( + title=params.get("title"), + media_type=params.get("media_type"), + exact=truthy(params.get("exact")), + ignore_tv_episodes=truthy(params.get("ignore_tv_episodes")), + include_unrated=truthy(params.get("include_unrated")), + yearcomp=yearcomp(params["year"]) if "year" in params else None, + limit_rows=as_int(params.get("per_page"), max=10, default=5), + user_ids=user_ids, + ) + ratings = (web_models.Rating(**r) for r in rows) aggr = web_models.aggregate_ratings(ratings, user_ids)