diff --git a/tests/test_db.py b/tests/test_db.py index d476408..7c4c96e 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -20,6 +20,15 @@ def a_movie(**kwds) -> models.Movie: return models.Movie(**args) +@pytest.mark.asyncio +async def test_current_patch_level(shared_conn: db.Database): + async with shared_conn.transaction(force_rollback=True): + patch_level = "some-patch-level" + assert patch_level != await db.current_patch_level(shared_conn) + await db.set_current_patch_level(shared_conn, patch_level) + assert patch_level == await db.current_patch_level(shared_conn) + + @pytest.mark.asyncio async def test_get(shared_conn: db.Database): async with shared_conn.transaction(force_rollback=True): diff --git a/unwind/db.py b/unwind/db.py index 6c21a9d..4b90085 100644 --- a/unwind/db.py +++ b/unwind/db.py @@ -7,6 +7,7 @@ from typing import Any, Iterable, Literal, Type, TypeVar import sqlalchemy as sa from databases import Database +from sqlalchemy.dialects.sqlite import insert from . import config from .models import ( @@ -15,7 +16,9 @@ from .models import ( Rating, User, asplain, + db_patches, fromplain, + metadata, movies, optional_fields, ratings, @@ -50,38 +53,22 @@ async def close_connection_pool() -> None: # Run automatic ANALYZE prior to closing the db, # see https://sqlite.com/lang_analyze.html. - await db.execute("PRAGMA analysis_limit=400") - await db.execute("PRAGMA optimize") + await db.execute(sa.text("PRAGMA analysis_limit=400")) + await db.execute(sa.text("PRAGMA optimize")) await db.disconnect() -async def _create_patch_db(db): - query = """ - CREATE TABLE IF NOT EXISTS db_patches ( - id INTEGER PRIMARY KEY, - current TEXT - ) - """ - await db.execute(query) - - -async def current_patch_level(db) -> str: - await _create_patch_db(db) - - query = "SELECT current FROM db_patches" +async def current_patch_level(db: Database) -> str: + query = sa.select(db_patches.c.current) current = await db.fetch_val(query) return current or "" -async def set_current_patch_level(db, current: str): - await _create_patch_db(db) - - query = """ - INSERT INTO db_patches VALUES (1, :current) - ON CONFLICT DO UPDATE SET current=excluded.current - """ - await db.execute(query, values={"current": current}) +async def set_current_patch_level(db: Database, current: str) -> None: + stmt = insert(db_patches).values(id=1, current=current) + stmt = stmt.on_conflict_do_update(set_={"current": stmt.excluded.current}) + await db.execute(stmt) db_patches_dir = Path(__file__).parent / "sql" @@ -222,8 +209,12 @@ def shared_connection() -> Database: if _shared_connection is None: uri = f"sqlite:///{config.storage_path}" + # uri = f"sqlite+aiosqlite:///{config.storage_path}" _shared_connection = Database(uri) + engine = sa.create_engine(uri, future=True) + metadata.create_all(engine, tables=[db_patches]) + return _shared_connection diff --git a/unwind/models.py b/unwind/models.py index 6d40b35..5628bb0 100644 --- a/unwind/models.py +++ b/unwind/models.py @@ -30,6 +30,7 @@ JSONObject = dict[str, JSON] T = TypeVar("T") mapper_registry = registry() +metadata = mapper_registry.metadata def annotations(tp: Type) -> tuple | None: @@ -203,12 +204,29 @@ def utcnow(): return datetime.utcnow().replace(tzinfo=timezone.utc) +@mapper_registry.mapped +@dataclass +class DbPatch: + __table__ = Table( + "db_patches", + metadata, + Column("id", Integer, primary_key=True), + Column("current", String), + ) + + id: int + current: str + + +db_patches = DbPatch.__table__ + + @mapper_registry.mapped @dataclass class Progress: __table__ = Table( "progress", - mapper_registry.metadata, + metadata, Column("id", String, primary_key=True), # ULID Column("type", String, nullable=False), Column("state", String, nullable=False), # JSON {"percent": ..., "error": ...} @@ -258,7 +276,7 @@ class Progress: class Movie: __table__ = Table( "movies", - mapper_registry.metadata, + metadata, Column("id", String, primary_key=True), # ULID Column("title", String, nullable=False), Column("original_title", String), @@ -336,7 +354,7 @@ Relation = Annotated[T | None, _RelationSentinel] class Rating: __table__ = Table( "ratings", - mapper_registry.metadata, + metadata, Column("id", String, primary_key=True), # ULID Column("movie_id", ForeignKey("movies.id"), nullable=False), # ULID Column("user_id", ForeignKey("users.id"), nullable=False), # ULID @@ -393,7 +411,7 @@ class UserGroup(TypedDict): class User: __table__ = Table( "users", - mapper_registry.metadata, + metadata, Column("id", String, primary_key=True), # ULID Column("imdb_id", String, nullable=False, unique=True), Column("name", String, nullable=False), @@ -433,7 +451,7 @@ class GroupUser(TypedDict): class Group: __table__ = Table( "groups", - mapper_registry.metadata, + metadata, Column("id", String, primary_key=True), # ULID Column("name", String, nullable=False), Column("users", String, nullable=False), # JSON array