migrate db.get to SQLAlchemy
This commit is contained in:
parent
af9c166124
commit
6f6354cfac
2 changed files with 45 additions and 10 deletions
|
|
@ -20,6 +20,33 @@ def a_movie(**kwds) -> models.Movie:
|
||||||
return models.Movie(**args)
|
return models.Movie(**args)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get(shared_conn: db.Database):
|
||||||
|
async with shared_conn.transaction(force_rollback=True):
|
||||||
|
m1 = a_movie()
|
||||||
|
await db.add(m1)
|
||||||
|
|
||||||
|
m2 = a_movie(release_year=m1.release_year + 1)
|
||||||
|
await db.add(m2)
|
||||||
|
|
||||||
|
assert None == await db.get(models.Movie)
|
||||||
|
assert None == await db.get(models.Movie, id="blerp")
|
||||||
|
assert m1 == await db.get(models.Movie, id=str(m1.id))
|
||||||
|
assert m2 == await db.get(models.Movie, release_year=m2.release_year)
|
||||||
|
assert None == await db.get(
|
||||||
|
models.Movie, id=str(m1.id), release_year=m2.release_year
|
||||||
|
)
|
||||||
|
assert m2 == await db.get(
|
||||||
|
models.Movie, id=str(m2.id), release_year=m2.release_year
|
||||||
|
)
|
||||||
|
assert m1 == await db.get(
|
||||||
|
models.Movie, media_type=m1.media_type, order_by=("release_year", "asc")
|
||||||
|
)
|
||||||
|
assert m2 == await db.get(
|
||||||
|
models.Movie, media_type=m1.media_type, order_by=("release_year", "desc")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_all(shared_conn: db.Database):
|
async def test_get_all(shared_conn: db.Database):
|
||||||
async with shared_conn.transaction(force_rollback=True):
|
async with shared_conn.transaction(force_rollback=True):
|
||||||
|
|
|
||||||
28
unwind/db.py
28
unwind/db.py
|
|
@ -132,7 +132,7 @@ async def apply_db_patches(db: Database):
|
||||||
|
|
||||||
async def get_import_progress() -> Progress | None:
|
async def get_import_progress() -> Progress | None:
|
||||||
"""Return the latest import progress."""
|
"""Return the latest import progress."""
|
||||||
return await get(Progress, type="import-imdb-movies", order_by="started DESC")
|
return await get(Progress, type="import-imdb-movies", order_by=("started", "desc"))
|
||||||
|
|
||||||
|
|
||||||
async def stop_import_progress(*, error: BaseException | None = None):
|
async def stop_import_progress(*, error: BaseException | None = None):
|
||||||
|
|
@ -244,25 +244,33 @@ ModelType = TypeVar("ModelType")
|
||||||
|
|
||||||
|
|
||||||
async def get(
|
async def get(
|
||||||
model: Type[ModelType], *, order_by: str | None = None, **kwds
|
model: Type[ModelType],
|
||||||
|
*,
|
||||||
|
order_by: tuple[str, Literal["asc", "desc"]] | None = None,
|
||||||
|
**field_values,
|
||||||
) -> ModelType | None:
|
) -> ModelType | None:
|
||||||
"""Load a model instance from the database.
|
"""Load a model instance from the database.
|
||||||
|
|
||||||
Passing `kwds` allows to filter the instance to load. You have to encode the
|
Passing `field_values` allows to filter the item to load. You have to encode the
|
||||||
values as the appropriate data type for the database prior to passing them
|
values as the appropriate data type for the database prior to passing them
|
||||||
to this function.
|
to this function.
|
||||||
"""
|
"""
|
||||||
values = {k: v for k, v in kwds.items() if v is not None}
|
if not field_values:
|
||||||
if not values:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
fields_ = ", ".join(f.name for f in fields(model))
|
table: sa.Table = model.__table__
|
||||||
cond = " AND ".join(f"{k}=:{k}" for k in values)
|
query = sa.select(model).where(
|
||||||
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
|
*(table.c[k] == v for k, v in field_values.items() if v is not None)
|
||||||
|
)
|
||||||
if order_by:
|
if order_by:
|
||||||
query += f" ORDER BY {order_by}"
|
order_col, order_dir = order_by
|
||||||
|
query = query.order_by(
|
||||||
|
table.c[order_col].asc()
|
||||||
|
if order_dir == "asc"
|
||||||
|
else table.c[order_col].desc()
|
||||||
|
)
|
||||||
async with locked_connection() as conn:
|
async with locked_connection() as conn:
|
||||||
row = await conn.fetch_one(query=query, values=values)
|
row = await conn.fetch_one(query)
|
||||||
return fromplain(model, row._mapping, serialized=True) if row else None
|
return fromplain(model, row._mapping, serialized=True) if row else None
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue