migrate db.current_patch_level to SQLAlchemy

This commit is contained in:
ducklet 2023-03-28 23:03:35 +02:00
parent e27b57050a
commit 37e8d53b78
3 changed files with 47 additions and 29 deletions

View file

@ -20,6 +20,15 @@ def a_movie(**kwds) -> models.Movie:
return models.Movie(**args) return models.Movie(**args)
@pytest.mark.asyncio
async def test_current_patch_level(shared_conn: db.Database):
async with shared_conn.transaction(force_rollback=True):
patch_level = "some-patch-level"
assert patch_level != await db.current_patch_level(shared_conn)
await db.set_current_patch_level(shared_conn, patch_level)
assert patch_level == await db.current_patch_level(shared_conn)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get(shared_conn: db.Database): async def test_get(shared_conn: db.Database):
async with shared_conn.transaction(force_rollback=True): async with shared_conn.transaction(force_rollback=True):

View file

@ -7,6 +7,7 @@ from typing import Any, Iterable, Literal, Type, TypeVar
import sqlalchemy as sa import sqlalchemy as sa
from databases import Database from databases import Database
from sqlalchemy.dialects.sqlite import insert
from . import config from . import config
from .models import ( from .models import (
@ -15,7 +16,9 @@ from .models import (
Rating, Rating,
User, User,
asplain, asplain,
db_patches,
fromplain, fromplain,
metadata,
movies, movies,
optional_fields, optional_fields,
ratings, ratings,
@ -50,38 +53,22 @@ async def close_connection_pool() -> None:
# Run automatic ANALYZE prior to closing the db, # Run automatic ANALYZE prior to closing the db,
# see https://sqlite.com/lang_analyze.html. # see https://sqlite.com/lang_analyze.html.
await db.execute("PRAGMA analysis_limit=400") await db.execute(sa.text("PRAGMA analysis_limit=400"))
await db.execute("PRAGMA optimize") await db.execute(sa.text("PRAGMA optimize"))
await db.disconnect() await db.disconnect()
async def _create_patch_db(db): async def current_patch_level(db: Database) -> str:
query = """ query = sa.select(db_patches.c.current)
CREATE TABLE IF NOT EXISTS db_patches (
id INTEGER PRIMARY KEY,
current TEXT
)
"""
await db.execute(query)
async def current_patch_level(db) -> str:
await _create_patch_db(db)
query = "SELECT current FROM db_patches"
current = await db.fetch_val(query) current = await db.fetch_val(query)
return current or "" return current or ""
async def set_current_patch_level(db, current: str): async def set_current_patch_level(db: Database, current: str) -> None:
await _create_patch_db(db) stmt = insert(db_patches).values(id=1, current=current)
stmt = stmt.on_conflict_do_update(set_={"current": stmt.excluded.current})
query = """ await db.execute(stmt)
INSERT INTO db_patches VALUES (1, :current)
ON CONFLICT DO UPDATE SET current=excluded.current
"""
await db.execute(query, values={"current": current})
db_patches_dir = Path(__file__).parent / "sql" db_patches_dir = Path(__file__).parent / "sql"
@ -222,8 +209,12 @@ def shared_connection() -> Database:
if _shared_connection is None: if _shared_connection is None:
uri = f"sqlite:///{config.storage_path}" uri = f"sqlite:///{config.storage_path}"
# uri = f"sqlite+aiosqlite:///{config.storage_path}"
_shared_connection = Database(uri) _shared_connection = Database(uri)
engine = sa.create_engine(uri, future=True)
metadata.create_all(engine, tables=[db_patches])
return _shared_connection return _shared_connection

View file

@ -30,6 +30,7 @@ JSONObject = dict[str, JSON]
T = TypeVar("T") T = TypeVar("T")
mapper_registry = registry() mapper_registry = registry()
metadata = mapper_registry.metadata
def annotations(tp: Type) -> tuple | None: def annotations(tp: Type) -> tuple | None:
@ -203,12 +204,29 @@ def utcnow():
return datetime.utcnow().replace(tzinfo=timezone.utc) return datetime.utcnow().replace(tzinfo=timezone.utc)
@mapper_registry.mapped
@dataclass
class DbPatch:
__table__ = Table(
"db_patches",
metadata,
Column("id", Integer, primary_key=True),
Column("current", String),
)
id: int
current: str
db_patches = DbPatch.__table__
@mapper_registry.mapped @mapper_registry.mapped
@dataclass @dataclass
class Progress: class Progress:
__table__ = Table( __table__ = Table(
"progress", "progress",
mapper_registry.metadata, metadata,
Column("id", String, primary_key=True), # ULID Column("id", String, primary_key=True), # ULID
Column("type", String, nullable=False), Column("type", String, nullable=False),
Column("state", String, nullable=False), # JSON {"percent": ..., "error": ...} Column("state", String, nullable=False), # JSON {"percent": ..., "error": ...}
@ -258,7 +276,7 @@ class Progress:
class Movie: class Movie:
__table__ = Table( __table__ = Table(
"movies", "movies",
mapper_registry.metadata, metadata,
Column("id", String, primary_key=True), # ULID Column("id", String, primary_key=True), # ULID
Column("title", String, nullable=False), Column("title", String, nullable=False),
Column("original_title", String), Column("original_title", String),
@ -336,7 +354,7 @@ Relation = Annotated[T | None, _RelationSentinel]
class Rating: class Rating:
__table__ = Table( __table__ = Table(
"ratings", "ratings",
mapper_registry.metadata, metadata,
Column("id", String, primary_key=True), # ULID Column("id", String, primary_key=True), # ULID
Column("movie_id", ForeignKey("movies.id"), nullable=False), # ULID Column("movie_id", ForeignKey("movies.id"), nullable=False), # ULID
Column("user_id", ForeignKey("users.id"), nullable=False), # ULID Column("user_id", ForeignKey("users.id"), nullable=False), # ULID
@ -393,7 +411,7 @@ class UserGroup(TypedDict):
class User: class User:
__table__ = Table( __table__ = Table(
"users", "users",
mapper_registry.metadata, metadata,
Column("id", String, primary_key=True), # ULID Column("id", String, primary_key=True), # ULID
Column("imdb_id", String, nullable=False, unique=True), Column("imdb_id", String, nullable=False, unique=True),
Column("name", String, nullable=False), Column("name", String, nullable=False),
@ -433,7 +451,7 @@ class GroupUser(TypedDict):
class Group: class Group:
__table__ = Table( __table__ = Table(
"groups", "groups",
mapper_registry.metadata, metadata,
Column("id", String, primary_key=True), # ULID Column("id", String, primary_key=True), # ULID
Column("name", String, nullable=False), Column("name", String, nullable=False),
Column("users", String, nullable=False), # JSON array Column("users", String, nullable=False), # JSON array