migrate db.current_patch_level to SQLAlchemy
This commit is contained in:
parent
e27b57050a
commit
37e8d53b78
3 changed files with 47 additions and 29 deletions
|
|
@ -20,6 +20,15 @@ def a_movie(**kwds) -> models.Movie:
|
|||
return models.Movie(**args)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_current_patch_level(shared_conn: db.Database):
|
||||
async with shared_conn.transaction(force_rollback=True):
|
||||
patch_level = "some-patch-level"
|
||||
assert patch_level != await db.current_patch_level(shared_conn)
|
||||
await db.set_current_patch_level(shared_conn, patch_level)
|
||||
assert patch_level == await db.current_patch_level(shared_conn)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get(shared_conn: db.Database):
|
||||
async with shared_conn.transaction(force_rollback=True):
|
||||
|
|
|
|||
39
unwind/db.py
39
unwind/db.py
|
|
@ -7,6 +7,7 @@ from typing import Any, Iterable, Literal, Type, TypeVar
|
|||
|
||||
import sqlalchemy as sa
|
||||
from databases import Database
|
||||
from sqlalchemy.dialects.sqlite import insert
|
||||
|
||||
from . import config
|
||||
from .models import (
|
||||
|
|
@ -15,7 +16,9 @@ from .models import (
|
|||
Rating,
|
||||
User,
|
||||
asplain,
|
||||
db_patches,
|
||||
fromplain,
|
||||
metadata,
|
||||
movies,
|
||||
optional_fields,
|
||||
ratings,
|
||||
|
|
@ -50,38 +53,22 @@ async def close_connection_pool() -> None:
|
|||
|
||||
# Run automatic ANALYZE prior to closing the db,
|
||||
# see https://sqlite.com/lang_analyze.html.
|
||||
await db.execute("PRAGMA analysis_limit=400")
|
||||
await db.execute("PRAGMA optimize")
|
||||
await db.execute(sa.text("PRAGMA analysis_limit=400"))
|
||||
await db.execute(sa.text("PRAGMA optimize"))
|
||||
|
||||
await db.disconnect()
|
||||
|
||||
|
||||
async def _create_patch_db(db):
|
||||
query = """
|
||||
CREATE TABLE IF NOT EXISTS db_patches (
|
||||
id INTEGER PRIMARY KEY,
|
||||
current TEXT
|
||||
)
|
||||
"""
|
||||
await db.execute(query)
|
||||
|
||||
|
||||
async def current_patch_level(db) -> str:
|
||||
await _create_patch_db(db)
|
||||
|
||||
query = "SELECT current FROM db_patches"
|
||||
async def current_patch_level(db: Database) -> str:
|
||||
query = sa.select(db_patches.c.current)
|
||||
current = await db.fetch_val(query)
|
||||
return current or ""
|
||||
|
||||
|
||||
async def set_current_patch_level(db, current: str):
|
||||
await _create_patch_db(db)
|
||||
|
||||
query = """
|
||||
INSERT INTO db_patches VALUES (1, :current)
|
||||
ON CONFLICT DO UPDATE SET current=excluded.current
|
||||
"""
|
||||
await db.execute(query, values={"current": current})
|
||||
async def set_current_patch_level(db: Database, current: str) -> None:
|
||||
stmt = insert(db_patches).values(id=1, current=current)
|
||||
stmt = stmt.on_conflict_do_update(set_={"current": stmt.excluded.current})
|
||||
await db.execute(stmt)
|
||||
|
||||
|
||||
db_patches_dir = Path(__file__).parent / "sql"
|
||||
|
|
@ -222,8 +209,12 @@ def shared_connection() -> Database:
|
|||
|
||||
if _shared_connection is None:
|
||||
uri = f"sqlite:///{config.storage_path}"
|
||||
# uri = f"sqlite+aiosqlite:///{config.storage_path}"
|
||||
_shared_connection = Database(uri)
|
||||
|
||||
engine = sa.create_engine(uri, future=True)
|
||||
metadata.create_all(engine, tables=[db_patches])
|
||||
|
||||
return _shared_connection
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ JSONObject = dict[str, JSON]
|
|||
T = TypeVar("T")
|
||||
|
||||
mapper_registry = registry()
|
||||
metadata = mapper_registry.metadata
|
||||
|
||||
|
||||
def annotations(tp: Type) -> tuple | None:
|
||||
|
|
@ -203,12 +204,29 @@ def utcnow():
|
|||
return datetime.utcnow().replace(tzinfo=timezone.utc)
|
||||
|
||||
|
||||
@mapper_registry.mapped
|
||||
@dataclass
|
||||
class DbPatch:
|
||||
__table__ = Table(
|
||||
"db_patches",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column("current", String),
|
||||
)
|
||||
|
||||
id: int
|
||||
current: str
|
||||
|
||||
|
||||
db_patches = DbPatch.__table__
|
||||
|
||||
|
||||
@mapper_registry.mapped
|
||||
@dataclass
|
||||
class Progress:
|
||||
__table__ = Table(
|
||||
"progress",
|
||||
mapper_registry.metadata,
|
||||
metadata,
|
||||
Column("id", String, primary_key=True), # ULID
|
||||
Column("type", String, nullable=False),
|
||||
Column("state", String, nullable=False), # JSON {"percent": ..., "error": ...}
|
||||
|
|
@ -258,7 +276,7 @@ class Progress:
|
|||
class Movie:
|
||||
__table__ = Table(
|
||||
"movies",
|
||||
mapper_registry.metadata,
|
||||
metadata,
|
||||
Column("id", String, primary_key=True), # ULID
|
||||
Column("title", String, nullable=False),
|
||||
Column("original_title", String),
|
||||
|
|
@ -336,7 +354,7 @@ Relation = Annotated[T | None, _RelationSentinel]
|
|||
class Rating:
|
||||
__table__ = Table(
|
||||
"ratings",
|
||||
mapper_registry.metadata,
|
||||
metadata,
|
||||
Column("id", String, primary_key=True), # ULID
|
||||
Column("movie_id", ForeignKey("movies.id"), nullable=False), # ULID
|
||||
Column("user_id", ForeignKey("users.id"), nullable=False), # ULID
|
||||
|
|
@ -393,7 +411,7 @@ class UserGroup(TypedDict):
|
|||
class User:
|
||||
__table__ = Table(
|
||||
"users",
|
||||
mapper_registry.metadata,
|
||||
metadata,
|
||||
Column("id", String, primary_key=True), # ULID
|
||||
Column("imdb_id", String, nullable=False, unique=True),
|
||||
Column("name", String, nullable=False),
|
||||
|
|
@ -433,7 +451,7 @@ class GroupUser(TypedDict):
|
|||
class Group:
|
||||
__table__ = Table(
|
||||
"groups",
|
||||
mapper_registry.metadata,
|
||||
metadata,
|
||||
Column("id", String, primary_key=True), # ULID
|
||||
Column("name", String, nullable=False),
|
||||
Column("users", String, nullable=False), # JSON array
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue