migrate db.ratings_for_movies to SQLAlchemy
This commit is contained in:
parent
1a3528e096
commit
1fd7e730b3
2 changed files with 70 additions and 24 deletions
|
|
@ -250,3 +250,69 @@ async def test_find_ratings(shared_conn: db.Database):
|
||||||
rows = await db.find_ratings(title="test", include_unrated=True)
|
rows = await db.find_ratings(title="test", include_unrated=True)
|
||||||
ratings = tuple(web_models.Rating(**r) for r in rows)
|
ratings = tuple(web_models.Rating(**r) for r in rows)
|
||||||
assert (web_models.Rating.from_movie(m1),) == ratings
|
assert (web_models.Rating.from_movie(m1),) == ratings
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ratings_for_movies(shared_conn: db.Database):
|
||||||
|
async with shared_conn.transaction(force_rollback=True):
|
||||||
|
m1 = a_movie()
|
||||||
|
await db.add(m1)
|
||||||
|
|
||||||
|
m2 = a_movie()
|
||||||
|
await db.add(m2)
|
||||||
|
|
||||||
|
u1 = models.User(
|
||||||
|
imdb_id="u00001",
|
||||||
|
name="User1",
|
||||||
|
secret="secret1",
|
||||||
|
)
|
||||||
|
await db.add(u1)
|
||||||
|
|
||||||
|
u2 = models.User(
|
||||||
|
imdb_id="u00002",
|
||||||
|
name="User2",
|
||||||
|
secret="secret2",
|
||||||
|
)
|
||||||
|
await db.add(u2)
|
||||||
|
|
||||||
|
r1 = models.Rating(
|
||||||
|
movie_id=m2.id,
|
||||||
|
movie=m2,
|
||||||
|
user_id=u1.id,
|
||||||
|
user=u1,
|
||||||
|
score=66,
|
||||||
|
rating_date=datetime.now(),
|
||||||
|
)
|
||||||
|
await db.add(r1)
|
||||||
|
|
||||||
|
# ---
|
||||||
|
|
||||||
|
movie_ids = [m1.id]
|
||||||
|
user_ids = []
|
||||||
|
assert tuple() == tuple(
|
||||||
|
await db.ratings_for_movies(movie_ids=movie_ids, user_ids=user_ids)
|
||||||
|
)
|
||||||
|
|
||||||
|
movie_ids = [m2.id]
|
||||||
|
user_ids = []
|
||||||
|
assert (r1,) == tuple(
|
||||||
|
await db.ratings_for_movies(movie_ids=movie_ids, user_ids=user_ids)
|
||||||
|
)
|
||||||
|
|
||||||
|
movie_ids = [m2.id]
|
||||||
|
user_ids = [u2.id]
|
||||||
|
assert tuple() == tuple(
|
||||||
|
await db.ratings_for_movies(movie_ids=movie_ids, user_ids=user_ids)
|
||||||
|
)
|
||||||
|
|
||||||
|
movie_ids = [m2.id]
|
||||||
|
user_ids = [u1.id]
|
||||||
|
assert (r1,) == tuple(
|
||||||
|
await db.ratings_for_movies(movie_ids=movie_ids, user_ids=user_ids)
|
||||||
|
)
|
||||||
|
|
||||||
|
movie_ids = [m1.id, m2.id]
|
||||||
|
user_ids = [u1.id, u2.id]
|
||||||
|
assert (r1,) == tuple(
|
||||||
|
await db.ratings_for_movies(movie_ids=movie_ids, user_ids=user_ids)
|
||||||
|
)
|
||||||
|
|
|
||||||
28
unwind/db.py
28
unwind/db.py
|
|
@ -537,38 +537,18 @@ def sql_fields(tp: Type):
|
||||||
return (f"{tp._table}.{f.name}" for f in fields(tp))
|
return (f"{tp._table}.{f.name}" for f in fields(tp))
|
||||||
|
|
||||||
|
|
||||||
def sql_in(column: str, values: Iterable[T], not_=False) -> tuple[str, dict[str, T]]:
|
|
||||||
c = column.replace(".", "___")
|
|
||||||
value_map = {f"{c}_{i}": v for i, v in enumerate(values, start=1)}
|
|
||||||
placeholders = ",".join(":" + k for k in value_map)
|
|
||||||
if not_:
|
|
||||||
return f"{column} NOT IN ({placeholders})", value_map
|
|
||||||
return f"{column} IN ({placeholders})", value_map
|
|
||||||
|
|
||||||
|
|
||||||
async def ratings_for_movies(
|
async def ratings_for_movies(
|
||||||
movie_ids: Iterable[ULID], user_ids: Iterable[ULID] = []
|
movie_ids: Iterable[ULID], user_ids: Iterable[ULID] = []
|
||||||
) -> Iterable[Rating]:
|
) -> Iterable[Rating]:
|
||||||
values: dict[str, str] = {}
|
conditions = [ratings.c.movie_id.in_(str(x) for x in movie_ids)]
|
||||||
conditions: list[str] = []
|
|
||||||
|
|
||||||
q, vm = sql_in("movie_id", [str(m) for m in movie_ids])
|
|
||||||
conditions.append(q)
|
|
||||||
values.update(vm)
|
|
||||||
|
|
||||||
if user_ids:
|
if user_ids:
|
||||||
q, vm = sql_in("user_id", [str(m) for m in user_ids])
|
conditions.append(ratings.c.user_id.in_(str(x) for x in user_ids))
|
||||||
conditions.append(q)
|
|
||||||
values.update(vm)
|
|
||||||
|
|
||||||
query = f"""
|
query = sa.select(ratings).where(*conditions)
|
||||||
SELECT {','.join(sql_fields(Rating))}
|
|
||||||
FROM {Rating._table}
|
|
||||||
WHERE {' AND '.join(f'({c})' for c in conditions) if conditions else '1=1'}
|
|
||||||
"""
|
|
||||||
|
|
||||||
async with locked_connection() as conn:
|
async with locked_connection() as conn:
|
||||||
rows = await conn.fetch_all(query, values)
|
rows = await conn.fetch_all(query)
|
||||||
|
|
||||||
return (fromplain(Rating, row._mapping, serialized=True) for row in rows)
|
return (fromplain(Rating, row._mapping, serialized=True) for row in rows)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue