chore: replace TypeVar with native syntax

This commit is contained in:
ducklet 2024-05-19 02:57:13 +02:00
parent 1ea09c1a45
commit 76a69b6340
4 changed files with 30 additions and 39 deletions

View file

@@ -1,7 +1,7 @@
import contextlib import contextlib
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Any, AsyncGenerator, Iterable, Literal, Sequence, Type, TypeVar from typing import Any, AsyncGenerator, Iterable, Literal, Sequence, Type
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
@@ -29,7 +29,6 @@ from .models import ( from .models import (
from .types import ULID from .types import ULID
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
T = TypeVar("T")
_engine: AsyncEngine | None = None _engine: AsyncEngine | None = None
@@ -240,7 +239,7 @@ async def add(conn: Connection, /, item: Model) -> None:
# Support late initializing - used for optimization. # Support late initializing - used for optimization.
if getattr(item, "_is_lazy", False): if getattr(item, "_is_lazy", False):
assert hasattr(item, "_lazy_init") assert hasattr(item, "_lazy_init")
item._lazy_init() # pyright: ignore [reportGeneralTypeIssues] item._lazy_init() # pyright: ignore[reportAttributeAccessIssue]
table: sa.Table = item.__table__ table: sa.Table = item.__table__
values = asplain(item, serialize=True) values = asplain(item, serialize=True)
@@ -262,17 +261,14 @@ async def fetch_one(
return result.first() return result.first()
ModelType = TypeVar("ModelType", bound=Model) async def get[T: Model](
async def get(
conn: Connection, conn: Connection,
/, /,
model: Type[ModelType], model: Type[T],
*, *,
order_by: tuple[sa.Column, Literal["asc", "desc"]] | None = None, order_by: tuple[sa.Column, Literal["asc", "desc"]] | None = None,
**field_values, **field_values,
) -> ModelType | None: ) -> T | None:
"""Load a model instance from the database. """Load a model instance from the database.
Passing `field_values` allows to filter the item to load. You have to encode the Passing `field_values` allows to filter the item to load. You have to encode the
@@ -295,9 +291,9 @@ async def get(
return fromplain(model, row._mapping, serialized=True) if row else None return fromplain(model, row._mapping, serialized=True) if row else None
async def get_many( async def get_many[T: Model](
conn: Connection, /, model: Type[ModelType], **field_sets: set | list conn: Connection, /, model: Type[T], **field_sets: set | list
) -> Iterable[ModelType]: ) -> Iterable[T]:
"""Return the items with any values matching all given field sets. """Return the items with any values matching all given field sets.
This is similar to `get_all`, but instead of a scalar value a list of values This is similar to `get_all`, but instead of a scalar value a list of values
@@ -314,9 +310,9 @@ async def get_many(
return (fromplain(model, row._mapping, serialized=True) for row in rows) return (fromplain(model, row._mapping, serialized=True) for row in rows)
async def get_all( async def get_all[T: Model](
conn: Connection, /, model: Type[ModelType], **field_values conn: Connection, /, model: Type[T], **field_values
) -> Iterable[ModelType]: ) -> Iterable[T]:
"""Filter all items by comparing all given field values. """Filter all items by comparing all given field values.
If no filters are given, all items will be returned. If no filters are given, all items will be returned.
@@ -333,7 +329,7 @@ async def update(conn: Connection, /, item: Model) -> None:
# Support late initializing - used for optimization. # Support late initializing - used for optimization.
if getattr(item, "_is_lazy", False): if getattr(item, "_is_lazy", False):
assert hasattr(item, "_lazy_init") assert hasattr(item, "_lazy_init")
item._lazy_init() # pyright: ignore [reportGeneralTypeIssues] item._lazy_init() # pyright: ignore[reportAttributeAccessIssue]
table: sa.Table = item.__table__ table: sa.Table = item.__table__
values = asplain(item, serialize=True) values = asplain(item, serialize=True)

View file

@@ -5,7 +5,7 @@ import logging
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Generator, Literal, Type, TypeVar, overload from typing import Generator, Literal, Type, overload
from . import config, db, request from . import config, db, request
from .db import add_or_update_many_movies from .db import add_or_update_many_movies
@@ -14,8 +14,6 @@ from .models import Movie
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
T = TypeVar("T")
# See # See
# - https://developer.imdb.com/non-commercial-datasets/ # - https://developer.imdb.com/non-commercial-datasets/
# - https://datasets.imdbws.com/ # - https://datasets.imdbws.com/
@@ -127,7 +125,7 @@ def read_imdb_tsv(
@overload @overload
def read_imdb_tsv( def read_imdb_tsv[T](
path: Path, row_type: Type[T], *, unpack: Literal[True] = True path: Path, row_type: Type[T], *, unpack: Literal[True] = True
) -> Generator[T, None, None]: ... ) -> Generator[T, None, None]: ...

View file

@@ -12,10 +12,8 @@ from typing import (
Mapping, Mapping,
Protocol, Protocol,
Type, Type,
TypeAlias,
TypeAliasType, TypeAliasType,
TypedDict, TypedDict,
TypeVar,
Union, Union,
get_args, get_args,
get_origin, get_origin,
@@ -27,12 +25,10 @@ from sqlalchemy.orm import registry
from .types import ULID from .types import ULID
from .utils import json_dump from .utils import json_dump
JSONScalar: TypeAlias = int | float | str | None type JSONScalar = int | float | str | None
type JSON = JSONScalar | list["JSON"] | dict[str, "JSON"] type JSON = JSONScalar | list["JSON"] | dict[str, "JSON"]
type JSONObject = dict[str, JSON] type JSONObject = dict[str, JSON]
T = TypeVar("T")
class Model(Protocol): class Model(Protocol):
__table__: ClassVar[Table] __table__: ClassVar[Table]
@@ -53,6 +49,10 @@ metadata.naming_convention = {
def annotations(tp: Type) -> tuple | None: def annotations(tp: Type) -> tuple | None:
# Support type aliases and generic aliases.
if isinstance(tp, TypeAliasType) or hasattr(tp, "__value__"):
tp = tp.__value__
return tp.__metadata__ if hasattr(tp, "__metadata__") else None # type: ignore return tp.__metadata__ if hasattr(tp, "__metadata__") else None # type: ignore
@@ -110,7 +110,7 @@ def optional_fields(o):
yield f yield f
def _id(x: T) -> T: def _id[T](x: T) -> T:
"""Return the given argument, aka. the identity function.""" """Return the given argument, aka. the identity function."""
return x return x
@@ -168,7 +168,7 @@ def asplain(
), f"Type mismatch: {f.name!a} ({target!a} != {type(v)!a})" ), f"Type mismatch: {f.name!a} ({target!a} != {type(v)!a})"
d[f.name] = v d[f.name] = v
elif target in {Literal}: elif target in {Literal}:
assert isinstance(v, JSONScalar) assert isinstance(v, JSONScalar.__value__)
d[f.name] = v d[f.name] = v
else: else:
raise ValueError(f"Unsupported value type: {f.name!a}: {type(v)!a}") raise ValueError(f"Unsupported value type: {f.name!a}: {type(v)!a}")
@@ -176,7 +176,7 @@ def asplain(
return d return d
def fromplain(cls: Type[T], d: Mapping, *, serialized: bool = False) -> T: def fromplain[T](cls: Type[T], d: Mapping, *, serialized: bool = False) -> T:
"""Return an instance of the given model using the given data. """Return an instance of the given model using the given data.
If `serialized` is `True`, collection types (lists, dicts, etc.) will be If `serialized` is `True`, collection types (lists, dicts, etc.) will be
@@ -379,7 +379,7 @@ dataclass containing the ID of the linked data.
The contents of the Relation are ignored or discarded when using The contents of the Relation are ignored or discarded when using
`asplain`, `fromplain`, and `validate`. `asplain`, `fromplain`, and `validate`.
""" """
Relation: TypeAlias = Annotated[T | None, _RelationSentinel] type Relation[T] = Annotated[T | None, _RelationSentinel]
type Access = Literal[ type Access = Literal[

View file

@@ -11,7 +11,7 @@ from hashlib import md5
from pathlib import Path from pathlib import Path
from random import random from random import random
from time import sleep, time from time import sleep, time
from typing import Any, Callable, ParamSpec, TypeVar, cast, overload from typing import Any, Callable, cast, overload
import bs4 import bs4
import httpx import httpx
@@ -24,13 +24,10 @@ if config.debug and config.cachedir:
config.cachedir.mkdir(exist_ok=True) config.cachedir.mkdir(exist_ok=True)
_shared_asession = None
_ASession_T = httpx.AsyncClient _ASession_T = httpx.AsyncClient
_Response_T = httpx.Response type _Response_T = httpx.Response
_T = TypeVar("_T") _shared_asession: _ASession_T | None = None
_P = ParamSpec("_P")
@asynccontextmanager @asynccontextmanager
@@ -59,17 +56,17 @@ async def asession():
_shared_asession = None _shared_asession = None
def _throttle( def _throttle[T, **P](
times: int, per_seconds: float, jitter: Callable[[], float] | None = None times: int, per_seconds: float, jitter: Callable[[], float] | None = None
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ) -> Callable[[Callable[P, T]], Callable[P, T]]:
calls: deque[float] = deque(maxlen=times) calls: deque[float] = deque(maxlen=times)
if jitter is None: if jitter is None:
jitter = lambda: 0.0 # noqa: E731 jitter = lambda: 0.0 # noqa: E731
def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]: def decorator(func: Callable[P, T]) -> Callable[P, T]:
@wraps(func) @wraps(func)
def inner(*args: _P.args, **kwds: _P.kwargs): def inner(*args: P.args, **kwds: P.kwargs):
# clean up # clean up
while calls: while calls:
if calls[0] + per_seconds > time(): if calls[0] + per_seconds > time():