improve strict typing

This commit is contained in:
ducklet 2023-03-28 23:32:24 +02:00
parent 8b5cbdf903
commit 2963a1d3f6
2 changed files with 25 additions and 16 deletions

View file

@ -11,6 +11,7 @@ from sqlalchemy.dialects.sqlite import insert
from . import config
from .models import (
Model,
Movie,
Progress,
Rating,
@ -76,7 +77,7 @@ async def set_current_patch_level(db: Database, current: str) -> None:
db_patches_dir = Path(__file__).parent / "sql"
async def apply_db_patches(db: Database):
async def apply_db_patches(db: Database) -> None:
"""Apply all remaining patches to the database.
Beware that patches will be applied in lexicographical order,
@ -124,7 +125,7 @@ async def get_import_progress() -> Progress | None:
return await get(Progress, type="import-imdb-movies", order_by=("started", "desc"))
async def stop_import_progress(*, error: BaseException | None = None):
async def stop_import_progress(*, error: BaseException | None = None) -> None:
"""Stop the current import.
If an error is given, it will be logged to the progress state.
@ -220,9 +221,10 @@ def shared_connection() -> Database:
return _shared_connection
async def add(item):
async def add(item: Model) -> None:
# Support late initializing - used for optimization.
if getattr(item, "_is_lazy", False):
assert hasattr(item, "_lazy_init")
item._lazy_init()
table: sa.Table = item.__table__
@ -232,7 +234,7 @@ async def add(item):
await conn.execute(stmt)
ModelType = TypeVar("ModelType")
ModelType = TypeVar("ModelType", bound=Model)
async def get(
@ -300,9 +302,10 @@ async def get_all(model: Type[ModelType], **field_values) -> Iterable[ModelType]
return (fromplain(model, row._mapping, serialized=True) for row in rows)
async def update(item):
async def update(item: Model) -> None:
# Support late initializing - used for optimization.
if getattr(item, "_is_lazy", False):
assert hasattr(item, "_lazy_init")
item._lazy_init()
table: sa.Table = item.__table__
@ -312,7 +315,7 @@ async def update(item):
await conn.execute(stmt)
async def remove(item):
async def remove(item: Model) -> None:
table: sa.Table = item.__table__
values = asplain(item, filter_fields={"id"}, serialize=True)
stmt = table.delete().where(table.c.id == values["id"])
@ -320,7 +323,7 @@ async def remove(item):
await conn.execute(stmt)
async def add_or_update_user(user: User):
async def add_or_update_user(user: User) -> None:
db_user = await get(User, imdb_id=user.imdb_id)
if not db_user:
await add(user)
@ -331,7 +334,7 @@ async def add_or_update_user(user: User):
await update(user)
async def add_or_update_many_movies(movies: list[Movie]):
async def add_or_update_many_movies(movies: list[Movie]) -> None:
"""Add or update Movies in the database.
This is an optimized version of `add_or_update_movie` for the purpose
@ -361,7 +364,7 @@ async def add_or_update_many_movies(movies: list[Movie]):
await update(movie)
async def add_or_update_movie(movie: Movie):
async def add_or_update_movie(movie: Movie) -> None:
"""Add or update a Movie in the database.
This is an upsert operation, but it will also update the Movie you pass
@ -419,7 +422,7 @@ async def find_ratings(
yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
limit_rows: int = 10,
user_ids: Iterable[str] = [],
):
) -> Iterable[dict[str, Any]]:
conditions = []
if title:

View file

@ -11,6 +11,7 @@ from typing import (
Container,
Literal,
Mapping,
Protocol,
Type,
TypedDict,
TypeVar,
@ -29,6 +30,11 @@ JSONObject = dict[str, JSON]
T = TypeVar("T")
class Model(Protocol):
__table__: ClassVar[Table]
mapper_registry = registry()
metadata = mapper_registry.metadata
@ -207,7 +213,7 @@ def utcnow():
@mapper_registry.mapped
@dataclass
class DbPatch:
__table__ = Table(
__table__: ClassVar[Table] = Table(
"db_patches",
metadata,
Column("id", Integer, primary_key=True),
@ -224,7 +230,7 @@ db_patches = DbPatch.__table__
@mapper_registry.mapped
@dataclass
class Progress:
__table__ = Table(
__table__: ClassVar[Table] = Table(
"progress",
metadata,
Column("id", String, primary_key=True), # ULID
@ -274,7 +280,7 @@ class Progress:
@mapper_registry.mapped
@dataclass
class Movie:
__table__ = Table(
__table__: ClassVar[Table] = Table(
"movies",
metadata,
Column("id", String, primary_key=True), # ULID
@ -352,7 +358,7 @@ Relation = Annotated[T | None, _RelationSentinel]
@mapper_registry.mapped
@dataclass
class Rating:
__table__ = Table(
__table__: ClassVar[Table] = Table(
"ratings",
metadata,
Column("id", String, primary_key=True), # ULID
@ -409,7 +415,7 @@ class UserGroup(TypedDict):
@mapper_registry.mapped
@dataclass
class User:
__table__ = Table(
__table__: ClassVar[Table] = Table(
"users",
metadata,
Column("id", String, primary_key=True), # ULID
@ -449,7 +455,7 @@ class GroupUser(TypedDict):
@mapper_registry.mapped
@dataclass
class Group:
__table__ = Table(
__table__: ClassVar[Table] = Table(
"groups",
metadata,
Column("id", String, primary_key=True), # ULID