From b91fcd3f55942e41c46986086df10fdd377dc8a0 Mon Sep 17 00:00:00 2001 From: ducklet Date: Thu, 23 Mar 2023 22:49:17 +0100 Subject: [PATCH] migrate `db.add`, `db.update`, `db.remove` to SQLA --- tests/test_db.py | 25 +++++++++++++++++++++++++ unwind/db.py | 18 +++++++++--------- 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/tests/test_db.py b/tests/test_db.py index 04cf8b0..33fff26 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -105,6 +105,31 @@ async def test_add_and_get(shared_conn: db.Database): assert m2 == await db.get(models.Movie, id=str(m2.id)) +@pytest.mark.asyncio +async def test_update(shared_conn: db.Database): + async with shared_conn.transaction(force_rollback=True): + m = a_movie() + await db.add(m) + + assert m == await db.get(models.Movie, id=str(m.id)) + m.title += "something else" + assert m != await db.get(models.Movie, id=str(m.id)) + + await db.update(m) + assert m == await db.get(models.Movie, id=str(m.id)) + + +@pytest.mark.asyncio +async def test_remove(shared_conn: db.Database): + async with shared_conn.transaction(force_rollback=True): + m1 = a_movie() + await db.add(m1) + assert m1 == await db.get(models.Movie, id=str(m1.id)) + + await db.remove(m1) + assert None == await db.get(models.Movie, id=str(m1.id)) + + @pytest.mark.asyncio async def test_find_ratings(shared_conn: db.Database): async with shared_conn.transaction(force_rollback=True): diff --git a/unwind/db.py b/unwind/db.py index 39d383e..1f15bc9 100644 --- a/unwind/db.py +++ b/unwind/db.py @@ -232,12 +232,11 @@ async def add(item): if getattr(item, "_is_lazy", False): item._lazy_init() + table: sa.Table = item.__table__ values = asplain(item, serialize=True) - keys = ", ".join(f"{k}" for k in values) - placeholders = ", ".join(f":{k}" for k in values) - query = f"INSERT INTO {item._table} ({keys}) VALUES ({placeholders})" + stmt = table.insert().values(values) async with locked_connection() as conn: - await conn.execute(query=query, values=values) + await conn.execute(stmt) ModelType = TypeVar("ModelType") @@ -313,18 +312,19 @@ async def update(item): if getattr(item, "_is_lazy", False): item._lazy_init() + table: sa.Table = item.__table__ values = asplain(item, serialize=True) - keys = ", ".join(f"{k}=:{k}" for k in values if k != "id") - query = f"UPDATE {item._table} SET {keys} WHERE id=:id" + stmt = table.update().where(table.c.id == values["id"]).values(values) async with locked_connection() as conn: - await conn.execute(query=query, values=values) + await conn.execute(stmt) async def remove(item): + table: sa.Table = item.__table__ values = asplain(item, filter_fields={"id"}, serialize=True) - query = f"DELETE FROM {item._table} WHERE id=:id" + stmt = table.delete().where(table.c.id == values["id"]) async with locked_connection() as conn: - await conn.execute(query=query, values=values) + await conn.execute(stmt) async def add_or_update_user(user: User):