make the shared connection internal to the db module
This should make it easier to refactor the code for removing the databases package.
This commit is contained in:
parent
22ea553f48
commit
1f42538481
3 changed files with 27 additions and 23 deletions
|
|
@ -17,7 +17,7 @@ def event_loop():
|
|||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def shared_conn():
|
||||
c = db.shared_connection()
|
||||
c = db._shared_connection()
|
||||
await c.connect()
|
||||
|
||||
await db.apply_db_patches(c)
|
||||
|
|
|
|||
46
unwind/db.py
46
unwind/db.py
|
|
@ -31,7 +31,7 @@ from .types import ULID
|
|||
log = logging.getLogger(__name__)
|
||||
T = TypeVar("T")
|
||||
|
||||
_shared_connection: Database | None = None
|
||||
_database: Database | None = None
|
||||
|
||||
|
||||
async def open_connection_pool() -> None:
|
||||
|
|
@ -39,7 +39,7 @@ async def open_connection_pool() -> None:
|
|||
|
||||
This function needs to be called before any access to the database can happen.
|
||||
"""
|
||||
db = shared_connection()
|
||||
db = _shared_connection()
|
||||
await db.connect()
|
||||
|
||||
await db.execute(sa.text("PRAGMA journal_mode=WAL"))
|
||||
|
|
@ -53,7 +53,7 @@ async def close_connection_pool() -> None:
|
|||
This function should be called before the app shuts down to ensure all data
|
||||
has been flushed to the database.
|
||||
"""
|
||||
db = shared_connection()
|
||||
db = _shared_connection()
|
||||
|
||||
# Run automatic ANALYZE prior to closing the db,
|
||||
# see https://sqlite.com/lang_analyze.html.
|
||||
|
|
@ -205,23 +205,27 @@ async def single_threaded():
|
|||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def locked_connection():
|
||||
async def _locked_connection():
|
||||
async with single_threaded():
|
||||
yield shared_connection()
|
||||
yield _shared_connection()
|
||||
|
||||
|
||||
def shared_connection() -> Database:
|
||||
global _shared_connection
|
||||
def _shared_connection() -> Database:
|
||||
global _database
|
||||
|
||||
if _shared_connection is None:
|
||||
if _database is None:
|
||||
uri = f"sqlite:///{config.storage_path}"
|
||||
# uri = f"sqlite+aiosqlite:///{config.storage_path}"
|
||||
_shared_connection = Database(uri)
|
||||
_database = Database(uri)
|
||||
|
||||
engine = sa.create_engine(uri, future=True)
|
||||
metadata.create_all(engine, tables=[db_patches])
|
||||
|
||||
return _shared_connection
|
||||
return _database
|
||||
|
||||
|
||||
def transaction():
|
||||
return _shared_connection().transaction()
|
||||
|
||||
|
||||
async def add(item: Model) -> None:
|
||||
|
|
@ -233,7 +237,7 @@ async def add(item: Model) -> None:
|
|||
table: sa.Table = item.__table__
|
||||
values = asplain(item, serialize=True)
|
||||
stmt = table.insert().values(values)
|
||||
async with locked_connection() as conn:
|
||||
async with _locked_connection() as conn:
|
||||
await conn.execute(stmt)
|
||||
|
||||
|
||||
|
|
@ -264,7 +268,7 @@ async def get(
|
|||
query = query.order_by(
|
||||
order_col.asc() if order_dir == "asc" else order_col.desc()
|
||||
)
|
||||
async with locked_connection() as conn:
|
||||
async with _locked_connection() as conn:
|
||||
row = await conn.fetch_one(query)
|
||||
return fromplain(model, row._mapping, serialized=True) if row else None
|
||||
|
||||
|
|
@ -284,7 +288,7 @@ async def get_many(
|
|||
|
||||
table: sa.Table = model.__table__
|
||||
query = sa.select(model).where(*(table.c[k].in_(v) for k, v in field_sets.items()))
|
||||
async with locked_connection() as conn:
|
||||
async with _locked_connection() as conn:
|
||||
rows = await conn.fetch_all(query)
|
||||
return (fromplain(model, row._mapping, serialized=True) for row in rows)
|
||||
|
||||
|
|
@ -298,7 +302,7 @@ async def get_all(model: Type[ModelType], **field_values) -> Iterable[ModelType]
|
|||
query = sa.select(model).where(
|
||||
*(table.c[k] == v for k, v in field_values.items() if v is not None)
|
||||
)
|
||||
async with locked_connection() as conn:
|
||||
async with _locked_connection() as conn:
|
||||
rows = await conn.fetch_all(query)
|
||||
return (fromplain(model, row._mapping, serialized=True) for row in rows)
|
||||
|
||||
|
|
@ -312,7 +316,7 @@ async def update(item: Model) -> None:
|
|||
table: sa.Table = item.__table__
|
||||
values = asplain(item, serialize=True)
|
||||
stmt = table.update().where(table.c.id == values["id"]).values(values)
|
||||
async with locked_connection() as conn:
|
||||
async with _locked_connection() as conn:
|
||||
await conn.execute(stmt)
|
||||
|
||||
|
||||
|
|
@ -320,7 +324,7 @@ async def remove(item: Model) -> None:
|
|||
table: sa.Table = item.__table__
|
||||
values = asplain(item, filter_fields={"id"}, serialize=True)
|
||||
stmt = table.delete().where(table.c.id == values["id"])
|
||||
async with locked_connection() as conn:
|
||||
async with _locked_connection() as conn:
|
||||
await conn.execute(stmt)
|
||||
|
||||
|
||||
|
|
@ -471,7 +475,7 @@ async def find_ratings(
|
|||
)
|
||||
.limit(limit_rows)
|
||||
)
|
||||
async with locked_connection() as conn:
|
||||
async with _locked_connection() as conn:
|
||||
rating_rows: AsyncGenerator[Rating, None] = conn.iterate(query) # type: ignore
|
||||
movie_ids = [r.movie_id async for r in rating_rows]
|
||||
|
||||
|
|
@ -487,7 +491,7 @@ async def find_ratings(
|
|||
)
|
||||
.limit(limit_rows - len(movie_ids))
|
||||
)
|
||||
async with locked_connection() as conn:
|
||||
async with _locked_connection() as conn:
|
||||
movie_rows: AsyncGenerator[Movie, None] = conn.iterate(query) # type: ignore
|
||||
movie_ids += [r.id async for r in movie_rows]
|
||||
|
||||
|
|
@ -523,7 +527,7 @@ async def ratings_for_movie_ids(
|
|||
.outerjoin_from(movies, ratings, movies.c.id == ratings.c.movie_id)
|
||||
.where(sa.or_(*conds))
|
||||
)
|
||||
async with locked_connection() as conn:
|
||||
async with _locked_connection() as conn:
|
||||
rows = await conn.fetch_all(query)
|
||||
return tuple(dict(r._mapping) for r in rows)
|
||||
|
||||
|
|
@ -538,7 +542,7 @@ async def ratings_for_movies(
|
|||
|
||||
query = sa.select(ratings).where(*conditions)
|
||||
|
||||
async with locked_connection() as conn:
|
||||
async with _locked_connection() as conn:
|
||||
rows = await conn.fetch_all(query)
|
||||
|
||||
return (fromplain(Rating, row._mapping, serialized=True) for row in rows)
|
||||
|
|
@ -602,7 +606,7 @@ async def find_movies(
|
|||
.offset(skip_rows)
|
||||
)
|
||||
|
||||
async with locked_connection() as conn:
|
||||
async with _locked_connection() as conn:
|
||||
rows = await conn.fetch_all(query)
|
||||
|
||||
movies_ = [fromplain(Movie, row._mapping, serialized=True) for row in rows]
|
||||
|
|
|
|||
|
|
@ -445,7 +445,7 @@ async def remove_user(request):
|
|||
if not user:
|
||||
return not_found()
|
||||
|
||||
async with db.shared_connection().transaction():
|
||||
async with db.transaction():
|
||||
# XXX remove user refs from groups and ratings
|
||||
|
||||
await db.remove(user)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue