diff --git a/tests/test_db.py b/tests/test_db.py index 822cdab..37e27c9 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -34,10 +34,37 @@ async def test_get_all(shared_conn: db.Database): assert [] == list(await db.get_all(models.Movie, id="blerp")) assert [m1] == list(await db.get_all(models.Movie, id=str(m1.id))) - assert [m1, m2] == list(await db.get_all(models.Movie, release_year=2013)) + assert [m1, m2] == list( + await db.get_all(models.Movie, release_year=m1.release_year) + ) assert [m1, m2, m3] == list(await db.get_all(models.Movie)) +@pytest.mark.asyncio +async def test_get_many(shared_conn: db.Database): + async with shared_conn.transaction(force_rollback=True): + m1 = a_movie() + await db.add(m1) + + m2 = a_movie(release_year=m1.release_year) + await db.add(m2) + + m3 = a_movie(release_year=m1.release_year + 1) + await db.add(m3) + + assert [] == list(await db.get_many(models.Movie)), "selected nothing" + assert [m1] == list(await db.get_many(models.Movie, id=[str(m1.id)])) + assert [m1] == list(await db.get_many(models.Movie, id={str(m1.id)})) + assert [m1, m2] == list( + await db.get_many(models.Movie, release_year=[m1.release_year]) + ) + assert [m1, m2, m3] == list( + await db.get_many( + models.Movie, release_year=[m1.release_year, m3.release_year] + ) + ) + + @pytest.mark.asyncio async def test_add_and_get(shared_conn: db.Database): async with shared_conn.transaction(force_rollback=True): diff --git a/unwind/db.py b/unwind/db.py index c3e4a0e..ea09873 100644 --- a/unwind/db.py +++ b/unwind/db.py @@ -266,28 +266,31 @@ async def get( return fromplain(model, row._mapping, serialized=True) if row else None -async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]: - keys = { - k: [f"{k}_{i}" for i, _ in enumerate(vs, start=1)] for k, vs in kwds.items() - } +async def get_many( + model: Type[ModelType], **field_sets: set | list +) -> Iterable[ModelType]: + """Return the items with any values matching all given field sets. - if not keys: + This is similar to `get_all`, but instead of a scalar value a list of values + must be given. If any of the given values is set for that field on an item, + the item is considered a match. + If no field values are given, no items will be returned. + """ + if not field_sets: return [] - values = {n: v for k, vs in kwds.items() for n, v in zip(keys[k], vs)} - - fields_ = ", ".join(f.name for f in fields(model)) - cond = " AND ".join( - f"{k} IN ({','.join(':'+n for n in ns)})" for k, ns in keys.items() - ) - query = f"SELECT {fields_} FROM {model._table} WHERE {cond}" + table: sa.Table = model.__table__ + query = sa.select(model).where(*(table.c[k].in_(v) for k, v in field_sets.items())) async with locked_connection() as conn: - rows = await conn.fetch_all(query=query, values=values) + rows = await conn.fetch_all(query) return (fromplain(model, row._mapping, serialized=True) for row in rows) async def get_all(model: Type[ModelType], **field_values) -> Iterable[ModelType]: - """Return all items matching all given field value.""" + """Filter all items by comparing all given field values. + + If no filters are given, all items will be returned. + """ table: sa.Table = model.__table__ query = sa.select(model).where( *(table.c[k] == v for k, v in field_values.items() if v is not None)