filter ratings for movies index route

This commit is contained in:
ducklet 2021-08-04 17:13:52 +02:00
parent e56ff8a6e6
commit ff4f8fa246

View file

@ -1,4 +1,3 @@
import json
import logging
import re
from pathlib import Path
@ -19,8 +18,10 @@ from .models import (
optional_fields,
utcnow,
)
from .types import ULID
log = logging.getLogger(__name__)
T = TypeVar("T")
_shared_connection: Optional[Database] = None
@ -478,6 +479,39 @@ def demux(tp: Type[ModelType], row) -> ModelType:
return fromplain(tp, {n: row[k] for k, (_, n) in sql_fieldmap(tp).items()})
def sql_in(column: str, values: list[T]) -> tuple[str, dict[str, T]]:
c = column.replace(".", "___")
value_map = {f"{c}_{i}": v for i, v in enumerate(values, start=1)}
placeholders = ",".join(":" + k for k in value_map)
return f"{column} IN ({placeholders})", value_map
async def ratings_for_movies(
movie_ids: Iterable[ULID], user_ids: Iterable[ULID] = []
) -> Iterable[Rating]:
values: dict[str, str] = {}
conditions: list[str] = []
q, vm = sql_in("movie_id", [str(m) for m in movie_ids])
conditions.append(q)
values.update(vm)
if user_ids:
q, vm = sql_in("user_id", [str(m) for m in user_ids])
conditions.append(q)
values.update(vm)
query = f"""
SELECT {','.join(sql_fields(Rating))}
FROM {Rating._table}
WHERE {' AND '.join(f'({c})' for c in conditions) if conditions else '1=1'}
"""
rows = await shared_connection().fetch_all(query, values)
return (fromplain(Rating, row) for row in rows)
async def find_movies(
*,
title: str = None,
@ -529,7 +563,6 @@ async def find_movies(
if not include_unrated:
conditions.append(f"{Movie._table}.score NOTNULL")
if not user_ids:
query = f"""
SELECT {','.join(sql_fields(Movie))}
FROM {Movie._table}
@ -541,40 +574,17 @@ async def find_movies(
LIMIT :skip_rows, :limit_rows
"""
rows = await shared_connection().fetch_all(bindparams(query, values))
return ((fromplain(Movie, row), []) for row in rows)
# XXX add user_ids filtering
movies = [fromplain(Movie, row) for row in rows]
fields_ = mux(Movie, Rating)
query = f"""
WITH movie_ids AS (
SELECT id
FROM {Movie._table}
WHERE {(' AND '.join(conditions)) if conditions else '1=1'}
ORDER BY
length({Movie._table}.title) ASC,
{Movie._table}.imdb_score DESC,
{Movie._table}.release_year DESC
LIMIT :skip_rows, :limit_rows
)
SELECT {fields_}
FROM {Movie._table}
LEFT JOIN {Rating._table} ON {Rating._table}.movie_id={Movie._table}.id
WHERE {Movie._table}.id IN movie_ids
"""
if not user_ids:
return ((m, []) for m in movies)
rows = await shared_connection().fetch_all(bindparams(query, values))
ratings = await ratings_for_movies((m.id for m in movies), user_ids)
aggreg: dict[ULID, tuple[Movie, list[Rating]]] = {}
for row in rows:
movie = demux(Movie, row)
_, ratings = aggreg.setdefault(movie.id, (movie, []))
try:
rating = demux(Rating, row)
except:
pass
else:
ratings.append(rating)
aggreg: dict[ULID, tuple[Movie, list[Rating]]] = {m.id: (m, []) for m in movies}
for rating in ratings:
aggreg[rating.movie_id][1].append(rating)
return aggreg.values()