make the shared connection internal to the db module

This should make it easier to refactor the code in preparation for removing the
databases package.
ducklet 2023-11-26 18:41:32 +01:00
parent 22ea553f48
commit 1f42538481
3 changed files with 27 additions and 23 deletions
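
From the caller side, code outside the db module now goes through module-level helpers such as db.transaction(), db.add() and db.remove() instead of grabbing the connection via db.shared_connection(). A minimal sketch of that calling convention, under stated assumptions: the import path and the remove_user_record helper are illustrative and not part of this commit, while db.transaction() and db.remove() are the helpers shown in the diff below.

import db  # illustrative import; the real module path depends on the project layout

async def remove_user_record(user) -> None:
    # Post-commit calling convention (sketch): the transaction comes from the
    # db facade, and the shared connection itself stays private to the module.
    async with db.transaction():
        await db.remove(user)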

View file

@@ -17,7 +17,7 @@ def event_loop():
 @pytest_asyncio.fixture(scope="session")
 async def shared_conn():
-    c = db.shared_connection()
+    c = db._shared_connection()
     await c.connect()
     await db.apply_db_patches(c)

View file

@@ -31,7 +31,7 @@ from .types import ULID
 log = logging.getLogger(__name__)
 T = TypeVar("T")
-_shared_connection: Database | None = None
+_database: Database | None = None
 async def open_connection_pool() -> None:
@@ -39,7 +39,7 @@ async def open_connection_pool() -> None:
     This function needs to be called before any access to the database can happen.
     """
-    db = shared_connection()
+    db = _shared_connection()
     await db.connect()
     await db.execute(sa.text("PRAGMA journal_mode=WAL"))
@@ -53,7 +53,7 @@ async def close_connection_pool() -> None:
     This function should be called before the app shuts down to ensure all data
     has been flushed to the database.
     """
-    db = shared_connection()
+    db = _shared_connection()
     # Run automatic ANALYZE prior to closing the db,
     # see https://sqlite.com/lang_analyze.html.
@@ -205,23 +205,27 @@ async def single_threaded():
 @contextlib.asynccontextmanager
-async def locked_connection():
+async def _locked_connection():
     async with single_threaded():
-        yield shared_connection()
+        yield _shared_connection()
-def shared_connection() -> Database:
-    global _shared_connection
+def _shared_connection() -> Database:
+    global _database
-    if _shared_connection is None:
+    if _database is None:
         uri = f"sqlite:///{config.storage_path}"
         # uri = f"sqlite+aiosqlite:///{config.storage_path}"
-        _shared_connection = Database(uri)
+        _database = Database(uri)
         engine = sa.create_engine(uri, future=True)
         metadata.create_all(engine, tables=[db_patches])
-    return _shared_connection
+    return _database
+def transaction():
+    return _shared_connection().transaction()
 async def add(item: Model) -> None:
@@ -233,7 +237,7 @@ async def add(item: Model) -> None:
     table: sa.Table = item.__table__
     values = asplain(item, serialize=True)
     stmt = table.insert().values(values)
-    async with locked_connection() as conn:
+    async with _locked_connection() as conn:
         await conn.execute(stmt)
@@ -264,7 +268,7 @@ async def get(
     query = query.order_by(
         order_col.asc() if order_dir == "asc" else order_col.desc()
     )
-    async with locked_connection() as conn:
+    async with _locked_connection() as conn:
         row = await conn.fetch_one(query)
     return fromplain(model, row._mapping, serialized=True) if row else None
@@ -284,7 +288,7 @@ async def get_many(
     table: sa.Table = model.__table__
     query = sa.select(model).where(*(table.c[k].in_(v) for k, v in field_sets.items()))
-    async with locked_connection() as conn:
+    async with _locked_connection() as conn:
         rows = await conn.fetch_all(query)
     return (fromplain(model, row._mapping, serialized=True) for row in rows)
@@ -298,7 +302,7 @@ async def get_all(model: Type[ModelType], **field_values) -> Iterable[ModelType]
     query = sa.select(model).where(
         *(table.c[k] == v for k, v in field_values.items() if v is not None)
     )
-    async with locked_connection() as conn:
+    async with _locked_connection() as conn:
         rows = await conn.fetch_all(query)
     return (fromplain(model, row._mapping, serialized=True) for row in rows)
@@ -312,7 +316,7 @@ async def update(item: Model) -> None:
     table: sa.Table = item.__table__
     values = asplain(item, serialize=True)
     stmt = table.update().where(table.c.id == values["id"]).values(values)
-    async with locked_connection() as conn:
+    async with _locked_connection() as conn:
         await conn.execute(stmt)
@@ -320,7 +324,7 @@ async def remove(item: Model) -> None:
     table: sa.Table = item.__table__
     values = asplain(item, filter_fields={"id"}, serialize=True)
     stmt = table.delete().where(table.c.id == values["id"])
-    async with locked_connection() as conn:
+    async with _locked_connection() as conn:
         await conn.execute(stmt)
@@ -471,7 +475,7 @@ async def find_ratings(
         )
         .limit(limit_rows)
     )
-    async with locked_connection() as conn:
+    async with _locked_connection() as conn:
         rating_rows: AsyncGenerator[Rating, None] = conn.iterate(query) # type: ignore
         movie_ids = [r.movie_id async for r in rating_rows]
@@ -487,7 +491,7 @@ async def find_ratings(
         )
         .limit(limit_rows - len(movie_ids))
     )
-    async with locked_connection() as conn:
+    async with _locked_connection() as conn:
         movie_rows: AsyncGenerator[Movie, None] = conn.iterate(query) # type: ignore
         movie_ids += [r.id async for r in movie_rows]
@@ -523,7 +527,7 @@ async def ratings_for_movie_ids(
         .outerjoin_from(movies, ratings, movies.c.id == ratings.c.movie_id)
         .where(sa.or_(*conds))
     )
-    async with locked_connection() as conn:
+    async with _locked_connection() as conn:
         rows = await conn.fetch_all(query)
     return tuple(dict(r._mapping) for r in rows)
@@ -538,7 +542,7 @@ async def ratings_for_movies(
     query = sa.select(ratings).where(*conditions)
-    async with locked_connection() as conn:
+    async with _locked_connection() as conn:
         rows = await conn.fetch_all(query)
     return (fromplain(Rating, row._mapping, serialized=True) for row in rows)
@@ -602,7 +606,7 @@ async def find_movies(
         .offset(skip_rows)
     )
-    async with locked_connection() as conn:
+    async with _locked_connection() as conn:
         rows = await conn.fetch_all(query)
     movies_ = [fromplain(Movie, row._mapping, serialized=True) for row in rows]

View file

@@ -445,7 +445,7 @@ async def remove_user(request):
     if not user:
         return not_found()
-    async with db.shared_connection().transaction():
+    async with db.transaction():
         # XXX remove user refs from groups and ratings
         await db.remove(user)