migrate db.find_ratings to SQLAlchemy

ducklet 2023-03-28 21:50:14 +02:00
parent d4933bf1a6
commit 1a3528e096
3 changed files with 54 additions and 55 deletions

View file

@@ -1,4 +1,5 @@
 from datetime import datetime
 import pytest
 from starlette.testclient import TestClient

View file

@@ -18,7 +18,9 @@ from .models import (
     asplain,
     fields,
     fromplain,
+    movies,
     optional_fields,
+    ratings,
     utcnow,
 )
 from .types import ULID
@@ -427,77 +429,72 @@ async def find_ratings(
     limit_rows: int = 10,
     user_ids: Iterable[str] = [],
 ):
-    values: dict[str, int | str] = {
-        "limit_rows": limit_rows,
-    }
     conditions = []
     if title:
-        values["escape"] = "#"
-        escaped_title = sql_escape(title, char=values["escape"])
-        values["pattern"] = (
+        escape_char = "#"
+        escaped_title = sql_escape(title, char=escape_char)
+        pattern = (
             "_".join(escaped_title.split())
             if exact
             else "%" + "%".join(escaped_title.split()) + "%"
         )
         conditions.append(
-            f"""
-            (
-                {Movie._table}.title LIKE :pattern ESCAPE :escape
-                OR {Movie._table}.original_title LIKE :pattern ESCAPE :escape
-            )
-            """
+            sa.or_(
+                movies.c.title.like(pattern, escape=escape_char),
+                movies.c.original_title.like(pattern, escape=escape_char),
+            )
         )
-    if yearcomp:
-        op, year = yearcomp
-        assert op in "<=>"
-        values["year"] = year
-        conditions.append(f"{Movie._table}.release_year{op}:year")
-    if media_type:
-        values["media_type"] = media_type
-        conditions.append(f"{Movie._table}.media_type=:media_type")
+    match yearcomp:
+        case ("<", year):
+            conditions.append(movies.c.release_year < year)
+        case ("=", year):
+            conditions.append(movies.c.release_year == year)
+        case (">", year):
+            conditions.append(movies.c.release_year > year)
+    if media_type is not None:
+        conditions.append(movies.c.media_type == media_type)
     if ignore_tv_episodes:
-        conditions.append(f"{Movie._table}.media_type!='TV Episode'")
-    user_condition = "1=1"
+        conditions.append(movies.c.media_type != "TV Episode")
+    user_condition = []
     if user_ids:
-        uvs = {f"user_id_{i}": v for i, v in enumerate(user_ids, start=1)}
-        values.update(uvs)
-        user_condition = f"{Rating._table}.user_id IN ({','.join(':'+n for n in uvs)})"
-    query = f"""
-        SELECT DISTINCT {Rating._table}.movie_id
-        FROM {Rating._table}
-        LEFT JOIN {Movie._table} ON {Movie._table}.id={Rating._table}.movie_id
-        WHERE {user_condition}{(' AND ' + ' AND '.join(conditions)) if conditions else ''}
-        ORDER BY length({Movie._table}.title) ASC, {Rating._table}.rating_date DESC, {Movie._table}.imdb_score DESC
-        LIMIT :limit_rows
-    """
+        user_condition.append(ratings.c.user_id.in_(user_ids))
+    query = (
+        sa.select(ratings.c.movie_id)
+        .distinct()
+        .outerjoin_from(ratings, movies, movies.c.id == ratings.c.movie_id)
+        .where(*conditions, *user_condition)
+        .order_by(
+            sa.func.length(movies.c.title).asc(),
+            ratings.c.rating_date.desc(),
+            movies.c.imdb_score.desc(),
+        )
+        .limit(limit_rows)
+    )
     async with locked_connection() as conn:
-        rows = await conn.fetch_all(bindparams(query, values))
-        movie_ids = tuple(r._mapping["movie_id"] for r in rows)
+        rows = conn.iterate(query)
+        movie_ids = [r.movie_id async for r in rows]
     if include_unrated and len(movie_ids) < limit_rows:
-        sqlin, sqlin_vals = sql_in("id", movie_ids, not_=True)
-        query = f"""
-            SELECT DISTINCT id AS movie_id
-            FROM {Movie._table}
-            WHERE {sqlin}
-            {('AND ' + ' AND '.join(conditions)) if conditions else ''}
-            ORDER BY length(title) ASC, imdb_score DESC, release_year DESC
-            LIMIT :limit_rows
-        """
+        query = (
+            sa.select(movies.c.id.label("movie_id"))
+            .distinct()
+            .where(movies.c.id.not_in(movie_ids), *conditions)
+            .order_by(
+                sa.func.length(movies.c.title).asc(),
+                movies.c.imdb_score.desc(),
+                movies.c.release_year.desc(),
+            )
+            .limit(limit_rows - len(movie_ids))
+        )
         async with locked_connection() as conn:
-            rows = await conn.fetch_all(
-                bindparams(
-                    query,
-                    {**values, **sqlin_vals, "limit_rows": limit_rows - len(movie_ids)},
-                )
-            )
-            movie_ids += tuple(r._mapping["movie_id"] for r in rows)
+            rows = conn.iterate(query)
+            movie_ids += [r.movie_id async for r in rows]
     return await ratings_for_movie_ids(ids=movie_ids)
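Note: the new expression-style query can be exercised on its own with plain SQLAlchemy Core. The sketch below is not the project's code: the Table definitions are guesses at the columns find_ratings touches (the real ones come from Movie.__table__ and Rating.__table__ in the models module), and the filter values are made up. It only shows that the chained select() / outerjoin_from() / where() / order_by() / limit() calls compile to the same LEFT JOIN ... LIKE ... ESCAPE ... ORDER BY length(...) statement the removed f-string SQL produced.

import sqlalchemy as sa

metadata = sa.MetaData()
# Stand-in tables with only the columns find_ratings uses.
movies = sa.Table(
    "movies",
    metadata,
    sa.Column("id", sa.String, primary_key=True),
    sa.Column("title", sa.String),
    sa.Column("original_title", sa.String),
    sa.Column("release_year", sa.Integer),
    sa.Column("media_type", sa.String),
    sa.Column("imdb_score", sa.Float),
)
ratings = sa.Table(
    "ratings",
    metadata,
    sa.Column("movie_id", sa.String),
    sa.Column("user_id", sa.String),
    sa.Column("rating_date", sa.DateTime),
)

escape_char = "#"
pattern = "%matrix%"  # made-up, already-escaped search pattern
conditions = [
    sa.or_(
        movies.c.title.like(pattern, escape=escape_char),
        movies.c.original_title.like(pattern, escape=escape_char),
    ),
    movies.c.release_year > 1999,
]

query = (
    sa.select(ratings.c.movie_id)
    .distinct()
    .outerjoin_from(ratings, movies, movies.c.id == ratings.c.movie_id)
    .where(*conditions, ratings.c.user_id.in_(["u1", "u2"]))
    .order_by(
        sa.func.length(movies.c.title).asc(),
        ratings.c.rating_date.desc(),
        movies.c.imdb_score.desc(),
    )
    .limit(10)
)

# Prints a SELECT DISTINCT ... LEFT OUTER JOIN ... WHERE ... LIKE ... ESCAPE '#'
# ... ORDER BY length(movies.title) ... LIMIT statement with bound parameters,
# i.e. the same shape as the removed hand-written SQL.
print(query)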
@@ -507,9 +504,6 @@ async def ratings_for_movie_ids(
 ) -> Iterable[dict[str, Any]]:
     conds = []
-    ratings = Rating.__table__
-    movies = Movie.__table__
     if ids:
         conds.append(movies.c.id.in_([str(x) for x in ids]))
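One more note on the user_ids and include_unrated changes: the old code built its own numbered placeholders (user_id_1, user_id_2, ... and the sql_in helper) for IN clauses, whereas the Core operators in_() and not_in() generate the parameter list themselves. A minimal sketch, assuming nothing beyond SQLAlchemy itself (table, column and values are placeholders):

import sqlalchemy as sa

metadata = sa.MetaData()
ratings = sa.Table("ratings", metadata, sa.Column("user_id", sa.String))

user_ids = ["alice", "bob"]

# in_() expands to one placeholder per element at execution time; no manual
# ":user_id_1, :user_id_2" bookkeeping is needed.
clause = ratings.c.user_id.in_(user_ids)
print(clause.compile(compile_kwargs={"literal_binds": True}))
# roughly: ratings.user_id IN ('alice', 'bob')

# not_in() is the negated form used by the include_unrated fallback query.
print(ratings.c.user_id.not_in(user_ids).compile(compile_kwargs={"literal_binds": True}))
# roughly: ratings.user_id NOT IN ('alice', 'bob')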

View file

@@ -12,8 +12,8 @@ from typing import (
     Literal,
     Mapping,
     Type,
-    TypeVar,
     TypedDict,
+    TypeVar,
     Union,
     get_args,
     get_origin,
@@ -318,6 +318,8 @@ class Movie:
         self._is_lazy = False
 
 
+movies = Movie.__table__
+
 _RelationSentinel = object()
 """Mark a model field as containing external data.
@ -372,6 +374,8 @@ class Rating:
) )
ratings = Rating.__table__
Access = Literal[ Access = Literal[
"r", # read "r", # read
"i", # index "i", # index