From 1fd7e730b349f48f07c1961644e83363f9ee0dfe Mon Sep 17 00:00:00 2001 From: ducklet Date: Tue, 28 Mar 2023 00:23:37 +0200 Subject: [PATCH] migrate `db.ratings_for_movies` to SQLAlchemy --- tests/test_db.py | 66 ++++++++++++++++++++++++++++++++++++++++++++++++ unwind/db.py | 28 +++----------------- 2 files changed, 70 insertions(+), 24 deletions(-) diff --git a/tests/test_db.py b/tests/test_db.py index 33fff26..92d75ae 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -250,3 +250,69 @@ async def test_find_ratings(shared_conn: db.Database): rows = await db.find_ratings(title="test", include_unrated=True) ratings = tuple(web_models.Rating(**r) for r in rows) assert (web_models.Rating.from_movie(m1),) == ratings + + +@pytest.mark.asyncio +async def test_ratings_for_movies(shared_conn: db.Database): + async with shared_conn.transaction(force_rollback=True): + m1 = a_movie() + await db.add(m1) + + m2 = a_movie() + await db.add(m2) + + u1 = models.User( + imdb_id="u00001", + name="User1", + secret="secret1", + ) + await db.add(u1) + + u2 = models.User( + imdb_id="u00002", + name="User2", + secret="secret2", + ) + await db.add(u2) + + r1 = models.Rating( + movie_id=m2.id, + movie=m2, + user_id=u1.id, + user=u1, + score=66, + rating_date=datetime.now(), + ) + await db.add(r1) + + # --- + + movie_ids = [m1.id] + user_ids = [] + assert tuple() == tuple( + await db.ratings_for_movies(movie_ids=movie_ids, user_ids=user_ids) + ) + + movie_ids = [m2.id] + user_ids = [] + assert (r1,) == tuple( + await db.ratings_for_movies(movie_ids=movie_ids, user_ids=user_ids) + ) + + movie_ids = [m2.id] + user_ids = [u2.id] + assert tuple() == tuple( + await db.ratings_for_movies(movie_ids=movie_ids, user_ids=user_ids) + ) + + movie_ids = [m2.id] + user_ids = [u1.id] + assert (r1,) == tuple( + await db.ratings_for_movies(movie_ids=movie_ids, user_ids=user_ids) + ) + + movie_ids = [m1.id, m2.id] + user_ids = [u1.id, u2.id] + assert (r1,) == tuple( + await db.ratings_for_movies(movie_ids=movie_ids, user_ids=user_ids) + ) diff --git a/unwind/db.py b/unwind/db.py index 34abd81..9e9f959 100644 --- a/unwind/db.py +++ b/unwind/db.py @@ -537,38 +537,18 @@ def sql_fields(tp: Type): return (f"{tp._table}.{f.name}" for f in fields(tp)) -def sql_in(column: str, values: Iterable[T], not_=False) -> tuple[str, dict[str, T]]: - c = column.replace(".", "___") - value_map = {f"{c}_{i}": v for i, v in enumerate(values, start=1)} - placeholders = ",".join(":" + k for k in value_map) - if not_: - return f"{column} NOT IN ({placeholders})", value_map - return f"{column} IN ({placeholders})", value_map - - async def ratings_for_movies( movie_ids: Iterable[ULID], user_ids: Iterable[ULID] = [] ) -> Iterable[Rating]: - values: dict[str, str] = {} - conditions: list[str] = [] - - q, vm = sql_in("movie_id", [str(m) for m in movie_ids]) - conditions.append(q) - values.update(vm) + conditions = [ratings.c.movie_id.in_(str(x) for x in movie_ids)] if user_ids: - q, vm = sql_in("user_id", [str(m) for m in user_ids]) - conditions.append(q) - values.update(vm) + conditions.append(ratings.c.user_id.in_(str(x) for x in user_ids)) - query = f""" - SELECT {','.join(sql_fields(Rating))} - FROM {Rating._table} - WHERE {' AND '.join(f'({c})' for c in conditions) if conditions else '1=1'} - """ + query = sa.select(ratings).where(*conditions) async with locked_connection() as conn: - rows = await conn.fetch_all(query, values) + rows = await conn.fetch_all(query) return (fromplain(Rating, row._mapping, serialized=True) for row in rows)