migrate db.get to SQLAlchemy

This commit is contained in:
ducklet 2023-03-20 21:37:50 +01:00
parent af9c166124
commit 6f6354cfac
2 changed files with 45 additions and 10 deletions

View file

@ -20,6 +20,33 @@ def a_movie(**kwds) -> models.Movie:
return models.Movie(**args) return models.Movie(**args)
@pytest.mark.asyncio
async def test_get(shared_conn: db.Database):
async with shared_conn.transaction(force_rollback=True):
m1 = a_movie()
await db.add(m1)
m2 = a_movie(release_year=m1.release_year + 1)
await db.add(m2)
assert None == await db.get(models.Movie)
assert None == await db.get(models.Movie, id="blerp")
assert m1 == await db.get(models.Movie, id=str(m1.id))
assert m2 == await db.get(models.Movie, release_year=m2.release_year)
assert None == await db.get(
models.Movie, id=str(m1.id), release_year=m2.release_year
)
assert m2 == await db.get(
models.Movie, id=str(m2.id), release_year=m2.release_year
)
assert m1 == await db.get(
models.Movie, media_type=m1.media_type, order_by=("release_year", "asc")
)
assert m2 == await db.get(
models.Movie, media_type=m1.media_type, order_by=("release_year", "desc")
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_all(shared_conn: db.Database): async def test_get_all(shared_conn: db.Database):
async with shared_conn.transaction(force_rollback=True): async with shared_conn.transaction(force_rollback=True):

View file

@ -132,7 +132,7 @@ async def apply_db_patches(db: Database):
async def get_import_progress() -> Progress | None: async def get_import_progress() -> Progress | None:
"""Return the latest import progress.""" """Return the latest import progress."""
return await get(Progress, type="import-imdb-movies", order_by="started DESC") return await get(Progress, type="import-imdb-movies", order_by=("started", "desc"))
async def stop_import_progress(*, error: BaseException | None = None): async def stop_import_progress(*, error: BaseException | None = None):
@ -244,25 +244,33 @@ ModelType = TypeVar("ModelType")
async def get( async def get(
model: Type[ModelType], *, order_by: str | None = None, **kwds model: Type[ModelType],
*,
order_by: tuple[str, Literal["asc", "desc"]] | None = None,
**field_values,
) -> ModelType | None: ) -> ModelType | None:
"""Load a model instance from the database. """Load a model instance from the database.
Passing `kwds` allows to filter the instance to load. You have to encode the Passing `field_values` allows to filter the item to load. You have to encode the
values as the appropriate data type for the database prior to passing them values as the appropriate data type for the database prior to passing them
to this function. to this function.
""" """
values = {k: v for k, v in kwds.items() if v is not None} if not field_values:
if not values:
return return
fields_ = ", ".join(f.name for f in fields(model)) table: sa.Table = model.__table__
cond = " AND ".join(f"{k}=:{k}" for k in values) query = sa.select(model).where(
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}" *(table.c[k] == v for k, v in field_values.items() if v is not None)
)
if order_by: if order_by:
query += f" ORDER BY {order_by}" order_col, order_dir = order_by
query = query.order_by(
table.c[order_col].asc()
if order_dir == "asc"
else table.c[order_col].desc()
)
async with locked_connection() as conn: async with locked_connection() as conn:
row = await conn.fetch_one(query=query, values=values) row = await conn.fetch_one(query)
return fromplain(model, row._mapping, serialized=True) if row else None return fromplain(model, row._mapping, serialized=True) if row else None