migrate db.get_many to SQLAlchemy

This commit is contained in:
ducklet 2023-03-19 23:14:59 +01:00
parent a444909b1f
commit af9c166124
2 changed files with 45 additions and 15 deletions

View file

@ -34,10 +34,37 @@ async def test_get_all(shared_conn: db.Database):
assert [] == list(await db.get_all(models.Movie, id="blerp")) assert [] == list(await db.get_all(models.Movie, id="blerp"))
assert [m1] == list(await db.get_all(models.Movie, id=str(m1.id))) assert [m1] == list(await db.get_all(models.Movie, id=str(m1.id)))
assert [m1, m2] == list(await db.get_all(models.Movie, release_year=2013)) assert [m1, m2] == list(
await db.get_all(models.Movie, release_year=m1.release_year)
)
assert [m1, m2, m3] == list(await db.get_all(models.Movie)) assert [m1, m2, m3] == list(await db.get_all(models.Movie))
@pytest.mark.asyncio
async def test_get_many(shared_conn: db.Database):
async with shared_conn.transaction(force_rollback=True):
m1 = a_movie()
await db.add(m1)
m2 = a_movie(release_year=m1.release_year)
await db.add(m2)
m3 = a_movie(release_year=m1.release_year + 1)
await db.add(m3)
assert [] == list(await db.get_many(models.Movie)), "selected nothing"
assert [m1] == list(await db.get_many(models.Movie, id=[str(m1.id)]))
assert [m1] == list(await db.get_many(models.Movie, id={str(m1.id)}))
assert [m1, m2] == list(
await db.get_many(models.Movie, release_year=[m1.release_year])
)
assert [m1, m2, m3] == list(
await db.get_many(
models.Movie, release_year=[m1.release_year, m3.release_year]
)
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_add_and_get(shared_conn: db.Database): async def test_add_and_get(shared_conn: db.Database):
async with shared_conn.transaction(force_rollback=True): async with shared_conn.transaction(force_rollback=True):

View file

@ -266,28 +266,31 @@ async def get(
return fromplain(model, row._mapping, serialized=True) if row else None return fromplain(model, row._mapping, serialized=True) if row else None
async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]: async def get_many(
keys = { model: Type[ModelType], **field_sets: set | list
k: [f"{k}_{i}" for i, _ in enumerate(vs, start=1)] for k, vs in kwds.items() ) -> Iterable[ModelType]:
} """Return the items with any values matching all given field sets.
if not keys: This is similar to `get_all`, but instead of a scalar value a list of values
must be given. If any of the given values is set for that field on an item,
the item is considered a match.
If no field values are given, no items will be returned.
"""
if not field_sets:
return [] return []
values = {n: v for k, vs in kwds.items() for n, v in zip(keys[k], vs)} table: sa.Table = model.__table__
query = sa.select(model).where(*(table.c[k].in_(v) for k, v in field_sets.items()))
fields_ = ", ".join(f.name for f in fields(model))
cond = " AND ".join(
f"{k} IN ({','.join(':'+n for n in ns)})" for k, ns in keys.items()
)
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
async with locked_connection() as conn: async with locked_connection() as conn:
rows = await conn.fetch_all(query=query, values=values) rows = await conn.fetch_all(query)
return (fromplain(model, row._mapping, serialized=True) for row in rows) return (fromplain(model, row._mapping, serialized=True) for row in rows)
async def get_all(model: Type[ModelType], **field_values) -> Iterable[ModelType]: async def get_all(model: Type[ModelType], **field_values) -> Iterable[ModelType]:
"""Return all items matching all given field value.""" """Filter all items by comparing all given field values.
If no filters are given, all items will be returned.
"""
table: sa.Table = model.__table__ table: sa.Table = model.__table__
query = sa.select(model).where( query = sa.select(model).where(
*(table.c[k] == v for k, v in field_values.items() if v is not None) *(table.c[k] == v for k, v in field_values.items() if v is not None)