From 6f6354cface73988eadb032e4446b58233efc300 Mon Sep 17 00:00:00 2001 From: ducklet Date: Mon, 20 Mar 2023 21:37:50 +0100 Subject: [PATCH] migrate `db.get` to SQLAlchemy --- tests/test_db.py | 27 +++++++++++++++++++++++++++ unwind/db.py | 28 ++++++++++++++++++---------- 2 files changed, 45 insertions(+), 10 deletions(-) diff --git a/tests/test_db.py b/tests/test_db.py index 37e27c9..04cf8b0 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -20,6 +20,33 @@ def a_movie(**kwds) -> models.Movie: return models.Movie(**args) +@pytest.mark.asyncio +async def test_get(shared_conn: db.Database): + async with shared_conn.transaction(force_rollback=True): + m1 = a_movie() + await db.add(m1) + + m2 = a_movie(release_year=m1.release_year + 1) + await db.add(m2) + + assert None == await db.get(models.Movie) + assert None == await db.get(models.Movie, id="blerp") + assert m1 == await db.get(models.Movie, id=str(m1.id)) + assert m2 == await db.get(models.Movie, release_year=m2.release_year) + assert None == await db.get( + models.Movie, id=str(m1.id), release_year=m2.release_year + ) + assert m2 == await db.get( + models.Movie, id=str(m2.id), release_year=m2.release_year + ) + assert m1 == await db.get( + models.Movie, media_type=m1.media_type, order_by=("release_year", "asc") + ) + assert m2 == await db.get( + models.Movie, media_type=m1.media_type, order_by=("release_year", "desc") + ) + + @pytest.mark.asyncio async def test_get_all(shared_conn: db.Database): async with shared_conn.transaction(force_rollback=True): diff --git a/unwind/db.py b/unwind/db.py index ea09873..39d383e 100644 --- a/unwind/db.py +++ b/unwind/db.py @@ -132,7 +132,7 @@ async def apply_db_patches(db: Database): async def get_import_progress() -> Progress | None: """Return the latest import progress.""" - return await get(Progress, type="import-imdb-movies", order_by="started DESC") + return await get(Progress, type="import-imdb-movies", order_by=("started", "desc")) async def stop_import_progress(*, error: BaseException | None = None): @@ -244,25 +244,33 @@ ModelType = TypeVar("ModelType") async def get( - model: Type[ModelType], *, order_by: str | None = None, **kwds + model: Type[ModelType], + *, + order_by: tuple[str, Literal["asc", "desc"]] | None = None, + **field_values, ) -> ModelType | None: """Load a model instance from the database. - Passing `kwds` allows to filter the instance to load. You have to encode the + Passing `field_values` allows to filter the item to load. You have to encode the values as the appropriate data type for the database prior to passing them to this function. """ - values = {k: v for k, v in kwds.items() if v is not None} - if not values: + if not field_values: return - fields_ = ", ".join(f.name for f in fields(model)) - cond = " AND ".join(f"{k}=:{k}" for k in values) - query = f"SELECT {fields_} FROM {model._table} WHERE {cond}" + table: sa.Table = model.__table__ + query = sa.select(model).where( + *(table.c[k] == v for k, v in field_values.items() if v is not None) + ) if order_by: - query += f" ORDER BY {order_by}" + order_col, order_dir = order_by + query = query.order_by( + table.c[order_col].asc() + if order_dir == "asc" + else table.c[order_col].desc() + ) async with locked_connection() as conn: - row = await conn.fetch_one(query=query, values=values) + row = await conn.fetch_one(query) return fromplain(model, row._mapping, serialized=True) if row else None