migrate db.get_all to pure SQLAlchemy

This commit is contained in:
ducklet 2023-03-19 22:36:33 +01:00
parent 5015815097
commit 1dd7bab4aa
2 changed files with 45 additions and 9 deletions

View file

@ -5,6 +5,42 @@ import pytest
from unwind import db, models, web_models from unwind import db, models, web_models
@pytest.mark.asyncio
async def test_get_all(shared_conn: db.Database):
async with shared_conn.transaction(force_rollback=True):
m1 = models.Movie(
title="test movie",
release_year=2013,
media_type="Movie",
imdb_id="tt0000000",
genres={"genre-1"},
)
await db.add(m1)
m2 = models.Movie(
title="test movie",
release_year=2013,
media_type="Movie",
imdb_id="tt0000001",
genres={"genre-1"},
)
await db.add(m2)
m3 = models.Movie(
title="test movie",
release_year=2014,
media_type="Movie",
imdb_id="tt0000002",
genres={"genre-1"},
)
await db.add(m3)
assert [] == list(await db.get_all(models.Movie, id="blerp"))
assert [m1] == list(await db.get_all(models.Movie, id=str(m1.id)))
assert [m1, m2] == list(await db.get_all(models.Movie, release_year=2013))
assert [m1, m2, m3] == list(await db.get_all(models.Movie))
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_add_and_get(shared_conn: db.Database): async def test_add_and_get(shared_conn: db.Database):
async with shared_conn.transaction(force_rollback=True): async with shared_conn.transaction(force_rollback=True):

View file

@ -6,7 +6,7 @@ import threading
from pathlib import Path from pathlib import Path
from typing import Any, Iterable, Literal, Type, TypeVar from typing import Any, Iterable, Literal, Type, TypeVar
import sqlalchemy import sqlalchemy as sa
from databases import Database from databases import Database
from . import config from . import config
@ -286,14 +286,14 @@ async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
return (fromplain(model, row._mapping, serialized=True) for row in rows) return (fromplain(model, row._mapping, serialized=True) for row in rows)
async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]: async def get_all(model: Type[ModelType], **field_values) -> Iterable[ModelType]:
values = {k: v for k, v in kwds.items() if v is not None} """Return all items matching all given field value."""
table: sa.Table = model.__table__
fields_ = ", ".join(f.name for f in fields(model)) query = sa.select(model).where(
cond = " AND ".join(f"{k}=:{k}" for k in values) or "1=1" *(table.c[k] == v for k, v in field_values.items() if v is not None)
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}" )
async with locked_connection() as conn: async with locked_connection() as conn:
rows = await conn.fetch_all(query=query, values=values) rows = await conn.fetch_all(query)
return (fromplain(model, row._mapping, serialized=True) for row in rows) return (fromplain(model, row._mapping, serialized=True) for row in rows)
@ -667,4 +667,4 @@ def bindparams(query: str, values: dict):
return f":{pump_key}" return f":{pump_key}"
pump_query = re.sub(r":(\w+)\b", pump, query) pump_query = re.sub(r":(\w+)\b", pump, query)
return sqlalchemy.text(pump_query).bindparams(**pump_vals) return sa.text(pump_query).bindparams(**pump_vals)