migrate db.get_all to pure SQLAlchemy
This commit is contained in:
parent
5015815097
commit
1dd7bab4aa
2 changed files with 45 additions and 9 deletions
|
|
@ -5,6 +5,42 @@ import pytest
|
||||||
from unwind import db, models, web_models
|
from unwind import db, models, web_models
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_all(shared_conn: db.Database):
|
||||||
|
async with shared_conn.transaction(force_rollback=True):
|
||||||
|
m1 = models.Movie(
|
||||||
|
title="test movie",
|
||||||
|
release_year=2013,
|
||||||
|
media_type="Movie",
|
||||||
|
imdb_id="tt0000000",
|
||||||
|
genres={"genre-1"},
|
||||||
|
)
|
||||||
|
await db.add(m1)
|
||||||
|
|
||||||
|
m2 = models.Movie(
|
||||||
|
title="test movie",
|
||||||
|
release_year=2013,
|
||||||
|
media_type="Movie",
|
||||||
|
imdb_id="tt0000001",
|
||||||
|
genres={"genre-1"},
|
||||||
|
)
|
||||||
|
await db.add(m2)
|
||||||
|
|
||||||
|
m3 = models.Movie(
|
||||||
|
title="test movie",
|
||||||
|
release_year=2014,
|
||||||
|
media_type="Movie",
|
||||||
|
imdb_id="tt0000002",
|
||||||
|
genres={"genre-1"},
|
||||||
|
)
|
||||||
|
await db.add(m3)
|
||||||
|
|
||||||
|
assert [] == list(await db.get_all(models.Movie, id="blerp"))
|
||||||
|
assert [m1] == list(await db.get_all(models.Movie, id=str(m1.id)))
|
||||||
|
assert [m1, m2] == list(await db.get_all(models.Movie, release_year=2013))
|
||||||
|
assert [m1, m2, m3] == list(await db.get_all(models.Movie))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_add_and_get(shared_conn: db.Database):
|
async def test_add_and_get(shared_conn: db.Database):
|
||||||
async with shared_conn.transaction(force_rollback=True):
|
async with shared_conn.transaction(force_rollback=True):
|
||||||
|
|
|
||||||
18
unwind/db.py
18
unwind/db.py
|
|
@ -6,7 +6,7 @@ import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Iterable, Literal, Type, TypeVar
|
from typing import Any, Iterable, Literal, Type, TypeVar
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy as sa
|
||||||
from databases import Database
|
from databases import Database
|
||||||
|
|
||||||
from . import config
|
from . import config
|
||||||
|
|
@ -286,14 +286,14 @@ async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
|
||||||
return (fromplain(model, row._mapping, serialized=True) for row in rows)
|
return (fromplain(model, row._mapping, serialized=True) for row in rows)
|
||||||
|
|
||||||
|
|
||||||
async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
|
async def get_all(model: Type[ModelType], **field_values) -> Iterable[ModelType]:
|
||||||
values = {k: v for k, v in kwds.items() if v is not None}
|
"""Return all items matching all given field value."""
|
||||||
|
table: sa.Table = model.__table__
|
||||||
fields_ = ", ".join(f.name for f in fields(model))
|
query = sa.select(model).where(
|
||||||
cond = " AND ".join(f"{k}=:{k}" for k in values) or "1=1"
|
*(table.c[k] == v for k, v in field_values.items() if v is not None)
|
||||||
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
|
)
|
||||||
async with locked_connection() as conn:
|
async with locked_connection() as conn:
|
||||||
rows = await conn.fetch_all(query=query, values=values)
|
rows = await conn.fetch_all(query)
|
||||||
return (fromplain(model, row._mapping, serialized=True) for row in rows)
|
return (fromplain(model, row._mapping, serialized=True) for row in rows)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -667,4 +667,4 @@ def bindparams(query: str, values: dict):
|
||||||
return f":{pump_key}"
|
return f":{pump_key}"
|
||||||
|
|
||||||
pump_query = re.sub(r":(\w+)\b", pump, query)
|
pump_query = re.sub(r":(\w+)\b", pump, query)
|
||||||
return sqlalchemy.text(pump_query).bindparams(**pump_vals)
|
return sa.text(pump_query).bindparams(**pump_vals)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue