filter ratings for movies index route
This commit is contained in:
parent
e56ff8a6e6
commit
ff4f8fa246
1 changed files with 48 additions and 38 deletions
74
unwind/db.py
74
unwind/db.py
|
|
@ -1,4 +1,3 @@
|
|||
import json
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
|
@ -19,8 +18,10 @@ from .models import (
|
|||
optional_fields,
|
||||
utcnow,
|
||||
)
|
||||
from .types import ULID
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
T = TypeVar("T")
|
||||
|
||||
_shared_connection: Optional[Database] = None
|
||||
|
||||
|
|
@ -478,6 +479,39 @@ def demux(tp: Type[ModelType], row) -> ModelType:
|
|||
return fromplain(tp, {n: row[k] for k, (_, n) in sql_fieldmap(tp).items()})
|
||||
|
||||
|
||||
def sql_in(column: str, values: list[T]) -> tuple[str, dict[str, T]]:
|
||||
c = column.replace(".", "___")
|
||||
value_map = {f"{c}_{i}": v for i, v in enumerate(values, start=1)}
|
||||
placeholders = ",".join(":" + k for k in value_map)
|
||||
return f"{column} IN ({placeholders})", value_map
|
||||
|
||||
|
||||
async def ratings_for_movies(
|
||||
movie_ids: Iterable[ULID], user_ids: Iterable[ULID] = []
|
||||
) -> Iterable[Rating]:
|
||||
values: dict[str, str] = {}
|
||||
conditions: list[str] = []
|
||||
|
||||
q, vm = sql_in("movie_id", [str(m) for m in movie_ids])
|
||||
conditions.append(q)
|
||||
values.update(vm)
|
||||
|
||||
if user_ids:
|
||||
q, vm = sql_in("user_id", [str(m) for m in user_ids])
|
||||
conditions.append(q)
|
||||
values.update(vm)
|
||||
|
||||
query = f"""
|
||||
SELECT {','.join(sql_fields(Rating))}
|
||||
FROM {Rating._table}
|
||||
WHERE {' AND '.join(f'({c})' for c in conditions) if conditions else '1=1'}
|
||||
"""
|
||||
|
||||
rows = await shared_connection().fetch_all(query, values)
|
||||
|
||||
return (fromplain(Rating, row) for row in rows)
|
||||
|
||||
|
||||
async def find_movies(
|
||||
*,
|
||||
title: str = None,
|
||||
|
|
@ -529,7 +563,6 @@ async def find_movies(
|
|||
if not include_unrated:
|
||||
conditions.append(f"{Movie._table}.score NOTNULL")
|
||||
|
||||
if not user_ids:
|
||||
query = f"""
|
||||
SELECT {','.join(sql_fields(Movie))}
|
||||
FROM {Movie._table}
|
||||
|
|
@ -541,40 +574,17 @@ async def find_movies(
|
|||
LIMIT :skip_rows, :limit_rows
|
||||
"""
|
||||
rows = await shared_connection().fetch_all(bindparams(query, values))
|
||||
return ((fromplain(Movie, row), []) for row in rows)
|
||||
|
||||
# XXX add user_ids filtering
|
||||
movies = [fromplain(Movie, row) for row in rows]
|
||||
|
||||
fields_ = mux(Movie, Rating)
|
||||
query = f"""
|
||||
WITH movie_ids AS (
|
||||
SELECT id
|
||||
FROM {Movie._table}
|
||||
WHERE {(' AND '.join(conditions)) if conditions else '1=1'}
|
||||
ORDER BY
|
||||
length({Movie._table}.title) ASC,
|
||||
{Movie._table}.imdb_score DESC,
|
||||
{Movie._table}.release_year DESC
|
||||
LIMIT :skip_rows, :limit_rows
|
||||
)
|
||||
SELECT {fields_}
|
||||
FROM {Movie._table}
|
||||
LEFT JOIN {Rating._table} ON {Rating._table}.movie_id={Movie._table}.id
|
||||
WHERE {Movie._table}.id IN movie_ids
|
||||
"""
|
||||
if not user_ids:
|
||||
return ((m, []) for m in movies)
|
||||
|
||||
rows = await shared_connection().fetch_all(bindparams(query, values))
|
||||
ratings = await ratings_for_movies((m.id for m in movies), user_ids)
|
||||
|
||||
aggreg: dict[ULID, tuple[Movie, list[Rating]]] = {}
|
||||
for row in rows:
|
||||
movie = demux(Movie, row)
|
||||
_, ratings = aggreg.setdefault(movie.id, (movie, []))
|
||||
try:
|
||||
rating = demux(Rating, row)
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
ratings.append(rating)
|
||||
aggreg: dict[ULID, tuple[Movie, list[Rating]]] = {m.id: (m, []) for m in movies}
|
||||
for rating in ratings:
|
||||
aggreg[rating.movie_id][1].append(rating)
|
||||
|
||||
return aggreg.values()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue