migrate db.get_all to pure SQLAlchemy

This commit is contained in:
ducklet 2023-03-19 22:36:33 +01:00
parent 5015815097
commit 1dd7bab4aa
2 changed files with 45 additions and 9 deletions

View file

@ -5,6 +5,42 @@ import pytest
from unwind import db, models, web_models
@pytest.mark.asyncio
async def test_get_all(shared_conn: db.Database):
async with shared_conn.transaction(force_rollback=True):
m1 = models.Movie(
title="test movie",
release_year=2013,
media_type="Movie",
imdb_id="tt0000000",
genres={"genre-1"},
)
await db.add(m1)
m2 = models.Movie(
title="test movie",
release_year=2013,
media_type="Movie",
imdb_id="tt0000001",
genres={"genre-1"},
)
await db.add(m2)
m3 = models.Movie(
title="test movie",
release_year=2014,
media_type="Movie",
imdb_id="tt0000002",
genres={"genre-1"},
)
await db.add(m3)
assert [] == list(await db.get_all(models.Movie, id="blerp"))
assert [m1] == list(await db.get_all(models.Movie, id=str(m1.id)))
assert [m1, m2] == list(await db.get_all(models.Movie, release_year=2013))
assert [m1, m2, m3] == list(await db.get_all(models.Movie))
@pytest.mark.asyncio
async def test_add_and_get(shared_conn: db.Database):
async with shared_conn.transaction(force_rollback=True):

View file

@ -6,7 +6,7 @@ import threading
from pathlib import Path
from typing import Any, Iterable, Literal, Type, TypeVar
import sqlalchemy
import sqlalchemy as sa
from databases import Database
from . import config
@ -286,14 +286,14 @@ async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
return (fromplain(model, row._mapping, serialized=True) for row in rows)
async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
values = {k: v for k, v in kwds.items() if v is not None}
fields_ = ", ".join(f.name for f in fields(model))
cond = " AND ".join(f"{k}=:{k}" for k in values) or "1=1"
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
async def get_all(model: Type[ModelType], **field_values) -> Iterable[ModelType]:
"""Return all items matching all given field value."""
table: sa.Table = model.__table__
query = sa.select(model).where(
*(table.c[k] == v for k, v in field_values.items() if v is not None)
)
async with locked_connection() as conn:
rows = await conn.fetch_all(query=query, values=values)
rows = await conn.fetch_all(query)
return (fromplain(model, row._mapping, serialized=True) for row in rows)
@ -667,4 +667,4 @@ def bindparams(query: str, values: dict):
return f":{pump_key}"
pump_query = re.sub(r":(\w+)\b", pump, query)
return sqlalchemy.text(pump_query).bindparams(**pump_vals)
return sa.text(pump_query).bindparams(**pump_vals)