chore: typing

This commit is contained in:
ducklet 2024-05-19 22:49:46 +02:00
parent b0f5ec4cc9
commit dd39849b8d
5 changed files with 82 additions and 57 deletions

View file

@ -28,7 +28,7 @@ from .models import (
ratings,
utcnow,
)
from .types import ULID
from .types import ULID, ImdbMovieId, UserIdStr
log = logging.getLogger(__name__)
@ -432,19 +432,16 @@ async def add_or_update_rating(conn: Connection, /, rating: Rating) -> bool:
return False
type MovieImdbId = str
async def get_awards(
conn: Connection, /, imdb_ids: list[MovieImdbId]
) -> dict[MovieImdbId, list[Award]]:
conn: Connection, /, imdb_ids: list[ImdbMovieId]
) -> dict[ImdbMovieId, list[Award]]:
query = (
sa.select(Award, movies.c.imdb_id)
.join(movies, awards.c.movie_id == movies.c.id)
.where(movies.c.imdb_id.in_(imdb_ids))
)
rows = await fetch_all(conn, query)
awards_dict: dict[MovieImdbId, list[Award]] = {}
awards_dict: dict[ImdbMovieId, list[Award]] = {}
for row in rows:
awards_dict.setdefault(row.imdb_id, []).append(
fromplain(Award, row._mapping, serialized=True)
@ -467,7 +464,7 @@ async def find_ratings(
include_unrated: bool = False,
yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
limit_rows: int = 10,
user_ids: Iterable[str] = [],
user_ids: Iterable[UserIdStr] = [],
) -> Iterable[dict[str, Any]]:
conditions = []

View file

@ -10,6 +10,7 @@ from typing import (
Container,
Literal,
Mapping,
NewType,
Protocol,
Type,
TypeAliasType,
@ -22,13 +23,21 @@ from typing import (
from sqlalchemy import Column, ForeignKey, Index, Integer, String, Table
from sqlalchemy.orm import registry
from .types import ULID
from .types import (
ULID,
AwardId,
GroupId,
ImdbMovieId,
JSONObject,
JSONScalar,
MovieId,
RatingId,
Score100,
UserId,
UserIdStr,
)
from .utils import json_dump
type JSONScalar = int | float | str | None
type JSON = JSONScalar | list["JSON"] | dict[str, "JSON"]
type JSONObject = dict[str, JSON]
class Model(Protocol):
__table__: ClassVar[Table]
@ -115,6 +124,19 @@ def _id[T](x: T) -> T:
return x
def _unpack(type_: Any) -> Any:
"""Return the wrapped type."""
# Handle type aliases.
if isinstance(type_, TypeAliasType):
return _unpack(type_.__value__)
# Handle newtypes.
if isinstance(type_, NewType):
return _unpack(type_.__supertype__)
return type_
def asplain(
o: object, *, filter_fields: Container[str] | None = None, serialize: bool = False
) -> dict[str, Any]:
@ -136,10 +158,7 @@ def asplain(
if filter_fields is not None and f.name not in filter_fields:
continue
target: Any = f.type
if isinstance(target, TypeAliasType):
# Support type aliases.
target = target.__value__
target: Any = _unpack(f.type)
# XXX this doesn't properly support any kind of nested types
if (otype := optional_type(f.type)) is not None:
@ -147,6 +166,8 @@ def asplain(
if (otype := get_origin(target)) is not None:
target = otype
target = _unpack(target)
v = getattr(o, f.name)
if is_optional(f.type) and v is None:
d[f.name] = None
@ -188,10 +209,7 @@ def fromplain[T](cls: Type[T], d: Mapping, *, serialized: bool = False) -> T:
dd: JSONObject = {}
for f in fields(cls):
target: Any = f.type
if isinstance(target, TypeAliasType):
# Support type aliases.
target = target.__value__
target: Any = _unpack(f.type)
otype = optional_type(f.type)
is_opt = otype is not None
@ -200,6 +218,8 @@ def fromplain[T](cls: Type[T], d: Mapping, *, serialized: bool = False) -> T:
if (xtype := get_origin(target)) is not None:
target = xtype
target = _unpack(target)
v = d[f.name]
if is_opt and v is None:
dd[f.name] = v
@ -225,10 +245,7 @@ def fromplain[T](cls: Type[T], d: Mapping, *, serialized: bool = False) -> T:
def validate(o: object) -> None:
for f in fields(o):
ftype = f.type
if isinstance(ftype, TypeAliasType):
# Support type aliases.
ftype = ftype.__value__
ftype = _unpack(f.type)
v = getattr(o, f.name)
vtype = type(v)
@ -243,11 +260,12 @@ def validate(o: object) -> None:
if is_union:
# Support unioned types.
utypes = get_args(ftype)
utypes = [_unpack(t) for t in utypes]
if vtype in utypes:
continue
# Support generic types (set[str], list[int], etc.)
gtypes = [g for u in utypes if (g := get_origin(u)) is not None]
gtypes = [_unpack(g) for u in utypes if (g := get_origin(u)) is not None]
if any(vtype is gtype for gtype in gtypes):
continue
@ -337,15 +355,15 @@ class Movie:
Column("updated", String, nullable=False), # datetime
)
id: ULID = field(default_factory=ULID)
id: MovieId = field(default_factory=ULID)
title: str = None # canonical title (usually English)
original_title: str | None = (
None # original title (usually transscribed to latin script)
)
release_year: int = None # canonical release date
media_type: str = None
imdb_id: str = None
imdb_score: int | None = None # range: [0,100]
imdb_id: ImdbMovieId = None
imdb_score: Score100 | None = None # range: [0,100]
imdb_votes: int | None = None
runtime: int | None = None # minutes
genres: set[str] | None = None
@ -418,8 +436,8 @@ class User:
Column("groups", String, nullable=False), # JSON array
)
id: ULID = field(default_factory=ULID)
imdb_id: str = None
id: UserId = field(default_factory=ULID)
imdb_id: ImdbMovieId = None
name: str = None # canonical user name
secret: str = None
groups: list[UserGroup] = field(default_factory=list)
@ -456,15 +474,15 @@ class Rating:
Column("finished", Integer), # bool
)
id: ULID = field(default_factory=ULID)
id: RatingId = field(default_factory=ULID)
movie_id: ULID = None
movie_id: MovieId = None
movie: Relation[Movie] = None
user_id: ULID = None
user_id: UserId = None
user: Relation[User] = None
score: int = None # range: [0,100]
score: Score100 = None # range: [0,100]
rating_date: datetime = None
favorite: bool | None = None
finished: bool | None = None
@ -487,7 +505,7 @@ Index("ratings_index", ratings.c.movie_id, ratings.c.user_id, unique=True)
class GroupUser(TypedDict):
id: str
id: UserIdStr
name: str
@ -502,7 +520,7 @@ class Group:
Column("users", String, nullable=False), # JSON array
)
id: ULID = field(default_factory=ULID)
id: GroupId = field(default_factory=ULID)
name: str = None
users: list[GroupUser] = field(default_factory=list)
@ -530,9 +548,9 @@ class Award:
Column("updated", String, nullable=False), # datetime
)
id: ULID = field(default_factory=ULID)
id: AwardId = field(default_factory=ULID)
movie_id: ULID = None
movie_id: MovieId = None
movie: Relation[Movie] = None
category: AwardCategory = None

View file

@ -1,9 +1,13 @@
import re
from typing import cast
from typing import NewType, cast
import ulid
from ulid.hints import Buffer
type JSONScalar = int | float | str | None
type JSON = JSONScalar | list["JSON"] | dict[str, "JSON"]
type JSONObject = dict[str, JSON]
class ULID(ulid.ULID):
"""Extended ULID type.
@ -29,3 +33,14 @@ class ULID(ulid.ULID):
buffer = cast(memoryview, ulid.new().memory)
super().__init__(buffer)
AwardId = NewType("AwardId", ULID)
GroupId = NewType("GroupId", ULID)
ImdbMovieId = NewType("ImdbMovieId", str)
MovieId = NewType("MovieId", ULID)
MovieIdStr = NewType("MovieIdStr", str)
RatingId = NewType("RatingId", ULID)
Score100 = NewType("Score100", int) # [0, 100]
UserId = NewType("UserId", ULID)
UserIdStr = NewType("UserIdStr", str)

View file

@ -27,8 +27,8 @@ from starlette.routing import Mount, Route
from . import config, db, imdb, imdb_import, web_models
from .db import close_connection_pool, find_movies, find_ratings, open_connection_pool
from .middleware.responsetime import ResponseTimeMiddleware
from .models import JSON, Access, Group, Movie, User, asplain
from .types import ULID
from .models import Access, Group, Movie, User, asplain
from .types import JSON, ULID
from .utils import b64decode, b64encode, phc_compare, phc_scrypt
log = logging.getLogger(__name__)

View file

@ -1,23 +1,22 @@
from dataclasses import dataclass
from typing import Container, Iterable
from . import imdb, models
from . import imdb, models, types
type URL = str
type Score100 = int # [0, 100]
@dataclass
class Rating:
canonical_title: str
imdb_score: Score100 | None
imdb_score: types.Score100 | None
imdb_votes: int | None
media_type: str
movie_imdb_id: str
movie_imdb_id: types.ImdbMovieId
original_title: str | None
release_year: int
user_id: str | None
user_score: Score100 | None
user_id: types.UserIdStr | None
user_score: types.Score100 | None
@classmethod
def from_movie(cls, movie: models.Movie, *, rating: models.Rating | None = None):
@ -37,12 +36,12 @@ class Rating:
@dataclass
class RatingAggregate:
canonical_title: str
imdb_score: Score100 | None
imdb_score: types.Score100 | None
imdb_votes: int | None
link: URL
media_type: str
original_title: str | None
user_scores: list[Score100]
user_scores: list[types.Score100]
year: int
awards: list[str]
@ -61,20 +60,16 @@ class RatingAggregate:
)
type ImdbMovieId = str
type UserId = str
def aggregate_ratings(
ratings: Iterable[Rating],
user_ids: Container[UserId],
user_ids: Container[types.UserIdStr],
*,
awards_dict: dict[ImdbMovieId, list[models.Award]] | None = None,
awards_dict: dict[types.ImdbMovieId, list[models.Award]] | None = None,
) -> Iterable[RatingAggregate]:
if awards_dict is None:
awards_dict = {}
aggr: dict[ImdbMovieId, RatingAggregate] = {}
aggr: dict[types.ImdbMovieId, RatingAggregate] = {}
for r in ratings:
awards = awards_dict.get(r.movie_imdb_id, [])