diff --git a/unwind/db.py b/unwind/db.py index 38b0107..16351c7 100644 --- a/unwind/db.py +++ b/unwind/db.py @@ -1,7 +1,7 @@ import contextlib import logging from pathlib import Path -from typing import Any, AsyncGenerator, Iterable, Literal, Sequence, Type, TypeVar +from typing import Any, AsyncGenerator, Iterable, Literal, Sequence, Type import sqlalchemy as sa from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine @@ -29,7 +29,6 @@ from .models import ( from .types import ULID log = logging.getLogger(__name__) -T = TypeVar("T") _engine: AsyncEngine | None = None @@ -240,7 +239,7 @@ async def add(conn: Connection, /, item: Model) -> None: # Support late initializing - used for optimization. if getattr(item, "_is_lazy", False): assert hasattr(item, "_lazy_init") - item._lazy_init() # pyright: ignore [reportGeneralTypeIssues] + item._lazy_init() # pyright: ignore[reportAttributeAccessIssue] table: sa.Table = item.__table__ values = asplain(item, serialize=True) @@ -262,17 +261,14 @@ async def fetch_one( return result.first() -ModelType = TypeVar("ModelType", bound=Model) - - -async def get( +async def get[T: Model]( conn: Connection, /, - model: Type[ModelType], + model: Type[T], *, order_by: tuple[sa.Column, Literal["asc", "desc"]] | None = None, **field_values, -) -> ModelType | None: +) -> T | None: """Load a model instance from the database. Passing `field_values` allows to filter the item to load. You have to encode the @@ -295,9 +291,9 @@ async def get( return fromplain(model, row._mapping, serialized=True) if row else None -async def get_many( - conn: Connection, /, model: Type[ModelType], **field_sets: set | list -) -> Iterable[ModelType]: +async def get_many[T: Model]( + conn: Connection, /, model: Type[T], **field_sets: set | list +) -> Iterable[T]: """Return the items with any values matching all given field sets. This is similar to `get_all`, but instead of a scalar value a list of values @@ -314,9 +310,9 @@ async def get_many( return (fromplain(model, row._mapping, serialized=True) for row in rows) -async def get_all( - conn: Connection, /, model: Type[ModelType], **field_values -) -> Iterable[ModelType]: +async def get_all[T: Model]( + conn: Connection, /, model: Type[T], **field_values +) -> Iterable[T]: """Filter all items by comparing all given field values. If no filters are given, all items will be returned. @@ -333,7 +329,7 @@ async def update(conn: Connection, /, item: Model) -> None: # Support late initializing - used for optimization. if getattr(item, "_is_lazy", False): assert hasattr(item, "_lazy_init") - item._lazy_init() # pyright: ignore [reportGeneralTypeIssues] + item._lazy_init() # pyright: ignore[reportAttributeAccessIssue] table: sa.Table = item.__table__ values = asplain(item, serialize=True) diff --git a/unwind/imdb_import.py b/unwind/imdb_import.py index 5464df0..28792e2 100644 --- a/unwind/imdb_import.py +++ b/unwind/imdb_import.py @@ -5,7 +5,7 @@ import logging from dataclasses import dataclass, fields from datetime import datetime, timezone from pathlib import Path -from typing import Generator, Literal, Type, TypeVar, overload +from typing import Generator, Literal, Type, overload from . import config, db, request from .db import add_or_update_many_movies @@ -14,8 +14,6 @@ from .models import Movie log = logging.getLogger(__name__) -T = TypeVar("T") - # See # - https://developer.imdb.com/non-commercial-datasets/ # - https://datasets.imdbws.com/ @@ -127,7 +125,7 @@ def read_imdb_tsv( @overload -def read_imdb_tsv( +def read_imdb_tsv[T]( path: Path, row_type: Type[T], *, unpack: Literal[True] = True ) -> Generator[T, None, None]: ... diff --git a/unwind/models.py b/unwind/models.py index 07b81fb..31e18c9 100644 --- a/unwind/models.py +++ b/unwind/models.py @@ -12,10 +12,8 @@ from typing import ( Mapping, Protocol, Type, - TypeAlias, TypeAliasType, TypedDict, - TypeVar, Union, get_args, get_origin, @@ -27,12 +25,10 @@ from sqlalchemy.orm import registry from .types import ULID from .utils import json_dump -JSONScalar: TypeAlias = int | float | str | None +type JSONScalar = int | float | str | None type JSON = JSONScalar | list["JSON"] | dict[str, "JSON"] type JSONObject = dict[str, JSON] -T = TypeVar("T") - class Model(Protocol): __table__: ClassVar[Table] @@ -53,6 +49,10 @@ metadata.naming_convention = { def annotations(tp: Type) -> tuple | None: + # Support type aliases and generic aliases. + if isinstance(tp, TypeAliasType) or hasattr(tp, "__value__"): + tp = tp.__value__ + return tp.__metadata__ if hasattr(tp, "__metadata__") else None # type: ignore @@ -110,7 +110,7 @@ def optional_fields(o): yield f -def _id(x: T) -> T: +def _id[T](x: T) -> T: """Return the given argument, aka. the identity function.""" return x @@ -168,7 +168,7 @@ def asplain( ), f"Type mismatch: {f.name!a} ({target!a} != {type(v)!a})" d[f.name] = v elif target in {Literal}: - assert isinstance(v, JSONScalar) + assert isinstance(v, JSONScalar.__value__) d[f.name] = v else: raise ValueError(f"Unsupported value type: {f.name!a}: {type(v)!a}") @@ -176,7 +176,7 @@ def asplain( return d -def fromplain(cls: Type[T], d: Mapping, *, serialized: bool = False) -> T: +def fromplain[T](cls: Type[T], d: Mapping, *, serialized: bool = False) -> T: """Return an instance of the given model using the given data. If `serialized` is `True`, collection types (lists, dicts, etc.) will be @@ -379,7 +379,7 @@ dataclass containing the ID of the linked data. The contents of the Relation are ignored or discarded when using `asplain`, `fromplain`, and `validate`. """ -Relation: TypeAlias = Annotated[T | None, _RelationSentinel] +type Relation[T] = Annotated[T | None, _RelationSentinel] type Access = Literal[ diff --git a/unwind/request.py b/unwind/request.py index f12936b..46d1e9b 100644 --- a/unwind/request.py +++ b/unwind/request.py @@ -11,7 +11,7 @@ from hashlib import md5 from pathlib import Path from random import random from time import sleep, time -from typing import Any, Callable, ParamSpec, TypeVar, cast, overload +from typing import Any, Callable, cast, overload import bs4 import httpx @@ -24,13 +24,10 @@ if config.debug and config.cachedir: config.cachedir.mkdir(exist_ok=True) -_shared_asession = None - _ASession_T = httpx.AsyncClient -_Response_T = httpx.Response +type _Response_T = httpx.Response -_T = TypeVar("_T") -_P = ParamSpec("_P") +_shared_asession: _ASession_T | None = None @asynccontextmanager @@ -59,17 +56,17 @@ async def asession(): _shared_asession = None -def _throttle( +def _throttle[T, **P]( times: int, per_seconds: float, jitter: Callable[[], float] | None = None -) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: +) -> Callable[[Callable[P, T]], Callable[P, T]]: calls: deque[float] = deque(maxlen=times) if jitter is None: jitter = lambda: 0.0 # noqa: E731 - def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]: + def decorator(func: Callable[P, T]) -> Callable[P, T]: @wraps(func) - def inner(*args: _P.args, **kwds: _P.kwargs): + def inner(*args: P.args, **kwds: P.kwargs): # clean up while calls: if calls[0] + per_seconds > time():