diff --git a/tests/test_web.py b/tests/test_web.py index 364edd5..5a4c3c5 100644 --- a/tests/test_web.py +++ b/tests/test_web.py @@ -1,4 +1,5 @@ from datetime import datetime + import pytest from starlette.testclient import TestClient diff --git a/unwind/db.py b/unwind/db.py index b91d3c1..34abd81 100644 --- a/unwind/db.py +++ b/unwind/db.py @@ -18,7 +18,9 @@ from .models import ( asplain, fields, fromplain, + movies, optional_fields, + ratings, utcnow, ) from .types import ULID @@ -427,77 +429,72 @@ async def find_ratings( limit_rows: int = 10, user_ids: Iterable[str] = [], ): - values: dict[str, int | str] = { - "limit_rows": limit_rows, - } - conditions = [] + if title: - values["escape"] = "#" - escaped_title = sql_escape(title, char=values["escape"]) - values["pattern"] = ( + escape_char = "#" + escaped_title = sql_escape(title, char=escape_char) + pattern = ( "_".join(escaped_title.split()) if exact else "%" + "%".join(escaped_title.split()) + "%" ) conditions.append( - f""" - ( - {Movie._table}.title LIKE :pattern ESCAPE :escape - OR {Movie._table}.original_title LIKE :pattern ESCAPE :escape + sa.or_( + movies.c.title.like(pattern, escape=escape_char), + movies.c.original_title.like(pattern, escape=escape_char), ) - """ ) - if yearcomp: - op, year = yearcomp - assert op in "<=>" - values["year"] = year - conditions.append(f"{Movie._table}.release_year{op}:year") + match yearcomp: + case ("<", year): + conditions.append(movies.c.release_year < year) + case ("=", year): + conditions.append(movies.c.release_year == year) + case (">", year): + conditions.append(movies.c.release_year > year) - if media_type: - values["media_type"] = media_type - conditions.append(f"{Movie._table}.media_type=:media_type") + if media_type is not None: + conditions.append(movies.c.media_type == media_type) if ignore_tv_episodes: - conditions.append(f"{Movie._table}.media_type!='TV Episode'") + conditions.append(movies.c.media_type != "TV Episode") - user_condition = "1=1" + user_condition = [] if user_ids: - uvs = {f"user_id_{i}": v for i, v in enumerate(user_ids, start=1)} - values.update(uvs) - user_condition = f"{Rating._table}.user_id IN ({','.join(':'+n for n in uvs)})" + user_condition.append(ratings.c.user_id.in_(user_ids)) - query = f""" - SELECT DISTINCT {Rating._table}.movie_id - FROM {Rating._table} - LEFT JOIN {Movie._table} ON {Movie._table}.id={Rating._table}.movie_id - WHERE {user_condition}{(' AND ' + ' AND '.join(conditions)) if conditions else ''} - ORDER BY length({Movie._table}.title) ASC, {Rating._table}.rating_date DESC, {Movie._table}.imdb_score DESC - LIMIT :limit_rows - """ + query = ( + sa.select(ratings.c.movie_id) + .distinct() + .outerjoin_from(ratings, movies, movies.c.id == ratings.c.movie_id) + .where(*conditions, *user_condition) + .order_by( + sa.func.length(movies.c.title).asc(), + ratings.c.rating_date.desc(), + movies.c.imdb_score.desc(), + ) + .limit(limit_rows) + ) async with locked_connection() as conn: - rows = await conn.fetch_all(bindparams(query, values)) - movie_ids = tuple(r._mapping["movie_id"] for r in rows) + rows = conn.iterate(query) + movie_ids = [r.movie_id async for r in rows] if include_unrated and len(movie_ids) < limit_rows: - sqlin, sqlin_vals = sql_in("id", movie_ids, not_=True) - query = f""" - SELECT DISTINCT id AS movie_id - FROM {Movie._table} - WHERE {sqlin} - {('AND ' + ' AND '.join(conditions)) if conditions else ''} - ORDER BY length(title) ASC, imdb_score DESC, release_year DESC - LIMIT :limit_rows - """ - async with locked_connection() as conn: - rows = await conn.fetch_all( - bindparams( - query, - {**values, **sqlin_vals, "limit_rows": limit_rows - len(movie_ids)}, - ) + query = ( + sa.select(movies.c.id.label("movie_id")) + .distinct() + .where(movies.c.id.not_in(movie_ids), *conditions) + .order_by( + sa.func.length(movies.c.title).asc(), + movies.c.imdb_score.desc(), + movies.c.release_year.desc(), ) - movie_ids += tuple(r._mapping["movie_id"] for r in rows) + .limit(limit_rows - len(movie_ids)) + ) + async with locked_connection() as conn: + rows = conn.iterate(query) + movie_ids += [r.movie_id async for r in rows] return await ratings_for_movie_ids(ids=movie_ids) @@ -507,9 +504,6 @@ async def ratings_for_movie_ids( ) -> Iterable[dict[str, Any]]: conds = [] - ratings = Rating.__table__ - movies = Movie.__table__ - if ids: conds.append(movies.c.id.in_([str(x) for x in ids])) diff --git a/unwind/models.py b/unwind/models.py index 964ac01..6d40b35 100644 --- a/unwind/models.py +++ b/unwind/models.py @@ -12,8 +12,8 @@ from typing import ( Literal, Mapping, Type, - TypeVar, TypedDict, + TypeVar, Union, get_args, get_origin, @@ -318,6 +318,8 @@ class Movie: self._is_lazy = False +movies = Movie.__table__ + _RelationSentinel = object() """Mark a model field as containing external data. @@ -372,6 +374,8 @@ class Rating: ) +ratings = Rating.__table__ + Access = Literal[ "r", # read "i", # index