improve strict typing

This commit is contained in:
ducklet 2023-03-28 23:32:24 +02:00
parent 8b5cbdf903
commit 2963a1d3f6
2 changed files with 25 additions and 16 deletions

View file

@ -11,6 +11,7 @@ from sqlalchemy.dialects.sqlite import insert
from . import config from . import config
from .models import ( from .models import (
Model,
Movie, Movie,
Progress, Progress,
Rating, Rating,
@ -76,7 +77,7 @@ async def set_current_patch_level(db: Database, current: str) -> None:
db_patches_dir = Path(__file__).parent / "sql" db_patches_dir = Path(__file__).parent / "sql"
async def apply_db_patches(db: Database): async def apply_db_patches(db: Database) -> None:
"""Apply all remaining patches to the database. """Apply all remaining patches to the database.
Beware that patches will be applied in lexicographical order, Beware that patches will be applied in lexicographical order,
@ -124,7 +125,7 @@ async def get_import_progress() -> Progress | None:
return await get(Progress, type="import-imdb-movies", order_by=("started", "desc")) return await get(Progress, type="import-imdb-movies", order_by=("started", "desc"))
async def stop_import_progress(*, error: BaseException | None = None): async def stop_import_progress(*, error: BaseException | None = None) -> None:
"""Stop the current import. """Stop the current import.
If an error is given, it will be logged to the progress state. If an error is given, it will be logged to the progress state.
@ -220,9 +221,10 @@ def shared_connection() -> Database:
return _shared_connection return _shared_connection
async def add(item): async def add(item: Model) -> None:
# Support late initializing - used for optimization. # Support late initializing - used for optimization.
if getattr(item, "_is_lazy", False): if getattr(item, "_is_lazy", False):
assert hasattr(item, "_lazy_init")
item._lazy_init() item._lazy_init()
table: sa.Table = item.__table__ table: sa.Table = item.__table__
@ -232,7 +234,7 @@ async def add(item):
await conn.execute(stmt) await conn.execute(stmt)
ModelType = TypeVar("ModelType") ModelType = TypeVar("ModelType", bound=Model)
async def get( async def get(
@ -300,9 +302,10 @@ async def get_all(model: Type[ModelType], **field_values) -> Iterable[ModelType]
return (fromplain(model, row._mapping, serialized=True) for row in rows) return (fromplain(model, row._mapping, serialized=True) for row in rows)
async def update(item): async def update(item: Model) -> None:
# Support late initializing - used for optimization. # Support late initializing - used for optimization.
if getattr(item, "_is_lazy", False): if getattr(item, "_is_lazy", False):
assert hasattr(item, "_lazy_init")
item._lazy_init() item._lazy_init()
table: sa.Table = item.__table__ table: sa.Table = item.__table__
@ -312,7 +315,7 @@ async def update(item):
await conn.execute(stmt) await conn.execute(stmt)
async def remove(item): async def remove(item: Model) -> None:
table: sa.Table = item.__table__ table: sa.Table = item.__table__
values = asplain(item, filter_fields={"id"}, serialize=True) values = asplain(item, filter_fields={"id"}, serialize=True)
stmt = table.delete().where(table.c.id == values["id"]) stmt = table.delete().where(table.c.id == values["id"])
@ -320,7 +323,7 @@ async def remove(item):
await conn.execute(stmt) await conn.execute(stmt)
async def add_or_update_user(user: User): async def add_or_update_user(user: User) -> None:
db_user = await get(User, imdb_id=user.imdb_id) db_user = await get(User, imdb_id=user.imdb_id)
if not db_user: if not db_user:
await add(user) await add(user)
@ -331,7 +334,7 @@ async def add_or_update_user(user: User):
await update(user) await update(user)
async def add_or_update_many_movies(movies: list[Movie]): async def add_or_update_many_movies(movies: list[Movie]) -> None:
"""Add or update Movies in the database. """Add or update Movies in the database.
This is an optimized version of `add_or_update_movie` for the purpose This is an optimized version of `add_or_update_movie` for the purpose
@ -361,7 +364,7 @@ async def add_or_update_many_movies(movies: list[Movie]):
await update(movie) await update(movie)
async def add_or_update_movie(movie: Movie): async def add_or_update_movie(movie: Movie) -> None:
"""Add or update a Movie in the database. """Add or update a Movie in the database.
This is an upsert operation, but it will also update the Movie you pass This is an upsert operation, but it will also update the Movie you pass
@ -419,7 +422,7 @@ async def find_ratings(
yearcomp: tuple[Literal["<", "=", ">"], int] | None = None, yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
limit_rows: int = 10, limit_rows: int = 10,
user_ids: Iterable[str] = [], user_ids: Iterable[str] = [],
): ) -> Iterable[dict[str, Any]]:
conditions = [] conditions = []
if title: if title:

View file

@ -11,6 +11,7 @@ from typing import (
Container, Container,
Literal, Literal,
Mapping, Mapping,
Protocol,
Type, Type,
TypedDict, TypedDict,
TypeVar, TypeVar,
@ -29,6 +30,11 @@ JSONObject = dict[str, JSON]
T = TypeVar("T") T = TypeVar("T")
class Model(Protocol):
__table__: ClassVar[Table]
mapper_registry = registry() mapper_registry = registry()
metadata = mapper_registry.metadata metadata = mapper_registry.metadata
@ -207,7 +213,7 @@ def utcnow():
@mapper_registry.mapped @mapper_registry.mapped
@dataclass @dataclass
class DbPatch: class DbPatch:
__table__ = Table( __table__: ClassVar[Table] = Table(
"db_patches", "db_patches",
metadata, metadata,
Column("id", Integer, primary_key=True), Column("id", Integer, primary_key=True),
@ -224,7 +230,7 @@ db_patches = DbPatch.__table__
@mapper_registry.mapped @mapper_registry.mapped
@dataclass @dataclass
class Progress: class Progress:
__table__ = Table( __table__: ClassVar[Table] = Table(
"progress", "progress",
metadata, metadata,
Column("id", String, primary_key=True), # ULID Column("id", String, primary_key=True), # ULID
@ -274,7 +280,7 @@ class Progress:
@mapper_registry.mapped @mapper_registry.mapped
@dataclass @dataclass
class Movie: class Movie:
__table__ = Table( __table__: ClassVar[Table] = Table(
"movies", "movies",
metadata, metadata,
Column("id", String, primary_key=True), # ULID Column("id", String, primary_key=True), # ULID
@ -352,7 +358,7 @@ Relation = Annotated[T | None, _RelationSentinel]
@mapper_registry.mapped @mapper_registry.mapped
@dataclass @dataclass
class Rating: class Rating:
__table__ = Table( __table__: ClassVar[Table] = Table(
"ratings", "ratings",
metadata, metadata,
Column("id", String, primary_key=True), # ULID Column("id", String, primary_key=True), # ULID
@ -409,7 +415,7 @@ class UserGroup(TypedDict):
@mapper_registry.mapped @mapper_registry.mapped
@dataclass @dataclass
class User: class User:
__table__ = Table( __table__: ClassVar[Table] = Table(
"users", "users",
metadata, metadata,
Column("id", String, primary_key=True), # ULID Column("id", String, primary_key=True), # ULID
@ -449,7 +455,7 @@ class GroupUser(TypedDict):
@mapper_registry.mapped @mapper_registry.mapped
@dataclass @dataclass
class Group: class Group:
__table__ = Table( __table__: ClassVar[Table] = Table(
"groups", "groups",
metadata, metadata,
Column("id", String, primary_key=True), # ULID Column("id", String, primary_key=True), # ULID