diff --git a/unwind/db.py b/unwind/db.py index e416d8b..a8e23d8 100644 --- a/unwind/db.py +++ b/unwind/db.py @@ -28,7 +28,7 @@ from .models import ( ratings, utcnow, ) -from .types import ULID +from .types import ULID, ImdbMovieId, UserIdStr log = logging.getLogger(__name__) @@ -432,19 +432,16 @@ async def add_or_update_rating(conn: Connection, /, rating: Rating) -> bool: return False -type MovieImdbId = str - - async def get_awards( - conn: Connection, /, imdb_ids: list[MovieImdbId] -) -> dict[MovieImdbId, list[Award]]: + conn: Connection, /, imdb_ids: list[ImdbMovieId] +) -> dict[ImdbMovieId, list[Award]]: query = ( sa.select(Award, movies.c.imdb_id) .join(movies, awards.c.movie_id == movies.c.id) .where(movies.c.imdb_id.in_(imdb_ids)) ) rows = await fetch_all(conn, query) - awards_dict: dict[MovieImdbId, list[Award]] = {} + awards_dict: dict[ImdbMovieId, list[Award]] = {} for row in rows: awards_dict.setdefault(row.imdb_id, []).append( fromplain(Award, row._mapping, serialized=True) @@ -467,7 +464,7 @@ async def find_ratings( include_unrated: bool = False, yearcomp: tuple[Literal["<", "=", ">"], int] | None = None, limit_rows: int = 10, - user_ids: Iterable[str] = [], + user_ids: Iterable[UserIdStr] = [], ) -> Iterable[dict[str, Any]]: conditions = [] diff --git a/unwind/models.py b/unwind/models.py index 2d59cd0..f952686 100644 --- a/unwind/models.py +++ b/unwind/models.py @@ -10,6 +10,7 @@ from typing import ( Container, Literal, Mapping, + NewType, Protocol, Type, TypeAliasType, @@ -22,13 +23,21 @@ from typing import ( from sqlalchemy import Column, ForeignKey, Index, Integer, String, Table from sqlalchemy.orm import registry -from .types import ULID +from .types import ( + ULID, + AwardId, + GroupId, + ImdbMovieId, + JSONObject, + JSONScalar, + MovieId, + RatingId, + Score100, + UserId, + UserIdStr, +) from .utils import json_dump -type JSONScalar = int | float | str | None -type JSON = JSONScalar | list["JSON"] | dict[str, "JSON"] -type JSONObject = dict[str, JSON] - class Model(Protocol): __table__: ClassVar[Table] @@ -115,6 +124,19 @@ def _id[T](x: T) -> T: return x +def _unpack(type_: Any) -> Any: + """Return the wrapped type.""" + # Handle type aliases. + if isinstance(type_, TypeAliasType): + return _unpack(type_.__value__) + + # Handle newtypes. + if isinstance(type_, NewType): + return _unpack(type_.__supertype__) + + return type_ + + def asplain( o: object, *, filter_fields: Container[str] | None = None, serialize: bool = False ) -> dict[str, Any]: @@ -136,10 +158,7 @@ def asplain( if filter_fields is not None and f.name not in filter_fields: continue - target: Any = f.type - if isinstance(target, TypeAliasType): - # Support type aliases. - target = target.__value__ + target: Any = _unpack(f.type) # XXX this doesn't properly support any kind of nested types if (otype := optional_type(f.type)) is not None: @@ -147,6 +166,8 @@ def asplain( if (otype := get_origin(target)) is not None: target = otype + target = _unpack(target) + v = getattr(o, f.name) if is_optional(f.type) and v is None: d[f.name] = None @@ -188,10 +209,7 @@ def fromplain[T](cls: Type[T], d: Mapping, *, serialized: bool = False) -> T: dd: JSONObject = {} for f in fields(cls): - target: Any = f.type - if isinstance(target, TypeAliasType): - # Support type aliases. - target = target.__value__ + target: Any = _unpack(f.type) otype = optional_type(f.type) is_opt = otype is not None @@ -200,6 +218,8 @@ def fromplain[T](cls: Type[T], d: Mapping, *, serialized: bool = False) -> T: if (xtype := get_origin(target)) is not None: target = xtype + target = _unpack(target) + v = d[f.name] if is_opt and v is None: dd[f.name] = v @@ -225,10 +245,7 @@ def fromplain[T](cls: Type[T], d: Mapping, *, serialized: bool = False) -> T: def validate(o: object) -> None: for f in fields(o): - ftype = f.type - if isinstance(ftype, TypeAliasType): - # Support type aliases. - ftype = ftype.__value__ + ftype = _unpack(f.type) v = getattr(o, f.name) vtype = type(v) @@ -243,11 +260,12 @@ def validate(o: object) -> None: if is_union: # Support unioned types. utypes = get_args(ftype) + utypes = [_unpack(t) for t in utypes] if vtype in utypes: continue # Support generic types (set[str], list[int], etc.) - gtypes = [g for u in utypes if (g := get_origin(u)) is not None] + gtypes = [_unpack(g) for u in utypes if (g := get_origin(u)) is not None] if any(vtype is gtype for gtype in gtypes): continue @@ -337,15 +355,15 @@ class Movie: Column("updated", String, nullable=False), # datetime ) - id: ULID = field(default_factory=ULID) + id: MovieId = field(default_factory=ULID) title: str = None # canonical title (usually English) original_title: str | None = ( None # original title (usually transscribed to latin script) ) release_year: int = None # canonical release date media_type: str = None - imdb_id: str = None - imdb_score: int | None = None # range: [0,100] + imdb_id: ImdbMovieId = None + imdb_score: Score100 | None = None # range: [0,100] imdb_votes: int | None = None runtime: int | None = None # minutes genres: set[str] | None = None @@ -418,8 +436,8 @@ class User: Column("groups", String, nullable=False), # JSON array ) - id: ULID = field(default_factory=ULID) - imdb_id: str = None + id: UserId = field(default_factory=ULID) + imdb_id: ImdbMovieId = None name: str = None # canonical user name secret: str = None groups: list[UserGroup] = field(default_factory=list) @@ -456,15 +474,15 @@ class Rating: Column("finished", Integer), # bool ) - id: ULID = field(default_factory=ULID) + id: RatingId = field(default_factory=ULID) - movie_id: ULID = None + movie_id: MovieId = None movie: Relation[Movie] = None - user_id: ULID = None + user_id: UserId = None user: Relation[User] = None - score: int = None # range: [0,100] + score: Score100 = None # range: [0,100] rating_date: datetime = None favorite: bool | None = None finished: bool | None = None @@ -487,7 +505,7 @@ Index("ratings_index", ratings.c.movie_id, ratings.c.user_id, unique=True) class GroupUser(TypedDict): - id: str + id: UserIdStr name: str @@ -502,7 +520,7 @@ class Group: Column("users", String, nullable=False), # JSON array ) - id: ULID = field(default_factory=ULID) + id: GroupId = field(default_factory=ULID) name: str = None users: list[GroupUser] = field(default_factory=list) @@ -530,9 +548,9 @@ class Award: Column("updated", String, nullable=False), # datetime ) - id: ULID = field(default_factory=ULID) + id: AwardId = field(default_factory=ULID) - movie_id: ULID = None + movie_id: MovieId = None movie: Relation[Movie] = None category: AwardCategory = None diff --git a/unwind/types.py b/unwind/types.py index 94c0e00..76ce3e8 100644 --- a/unwind/types.py +++ b/unwind/types.py @@ -1,9 +1,13 @@ import re -from typing import cast +from typing import NewType, cast import ulid from ulid.hints import Buffer +type JSONScalar = int | float | str | None +type JSON = JSONScalar | list["JSON"] | dict[str, "JSON"] +type JSONObject = dict[str, JSON] + class ULID(ulid.ULID): """Extended ULID type. @@ -29,3 +33,14 @@ class ULID(ulid.ULID): buffer = cast(memoryview, ulid.new().memory) super().__init__(buffer) + + +AwardId = NewType("AwardId", ULID) +GroupId = NewType("GroupId", ULID) +ImdbMovieId = NewType("ImdbMovieId", str) +MovieId = NewType("MovieId", ULID) +MovieIdStr = NewType("MovieIdStr", str) +RatingId = NewType("RatingId", ULID) +Score100 = NewType("Score100", int) # [0, 100] +UserId = NewType("UserId", ULID) +UserIdStr = NewType("UserIdStr", str) diff --git a/unwind/web.py b/unwind/web.py index ee024e9..3f62a53 100644 --- a/unwind/web.py +++ b/unwind/web.py @@ -27,8 +27,8 @@ from starlette.routing import Mount, Route from . import config, db, imdb, imdb_import, web_models from .db import close_connection_pool, find_movies, find_ratings, open_connection_pool from .middleware.responsetime import ResponseTimeMiddleware -from .models import JSON, Access, Group, Movie, User, asplain -from .types import ULID +from .models import Access, Group, Movie, User, asplain +from .types import JSON, ULID from .utils import b64decode, b64encode, phc_compare, phc_scrypt log = logging.getLogger(__name__) diff --git a/unwind/web_models.py b/unwind/web_models.py index 0551ba3..42cb4dc 100644 --- a/unwind/web_models.py +++ b/unwind/web_models.py @@ -1,23 +1,22 @@ from dataclasses import dataclass from typing import Container, Iterable -from . import imdb, models +from . import imdb, models, types type URL = str -type Score100 = int # [0, 100] @dataclass class Rating: canonical_title: str - imdb_score: Score100 | None + imdb_score: types.Score100 | None imdb_votes: int | None media_type: str - movie_imdb_id: str + movie_imdb_id: types.ImdbMovieId original_title: str | None release_year: int - user_id: str | None - user_score: Score100 | None + user_id: types.UserIdStr | None + user_score: types.Score100 | None @classmethod def from_movie(cls, movie: models.Movie, *, rating: models.Rating | None = None): @@ -37,12 +36,12 @@ class Rating: @dataclass class RatingAggregate: canonical_title: str - imdb_score: Score100 | None + imdb_score: types.Score100 | None imdb_votes: int | None link: URL media_type: str original_title: str | None - user_scores: list[Score100] + user_scores: list[types.Score100] year: int awards: list[str] @@ -61,20 +60,16 @@ class RatingAggregate: ) -type ImdbMovieId = str -type UserId = str - - def aggregate_ratings( ratings: Iterable[Rating], - user_ids: Container[UserId], + user_ids: Container[types.UserIdStr], *, - awards_dict: dict[ImdbMovieId, list[models.Award]] | None = None, + awards_dict: dict[types.ImdbMovieId, list[models.Award]] | None = None, ) -> Iterable[RatingAggregate]: if awards_dict is None: awards_dict = {} - aggr: dict[ImdbMovieId, RatingAggregate] = {} + aggr: dict[types.ImdbMovieId, RatingAggregate] = {} for r in ratings: awards = awards_dict.get(r.movie_imdb_id, [])