diff --git a/unwind/db.py b/unwind/db.py index 4b71aff..e54754b 100644 --- a/unwind/db.py +++ b/unwind/db.py @@ -456,43 +456,39 @@ async def find_ratings( values.update(uvs) user_condition = f"{Rating._table}.user_id IN ({','.join(':'+n for n in uvs)})" - source_table = "newest_movies" - ctes = [ - f"""{source_table} AS ( - SELECT DISTINCT {Rating._table}.movie_id - FROM {Rating._table} - LEFT JOIN {Movie._table} ON {Movie._table}.id={Rating._table}.movie_id - WHERE {user_condition}{(' AND ' + ' AND '.join(conditions)) if conditions else ''} - ORDER BY length({Movie._table}.title) ASC, {Rating._table}.rating_date DESC, {Movie._table}.imdb_score DESC - LIMIT :limit_rows - )""" - ] - - if include_unrated: - source_table = "target_movies" - ctes.extend( - [ - f"""unrated_movies AS ( - SELECT DISTINCT id AS movie_id - FROM {Movie._table} - WHERE id NOT IN newest_movies - {('AND ' + ' AND '.join(conditions)) if conditions else ''} - ORDER BY length(title) ASC, imdb_score DESC, release_year DESC - LIMIT :limit_rows - )""", - f"""{source_table} AS ( - SELECT * FROM newest_movies - UNION ALL -- using ALL here avoids the reordering of IDs - SELECT * FROM unrated_movies - LIMIT :limit_rows - )""", - ] - ) - query = f""" - WITH - {','.join(ctes)} + SELECT DISTINCT {Rating._table}.movie_id + FROM {Rating._table} + LEFT JOIN {Movie._table} ON {Movie._table}.id={Rating._table}.movie_id + WHERE {user_condition}{(' AND ' + ' AND '.join(conditions)) if conditions else ''} + ORDER BY length({Movie._table}.title) ASC, {Rating._table}.rating_date DESC, {Movie._table}.imdb_score DESC + LIMIT :limit_rows + """ + async with locked_connection() as conn: + rows = await conn.fetch_all(bindparams(query, values)) + movie_ids = tuple(r["movie_id"] for r in rows) + if include_unrated and len(movie_ids) < limit_rows: + sqlin, sqlin_vals = sql_in("id", movie_ids, not_=True) + query = f""" + SELECT DISTINCT id AS movie_id + FROM {Movie._table} + WHERE {sqlin} + {('AND ' + ' AND '.join(conditions)) if conditions else ''} + ORDER BY length(title) ASC, imdb_score DESC, release_year DESC + LIMIT :limit_rows + """ + async with locked_connection() as conn: + rows = await conn.fetch_all( + bindparams( + query, + {**values, **sqlin_vals, "limit_rows": limit_rows - len(movie_ids)}, + ) + ) + movie_ids += tuple(r["movie_id"] for r in rows) + + sqlin, sqlin_vals = sql_in(f"{Movie._table}.id", movie_ids) + query = f""" SELECT {Rating._table}.score AS user_score, {Rating._table}.user_id AS user_id, @@ -503,13 +499,13 @@ async def find_ratings( {Movie._table}.title AS canonical_title, {Movie._table}.original_title AS original_title, {Movie._table}.release_year AS release_year - FROM {source_table} - LEFT JOIN {Rating._table} ON {Rating._table}.movie_id={source_table}.movie_id - LEFT JOIN {Movie._table} ON {Movie._table}.id={source_table}.movie_id + FROM {Movie._table} + LEFT JOIN {Rating._table} ON {Movie._table}.id={Rating._table}.movie_id + WHERE {sqlin} """ async with locked_connection() as conn: - rows = await conn.fetch_all(bindparams(query, values)) + rows = await conn.fetch_all(bindparams(query, {**values, **sqlin_vals})) return tuple(dict(r) for r in rows) @@ -533,10 +529,12 @@ def demux(tp: Type[ModelType], row) -> ModelType: return fromplain(tp, d, serialized=True) -def sql_in(column: str, values: list[T]) -> tuple[str, dict[str, T]]: +def sql_in(column: str, values: Iterable[T], not_=False) -> tuple[str, dict[str, T]]: c = column.replace(".", "___") value_map = {f"{c}_{i}": v for i, v in enumerate(values, start=1)} placeholders = ",".join(":" + k for k in value_map) + if not_: + return f"{column} NOT IN ({placeholders})", value_map return f"{column} IN ({placeholders})", value_map