chore: typing

This commit is contained in:
ducklet 2024-05-19 22:49:46 +02:00
parent b0f5ec4cc9
commit dd39849b8d
5 changed files with 82 additions and 57 deletions

View file

@ -28,7 +28,7 @@ from .models import (
ratings, ratings,
utcnow, utcnow,
) )
from .types import ULID from .types import ULID, ImdbMovieId, UserIdStr
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -432,19 +432,16 @@ async def add_or_update_rating(conn: Connection, /, rating: Rating) -> bool:
return False return False
type MovieImdbId = str
async def get_awards( async def get_awards(
conn: Connection, /, imdb_ids: list[MovieImdbId] conn: Connection, /, imdb_ids: list[ImdbMovieId]
) -> dict[MovieImdbId, list[Award]]: ) -> dict[ImdbMovieId, list[Award]]:
query = ( query = (
sa.select(Award, movies.c.imdb_id) sa.select(Award, movies.c.imdb_id)
.join(movies, awards.c.movie_id == movies.c.id) .join(movies, awards.c.movie_id == movies.c.id)
.where(movies.c.imdb_id.in_(imdb_ids)) .where(movies.c.imdb_id.in_(imdb_ids))
) )
rows = await fetch_all(conn, query) rows = await fetch_all(conn, query)
awards_dict: dict[MovieImdbId, list[Award]] = {} awards_dict: dict[ImdbMovieId, list[Award]] = {}
for row in rows: for row in rows:
awards_dict.setdefault(row.imdb_id, []).append( awards_dict.setdefault(row.imdb_id, []).append(
fromplain(Award, row._mapping, serialized=True) fromplain(Award, row._mapping, serialized=True)
@ -467,7 +464,7 @@ async def find_ratings(
include_unrated: bool = False, include_unrated: bool = False,
yearcomp: tuple[Literal["<", "=", ">"], int] | None = None, yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
limit_rows: int = 10, limit_rows: int = 10,
user_ids: Iterable[str] = [], user_ids: Iterable[UserIdStr] = [],
) -> Iterable[dict[str, Any]]: ) -> Iterable[dict[str, Any]]:
conditions = [] conditions = []

View file

@ -10,6 +10,7 @@ from typing import (
Container, Container,
Literal, Literal,
Mapping, Mapping,
NewType,
Protocol, Protocol,
Type, Type,
TypeAliasType, TypeAliasType,
@ -22,13 +23,21 @@ from typing import (
from sqlalchemy import Column, ForeignKey, Index, Integer, String, Table from sqlalchemy import Column, ForeignKey, Index, Integer, String, Table
from sqlalchemy.orm import registry from sqlalchemy.orm import registry
from .types import ULID from .types import (
ULID,
AwardId,
GroupId,
ImdbMovieId,
JSONObject,
JSONScalar,
MovieId,
RatingId,
Score100,
UserId,
UserIdStr,
)
from .utils import json_dump from .utils import json_dump
type JSONScalar = int | float | str | None
type JSON = JSONScalar | list["JSON"] | dict[str, "JSON"]
type JSONObject = dict[str, JSON]
class Model(Protocol): class Model(Protocol):
__table__: ClassVar[Table] __table__: ClassVar[Table]
@ -115,6 +124,19 @@ def _id[T](x: T) -> T:
return x return x
def _unpack(type_: Any) -> Any:
"""Return the wrapped type."""
# Handle type aliases.
if isinstance(type_, TypeAliasType):
return _unpack(type_.__value__)
# Handle newtypes.
if isinstance(type_, NewType):
return _unpack(type_.__supertype__)
return type_
def asplain( def asplain(
o: object, *, filter_fields: Container[str] | None = None, serialize: bool = False o: object, *, filter_fields: Container[str] | None = None, serialize: bool = False
) -> dict[str, Any]: ) -> dict[str, Any]:
@ -136,10 +158,7 @@ def asplain(
if filter_fields is not None and f.name not in filter_fields: if filter_fields is not None and f.name not in filter_fields:
continue continue
target: Any = f.type target: Any = _unpack(f.type)
if isinstance(target, TypeAliasType):
# Support type aliases.
target = target.__value__
# XXX this doesn't properly support any kind of nested types # XXX this doesn't properly support any kind of nested types
if (otype := optional_type(f.type)) is not None: if (otype := optional_type(f.type)) is not None:
@ -147,6 +166,8 @@ def asplain(
if (otype := get_origin(target)) is not None: if (otype := get_origin(target)) is not None:
target = otype target = otype
target = _unpack(target)
v = getattr(o, f.name) v = getattr(o, f.name)
if is_optional(f.type) and v is None: if is_optional(f.type) and v is None:
d[f.name] = None d[f.name] = None
@ -188,10 +209,7 @@ def fromplain[T](cls: Type[T], d: Mapping, *, serialized: bool = False) -> T:
dd: JSONObject = {} dd: JSONObject = {}
for f in fields(cls): for f in fields(cls):
target: Any = f.type target: Any = _unpack(f.type)
if isinstance(target, TypeAliasType):
# Support type aliases.
target = target.__value__
otype = optional_type(f.type) otype = optional_type(f.type)
is_opt = otype is not None is_opt = otype is not None
@ -200,6 +218,8 @@ def fromplain[T](cls: Type[T], d: Mapping, *, serialized: bool = False) -> T:
if (xtype := get_origin(target)) is not None: if (xtype := get_origin(target)) is not None:
target = xtype target = xtype
target = _unpack(target)
v = d[f.name] v = d[f.name]
if is_opt and v is None: if is_opt and v is None:
dd[f.name] = v dd[f.name] = v
@ -225,10 +245,7 @@ def fromplain[T](cls: Type[T], d: Mapping, *, serialized: bool = False) -> T:
def validate(o: object) -> None: def validate(o: object) -> None:
for f in fields(o): for f in fields(o):
ftype = f.type ftype = _unpack(f.type)
if isinstance(ftype, TypeAliasType):
# Support type aliases.
ftype = ftype.__value__
v = getattr(o, f.name) v = getattr(o, f.name)
vtype = type(v) vtype = type(v)
@ -243,11 +260,12 @@ def validate(o: object) -> None:
if is_union: if is_union:
# Support unioned types. # Support unioned types.
utypes = get_args(ftype) utypes = get_args(ftype)
utypes = [_unpack(t) for t in utypes]
if vtype in utypes: if vtype in utypes:
continue continue
# Support generic types (set[str], list[int], etc.) # Support generic types (set[str], list[int], etc.)
gtypes = [g for u in utypes if (g := get_origin(u)) is not None] gtypes = [_unpack(g) for u in utypes if (g := get_origin(u)) is not None]
if any(vtype is gtype for gtype in gtypes): if any(vtype is gtype for gtype in gtypes):
continue continue
@ -337,15 +355,15 @@ class Movie:
Column("updated", String, nullable=False), # datetime Column("updated", String, nullable=False), # datetime
) )
id: ULID = field(default_factory=ULID) id: MovieId = field(default_factory=ULID)
title: str = None # canonical title (usually English) title: str = None # canonical title (usually English)
original_title: str | None = ( original_title: str | None = (
None # original title (usually transscribed to latin script) None # original title (usually transscribed to latin script)
) )
release_year: int = None # canonical release date release_year: int = None # canonical release date
media_type: str = None media_type: str = None
imdb_id: str = None imdb_id: ImdbMovieId = None
imdb_score: int | None = None # range: [0,100] imdb_score: Score100 | None = None # range: [0,100]
imdb_votes: int | None = None imdb_votes: int | None = None
runtime: int | None = None # minutes runtime: int | None = None # minutes
genres: set[str] | None = None genres: set[str] | None = None
@ -418,8 +436,8 @@ class User:
Column("groups", String, nullable=False), # JSON array Column("groups", String, nullable=False), # JSON array
) )
id: ULID = field(default_factory=ULID) id: UserId = field(default_factory=ULID)
imdb_id: str = None imdb_id: ImdbMovieId = None
name: str = None # canonical user name name: str = None # canonical user name
secret: str = None secret: str = None
groups: list[UserGroup] = field(default_factory=list) groups: list[UserGroup] = field(default_factory=list)
@ -456,15 +474,15 @@ class Rating:
Column("finished", Integer), # bool Column("finished", Integer), # bool
) )
id: ULID = field(default_factory=ULID) id: RatingId = field(default_factory=ULID)
movie_id: ULID = None movie_id: MovieId = None
movie: Relation[Movie] = None movie: Relation[Movie] = None
user_id: ULID = None user_id: UserId = None
user: Relation[User] = None user: Relation[User] = None
score: int = None # range: [0,100] score: Score100 = None # range: [0,100]
rating_date: datetime = None rating_date: datetime = None
favorite: bool | None = None favorite: bool | None = None
finished: bool | None = None finished: bool | None = None
@ -487,7 +505,7 @@ Index("ratings_index", ratings.c.movie_id, ratings.c.user_id, unique=True)
class GroupUser(TypedDict): class GroupUser(TypedDict):
id: str id: UserIdStr
name: str name: str
@ -502,7 +520,7 @@ class Group:
Column("users", String, nullable=False), # JSON array Column("users", String, nullable=False), # JSON array
) )
id: ULID = field(default_factory=ULID) id: GroupId = field(default_factory=ULID)
name: str = None name: str = None
users: list[GroupUser] = field(default_factory=list) users: list[GroupUser] = field(default_factory=list)
@ -530,9 +548,9 @@ class Award:
Column("updated", String, nullable=False), # datetime Column("updated", String, nullable=False), # datetime
) )
id: ULID = field(default_factory=ULID) id: AwardId = field(default_factory=ULID)
movie_id: ULID = None movie_id: MovieId = None
movie: Relation[Movie] = None movie: Relation[Movie] = None
category: AwardCategory = None category: AwardCategory = None

View file

@ -1,9 +1,13 @@
import re import re
from typing import cast from typing import NewType, cast
import ulid import ulid
from ulid.hints import Buffer from ulid.hints import Buffer
type JSONScalar = int | float | str | None
type JSON = JSONScalar | list["JSON"] | dict[str, "JSON"]
type JSONObject = dict[str, JSON]
class ULID(ulid.ULID): class ULID(ulid.ULID):
"""Extended ULID type. """Extended ULID type.
@ -29,3 +33,14 @@ class ULID(ulid.ULID):
buffer = cast(memoryview, ulid.new().memory) buffer = cast(memoryview, ulid.new().memory)
super().__init__(buffer) super().__init__(buffer)
AwardId = NewType("AwardId", ULID)
GroupId = NewType("GroupId", ULID)
ImdbMovieId = NewType("ImdbMovieId", str)
MovieId = NewType("MovieId", ULID)
MovieIdStr = NewType("MovieIdStr", str)
RatingId = NewType("RatingId", ULID)
Score100 = NewType("Score100", int) # [0, 100]
UserId = NewType("UserId", ULID)
UserIdStr = NewType("UserIdStr", str)

View file

@ -27,8 +27,8 @@ from starlette.routing import Mount, Route
from . import config, db, imdb, imdb_import, web_models from . import config, db, imdb, imdb_import, web_models
from .db import close_connection_pool, find_movies, find_ratings, open_connection_pool from .db import close_connection_pool, find_movies, find_ratings, open_connection_pool
from .middleware.responsetime import ResponseTimeMiddleware from .middleware.responsetime import ResponseTimeMiddleware
from .models import JSON, Access, Group, Movie, User, asplain from .models import Access, Group, Movie, User, asplain
from .types import ULID from .types import JSON, ULID
from .utils import b64decode, b64encode, phc_compare, phc_scrypt from .utils import b64decode, b64encode, phc_compare, phc_scrypt
log = logging.getLogger(__name__) log = logging.getLogger(__name__)

View file

@ -1,23 +1,22 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Container, Iterable from typing import Container, Iterable
from . import imdb, models from . import imdb, models, types
type URL = str type URL = str
type Score100 = int # [0, 100]
@dataclass @dataclass
class Rating: class Rating:
canonical_title: str canonical_title: str
imdb_score: Score100 | None imdb_score: types.Score100 | None
imdb_votes: int | None imdb_votes: int | None
media_type: str media_type: str
movie_imdb_id: str movie_imdb_id: types.ImdbMovieId
original_title: str | None original_title: str | None
release_year: int release_year: int
user_id: str | None user_id: types.UserIdStr | None
user_score: Score100 | None user_score: types.Score100 | None
@classmethod @classmethod
def from_movie(cls, movie: models.Movie, *, rating: models.Rating | None = None): def from_movie(cls, movie: models.Movie, *, rating: models.Rating | None = None):
@ -37,12 +36,12 @@ class Rating:
@dataclass @dataclass
class RatingAggregate: class RatingAggregate:
canonical_title: str canonical_title: str
imdb_score: Score100 | None imdb_score: types.Score100 | None
imdb_votes: int | None imdb_votes: int | None
link: URL link: URL
media_type: str media_type: str
original_title: str | None original_title: str | None
user_scores: list[Score100] user_scores: list[types.Score100]
year: int year: int
awards: list[str] awards: list[str]
@ -61,20 +60,16 @@ class RatingAggregate:
) )
type ImdbMovieId = str
type UserId = str
def aggregate_ratings( def aggregate_ratings(
ratings: Iterable[Rating], ratings: Iterable[Rating],
user_ids: Container[UserId], user_ids: Container[types.UserIdStr],
*, *,
awards_dict: dict[ImdbMovieId, list[models.Award]] | None = None, awards_dict: dict[types.ImdbMovieId, list[models.Award]] | None = None,
) -> Iterable[RatingAggregate]: ) -> Iterable[RatingAggregate]:
if awards_dict is None: if awards_dict is None:
awards_dict = {} awards_dict = {}
aggr: dict[ImdbMovieId, RatingAggregate] = {} aggr: dict[types.ImdbMovieId, RatingAggregate] = {}
for r in ratings: for r in ratings:
awards = awards_dict.get(r.movie_imdb_id, []) awards = awards_dict.get(r.movie_imdb_id, [])