migrate db.get_all to pure SQLAlchemy
This commit is contained in:
parent
5015815097
commit
1dd7bab4aa
2 changed files with 45 additions and 9 deletions
|
|
@ -5,6 +5,42 @@ import pytest
|
|||
from unwind import db, models, web_models
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all(shared_conn: db.Database):
|
||||
async with shared_conn.transaction(force_rollback=True):
|
||||
m1 = models.Movie(
|
||||
title="test movie",
|
||||
release_year=2013,
|
||||
media_type="Movie",
|
||||
imdb_id="tt0000000",
|
||||
genres={"genre-1"},
|
||||
)
|
||||
await db.add(m1)
|
||||
|
||||
m2 = models.Movie(
|
||||
title="test movie",
|
||||
release_year=2013,
|
||||
media_type="Movie",
|
||||
imdb_id="tt0000001",
|
||||
genres={"genre-1"},
|
||||
)
|
||||
await db.add(m2)
|
||||
|
||||
m3 = models.Movie(
|
||||
title="test movie",
|
||||
release_year=2014,
|
||||
media_type="Movie",
|
||||
imdb_id="tt0000002",
|
||||
genres={"genre-1"},
|
||||
)
|
||||
await db.add(m3)
|
||||
|
||||
assert [] == list(await db.get_all(models.Movie, id="blerp"))
|
||||
assert [m1] == list(await db.get_all(models.Movie, id=str(m1.id)))
|
||||
assert [m1, m2] == list(await db.get_all(models.Movie, release_year=2013))
|
||||
assert [m1, m2, m3] == list(await db.get_all(models.Movie))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_and_get(shared_conn: db.Database):
|
||||
async with shared_conn.transaction(force_rollback=True):
|
||||
|
|
|
|||
18
unwind/db.py
18
unwind/db.py
|
|
@ -6,7 +6,7 @@ import threading
|
|||
from pathlib import Path
|
||||
from typing import Any, Iterable, Literal, Type, TypeVar
|
||||
|
||||
import sqlalchemy
|
||||
import sqlalchemy as sa
|
||||
from databases import Database
|
||||
|
||||
from . import config
|
||||
|
|
@ -286,14 +286,14 @@ async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
|
|||
return (fromplain(model, row._mapping, serialized=True) for row in rows)
|
||||
|
||||
|
||||
async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
|
||||
values = {k: v for k, v in kwds.items() if v is not None}
|
||||
|
||||
fields_ = ", ".join(f.name for f in fields(model))
|
||||
cond = " AND ".join(f"{k}=:{k}" for k in values) or "1=1"
|
||||
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
|
||||
async def get_all(model: Type[ModelType], **field_values) -> Iterable[ModelType]:
|
||||
"""Return all items matching all given field value."""
|
||||
table: sa.Table = model.__table__
|
||||
query = sa.select(model).where(
|
||||
*(table.c[k] == v for k, v in field_values.items() if v is not None)
|
||||
)
|
||||
async with locked_connection() as conn:
|
||||
rows = await conn.fetch_all(query=query, values=values)
|
||||
rows = await conn.fetch_all(query)
|
||||
return (fromplain(model, row._mapping, serialized=True) for row in rows)
|
||||
|
||||
|
||||
|
|
@ -667,4 +667,4 @@ def bindparams(query: str, values: dict):
|
|||
return f":{pump_key}"
|
||||
|
||||
pump_query = re.sub(r":(\w+)\b", pump, query)
|
||||
return sqlalchemy.text(pump_query).bindparams(**pump_vals)
|
||||
return sa.text(pump_query).bindparams(**pump_vals)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue