diff --git a/unwind/db.py b/unwind/db.py index e5b43ab..1381e52 100644 --- a/unwind/db.py +++ b/unwind/db.py @@ -1,4 +1,3 @@ -import json import logging import re from pathlib import Path @@ -19,8 +18,10 @@ from .models import ( optional_fields, utcnow, ) +from .types import ULID log = logging.getLogger(__name__) +T = TypeVar("T") _shared_connection: Optional[Database] = None @@ -478,6 +479,39 @@ def demux(tp: Type[ModelType], row) -> ModelType: return fromplain(tp, {n: row[k] for k, (_, n) in sql_fieldmap(tp).items()}) +def sql_in(column: str, values: list[T]) -> tuple[str, dict[str, T]]: + c = column.replace(".", "___") + value_map = {f"{c}_{i}": v for i, v in enumerate(values, start=1)} + placeholders = ",".join(":" + k for k in value_map) + return f"{column} IN ({placeholders})", value_map + + +async def ratings_for_movies( + movie_ids: Iterable[ULID], user_ids: Iterable[ULID] = [] +) -> Iterable[Rating]: + values: dict[str, str] = {} + conditions: list[str] = [] + + q, vm = sql_in("movie_id", [str(m) for m in movie_ids]) + conditions.append(q) + values.update(vm) + + if user_ids: + q, vm = sql_in("user_id", [str(m) for m in user_ids]) + conditions.append(q) + values.update(vm) + + query = f""" + SELECT {','.join(sql_fields(Rating))} + FROM {Rating._table} + WHERE {' AND '.join(f'({c})' for c in conditions) if conditions else '1=1'} + """ + + rows = await shared_connection().fetch_all(query, values) + + return (fromplain(Rating, row) for row in rows) + + async def find_movies( *, title: str = None, @@ -529,52 +563,28 @@ async def find_movies( if not include_unrated: conditions.append(f"{Movie._table}.score NOTNULL") - if not user_ids: - query = f""" - SELECT {','.join(sql_fields(Movie))} - FROM {Movie._table} - WHERE {(' AND '.join(conditions)) if conditions else '1=1'} - ORDER BY - length({Movie._table}.title) ASC, - {Movie._table}.imdb_score DESC, - {Movie._table}.release_year DESC - LIMIT :skip_rows, :limit_rows - """ - rows = await shared_connection().fetch_all(bindparams(query, values)) - return ((fromplain(Movie, row), []) for row in rows) - - # XXX add user_ids filtering - - fields_ = mux(Movie, Rating) query = f""" - WITH movie_ids AS ( - SELECT id - FROM {Movie._table} + SELECT {','.join(sql_fields(Movie))} + FROM {Movie._table} WHERE {(' AND '.join(conditions)) if conditions else '1=1'} ORDER BY length({Movie._table}.title) ASC, {Movie._table}.imdb_score DESC, {Movie._table}.release_year DESC - LIMIT :skip_rows, :limit_rows - ) - SELECT {fields_} - FROM {Movie._table} - LEFT JOIN {Rating._table} ON {Rating._table}.movie_id={Movie._table}.id - WHERE {Movie._table}.id IN movie_ids + LIMIT :skip_rows, :limit_rows """ - rows = await shared_connection().fetch_all(bindparams(query, values)) - aggreg: dict[ULID, tuple[Movie, list[Rating]]] = {} - for row in rows: - movie = demux(Movie, row) - _, ratings = aggreg.setdefault(movie.id, (movie, [])) - try: - rating = demux(Rating, row) - except: - pass - else: - ratings.append(rating) + movies = [fromplain(Movie, row) for row in rows] + + if not user_ids: + return ((m, []) for m in movies) + + ratings = await ratings_for_movies((m.id for m in movies), user_ids) + + aggreg: dict[ULID, tuple[Movie, list[Rating]]] = {m.id: (m, []) for m in movies} + for rating in ratings: + aggreg[rating.movie_id][1].append(rating) return aggreg.values()