migrate db.find_ratings to SQLAlchemy

This commit is contained in:
ducklet 2023-03-28 21:50:14 +02:00
parent d4933bf1a6
commit 1a3528e096
3 changed files with 54 additions and 55 deletions

View file

@ -1,4 +1,5 @@
from datetime import datetime
import pytest
from starlette.testclient import TestClient

View file

@ -18,7 +18,9 @@ from .models import (
asplain,
fields,
fromplain,
movies,
optional_fields,
ratings,
utcnow,
)
from .types import ULID
@ -427,77 +429,72 @@ async def find_ratings(
limit_rows: int = 10,
user_ids: Iterable[str] = [],
):
values: dict[str, int | str] = {
"limit_rows": limit_rows,
}
conditions = []
if title:
values["escape"] = "#"
escaped_title = sql_escape(title, char=values["escape"])
values["pattern"] = (
escape_char = "#"
escaped_title = sql_escape(title, char=escape_char)
pattern = (
"_".join(escaped_title.split())
if exact
else "%" + "%".join(escaped_title.split()) + "%"
)
conditions.append(
f"""
(
{Movie._table}.title LIKE :pattern ESCAPE :escape
OR {Movie._table}.original_title LIKE :pattern ESCAPE :escape
sa.or_(
movies.c.title.like(pattern, escape=escape_char),
movies.c.original_title.like(pattern, escape=escape_char),
)
"""
)
if yearcomp:
op, year = yearcomp
assert op in "<=>"
values["year"] = year
conditions.append(f"{Movie._table}.release_year{op}:year")
match yearcomp:
case ("<", year):
conditions.append(movies.c.release_year < year)
case ("=", year):
conditions.append(movies.c.release_year == year)
case (">", year):
conditions.append(movies.c.release_year > year)
if media_type:
values["media_type"] = media_type
conditions.append(f"{Movie._table}.media_type=:media_type")
if media_type is not None:
conditions.append(movies.c.media_type == media_type)
if ignore_tv_episodes:
conditions.append(f"{Movie._table}.media_type!='TV Episode'")
conditions.append(movies.c.media_type != "TV Episode")
user_condition = "1=1"
user_condition = []
if user_ids:
uvs = {f"user_id_{i}": v for i, v in enumerate(user_ids, start=1)}
values.update(uvs)
user_condition = f"{Rating._table}.user_id IN ({','.join(':'+n for n in uvs)})"
user_condition.append(ratings.c.user_id.in_(user_ids))
query = f"""
SELECT DISTINCT {Rating._table}.movie_id
FROM {Rating._table}
LEFT JOIN {Movie._table} ON {Movie._table}.id={Rating._table}.movie_id
WHERE {user_condition}{(' AND ' + ' AND '.join(conditions)) if conditions else ''}
ORDER BY length({Movie._table}.title) ASC, {Rating._table}.rating_date DESC, {Movie._table}.imdb_score DESC
LIMIT :limit_rows
"""
query = (
sa.select(ratings.c.movie_id)
.distinct()
.outerjoin_from(ratings, movies, movies.c.id == ratings.c.movie_id)
.where(*conditions, *user_condition)
.order_by(
sa.func.length(movies.c.title).asc(),
ratings.c.rating_date.desc(),
movies.c.imdb_score.desc(),
)
.limit(limit_rows)
)
async with locked_connection() as conn:
rows = await conn.fetch_all(bindparams(query, values))
movie_ids = tuple(r._mapping["movie_id"] for r in rows)
rows = conn.iterate(query)
movie_ids = [r.movie_id async for r in rows]
if include_unrated and len(movie_ids) < limit_rows:
sqlin, sqlin_vals = sql_in("id", movie_ids, not_=True)
query = f"""
SELECT DISTINCT id AS movie_id
FROM {Movie._table}
WHERE {sqlin}
{('AND ' + ' AND '.join(conditions)) if conditions else ''}
ORDER BY length(title) ASC, imdb_score DESC, release_year DESC
LIMIT :limit_rows
"""
query = (
sa.select(movies.c.id.label("movie_id"))
.distinct()
.where(movies.c.id.not_in(movie_ids), *conditions)
.order_by(
sa.func.length(movies.c.title).asc(),
movies.c.imdb_score.desc(),
movies.c.release_year.desc(),
)
.limit(limit_rows - len(movie_ids))
)
async with locked_connection() as conn:
rows = await conn.fetch_all(
bindparams(
query,
{**values, **sqlin_vals, "limit_rows": limit_rows - len(movie_ids)},
)
)
movie_ids += tuple(r._mapping["movie_id"] for r in rows)
rows = conn.iterate(query)
movie_ids += [r.movie_id async for r in rows]
return await ratings_for_movie_ids(ids=movie_ids)
@ -507,9 +504,6 @@ async def ratings_for_movie_ids(
) -> Iterable[dict[str, Any]]:
conds = []
ratings = Rating.__table__
movies = Movie.__table__
if ids:
conds.append(movies.c.id.in_([str(x) for x in ids]))

View file

@ -12,8 +12,8 @@ from typing import (
Literal,
Mapping,
Type,
TypeVar,
TypedDict,
TypeVar,
Union,
get_args,
get_origin,
@ -318,6 +318,8 @@ class Movie:
self._is_lazy = False
movies = Movie.__table__
_RelationSentinel = object()
"""Mark a model field as containing external data.
@ -372,6 +374,8 @@ class Rating:
)
ratings = Rating.__table__
Access = Literal[
"r", # read
"i", # index