chore: replace TypeVar with native syntax
This commit is contained in:
parent
1ea09c1a45
commit
76a69b6340
4 changed files with 30 additions and 39 deletions
28
unwind/db.py
28
unwind/db.py
|
|
@ -1,7 +1,7 @@
|
|||
import contextlib
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, AsyncGenerator, Iterable, Literal, Sequence, Type, TypeVar
|
||||
from typing import Any, AsyncGenerator, Iterable, Literal, Sequence, Type
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
|
||||
|
|
@ -29,7 +29,6 @@ from .models import (
|
|||
from .types import ULID
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
T = TypeVar("T")
|
||||
|
||||
_engine: AsyncEngine | None = None
|
||||
|
||||
|
|
@ -240,7 +239,7 @@ async def add(conn: Connection, /, item: Model) -> None:
|
|||
# Support late initializing - used for optimization.
|
||||
if getattr(item, "_is_lazy", False):
|
||||
assert hasattr(item, "_lazy_init")
|
||||
item._lazy_init() # pyright: ignore [reportGeneralTypeIssues]
|
||||
item._lazy_init() # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
table: sa.Table = item.__table__
|
||||
values = asplain(item, serialize=True)
|
||||
|
|
@ -262,17 +261,14 @@ async def fetch_one(
|
|||
return result.first()
|
||||
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=Model)
|
||||
|
||||
|
||||
async def get(
|
||||
async def get[T: Model](
|
||||
conn: Connection,
|
||||
/,
|
||||
model: Type[ModelType],
|
||||
model: Type[T],
|
||||
*,
|
||||
order_by: tuple[sa.Column, Literal["asc", "desc"]] | None = None,
|
||||
**field_values,
|
||||
) -> ModelType | None:
|
||||
) -> T | None:
|
||||
"""Load a model instance from the database.
|
||||
|
||||
Passing `field_values` allows to filter the item to load. You have to encode the
|
||||
|
|
@ -295,9 +291,9 @@ async def get(
|
|||
return fromplain(model, row._mapping, serialized=True) if row else None
|
||||
|
||||
|
||||
async def get_many(
|
||||
conn: Connection, /, model: Type[ModelType], **field_sets: set | list
|
||||
) -> Iterable[ModelType]:
|
||||
async def get_many[T: Model](
|
||||
conn: Connection, /, model: Type[T], **field_sets: set | list
|
||||
) -> Iterable[T]:
|
||||
"""Return the items with any values matching all given field sets.
|
||||
|
||||
This is similar to `get_all`, but instead of a scalar value a list of values
|
||||
|
|
@ -314,9 +310,9 @@ async def get_many(
|
|||
return (fromplain(model, row._mapping, serialized=True) for row in rows)
|
||||
|
||||
|
||||
async def get_all(
|
||||
conn: Connection, /, model: Type[ModelType], **field_values
|
||||
) -> Iterable[ModelType]:
|
||||
async def get_all[T: Model](
|
||||
conn: Connection, /, model: Type[T], **field_values
|
||||
) -> Iterable[T]:
|
||||
"""Filter all items by comparing all given field values.
|
||||
|
||||
If no filters are given, all items will be returned.
|
||||
|
|
@ -333,7 +329,7 @@ async def update(conn: Connection, /, item: Model) -> None:
|
|||
# Support late initializing - used for optimization.
|
||||
if getattr(item, "_is_lazy", False):
|
||||
assert hasattr(item, "_lazy_init")
|
||||
item._lazy_init() # pyright: ignore [reportGeneralTypeIssues]
|
||||
item._lazy_init() # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
table: sa.Table = item.__table__
|
||||
values = asplain(item, serialize=True)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import logging
|
|||
from dataclasses import dataclass, fields
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Generator, Literal, Type, TypeVar, overload
|
||||
from typing import Generator, Literal, Type, overload
|
||||
|
||||
from . import config, db, request
|
||||
from .db import add_or_update_many_movies
|
||||
|
|
@ -14,8 +14,6 @@ from .models import Movie
|
|||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
# See
|
||||
# - https://developer.imdb.com/non-commercial-datasets/
|
||||
# - https://datasets.imdbws.com/
|
||||
|
|
@ -127,7 +125,7 @@ def read_imdb_tsv(
|
|||
|
||||
|
||||
@overload
|
||||
def read_imdb_tsv(
|
||||
def read_imdb_tsv[T](
|
||||
path: Path, row_type: Type[T], *, unpack: Literal[True] = True
|
||||
) -> Generator[T, None, None]: ...
|
||||
|
||||
|
|
|
|||
|
|
@ -12,10 +12,8 @@ from typing import (
|
|||
Mapping,
|
||||
Protocol,
|
||||
Type,
|
||||
TypeAlias,
|
||||
TypeAliasType,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
|
|
@ -27,12 +25,10 @@ from sqlalchemy.orm import registry
|
|||
from .types import ULID
|
||||
from .utils import json_dump
|
||||
|
||||
JSONScalar: TypeAlias = int | float | str | None
|
||||
type JSONScalar = int | float | str | None
|
||||
type JSON = JSONScalar | list["JSON"] | dict[str, "JSON"]
|
||||
type JSONObject = dict[str, JSON]
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Model(Protocol):
|
||||
__table__: ClassVar[Table]
|
||||
|
|
@ -53,6 +49,10 @@ metadata.naming_convention = {
|
|||
|
||||
|
||||
def annotations(tp: Type) -> tuple | None:
|
||||
# Support type aliases and generic aliases.
|
||||
if isinstance(tp, TypeAliasType) or hasattr(tp, "__value__"):
|
||||
tp = tp.__value__
|
||||
|
||||
return tp.__metadata__ if hasattr(tp, "__metadata__") else None # type: ignore
|
||||
|
||||
|
||||
|
|
@ -110,7 +110,7 @@ def optional_fields(o):
|
|||
yield f
|
||||
|
||||
|
||||
def _id(x: T) -> T:
|
||||
def _id[T](x: T) -> T:
|
||||
"""Return the given argument, aka. the identity function."""
|
||||
return x
|
||||
|
||||
|
|
@ -168,7 +168,7 @@ def asplain(
|
|||
), f"Type mismatch: {f.name!a} ({target!a} != {type(v)!a})"
|
||||
d[f.name] = v
|
||||
elif target in {Literal}:
|
||||
assert isinstance(v, JSONScalar)
|
||||
assert isinstance(v, JSONScalar.__value__)
|
||||
d[f.name] = v
|
||||
else:
|
||||
raise ValueError(f"Unsupported value type: {f.name!a}: {type(v)!a}")
|
||||
|
|
@ -176,7 +176,7 @@ def asplain(
|
|||
return d
|
||||
|
||||
|
||||
def fromplain(cls: Type[T], d: Mapping, *, serialized: bool = False) -> T:
|
||||
def fromplain[T](cls: Type[T], d: Mapping, *, serialized: bool = False) -> T:
|
||||
"""Return an instance of the given model using the given data.
|
||||
|
||||
If `serialized` is `True`, collection types (lists, dicts, etc.) will be
|
||||
|
|
@ -379,7 +379,7 @@ dataclass containing the ID of the linked data.
|
|||
The contents of the Relation are ignored or discarded when using
|
||||
`asplain`, `fromplain`, and `validate`.
|
||||
"""
|
||||
Relation: TypeAlias = Annotated[T | None, _RelationSentinel]
|
||||
type Relation[T] = Annotated[T | None, _RelationSentinel]
|
||||
|
||||
|
||||
type Access = Literal[
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from hashlib import md5
|
|||
from pathlib import Path
|
||||
from random import random
|
||||
from time import sleep, time
|
||||
from typing import Any, Callable, ParamSpec, TypeVar, cast, overload
|
||||
from typing import Any, Callable, cast, overload
|
||||
|
||||
import bs4
|
||||
import httpx
|
||||
|
|
@ -24,13 +24,10 @@ if config.debug and config.cachedir:
|
|||
config.cachedir.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
_shared_asession = None
|
||||
|
||||
_ASession_T = httpx.AsyncClient
|
||||
_Response_T = httpx.Response
|
||||
type _Response_T = httpx.Response
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_P = ParamSpec("_P")
|
||||
_shared_asession: _ASession_T | None = None
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
|
|
@ -59,17 +56,17 @@ async def asession():
|
|||
_shared_asession = None
|
||||
|
||||
|
||||
def _throttle(
|
||||
def _throttle[T, **P](
|
||||
times: int, per_seconds: float, jitter: Callable[[], float] | None = None
|
||||
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
|
||||
) -> Callable[[Callable[P, T]], Callable[P, T]]:
|
||||
calls: deque[float] = deque(maxlen=times)
|
||||
|
||||
if jitter is None:
|
||||
jitter = lambda: 0.0 # noqa: E731
|
||||
|
||||
def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]:
|
||||
def decorator(func: Callable[P, T]) -> Callable[P, T]:
|
||||
@wraps(func)
|
||||
def inner(*args: _P.args, **kwds: _P.kwargs):
|
||||
def inner(*args: P.args, **kwds: P.kwargs):
|
||||
# clean up
|
||||
while calls:
|
||||
if calls[0] + per_seconds > time():
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue