migrate db.ratings_for_movie_ids to SQLAlchemy

This commit is contained in:
ducklet 2023-03-23 23:33:59 +01:00
parent b91fcd3f55
commit d4933bf1a6

View file

@ -505,40 +505,37 @@ async def find_ratings(
async def ratings_for_movie_ids( async def ratings_for_movie_ids(
ids: Iterable[ULID | str] = [], imdb_ids: Iterable[str] = [] ids: Iterable[ULID | str] = [], imdb_ids: Iterable[str] = []
) -> Iterable[dict[str, Any]]: ) -> Iterable[dict[str, Any]]:
conds: list[str] = [] conds = []
vals: dict[str, str] = {}
ratings = Rating.__table__
movies = Movie.__table__
if ids: if ids:
sqlin, sqlin_vals = sql_in(f"{Movie._table}.id", (str(x) for x in ids)) conds.append(movies.c.id.in_([str(x) for x in ids]))
conds.append(sqlin)
vals.update(sqlin_vals)
if imdb_ids: if imdb_ids:
sqlin, sqlin_vals = sql_in(f"{Movie._table}.imdb_id", imdb_ids) conds.append(movies.c.imdb_id.in_(imdb_ids))
conds.append(sqlin)
vals.update(sqlin_vals)
if not conds: if not conds:
return [] return []
query = f""" query = (
SELECT sa.select(
{Rating._table}.score AS user_score, ratings.c.score.label("user_score"),
{Rating._table}.user_id AS user_id, ratings.c.user_id.label("user_id"),
{Movie._table}.imdb_score, movies.c.imdb_score,
{Movie._table}.imdb_votes, movies.c.imdb_votes,
{Movie._table}.imdb_id AS movie_imdb_id, movies.c.imdb_id.label("movie_imdb_id"),
{Movie._table}.media_type AS media_type, movies.c.media_type.label("media_type"),
{Movie._table}.title AS canonical_title, movies.c.title.label("canonical_title"),
{Movie._table}.original_title AS original_title, movies.c.original_title.label("original_title"),
{Movie._table}.release_year AS release_year movies.c.release_year.label("release_year"),
FROM {Movie._table} )
LEFT JOIN {Rating._table} ON {Movie._table}.id={Rating._table}.movie_id .outerjoin_from(movies, ratings, movies.c.id == ratings.c.movie_id)
WHERE {(' OR '.join(conds))} .where(sa.or_(*conds))
""" )
async with locked_connection() as conn: async with locked_connection() as conn:
rows = await conn.fetch_all(bindparams(query, vals)) rows = await conn.fetch_all(query)
return tuple(dict(r._mapping) for r in rows) return tuple(dict(r._mapping) for r in rows)