migrate db.find_movies to SQLAlchemy

This commit is contained in:
ducklet 2023-03-28 21:49:02 +02:00
parent 1fd7e730b3
commit 84bbe331ee
2 changed files with 118 additions and 38 deletions

View file

@ -316,3 +316,86 @@ async def test_ratings_for_movies(shared_conn: db.Database):
assert (r1,) == tuple( assert (r1,) == tuple(
await db.ratings_for_movies(movie_ids=movie_ids, user_ids=user_ids) await db.ratings_for_movies(movie_ids=movie_ids, user_ids=user_ids)
) )
@pytest.mark.asyncio
async def test_find_movies(shared_conn: db.Database):
async with shared_conn.transaction(force_rollback=True):
m1 = a_movie(title="movie one")
await db.add(m1)
m2 = a_movie(title="movie two", imdb_score=33, release_year=m1.release_year + 1)
await db.add(m2)
u1 = models.User(
imdb_id="u00001",
name="User1",
secret="secret1",
)
await db.add(u1)
u2 = models.User(
imdb_id="u00002",
name="User2",
secret="secret2",
)
await db.add(u2)
r1 = models.Rating(
movie_id=m2.id,
movie=m2,
user_id=u1.id,
user=u1,
score=66,
rating_date=datetime.now(),
)
await db.add(r1)
# ---
assert () == tuple(await db.find_movies(title=m1.title, include_unrated=False))
assert ((m1, []),) == tuple(
await db.find_movies(title=m1.title, include_unrated=True)
)
assert ((m1, []),) == tuple(
await db.find_movies(title="mo on", exact=False, include_unrated=True)
)
assert ((m1, []),) == tuple(
await db.find_movies(title="movie one", exact=True, include_unrated=True)
)
assert () == tuple(
await db.find_movies(title="mo on", exact=True, include_unrated=True)
)
assert ((m2, []),) == tuple(
await db.find_movies(title="movie", exact=False, include_unrated=False)
)
assert ((m2, []), (m1, [])) == tuple(
await db.find_movies(title="movie", exact=False, include_unrated=True)
)
assert ((m1, []),) == tuple(
await db.find_movies(include_unrated=True, yearcomp=("=", m1.release_year))
)
assert ((m2, []),) == tuple(
await db.find_movies(include_unrated=True, yearcomp=("=", m2.release_year))
)
assert ((m1, []),) == tuple(
await db.find_movies(include_unrated=True, yearcomp=("<", m2.release_year))
)
assert ((m2, []),) == tuple(
await db.find_movies(include_unrated=True, yearcomp=(">", m1.release_year))
)
assert ((m2, []), (m1, [])) == tuple(await db.find_movies(include_unrated=True))
assert ((m2, []),) == tuple(
await db.find_movies(include_unrated=True, limit_rows=1)
)
assert ((m1, []),) == tuple(
await db.find_movies(include_unrated=True, skip_rows=1)
)
assert ((m2, [r1]), (m1, [])) == tuple(
await db.find_movies(include_unrated=True, user_ids=[u1.id, u2.id])
)

View file

@ -565,66 +565,63 @@ async def find_movies(
include_unrated: bool = False, include_unrated: bool = False,
user_ids: list[ULID] = [], user_ids: list[ULID] = [],
) -> Iterable[tuple[Movie, list[Rating]]]: ) -> Iterable[tuple[Movie, list[Rating]]]:
values: dict[str, int | str] = {
"limit_rows": limit_rows,
"skip_rows": skip_rows,
}
conditions = [] conditions = []
if title: if title:
values["escape"] = "#" escape_char = "#"
escaped_title = sql_escape(title, char=values["escape"]) escaped_title = sql_escape(title, char=escape_char)
values["pattern"] = ( pattern = (
"_".join(escaped_title.split()) "_".join(escaped_title.split())
if exact if exact
else "%" + "%".join(escaped_title.split()) + "%" else "%" + "%".join(escaped_title.split()) + "%"
) )
conditions.append( conditions.append(
f""" sa.or_(
( movies.c.title.like(pattern, escape=escape_char),
{Movie._table}.title LIKE :pattern ESCAPE :escape movies.c.original_title.like(pattern, escape=escape_char),
OR {Movie._table}.original_title LIKE :pattern ESCAPE :escape
) )
"""
) )
if yearcomp: match yearcomp:
op, year = yearcomp case ("<", year):
assert op in "<=>" conditions.append(movies.c.release_year < year)
values["year"] = year case ("=", year):
conditions.append(f"{Movie._table}.release_year{op}:year") conditions.append(movies.c.release_year == year)
case (">", year):
conditions.append(movies.c.release_year > year)
if media_type: if media_type is not None:
values["media_type"] = media_type conditions.append(movies.c.media_type == media_type)
conditions.append(f"{Movie._table}.media_type=:media_type")
if ignore_tv_episodes: if ignore_tv_episodes:
conditions.append(f"{Movie._table}.media_type!='TV Episode'") conditions.append(movies.c.media_type != "TV Episode")
if not include_unrated: if not include_unrated:
conditions.append(f"{Movie._table}.imdb_score NOTNULL") conditions.append(movies.c.imdb_score != None)
query = (
sa.select(movies)
.where(*conditions)
.order_by(
sa.func.length(movies.c.title).asc(),
movies.c.imdb_score.desc(),
movies.c.release_year.desc(),
)
.limit(limit_rows)
.offset(skip_rows)
)
query = f"""
SELECT {','.join(sql_fields(Movie))}
FROM {Movie._table}
WHERE {(' AND '.join(conditions)) if conditions else '1=1'}
ORDER BY
length({Movie._table}.title) ASC,
{Movie._table}.imdb_score DESC,
{Movie._table}.release_year DESC
LIMIT :skip_rows, :limit_rows
"""
async with locked_connection() as conn: async with locked_connection() as conn:
rows = await conn.fetch_all(bindparams(query, values)) rows = await conn.fetch_all(query)
movies = [fromplain(Movie, row._mapping, serialized=True) for row in rows] movies_ = [fromplain(Movie, row._mapping, serialized=True) for row in rows]
if not user_ids: if not user_ids:
return ((m, []) for m in movies) return ((m, []) for m in movies_)
ratings = await ratings_for_movies((m.id for m in movies), user_ids) ratings = await ratings_for_movies((m.id for m in movies_), user_ids)
aggreg: dict[ULID, tuple[Movie, list[Rating]]] = {m.id: (m, []) for m in movies} aggreg: dict[ULID, tuple[Movie, list[Rating]]] = {m.id: (m, []) for m in movies_}
for rating in ratings: for rating in ratings:
aggreg[rating.movie_id][1].append(rating) aggreg[rating.movie_id][1].append(rating)