chore: replace TypeVar with native syntax

This commit is contained in:
ducklet 2024-05-19 02:57:13 +02:00
parent 1ea09c1a45
commit 76a69b6340
4 changed files with 30 additions and 39 deletions

View file

@ -1,7 +1,7 @@
import contextlib
import logging
from pathlib import Path
from typing import Any, AsyncGenerator, Iterable, Literal, Sequence, Type, TypeVar
from typing import Any, AsyncGenerator, Iterable, Literal, Sequence, Type
import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
@ -29,7 +29,6 @@ from .models import (
from .types import ULID
log = logging.getLogger(__name__)
T = TypeVar("T")
_engine: AsyncEngine | None = None
@ -240,7 +239,7 @@ async def add(conn: Connection, /, item: Model) -> None:
# Support late initializing - used for optimization.
if getattr(item, "_is_lazy", False):
assert hasattr(item, "_lazy_init")
item._lazy_init() # pyright: ignore [reportGeneralTypeIssues]
item._lazy_init() # pyright: ignore[reportAttributeAccessIssue]
table: sa.Table = item.__table__
values = asplain(item, serialize=True)
@ -262,17 +261,14 @@ async def fetch_one(
return result.first()
ModelType = TypeVar("ModelType", bound=Model)
async def get(
async def get[T: Model](
conn: Connection,
/,
model: Type[ModelType],
model: Type[T],
*,
order_by: tuple[sa.Column, Literal["asc", "desc"]] | None = None,
**field_values,
) -> ModelType | None:
) -> T | None:
"""Load a model instance from the database.
Passing `field_values` allows to filter the item to load. You have to encode the
@ -295,9 +291,9 @@ async def get(
return fromplain(model, row._mapping, serialized=True) if row else None
async def get_many(
conn: Connection, /, model: Type[ModelType], **field_sets: set | list
) -> Iterable[ModelType]:
async def get_many[T: Model](
conn: Connection, /, model: Type[T], **field_sets: set | list
) -> Iterable[T]:
"""Return the items with any values matching all given field sets.
This is similar to `get_all`, but instead of a scalar value a list of values
@ -314,9 +310,9 @@ async def get_many(
return (fromplain(model, row._mapping, serialized=True) for row in rows)
async def get_all(
conn: Connection, /, model: Type[ModelType], **field_values
) -> Iterable[ModelType]:
async def get_all[T: Model](
conn: Connection, /, model: Type[T], **field_values
) -> Iterable[T]:
"""Filter all items by comparing all given field values.
If no filters are given, all items will be returned.
@ -333,7 +329,7 @@ async def update(conn: Connection, /, item: Model) -> None:
# Support late initializing - used for optimization.
if getattr(item, "_is_lazy", False):
assert hasattr(item, "_lazy_init")
item._lazy_init() # pyright: ignore [reportGeneralTypeIssues]
item._lazy_init() # pyright: ignore[reportAttributeAccessIssue]
table: sa.Table = item.__table__
values = asplain(item, serialize=True)

View file

@ -5,7 +5,7 @@ import logging
from dataclasses import dataclass, fields
from datetime import datetime, timezone
from pathlib import Path
from typing import Generator, Literal, Type, TypeVar, overload
from typing import Generator, Literal, Type, overload
from . import config, db, request
from .db import add_or_update_many_movies
@ -14,8 +14,6 @@ from .models import Movie
log = logging.getLogger(__name__)
T = TypeVar("T")
# See
# - https://developer.imdb.com/non-commercial-datasets/
# - https://datasets.imdbws.com/
@ -127,7 +125,7 @@ def read_imdb_tsv(
@overload
def read_imdb_tsv(
def read_imdb_tsv[T](
path: Path, row_type: Type[T], *, unpack: Literal[True] = True
) -> Generator[T, None, None]: ...

View file

@ -12,10 +12,8 @@ from typing import (
Mapping,
Protocol,
Type,
TypeAlias,
TypeAliasType,
TypedDict,
TypeVar,
Union,
get_args,
get_origin,
@ -27,12 +25,10 @@ from sqlalchemy.orm import registry
from .types import ULID
from .utils import json_dump
JSONScalar: TypeAlias = int | float | str | None
type JSONScalar = int | float | str | None
type JSON = JSONScalar | list["JSON"] | dict[str, "JSON"]
type JSONObject = dict[str, JSON]
T = TypeVar("T")
class Model(Protocol):
__table__: ClassVar[Table]
@ -53,6 +49,10 @@ metadata.naming_convention = {
def annotations(tp: Type) -> tuple | None:
# Support type aliases and generic aliases.
if isinstance(tp, TypeAliasType) or hasattr(tp, "__value__"):
tp = tp.__value__
return tp.__metadata__ if hasattr(tp, "__metadata__") else None # type: ignore
@ -110,7 +110,7 @@ def optional_fields(o):
yield f
def _id(x: T) -> T:
def _id[T](x: T) -> T:
"""Return the given argument, aka. the identity function."""
return x
@ -168,7 +168,7 @@ def asplain(
), f"Type mismatch: {f.name!a} ({target!a} != {type(v)!a})"
d[f.name] = v
elif target in {Literal}:
assert isinstance(v, JSONScalar)
assert isinstance(v, JSONScalar.__value__)
d[f.name] = v
else:
raise ValueError(f"Unsupported value type: {f.name!a}: {type(v)!a}")
@ -176,7 +176,7 @@ def asplain(
return d
def fromplain(cls: Type[T], d: Mapping, *, serialized: bool = False) -> T:
def fromplain[T](cls: Type[T], d: Mapping, *, serialized: bool = False) -> T:
"""Return an instance of the given model using the given data.
If `serialized` is `True`, collection types (lists, dicts, etc.) will be
@ -379,7 +379,7 @@ dataclass containing the ID of the linked data.
The contents of the Relation are ignored or discarded when using
`asplain`, `fromplain`, and `validate`.
"""
Relation: TypeAlias = Annotated[T | None, _RelationSentinel]
type Relation[T] = Annotated[T | None, _RelationSentinel]
type Access = Literal[

View file

@ -11,7 +11,7 @@ from hashlib import md5
from pathlib import Path
from random import random
from time import sleep, time
from typing import Any, Callable, ParamSpec, TypeVar, cast, overload
from typing import Any, Callable, cast, overload
import bs4
import httpx
@ -24,13 +24,10 @@ if config.debug and config.cachedir:
config.cachedir.mkdir(exist_ok=True)
_shared_asession = None
_ASession_T = httpx.AsyncClient
_Response_T = httpx.Response
type _Response_T = httpx.Response
_T = TypeVar("_T")
_P = ParamSpec("_P")
_shared_asession: _ASession_T | None = None
@asynccontextmanager
@ -59,17 +56,17 @@ async def asession():
_shared_asession = None
def _throttle(
def _throttle[T, **P](
times: int, per_seconds: float, jitter: Callable[[], float] | None = None
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
) -> Callable[[Callable[P, T]], Callable[P, T]]:
calls: deque[float] = deque(maxlen=times)
if jitter is None:
jitter = lambda: 0.0 # noqa: E731
def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]:
def decorator(func: Callable[P, T]) -> Callable[P, T]:
@wraps(func)
def inner(*args: _P.args, **kwds: _P.kwargs):
def inner(*args: P.args, **kwds: P.kwargs):
# clean up
while calls:
if calls[0] + per_seconds > time():