improve strict typing
This commit is contained in:
parent
8b5cbdf903
commit
2963a1d3f6
2 changed files with 25 additions and 16 deletions
23
unwind/db.py
23
unwind/db.py
|
|
@ -11,6 +11,7 @@ from sqlalchemy.dialects.sqlite import insert
|
||||||
|
|
||||||
from . import config
|
from . import config
|
||||||
from .models import (
|
from .models import (
|
||||||
|
Model,
|
||||||
Movie,
|
Movie,
|
||||||
Progress,
|
Progress,
|
||||||
Rating,
|
Rating,
|
||||||
|
|
@ -76,7 +77,7 @@ async def set_current_patch_level(db: Database, current: str) -> None:
|
||||||
db_patches_dir = Path(__file__).parent / "sql"
|
db_patches_dir = Path(__file__).parent / "sql"
|
||||||
|
|
||||||
|
|
||||||
async def apply_db_patches(db: Database):
|
async def apply_db_patches(db: Database) -> None:
|
||||||
"""Apply all remaining patches to the database.
|
"""Apply all remaining patches to the database.
|
||||||
|
|
||||||
Beware that patches will be applied in lexicographical order,
|
Beware that patches will be applied in lexicographical order,
|
||||||
|
|
@ -124,7 +125,7 @@ async def get_import_progress() -> Progress | None:
|
||||||
return await get(Progress, type="import-imdb-movies", order_by=("started", "desc"))
|
return await get(Progress, type="import-imdb-movies", order_by=("started", "desc"))
|
||||||
|
|
||||||
|
|
||||||
async def stop_import_progress(*, error: BaseException | None = None):
|
async def stop_import_progress(*, error: BaseException | None = None) -> None:
|
||||||
"""Stop the current import.
|
"""Stop the current import.
|
||||||
|
|
||||||
If an error is given, it will be logged to the progress state.
|
If an error is given, it will be logged to the progress state.
|
||||||
|
|
@ -220,9 +221,10 @@ def shared_connection() -> Database:
|
||||||
return _shared_connection
|
return _shared_connection
|
||||||
|
|
||||||
|
|
||||||
async def add(item):
|
async def add(item: Model) -> None:
|
||||||
# Support late initializing - used for optimization.
|
# Support late initializing - used for optimization.
|
||||||
if getattr(item, "_is_lazy", False):
|
if getattr(item, "_is_lazy", False):
|
||||||
|
assert hasattr(item, "_lazy_init")
|
||||||
item._lazy_init()
|
item._lazy_init()
|
||||||
|
|
||||||
table: sa.Table = item.__table__
|
table: sa.Table = item.__table__
|
||||||
|
|
@ -232,7 +234,7 @@ async def add(item):
|
||||||
await conn.execute(stmt)
|
await conn.execute(stmt)
|
||||||
|
|
||||||
|
|
||||||
ModelType = TypeVar("ModelType")
|
ModelType = TypeVar("ModelType", bound=Model)
|
||||||
|
|
||||||
|
|
||||||
async def get(
|
async def get(
|
||||||
|
|
@ -300,9 +302,10 @@ async def get_all(model: Type[ModelType], **field_values) -> Iterable[ModelType]
|
||||||
return (fromplain(model, row._mapping, serialized=True) for row in rows)
|
return (fromplain(model, row._mapping, serialized=True) for row in rows)
|
||||||
|
|
||||||
|
|
||||||
async def update(item):
|
async def update(item: Model) -> None:
|
||||||
# Support late initializing - used for optimization.
|
# Support late initializing - used for optimization.
|
||||||
if getattr(item, "_is_lazy", False):
|
if getattr(item, "_is_lazy", False):
|
||||||
|
assert hasattr(item, "_lazy_init")
|
||||||
item._lazy_init()
|
item._lazy_init()
|
||||||
|
|
||||||
table: sa.Table = item.__table__
|
table: sa.Table = item.__table__
|
||||||
|
|
@ -312,7 +315,7 @@ async def update(item):
|
||||||
await conn.execute(stmt)
|
await conn.execute(stmt)
|
||||||
|
|
||||||
|
|
||||||
async def remove(item):
|
async def remove(item: Model) -> None:
|
||||||
table: sa.Table = item.__table__
|
table: sa.Table = item.__table__
|
||||||
values = asplain(item, filter_fields={"id"}, serialize=True)
|
values = asplain(item, filter_fields={"id"}, serialize=True)
|
||||||
stmt = table.delete().where(table.c.id == values["id"])
|
stmt = table.delete().where(table.c.id == values["id"])
|
||||||
|
|
@ -320,7 +323,7 @@ async def remove(item):
|
||||||
await conn.execute(stmt)
|
await conn.execute(stmt)
|
||||||
|
|
||||||
|
|
||||||
async def add_or_update_user(user: User):
|
async def add_or_update_user(user: User) -> None:
|
||||||
db_user = await get(User, imdb_id=user.imdb_id)
|
db_user = await get(User, imdb_id=user.imdb_id)
|
||||||
if not db_user:
|
if not db_user:
|
||||||
await add(user)
|
await add(user)
|
||||||
|
|
@ -331,7 +334,7 @@ async def add_or_update_user(user: User):
|
||||||
await update(user)
|
await update(user)
|
||||||
|
|
||||||
|
|
||||||
async def add_or_update_many_movies(movies: list[Movie]):
|
async def add_or_update_many_movies(movies: list[Movie]) -> None:
|
||||||
"""Add or update Movies in the database.
|
"""Add or update Movies in the database.
|
||||||
|
|
||||||
This is an optimized version of `add_or_update_movie` for the purpose
|
This is an optimized version of `add_or_update_movie` for the purpose
|
||||||
|
|
@ -361,7 +364,7 @@ async def add_or_update_many_movies(movies: list[Movie]):
|
||||||
await update(movie)
|
await update(movie)
|
||||||
|
|
||||||
|
|
||||||
async def add_or_update_movie(movie: Movie):
|
async def add_or_update_movie(movie: Movie) -> None:
|
||||||
"""Add or update a Movie in the database.
|
"""Add or update a Movie in the database.
|
||||||
|
|
||||||
This is an upsert operation, but it will also update the Movie you pass
|
This is an upsert operation, but it will also update the Movie you pass
|
||||||
|
|
@ -419,7 +422,7 @@ async def find_ratings(
|
||||||
yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
|
yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
|
||||||
limit_rows: int = 10,
|
limit_rows: int = 10,
|
||||||
user_ids: Iterable[str] = [],
|
user_ids: Iterable[str] = [],
|
||||||
):
|
) -> Iterable[dict[str, Any]]:
|
||||||
conditions = []
|
conditions = []
|
||||||
|
|
||||||
if title:
|
if title:
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ from typing import (
|
||||||
Container,
|
Container,
|
||||||
Literal,
|
Literal,
|
||||||
Mapping,
|
Mapping,
|
||||||
|
Protocol,
|
||||||
Type,
|
Type,
|
||||||
TypedDict,
|
TypedDict,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
|
|
@ -29,6 +30,11 @@ JSONObject = dict[str, JSON]
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class Model(Protocol):
|
||||||
|
__table__: ClassVar[Table]
|
||||||
|
|
||||||
|
|
||||||
mapper_registry = registry()
|
mapper_registry = registry()
|
||||||
metadata = mapper_registry.metadata
|
metadata = mapper_registry.metadata
|
||||||
|
|
||||||
|
|
@ -207,7 +213,7 @@ def utcnow():
|
||||||
@mapper_registry.mapped
|
@mapper_registry.mapped
|
||||||
@dataclass
|
@dataclass
|
||||||
class DbPatch:
|
class DbPatch:
|
||||||
__table__ = Table(
|
__table__: ClassVar[Table] = Table(
|
||||||
"db_patches",
|
"db_patches",
|
||||||
metadata,
|
metadata,
|
||||||
Column("id", Integer, primary_key=True),
|
Column("id", Integer, primary_key=True),
|
||||||
|
|
@ -224,7 +230,7 @@ db_patches = DbPatch.__table__
|
||||||
@mapper_registry.mapped
|
@mapper_registry.mapped
|
||||||
@dataclass
|
@dataclass
|
||||||
class Progress:
|
class Progress:
|
||||||
__table__ = Table(
|
__table__: ClassVar[Table] = Table(
|
||||||
"progress",
|
"progress",
|
||||||
metadata,
|
metadata,
|
||||||
Column("id", String, primary_key=True), # ULID
|
Column("id", String, primary_key=True), # ULID
|
||||||
|
|
@ -274,7 +280,7 @@ class Progress:
|
||||||
@mapper_registry.mapped
|
@mapper_registry.mapped
|
||||||
@dataclass
|
@dataclass
|
||||||
class Movie:
|
class Movie:
|
||||||
__table__ = Table(
|
__table__: ClassVar[Table] = Table(
|
||||||
"movies",
|
"movies",
|
||||||
metadata,
|
metadata,
|
||||||
Column("id", String, primary_key=True), # ULID
|
Column("id", String, primary_key=True), # ULID
|
||||||
|
|
@ -352,7 +358,7 @@ Relation = Annotated[T | None, _RelationSentinel]
|
||||||
@mapper_registry.mapped
|
@mapper_registry.mapped
|
||||||
@dataclass
|
@dataclass
|
||||||
class Rating:
|
class Rating:
|
||||||
__table__ = Table(
|
__table__: ClassVar[Table] = Table(
|
||||||
"ratings",
|
"ratings",
|
||||||
metadata,
|
metadata,
|
||||||
Column("id", String, primary_key=True), # ULID
|
Column("id", String, primary_key=True), # ULID
|
||||||
|
|
@ -409,7 +415,7 @@ class UserGroup(TypedDict):
|
||||||
@mapper_registry.mapped
|
@mapper_registry.mapped
|
||||||
@dataclass
|
@dataclass
|
||||||
class User:
|
class User:
|
||||||
__table__ = Table(
|
__table__: ClassVar[Table] = Table(
|
||||||
"users",
|
"users",
|
||||||
metadata,
|
metadata,
|
||||||
Column("id", String, primary_key=True), # ULID
|
Column("id", String, primary_key=True), # ULID
|
||||||
|
|
@ -449,7 +455,7 @@ class GroupUser(TypedDict):
|
||||||
@mapper_registry.mapped
|
@mapper_registry.mapped
|
||||||
@dataclass
|
@dataclass
|
||||||
class Group:
|
class Group:
|
||||||
__table__ = Table(
|
__table__: ClassVar[Table] = Table(
|
||||||
"groups",
|
"groups",
|
||||||
metadata,
|
metadata,
|
||||||
Column("id", String, primary_key=True), # ULID
|
Column("id", String, primary_key=True), # ULID
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue