migrate db.get_many to SQLAlchemy
This commit is contained in:
parent
a444909b1f
commit
af9c166124
2 changed files with 45 additions and 15 deletions
|
|
@ -34,10 +34,37 @@ async def test_get_all(shared_conn: db.Database):
|
||||||
|
|
||||||
assert [] == list(await db.get_all(models.Movie, id="blerp"))
|
assert [] == list(await db.get_all(models.Movie, id="blerp"))
|
||||||
assert [m1] == list(await db.get_all(models.Movie, id=str(m1.id)))
|
assert [m1] == list(await db.get_all(models.Movie, id=str(m1.id)))
|
||||||
assert [m1, m2] == list(await db.get_all(models.Movie, release_year=2013))
|
assert [m1, m2] == list(
|
||||||
|
await db.get_all(models.Movie, release_year=m1.release_year)
|
||||||
|
)
|
||||||
assert [m1, m2, m3] == list(await db.get_all(models.Movie))
|
assert [m1, m2, m3] == list(await db.get_all(models.Movie))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_many(shared_conn: db.Database):
|
||||||
|
async with shared_conn.transaction(force_rollback=True):
|
||||||
|
m1 = a_movie()
|
||||||
|
await db.add(m1)
|
||||||
|
|
||||||
|
m2 = a_movie(release_year=m1.release_year)
|
||||||
|
await db.add(m2)
|
||||||
|
|
||||||
|
m3 = a_movie(release_year=m1.release_year + 1)
|
||||||
|
await db.add(m3)
|
||||||
|
|
||||||
|
assert [] == list(await db.get_many(models.Movie)), "selected nothing"
|
||||||
|
assert [m1] == list(await db.get_many(models.Movie, id=[str(m1.id)]))
|
||||||
|
assert [m1] == list(await db.get_many(models.Movie, id={str(m1.id)}))
|
||||||
|
assert [m1, m2] == list(
|
||||||
|
await db.get_many(models.Movie, release_year=[m1.release_year])
|
||||||
|
)
|
||||||
|
assert [m1, m2, m3] == list(
|
||||||
|
await db.get_many(
|
||||||
|
models.Movie, release_year=[m1.release_year, m3.release_year]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_add_and_get(shared_conn: db.Database):
|
async def test_add_and_get(shared_conn: db.Database):
|
||||||
async with shared_conn.transaction(force_rollback=True):
|
async with shared_conn.transaction(force_rollback=True):
|
||||||
|
|
|
||||||
31
unwind/db.py
31
unwind/db.py
|
|
@ -266,28 +266,31 @@ async def get(
|
||||||
return fromplain(model, row._mapping, serialized=True) if row else None
|
return fromplain(model, row._mapping, serialized=True) if row else None
|
||||||
|
|
||||||
|
|
||||||
async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
|
async def get_many(
|
||||||
keys = {
|
model: Type[ModelType], **field_sets: set | list
|
||||||
k: [f"{k}_{i}" for i, _ in enumerate(vs, start=1)] for k, vs in kwds.items()
|
) -> Iterable[ModelType]:
|
||||||
}
|
"""Return the items with any values matching all given field sets.
|
||||||
|
|
||||||
if not keys:
|
This is similar to `get_all`, but instead of a scalar value a list of values
|
||||||
|
must be given. If any of the given values is set for that field on an item,
|
||||||
|
the item is considered a match.
|
||||||
|
If no field values are given, no items will be returned.
|
||||||
|
"""
|
||||||
|
if not field_sets:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
values = {n: v for k, vs in kwds.items() for n, v in zip(keys[k], vs)}
|
table: sa.Table = model.__table__
|
||||||
|
query = sa.select(model).where(*(table.c[k].in_(v) for k, v in field_sets.items()))
|
||||||
fields_ = ", ".join(f.name for f in fields(model))
|
|
||||||
cond = " AND ".join(
|
|
||||||
f"{k} IN ({','.join(':'+n for n in ns)})" for k, ns in keys.items()
|
|
||||||
)
|
|
||||||
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
|
|
||||||
async with locked_connection() as conn:
|
async with locked_connection() as conn:
|
||||||
rows = await conn.fetch_all(query=query, values=values)
|
rows = await conn.fetch_all(query)
|
||||||
return (fromplain(model, row._mapping, serialized=True) for row in rows)
|
return (fromplain(model, row._mapping, serialized=True) for row in rows)
|
||||||
|
|
||||||
|
|
||||||
async def get_all(model: Type[ModelType], **field_values) -> Iterable[ModelType]:
|
async def get_all(model: Type[ModelType], **field_values) -> Iterable[ModelType]:
|
||||||
"""Return all items matching all given field value."""
|
"""Filter all items by comparing all given field values.
|
||||||
|
|
||||||
|
If no filters are given, all items will be returned.
|
||||||
|
"""
|
||||||
table: sa.Table = model.__table__
|
table: sa.Table = model.__table__
|
||||||
query = sa.select(model).where(
|
query = sa.select(model).where(
|
||||||
*(table.c[k] == v for k, v in field_values.items() if v is not None)
|
*(table.c[k] == v for k, v in field_values.items() if v is not None)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue