diff --git a/tests/test_db.py b/tests/test_db.py index 92d75ae..d476408 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -316,3 +316,86 @@ async def test_ratings_for_movies(shared_conn: db.Database): assert (r1,) == tuple( await db.ratings_for_movies(movie_ids=movie_ids, user_ids=user_ids) ) + + +@pytest.mark.asyncio +async def test_find_movies(shared_conn: db.Database): + async with shared_conn.transaction(force_rollback=True): + m1 = a_movie(title="movie one") + await db.add(m1) + + m2 = a_movie(title="movie two", imdb_score=33, release_year=m1.release_year + 1) + await db.add(m2) + + u1 = models.User( + imdb_id="u00001", + name="User1", + secret="secret1", + ) + await db.add(u1) + + u2 = models.User( + imdb_id="u00002", + name="User2", + secret="secret2", + ) + await db.add(u2) + + r1 = models.Rating( + movie_id=m2.id, + movie=m2, + user_id=u1.id, + user=u1, + score=66, + rating_date=datetime.now(), + ) + await db.add(r1) + + # --- + + assert () == tuple(await db.find_movies(title=m1.title, include_unrated=False)) + assert ((m1, []),) == tuple( + await db.find_movies(title=m1.title, include_unrated=True) + ) + + assert ((m1, []),) == tuple( + await db.find_movies(title="mo on", exact=False, include_unrated=True) + ) + assert ((m1, []),) == tuple( + await db.find_movies(title="movie one", exact=True, include_unrated=True) + ) + assert () == tuple( + await db.find_movies(title="mo on", exact=True, include_unrated=True) + ) + + assert ((m2, []),) == tuple( + await db.find_movies(title="movie", exact=False, include_unrated=False) + ) + assert ((m2, []), (m1, [])) == tuple( + await db.find_movies(title="movie", exact=False, include_unrated=True) + ) + + assert ((m1, []),) == tuple( + await db.find_movies(include_unrated=True, yearcomp=("=", m1.release_year)) + ) + assert ((m2, []),) == tuple( + await db.find_movies(include_unrated=True, yearcomp=("=", m2.release_year)) + ) + assert ((m1, []),) == tuple( + await db.find_movies(include_unrated=True, yearcomp=("<", m2.release_year)) + ) + assert ((m2, []),) == tuple( + await db.find_movies(include_unrated=True, yearcomp=(">", m1.release_year)) + ) + + assert ((m2, []), (m1, [])) == tuple(await db.find_movies(include_unrated=True)) + assert ((m2, []),) == tuple( + await db.find_movies(include_unrated=True, limit_rows=1) + ) + assert ((m1, []),) == tuple( + await db.find_movies(include_unrated=True, skip_rows=1) + ) + + assert ((m2, [r1]), (m1, [])) == tuple( + await db.find_movies(include_unrated=True, user_ids=[u1.id, u2.id]) + ) diff --git a/unwind/db.py b/unwind/db.py index 9e9f959..a610b78 100644 --- a/unwind/db.py +++ b/unwind/db.py @@ -565,66 +565,63 @@ async def find_movies( include_unrated: bool = False, user_ids: list[ULID] = [], ) -> Iterable[tuple[Movie, list[Rating]]]: - values: dict[str, int | str] = { - "limit_rows": limit_rows, - "skip_rows": skip_rows, - } - conditions = [] + if title: - values["escape"] = "#" - escaped_title = sql_escape(title, char=values["escape"]) - values["pattern"] = ( + escape_char = "#" + escaped_title = sql_escape(title, char=escape_char) + pattern = ( "_".join(escaped_title.split()) if exact else "%" + "%".join(escaped_title.split()) + "%" ) conditions.append( - f""" - ( - {Movie._table}.title LIKE :pattern ESCAPE :escape - OR {Movie._table}.original_title LIKE :pattern ESCAPE :escape + sa.or_( + movies.c.title.like(pattern, escape=escape_char), + movies.c.original_title.like(pattern, escape=escape_char), ) - """ ) - if yearcomp: - op, year = yearcomp - assert op in "<=>" - values["year"] = year - conditions.append(f"{Movie._table}.release_year{op}:year") + match yearcomp: + case ("<", year): + conditions.append(movies.c.release_year < year) + case ("=", year): + conditions.append(movies.c.release_year == year) + case (">", year): + conditions.append(movies.c.release_year > year) - if media_type: - values["media_type"] = media_type - conditions.append(f"{Movie._table}.media_type=:media_type") + if media_type is not None: + conditions.append(movies.c.media_type == media_type) if ignore_tv_episodes: - conditions.append(f"{Movie._table}.media_type!='TV Episode'") + conditions.append(movies.c.media_type != "TV Episode") if not include_unrated: - conditions.append(f"{Movie._table}.imdb_score NOTNULL") + conditions.append(movies.c.imdb_score != None) + + query = ( + sa.select(movies) + .where(*conditions) + .order_by( + sa.func.length(movies.c.title).asc(), + movies.c.imdb_score.desc(), + movies.c.release_year.desc(), + ) + .limit(limit_rows) + .offset(skip_rows) + ) - query = f""" - SELECT {','.join(sql_fields(Movie))} - FROM {Movie._table} - WHERE {(' AND '.join(conditions)) if conditions else '1=1'} - ORDER BY - length({Movie._table}.title) ASC, - {Movie._table}.imdb_score DESC, - {Movie._table}.release_year DESC - LIMIT :skip_rows, :limit_rows - """ async with locked_connection() as conn: - rows = await conn.fetch_all(bindparams(query, values)) + rows = await conn.fetch_all(query) - movies = [fromplain(Movie, row._mapping, serialized=True) for row in rows] + movies_ = [fromplain(Movie, row._mapping, serialized=True) for row in rows] if not user_ids: - return ((m, []) for m in movies) + return ((m, []) for m in movies_) - ratings = await ratings_for_movies((m.id for m in movies), user_ids) + ratings = await ratings_for_movies((m.id for m in movies_), user_ids) - aggreg: dict[ULID, tuple[Movie, list[Rating]]] = {m.id: (m, []) for m in movies} + aggreg: dict[ULID, tuple[Movie, list[Rating]]] = {m.id: (m, []) for m in movies_} for rating in ratings: aggreg[rating.movie_id][1].append(rating)