simplify SQL query a bit

With SQLite it's in general not a big issue to run multiple smaller
queries instead a big one, because the overhead per request is much
smaller than with separate DBMS.
This should make it easier to extend `find_ratings` in the future.
This commit is contained in:
ducklet 2021-12-19 19:35:53 +01:00
parent e1f35143df
commit adfead81fc

View file

@ -456,43 +456,39 @@ async def find_ratings(
values.update(uvs) values.update(uvs)
user_condition = f"{Rating._table}.user_id IN ({','.join(':'+n for n in uvs)})" user_condition = f"{Rating._table}.user_id IN ({','.join(':'+n for n in uvs)})"
source_table = "newest_movies" query = f"""
ctes = [
f"""{source_table} AS (
SELECT DISTINCT {Rating._table}.movie_id SELECT DISTINCT {Rating._table}.movie_id
FROM {Rating._table} FROM {Rating._table}
LEFT JOIN {Movie._table} ON {Movie._table}.id={Rating._table}.movie_id LEFT JOIN {Movie._table} ON {Movie._table}.id={Rating._table}.movie_id
WHERE {user_condition}{(' AND ' + ' AND '.join(conditions)) if conditions else ''} WHERE {user_condition}{(' AND ' + ' AND '.join(conditions)) if conditions else ''}
ORDER BY length({Movie._table}.title) ASC, {Rating._table}.rating_date DESC, {Movie._table}.imdb_score DESC ORDER BY length({Movie._table}.title) ASC, {Rating._table}.rating_date DESC, {Movie._table}.imdb_score DESC
LIMIT :limit_rows LIMIT :limit_rows
)""" """
] async with locked_connection() as conn:
rows = await conn.fetch_all(bindparams(query, values))
movie_ids = tuple(r["movie_id"] for r in rows)
if include_unrated: if include_unrated and len(movie_ids) < limit_rows:
source_table = "target_movies" sqlin, sqlin_vals = sql_in("id", movie_ids, not_=True)
ctes.extend( query = f"""
[
f"""unrated_movies AS (
SELECT DISTINCT id AS movie_id SELECT DISTINCT id AS movie_id
FROM {Movie._table} FROM {Movie._table}
WHERE id NOT IN newest_movies WHERE {sqlin}
{('AND ' + ' AND '.join(conditions)) if conditions else ''} {('AND ' + ' AND '.join(conditions)) if conditions else ''}
ORDER BY length(title) ASC, imdb_score DESC, release_year DESC ORDER BY length(title) ASC, imdb_score DESC, release_year DESC
LIMIT :limit_rows LIMIT :limit_rows
)""", """
f"""{source_table} AS ( async with locked_connection() as conn:
SELECT * FROM newest_movies rows = await conn.fetch_all(
UNION ALL -- using ALL here avoids the reordering of IDs bindparams(
SELECT * FROM unrated_movies query,
LIMIT :limit_rows {**values, **sqlin_vals, "limit_rows": limit_rows - len(movie_ids)},
)""",
]
) )
)
movie_ids += tuple(r["movie_id"] for r in rows)
sqlin, sqlin_vals = sql_in(f"{Movie._table}.id", movie_ids)
query = f""" query = f"""
WITH
{','.join(ctes)}
SELECT SELECT
{Rating._table}.score AS user_score, {Rating._table}.score AS user_score,
{Rating._table}.user_id AS user_id, {Rating._table}.user_id AS user_id,
@ -503,13 +499,13 @@ async def find_ratings(
{Movie._table}.title AS canonical_title, {Movie._table}.title AS canonical_title,
{Movie._table}.original_title AS original_title, {Movie._table}.original_title AS original_title,
{Movie._table}.release_year AS release_year {Movie._table}.release_year AS release_year
FROM {source_table} FROM {Movie._table}
LEFT JOIN {Rating._table} ON {Rating._table}.movie_id={source_table}.movie_id LEFT JOIN {Rating._table} ON {Movie._table}.id={Rating._table}.movie_id
LEFT JOIN {Movie._table} ON {Movie._table}.id={source_table}.movie_id WHERE {sqlin}
""" """
async with locked_connection() as conn: async with locked_connection() as conn:
rows = await conn.fetch_all(bindparams(query, values)) rows = await conn.fetch_all(bindparams(query, {**values, **sqlin_vals}))
return tuple(dict(r) for r in rows) return tuple(dict(r) for r in rows)
@ -533,10 +529,12 @@ def demux(tp: Type[ModelType], row) -> ModelType:
return fromplain(tp, d, serialized=True) return fromplain(tp, d, serialized=True)
def sql_in(column: str, values: list[T]) -> tuple[str, dict[str, T]]: def sql_in(column: str, values: Iterable[T], not_=False) -> tuple[str, dict[str, T]]:
c = column.replace(".", "___") c = column.replace(".", "___")
value_map = {f"{c}_{i}": v for i, v in enumerate(values, start=1)} value_map = {f"{c}_{i}": v for i, v in enumerate(values, start=1)}
placeholders = ",".join(":" + k for k in value_map) placeholders = ",".join(":" + k for k in value_map)
if not_:
return f"{column} NOT IN ({placeholders})", value_map
return f"{column} IN ({placeholders})", value_map return f"{column} IN ({placeholders})", value_map