From 2963a1d3f6ef916d8926e389a8adf4424bda6fb4 Mon Sep 17 00:00:00 2001 From: ducklet Date: Tue, 28 Mar 2023 23:32:24 +0200 Subject: [PATCH] improve strict typing --- unwind/db.py | 23 +++++++++++++---------- unwind/models.py | 18 ++++++++++++------ 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/unwind/db.py b/unwind/db.py index fb0cabe..63c042a 100644 --- a/unwind/db.py +++ b/unwind/db.py @@ -11,6 +11,7 @@ from sqlalchemy.dialects.sqlite import insert from . import config from .models import ( + Model, Movie, Progress, Rating, @@ -76,7 +77,7 @@ async def set_current_patch_level(db: Database, current: str) -> None: db_patches_dir = Path(__file__).parent / "sql" -async def apply_db_patches(db: Database): +async def apply_db_patches(db: Database) -> None: """Apply all remaining patches to the database. Beware that patches will be applied in lexicographical order, @@ -124,7 +125,7 @@ async def get_import_progress() -> Progress | None: return await get(Progress, type="import-imdb-movies", order_by=("started", "desc")) -async def stop_import_progress(*, error: BaseException | None = None): +async def stop_import_progress(*, error: BaseException | None = None) -> None: """Stop the current import. If an error is given, it will be logged to the progress state. @@ -220,9 +221,10 @@ def shared_connection() -> Database: return _shared_connection -async def add(item): +async def add(item: Model) -> None: # Support late initializing - used for optimization. if getattr(item, "_is_lazy", False): + assert hasattr(item, "_lazy_init") item._lazy_init() table: sa.Table = item.__table__ @@ -232,7 +234,7 @@ async def add(item): await conn.execute(stmt) -ModelType = TypeVar("ModelType") +ModelType = TypeVar("ModelType", bound=Model) async def get( @@ -300,9 +302,10 @@ async def get_all(model: Type[ModelType], **field_values) -> Iterable[ModelType] return (fromplain(model, row._mapping, serialized=True) for row in rows) -async def update(item): +async def update(item: Model) -> None: # Support late initializing - used for optimization. if getattr(item, "_is_lazy", False): + assert hasattr(item, "_lazy_init") item._lazy_init() table: sa.Table = item.__table__ @@ -312,7 +315,7 @@ async def update(item): await conn.execute(stmt) -async def remove(item): +async def remove(item: Model) -> None: table: sa.Table = item.__table__ values = asplain(item, filter_fields={"id"}, serialize=True) stmt = table.delete().where(table.c.id == values["id"]) @@ -320,7 +323,7 @@ async def remove(item): await conn.execute(stmt) -async def add_or_update_user(user: User): +async def add_or_update_user(user: User) -> None: db_user = await get(User, imdb_id=user.imdb_id) if not db_user: await add(user) @@ -331,7 +334,7 @@ async def add_or_update_user(user: User): await update(user) -async def add_or_update_many_movies(movies: list[Movie]): +async def add_or_update_many_movies(movies: list[Movie]) -> None: """Add or update Movies in the database. This is an optimized version of `add_or_update_movie` for the purpose @@ -361,7 +364,7 @@ async def add_or_update_many_movies(movies: list[Movie]): await update(movie) -async def add_or_update_movie(movie: Movie): +async def add_or_update_movie(movie: Movie) -> None: """Add or update a Movie in the database. This is an upsert operation, but it will also update the Movie you pass @@ -419,7 +422,7 @@ async def find_ratings( yearcomp: tuple[Literal["<", "=", ">"], int] | None = None, limit_rows: int = 10, user_ids: Iterable[str] = [], -): +) -> Iterable[dict[str, Any]]: conditions = [] if title: diff --git a/unwind/models.py b/unwind/models.py index 5628bb0..93b51f9 100644 --- a/unwind/models.py +++ b/unwind/models.py @@ -11,6 +11,7 @@ from typing import ( Container, Literal, Mapping, + Protocol, Type, TypedDict, TypeVar, @@ -29,6 +30,11 @@ JSONObject = dict[str, JSON] T = TypeVar("T") + +class Model(Protocol): + __table__: ClassVar[Table] + + mapper_registry = registry() metadata = mapper_registry.metadata @@ -207,7 +213,7 @@ def utcnow(): @mapper_registry.mapped @dataclass class DbPatch: - __table__ = Table( + __table__: ClassVar[Table] = Table( "db_patches", metadata, Column("id", Integer, primary_key=True), @@ -224,7 +230,7 @@ db_patches = DbPatch.__table__ @mapper_registry.mapped @dataclass class Progress: - __table__ = Table( + __table__: ClassVar[Table] = Table( "progress", metadata, Column("id", String, primary_key=True), # ULID @@ -274,7 +280,7 @@ class Progress: @mapper_registry.mapped @dataclass class Movie: - __table__ = Table( + __table__: ClassVar[Table] = Table( "movies", metadata, Column("id", String, primary_key=True), # ULID @@ -352,7 +358,7 @@ Relation = Annotated[T | None, _RelationSentinel] @mapper_registry.mapped @dataclass class Rating: - __table__ = Table( + __table__: ClassVar[Table] = Table( "ratings", metadata, Column("id", String, primary_key=True), # ULID @@ -409,7 +415,7 @@ class UserGroup(TypedDict): @mapper_registry.mapped @dataclass class User: - __table__ = Table( + __table__: ClassVar[Table] = Table( "users", metadata, Column("id", String, primary_key=True), # ULID @@ -449,7 +455,7 @@ class GroupUser(TypedDict): @mapper_registry.mapped @dataclass class Group: - __table__ = Table( + __table__: ClassVar[Table] = Table( "groups", metadata, Column("id", String, primary_key=True), # ULID