diff --git a/tests/conftest.py b/tests/conftest.py index e57d3e1..470bc4d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,7 +17,7 @@ def event_loop(): @pytest_asyncio.fixture(scope="session") async def shared_conn(): - c = db.shared_connection() + c = db._shared_connection() await c.connect() await db.apply_db_patches(c) diff --git a/unwind/db.py b/unwind/db.py index 278c0c0..51e66d5 100644 --- a/unwind/db.py +++ b/unwind/db.py @@ -31,7 +31,7 @@ from .types import ULID log = logging.getLogger(__name__) T = TypeVar("T") -_shared_connection: Database | None = None +_database: Database | None = None async def open_connection_pool() -> None: @@ -39,7 +39,7 @@ async def open_connection_pool() -> None: This function needs to be called before any access to the database can happen. """ - db = shared_connection() + db = _shared_connection() await db.connect() await db.execute(sa.text("PRAGMA journal_mode=WAL")) @@ -53,7 +53,7 @@ async def close_connection_pool() -> None: This function should be called before the app shuts down to ensure all data has been flushed to the database. """ - db = shared_connection() + db = _shared_connection() # Run automatic ANALYZE prior to closing the db, # see https://sqlite.com/lang_analyze.html. @@ -205,23 +205,27 @@ async def single_threaded(): @contextlib.asynccontextmanager -async def locked_connection(): +async def _locked_connection(): async with single_threaded(): - yield shared_connection() + yield _shared_connection() -def shared_connection() -> Database: - global _shared_connection +def _shared_connection() -> Database: + global _database - if _shared_connection is None: + if _database is None: uri = f"sqlite:///{config.storage_path}" # uri = f"sqlite+aiosqlite:///{config.storage_path}" - _shared_connection = Database(uri) + _database = Database(uri) engine = sa.create_engine(uri, future=True) metadata.create_all(engine, tables=[db_patches]) - return _shared_connection + return _database + + +def transaction(): + return _shared_connection().transaction() async def add(item: Model) -> None: @@ -233,7 +237,7 @@ async def add(item: Model) -> None: table: sa.Table = item.__table__ values = asplain(item, serialize=True) stmt = table.insert().values(values) - async with locked_connection() as conn: + async with _locked_connection() as conn: await conn.execute(stmt) @@ -264,7 +268,7 @@ async def get( query = query.order_by( order_col.asc() if order_dir == "asc" else order_col.desc() ) - async with locked_connection() as conn: + async with _locked_connection() as conn: row = await conn.fetch_one(query) return fromplain(model, row._mapping, serialized=True) if row else None @@ -284,7 +288,7 @@ async def get_many( table: sa.Table = model.__table__ query = sa.select(model).where(*(table.c[k].in_(v) for k, v in field_sets.items())) - async with locked_connection() as conn: + async with _locked_connection() as conn: rows = await conn.fetch_all(query) return (fromplain(model, row._mapping, serialized=True) for row in rows) @@ -298,7 +302,7 @@ async def get_all(model: Type[ModelType], **field_values) -> Iterable[ModelType] query = sa.select(model).where( *(table.c[k] == v for k, v in field_values.items() if v is not None) ) - async with locked_connection() as conn: + async with _locked_connection() as conn: rows = await conn.fetch_all(query) return (fromplain(model, row._mapping, serialized=True) for row in rows) @@ -312,7 +316,7 @@ async def update(item: Model) -> None: table: sa.Table = item.__table__ values = asplain(item, serialize=True) stmt = table.update().where(table.c.id == values["id"]).values(values) - async with locked_connection() as conn: + async with _locked_connection() as conn: await conn.execute(stmt) @@ -320,7 +324,7 @@ async def remove(item: Model) -> None: table: sa.Table = item.__table__ values = asplain(item, filter_fields={"id"}, serialize=True) stmt = table.delete().where(table.c.id == values["id"]) - async with locked_connection() as conn: + async with _locked_connection() as conn: await conn.execute(stmt) @@ -471,7 +475,7 @@ async def find_ratings( ) .limit(limit_rows) ) - async with locked_connection() as conn: + async with _locked_connection() as conn: rating_rows: AsyncGenerator[Rating, None] = conn.iterate(query) # type: ignore movie_ids = [r.movie_id async for r in rating_rows] @@ -487,7 +491,7 @@ async def find_ratings( ) .limit(limit_rows - len(movie_ids)) ) - async with locked_connection() as conn: + async with _locked_connection() as conn: movie_rows: AsyncGenerator[Movie, None] = conn.iterate(query) # type: ignore movie_ids += [r.id async for r in movie_rows] @@ -523,7 +527,7 @@ async def ratings_for_movie_ids( .outerjoin_from(movies, ratings, movies.c.id == ratings.c.movie_id) .where(sa.or_(*conds)) ) - async with locked_connection() as conn: + async with _locked_connection() as conn: rows = await conn.fetch_all(query) return tuple(dict(r._mapping) for r in rows) @@ -538,7 +542,7 @@ async def ratings_for_movies( query = sa.select(ratings).where(*conditions) - async with locked_connection() as conn: + async with _locked_connection() as conn: rows = await conn.fetch_all(query) return (fromplain(Rating, row._mapping, serialized=True) for row in rows) @@ -602,7 +606,7 @@ async def find_movies( .offset(skip_rows) ) - async with locked_connection() as conn: + async with _locked_connection() as conn: rows = await conn.fetch_all(query) movies_ = [fromplain(Movie, row._mapping, serialized=True) for row in rows] diff --git a/unwind/web.py b/unwind/web.py index 3ebbcdc..bddd54d 100644 --- a/unwind/web.py +++ b/unwind/web.py @@ -445,7 +445,7 @@ async def remove_user(request): if not user: return not_found() - async with db.shared_connection().transaction(): + async with db.transaction(): # XXX remove user refs from groups and ratings await db.remove(user)