From 1dd7bab4aad8c5a3485e478ec53529ce418559d8 Mon Sep 17 00:00:00 2001 From: ducklet Date: Sun, 19 Mar 2023 22:36:33 +0100 Subject: [PATCH] migrate `db.get_all` to pure SQLAlchemy --- tests/test_db.py | 36 ++++++++++++++++++++++++++++++++++++ unwind/db.py | 18 +++++++++--------- 2 files changed, 45 insertions(+), 9 deletions(-) diff --git a/tests/test_db.py b/tests/test_db.py index ac8e64b..4670ffe 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -5,6 +5,42 @@ import pytest from unwind import db, models, web_models +@pytest.mark.asyncio +async def test_get_all(shared_conn: db.Database): + async with shared_conn.transaction(force_rollback=True): + m1 = models.Movie( + title="test movie", + release_year=2013, + media_type="Movie", + imdb_id="tt0000000", + genres={"genre-1"}, + ) + await db.add(m1) + + m2 = models.Movie( + title="test movie", + release_year=2013, + media_type="Movie", + imdb_id="tt0000001", + genres={"genre-1"}, + ) + await db.add(m2) + + m3 = models.Movie( + title="test movie", + release_year=2014, + media_type="Movie", + imdb_id="tt0000002", + genres={"genre-1"}, + ) + await db.add(m3) + + assert [] == list(await db.get_all(models.Movie, id="blerp")) + assert [m1] == list(await db.get_all(models.Movie, id=str(m1.id))) + assert [m1, m2] == list(await db.get_all(models.Movie, release_year=2013)) + assert [m1, m2, m3] == list(await db.get_all(models.Movie)) + + @pytest.mark.asyncio async def test_add_and_get(shared_conn: db.Database): async with shared_conn.transaction(force_rollback=True): diff --git a/unwind/db.py b/unwind/db.py index c07b3a9..c3e4a0e 100644 --- a/unwind/db.py +++ b/unwind/db.py @@ -6,7 +6,7 @@ import threading from pathlib import Path from typing import Any, Iterable, Literal, Type, TypeVar -import sqlalchemy +import sqlalchemy as sa from databases import Database from . import config @@ -286,14 +286,14 @@ async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]: return (fromplain(model, row._mapping, serialized=True) for row in rows) -async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]: - values = {k: v for k, v in kwds.items() if v is not None} - - fields_ = ", ".join(f.name for f in fields(model)) - cond = " AND ".join(f"{k}=:{k}" for k in values) or "1=1" - query = f"SELECT {fields_} FROM {model._table} WHERE {cond}" +async def get_all(model: Type[ModelType], **field_values) -> Iterable[ModelType]: + """Return all items matching all given field value.""" + table: sa.Table = model.__table__ + query = sa.select(model).where( + *(table.c[k] == v for k, v in field_values.items() if v is not None) + ) async with locked_connection() as conn: - rows = await conn.fetch_all(query=query, values=values) + rows = await conn.fetch_all(query) return (fromplain(model, row._mapping, serialized=True) for row in rows) @@ -667,4 +667,4 @@ def bindparams(query: str, values: dict): return f":{pump_key}" pump_query = re.sub(r":(\w+)\b", pump, query) - return sqlalchemy.text(pump_query).bindparams(**pump_vals) + return sa.text(pump_query).bindparams(**pump_vals)