make the shared connection internal to the db module

This should make it easier to refactor the code for removing the
databases package.
This commit is contained in:
ducklet 2023-11-26 18:41:32 +01:00
parent 22ea553f48
commit 1f42538481
3 changed files with 27 additions and 23 deletions

View file

@ -17,7 +17,7 @@ def event_loop():
@pytest_asyncio.fixture(scope="session") @pytest_asyncio.fixture(scope="session")
async def shared_conn(): async def shared_conn():
c = db.shared_connection() c = db._shared_connection()
await c.connect() await c.connect()
await db.apply_db_patches(c) await db.apply_db_patches(c)

View file

@ -31,7 +31,7 @@ from .types import ULID
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
T = TypeVar("T") T = TypeVar("T")
_shared_connection: Database | None = None _database: Database | None = None
async def open_connection_pool() -> None: async def open_connection_pool() -> None:
@ -39,7 +39,7 @@ async def open_connection_pool() -> None:
This function needs to be called before any access to the database can happen. This function needs to be called before any access to the database can happen.
""" """
db = shared_connection() db = _shared_connection()
await db.connect() await db.connect()
await db.execute(sa.text("PRAGMA journal_mode=WAL")) await db.execute(sa.text("PRAGMA journal_mode=WAL"))
@ -53,7 +53,7 @@ async def close_connection_pool() -> None:
This function should be called before the app shuts down to ensure all data This function should be called before the app shuts down to ensure all data
has been flushed to the database. has been flushed to the database.
""" """
db = shared_connection() db = _shared_connection()
# Run automatic ANALYZE prior to closing the db, # Run automatic ANALYZE prior to closing the db,
# see https://sqlite.com/lang_analyze.html. # see https://sqlite.com/lang_analyze.html.
@ -205,23 +205,27 @@ async def single_threaded():
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def locked_connection(): async def _locked_connection():
async with single_threaded(): async with single_threaded():
yield shared_connection() yield _shared_connection()
def shared_connection() -> Database: def _shared_connection() -> Database:
global _shared_connection global _database
if _shared_connection is None: if _database is None:
uri = f"sqlite:///{config.storage_path}" uri = f"sqlite:///{config.storage_path}"
# uri = f"sqlite+aiosqlite:///{config.storage_path}" # uri = f"sqlite+aiosqlite:///{config.storage_path}"
_shared_connection = Database(uri) _database = Database(uri)
engine = sa.create_engine(uri, future=True) engine = sa.create_engine(uri, future=True)
metadata.create_all(engine, tables=[db_patches]) metadata.create_all(engine, tables=[db_patches])
return _shared_connection return _database
def transaction():
return _shared_connection().transaction()
async def add(item: Model) -> None: async def add(item: Model) -> None:
@ -233,7 +237,7 @@ async def add(item: Model) -> None:
table: sa.Table = item.__table__ table: sa.Table = item.__table__
values = asplain(item, serialize=True) values = asplain(item, serialize=True)
stmt = table.insert().values(values) stmt = table.insert().values(values)
async with locked_connection() as conn: async with _locked_connection() as conn:
await conn.execute(stmt) await conn.execute(stmt)
@ -264,7 +268,7 @@ async def get(
query = query.order_by( query = query.order_by(
order_col.asc() if order_dir == "asc" else order_col.desc() order_col.asc() if order_dir == "asc" else order_col.desc()
) )
async with locked_connection() as conn: async with _locked_connection() as conn:
row = await conn.fetch_one(query) row = await conn.fetch_one(query)
return fromplain(model, row._mapping, serialized=True) if row else None return fromplain(model, row._mapping, serialized=True) if row else None
@ -284,7 +288,7 @@ async def get_many(
table: sa.Table = model.__table__ table: sa.Table = model.__table__
query = sa.select(model).where(*(table.c[k].in_(v) for k, v in field_sets.items())) query = sa.select(model).where(*(table.c[k].in_(v) for k, v in field_sets.items()))
async with locked_connection() as conn: async with _locked_connection() as conn:
rows = await conn.fetch_all(query) rows = await conn.fetch_all(query)
return (fromplain(model, row._mapping, serialized=True) for row in rows) return (fromplain(model, row._mapping, serialized=True) for row in rows)
@ -298,7 +302,7 @@ async def get_all(model: Type[ModelType], **field_values) -> Iterable[ModelType]
query = sa.select(model).where( query = sa.select(model).where(
*(table.c[k] == v for k, v in field_values.items() if v is not None) *(table.c[k] == v for k, v in field_values.items() if v is not None)
) )
async with locked_connection() as conn: async with _locked_connection() as conn:
rows = await conn.fetch_all(query) rows = await conn.fetch_all(query)
return (fromplain(model, row._mapping, serialized=True) for row in rows) return (fromplain(model, row._mapping, serialized=True) for row in rows)
@ -312,7 +316,7 @@ async def update(item: Model) -> None:
table: sa.Table = item.__table__ table: sa.Table = item.__table__
values = asplain(item, serialize=True) values = asplain(item, serialize=True)
stmt = table.update().where(table.c.id == values["id"]).values(values) stmt = table.update().where(table.c.id == values["id"]).values(values)
async with locked_connection() as conn: async with _locked_connection() as conn:
await conn.execute(stmt) await conn.execute(stmt)
@ -320,7 +324,7 @@ async def remove(item: Model) -> None:
table: sa.Table = item.__table__ table: sa.Table = item.__table__
values = asplain(item, filter_fields={"id"}, serialize=True) values = asplain(item, filter_fields={"id"}, serialize=True)
stmt = table.delete().where(table.c.id == values["id"]) stmt = table.delete().where(table.c.id == values["id"])
async with locked_connection() as conn: async with _locked_connection() as conn:
await conn.execute(stmt) await conn.execute(stmt)
@ -471,7 +475,7 @@ async def find_ratings(
) )
.limit(limit_rows) .limit(limit_rows)
) )
async with locked_connection() as conn: async with _locked_connection() as conn:
rating_rows: AsyncGenerator[Rating, None] = conn.iterate(query) # type: ignore rating_rows: AsyncGenerator[Rating, None] = conn.iterate(query) # type: ignore
movie_ids = [r.movie_id async for r in rating_rows] movie_ids = [r.movie_id async for r in rating_rows]
@ -487,7 +491,7 @@ async def find_ratings(
) )
.limit(limit_rows - len(movie_ids)) .limit(limit_rows - len(movie_ids))
) )
async with locked_connection() as conn: async with _locked_connection() as conn:
movie_rows: AsyncGenerator[Movie, None] = conn.iterate(query) # type: ignore movie_rows: AsyncGenerator[Movie, None] = conn.iterate(query) # type: ignore
movie_ids += [r.id async for r in movie_rows] movie_ids += [r.id async for r in movie_rows]
@ -523,7 +527,7 @@ async def ratings_for_movie_ids(
.outerjoin_from(movies, ratings, movies.c.id == ratings.c.movie_id) .outerjoin_from(movies, ratings, movies.c.id == ratings.c.movie_id)
.where(sa.or_(*conds)) .where(sa.or_(*conds))
) )
async with locked_connection() as conn: async with _locked_connection() as conn:
rows = await conn.fetch_all(query) rows = await conn.fetch_all(query)
return tuple(dict(r._mapping) for r in rows) return tuple(dict(r._mapping) for r in rows)
@ -538,7 +542,7 @@ async def ratings_for_movies(
query = sa.select(ratings).where(*conditions) query = sa.select(ratings).where(*conditions)
async with locked_connection() as conn: async with _locked_connection() as conn:
rows = await conn.fetch_all(query) rows = await conn.fetch_all(query)
return (fromplain(Rating, row._mapping, serialized=True) for row in rows) return (fromplain(Rating, row._mapping, serialized=True) for row in rows)
@ -602,7 +606,7 @@ async def find_movies(
.offset(skip_rows) .offset(skip_rows)
) )
async with locked_connection() as conn: async with _locked_connection() as conn:
rows = await conn.fetch_all(query) rows = await conn.fetch_all(query)
movies_ = [fromplain(Movie, row._mapping, serialized=True) for row in rows] movies_ = [fromplain(Movie, row._mapping, serialized=True) for row in rows]

View file

@ -445,7 +445,7 @@ async def remove_user(request):
if not user: if not user:
return not_found() return not_found()
async with db.shared_connection().transaction(): async with db.transaction():
# XXX remove user refs from groups and ratings # XXX remove user refs from groups and ratings
await db.remove(user) await db.remove(user)