make the shared connection internal to the db module
This should make it easier to refactor the code for removing the databases package.
This commit is contained in:
parent
22ea553f48
commit
1f42538481
3 changed files with 27 additions and 23 deletions
|
|
@ -17,7 +17,7 @@ def event_loop():
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session")
|
@pytest_asyncio.fixture(scope="session")
|
||||||
async def shared_conn():
|
async def shared_conn():
|
||||||
c = db.shared_connection()
|
c = db._shared_connection()
|
||||||
await c.connect()
|
await c.connect()
|
||||||
|
|
||||||
await db.apply_db_patches(c)
|
await db.apply_db_patches(c)
|
||||||
|
|
|
||||||
46
unwind/db.py
46
unwind/db.py
|
|
@ -31,7 +31,7 @@ from .types import ULID
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
_shared_connection: Database | None = None
|
_database: Database | None = None
|
||||||
|
|
||||||
|
|
||||||
async def open_connection_pool() -> None:
|
async def open_connection_pool() -> None:
|
||||||
|
|
@ -39,7 +39,7 @@ async def open_connection_pool() -> None:
|
||||||
|
|
||||||
This function needs to be called before any access to the database can happen.
|
This function needs to be called before any access to the database can happen.
|
||||||
"""
|
"""
|
||||||
db = shared_connection()
|
db = _shared_connection()
|
||||||
await db.connect()
|
await db.connect()
|
||||||
|
|
||||||
await db.execute(sa.text("PRAGMA journal_mode=WAL"))
|
await db.execute(sa.text("PRAGMA journal_mode=WAL"))
|
||||||
|
|
@ -53,7 +53,7 @@ async def close_connection_pool() -> None:
|
||||||
This function should be called before the app shuts down to ensure all data
|
This function should be called before the app shuts down to ensure all data
|
||||||
has been flushed to the database.
|
has been flushed to the database.
|
||||||
"""
|
"""
|
||||||
db = shared_connection()
|
db = _shared_connection()
|
||||||
|
|
||||||
# Run automatic ANALYZE prior to closing the db,
|
# Run automatic ANALYZE prior to closing the db,
|
||||||
# see https://sqlite.com/lang_analyze.html.
|
# see https://sqlite.com/lang_analyze.html.
|
||||||
|
|
@ -205,23 +205,27 @@ async def single_threaded():
|
||||||
|
|
||||||
|
|
||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
async def locked_connection():
|
async def _locked_connection():
|
||||||
async with single_threaded():
|
async with single_threaded():
|
||||||
yield shared_connection()
|
yield _shared_connection()
|
||||||
|
|
||||||
|
|
||||||
def shared_connection() -> Database:
|
def _shared_connection() -> Database:
|
||||||
global _shared_connection
|
global _database
|
||||||
|
|
||||||
if _shared_connection is None:
|
if _database is None:
|
||||||
uri = f"sqlite:///{config.storage_path}"
|
uri = f"sqlite:///{config.storage_path}"
|
||||||
# uri = f"sqlite+aiosqlite:///{config.storage_path}"
|
# uri = f"sqlite+aiosqlite:///{config.storage_path}"
|
||||||
_shared_connection = Database(uri)
|
_database = Database(uri)
|
||||||
|
|
||||||
engine = sa.create_engine(uri, future=True)
|
engine = sa.create_engine(uri, future=True)
|
||||||
metadata.create_all(engine, tables=[db_patches])
|
metadata.create_all(engine, tables=[db_patches])
|
||||||
|
|
||||||
return _shared_connection
|
return _database
|
||||||
|
|
||||||
|
|
||||||
|
def transaction():
|
||||||
|
return _shared_connection().transaction()
|
||||||
|
|
||||||
|
|
||||||
async def add(item: Model) -> None:
|
async def add(item: Model) -> None:
|
||||||
|
|
@ -233,7 +237,7 @@ async def add(item: Model) -> None:
|
||||||
table: sa.Table = item.__table__
|
table: sa.Table = item.__table__
|
||||||
values = asplain(item, serialize=True)
|
values = asplain(item, serialize=True)
|
||||||
stmt = table.insert().values(values)
|
stmt = table.insert().values(values)
|
||||||
async with locked_connection() as conn:
|
async with _locked_connection() as conn:
|
||||||
await conn.execute(stmt)
|
await conn.execute(stmt)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -264,7 +268,7 @@ async def get(
|
||||||
query = query.order_by(
|
query = query.order_by(
|
||||||
order_col.asc() if order_dir == "asc" else order_col.desc()
|
order_col.asc() if order_dir == "asc" else order_col.desc()
|
||||||
)
|
)
|
||||||
async with locked_connection() as conn:
|
async with _locked_connection() as conn:
|
||||||
row = await conn.fetch_one(query)
|
row = await conn.fetch_one(query)
|
||||||
return fromplain(model, row._mapping, serialized=True) if row else None
|
return fromplain(model, row._mapping, serialized=True) if row else None
|
||||||
|
|
||||||
|
|
@ -284,7 +288,7 @@ async def get_many(
|
||||||
|
|
||||||
table: sa.Table = model.__table__
|
table: sa.Table = model.__table__
|
||||||
query = sa.select(model).where(*(table.c[k].in_(v) for k, v in field_sets.items()))
|
query = sa.select(model).where(*(table.c[k].in_(v) for k, v in field_sets.items()))
|
||||||
async with locked_connection() as conn:
|
async with _locked_connection() as conn:
|
||||||
rows = await conn.fetch_all(query)
|
rows = await conn.fetch_all(query)
|
||||||
return (fromplain(model, row._mapping, serialized=True) for row in rows)
|
return (fromplain(model, row._mapping, serialized=True) for row in rows)
|
||||||
|
|
||||||
|
|
@ -298,7 +302,7 @@ async def get_all(model: Type[ModelType], **field_values) -> Iterable[ModelType]
|
||||||
query = sa.select(model).where(
|
query = sa.select(model).where(
|
||||||
*(table.c[k] == v for k, v in field_values.items() if v is not None)
|
*(table.c[k] == v for k, v in field_values.items() if v is not None)
|
||||||
)
|
)
|
||||||
async with locked_connection() as conn:
|
async with _locked_connection() as conn:
|
||||||
rows = await conn.fetch_all(query)
|
rows = await conn.fetch_all(query)
|
||||||
return (fromplain(model, row._mapping, serialized=True) for row in rows)
|
return (fromplain(model, row._mapping, serialized=True) for row in rows)
|
||||||
|
|
||||||
|
|
@ -312,7 +316,7 @@ async def update(item: Model) -> None:
|
||||||
table: sa.Table = item.__table__
|
table: sa.Table = item.__table__
|
||||||
values = asplain(item, serialize=True)
|
values = asplain(item, serialize=True)
|
||||||
stmt = table.update().where(table.c.id == values["id"]).values(values)
|
stmt = table.update().where(table.c.id == values["id"]).values(values)
|
||||||
async with locked_connection() as conn:
|
async with _locked_connection() as conn:
|
||||||
await conn.execute(stmt)
|
await conn.execute(stmt)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -320,7 +324,7 @@ async def remove(item: Model) -> None:
|
||||||
table: sa.Table = item.__table__
|
table: sa.Table = item.__table__
|
||||||
values = asplain(item, filter_fields={"id"}, serialize=True)
|
values = asplain(item, filter_fields={"id"}, serialize=True)
|
||||||
stmt = table.delete().where(table.c.id == values["id"])
|
stmt = table.delete().where(table.c.id == values["id"])
|
||||||
async with locked_connection() as conn:
|
async with _locked_connection() as conn:
|
||||||
await conn.execute(stmt)
|
await conn.execute(stmt)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -471,7 +475,7 @@ async def find_ratings(
|
||||||
)
|
)
|
||||||
.limit(limit_rows)
|
.limit(limit_rows)
|
||||||
)
|
)
|
||||||
async with locked_connection() as conn:
|
async with _locked_connection() as conn:
|
||||||
rating_rows: AsyncGenerator[Rating, None] = conn.iterate(query) # type: ignore
|
rating_rows: AsyncGenerator[Rating, None] = conn.iterate(query) # type: ignore
|
||||||
movie_ids = [r.movie_id async for r in rating_rows]
|
movie_ids = [r.movie_id async for r in rating_rows]
|
||||||
|
|
||||||
|
|
@ -487,7 +491,7 @@ async def find_ratings(
|
||||||
)
|
)
|
||||||
.limit(limit_rows - len(movie_ids))
|
.limit(limit_rows - len(movie_ids))
|
||||||
)
|
)
|
||||||
async with locked_connection() as conn:
|
async with _locked_connection() as conn:
|
||||||
movie_rows: AsyncGenerator[Movie, None] = conn.iterate(query) # type: ignore
|
movie_rows: AsyncGenerator[Movie, None] = conn.iterate(query) # type: ignore
|
||||||
movie_ids += [r.id async for r in movie_rows]
|
movie_ids += [r.id async for r in movie_rows]
|
||||||
|
|
||||||
|
|
@ -523,7 +527,7 @@ async def ratings_for_movie_ids(
|
||||||
.outerjoin_from(movies, ratings, movies.c.id == ratings.c.movie_id)
|
.outerjoin_from(movies, ratings, movies.c.id == ratings.c.movie_id)
|
||||||
.where(sa.or_(*conds))
|
.where(sa.or_(*conds))
|
||||||
)
|
)
|
||||||
async with locked_connection() as conn:
|
async with _locked_connection() as conn:
|
||||||
rows = await conn.fetch_all(query)
|
rows = await conn.fetch_all(query)
|
||||||
return tuple(dict(r._mapping) for r in rows)
|
return tuple(dict(r._mapping) for r in rows)
|
||||||
|
|
||||||
|
|
@ -538,7 +542,7 @@ async def ratings_for_movies(
|
||||||
|
|
||||||
query = sa.select(ratings).where(*conditions)
|
query = sa.select(ratings).where(*conditions)
|
||||||
|
|
||||||
async with locked_connection() as conn:
|
async with _locked_connection() as conn:
|
||||||
rows = await conn.fetch_all(query)
|
rows = await conn.fetch_all(query)
|
||||||
|
|
||||||
return (fromplain(Rating, row._mapping, serialized=True) for row in rows)
|
return (fromplain(Rating, row._mapping, serialized=True) for row in rows)
|
||||||
|
|
@ -602,7 +606,7 @@ async def find_movies(
|
||||||
.offset(skip_rows)
|
.offset(skip_rows)
|
||||||
)
|
)
|
||||||
|
|
||||||
async with locked_connection() as conn:
|
async with _locked_connection() as conn:
|
||||||
rows = await conn.fetch_all(query)
|
rows = await conn.fetch_all(query)
|
||||||
|
|
||||||
movies_ = [fromplain(Movie, row._mapping, serialized=True) for row in rows]
|
movies_ = [fromplain(Movie, row._mapping, serialized=True) for row in rows]
|
||||||
|
|
|
||||||
|
|
@ -445,7 +445,7 @@ async def remove_user(request):
|
||||||
if not user:
|
if not user:
|
||||||
return not_found()
|
return not_found()
|
||||||
|
|
||||||
async with db.shared_connection().transaction():
|
async with db.transaction():
|
||||||
# XXX remove user refs from groups and ratings
|
# XXX remove user refs from groups and ratings
|
||||||
|
|
||||||
await db.remove(user)
|
await db.remove(user)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue