chore: replace TypeVar with native syntax
This commit is contained in:
parent
1ea09c1a45
commit
76a69b6340
4 changed files with 30 additions and 39 deletions
28
unwind/db.py
28
unwind/db.py
|
|
@ -1,7 +1,7 @@
|
||||||
import contextlib
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, AsyncGenerator, Iterable, Literal, Sequence, Type, TypeVar
|
from typing import Any, AsyncGenerator, Iterable, Literal, Sequence, Type
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
|
||||||
|
|
@ -29,7 +29,6 @@ from .models import (
|
||||||
from .types import ULID
|
from .types import ULID
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
_engine: AsyncEngine | None = None
|
_engine: AsyncEngine | None = None
|
||||||
|
|
||||||
|
|
@ -240,7 +239,7 @@ async def add(conn: Connection, /, item: Model) -> None:
|
||||||
# Support late initializing - used for optimization.
|
# Support late initializing - used for optimization.
|
||||||
if getattr(item, "_is_lazy", False):
|
if getattr(item, "_is_lazy", False):
|
||||||
assert hasattr(item, "_lazy_init")
|
assert hasattr(item, "_lazy_init")
|
||||||
item._lazy_init() # pyright: ignore [reportGeneralTypeIssues]
|
item._lazy_init() # pyright: ignore[reportAttributeAccessIssue]
|
||||||
|
|
||||||
table: sa.Table = item.__table__
|
table: sa.Table = item.__table__
|
||||||
values = asplain(item, serialize=True)
|
values = asplain(item, serialize=True)
|
||||||
|
|
@ -262,17 +261,14 @@ async def fetch_one(
|
||||||
return result.first()
|
return result.first()
|
||||||
|
|
||||||
|
|
||||||
ModelType = TypeVar("ModelType", bound=Model)
|
async def get[T: Model](
|
||||||
|
|
||||||
|
|
||||||
async def get(
|
|
||||||
conn: Connection,
|
conn: Connection,
|
||||||
/,
|
/,
|
||||||
model: Type[ModelType],
|
model: Type[T],
|
||||||
*,
|
*,
|
||||||
order_by: tuple[sa.Column, Literal["asc", "desc"]] | None = None,
|
order_by: tuple[sa.Column, Literal["asc", "desc"]] | None = None,
|
||||||
**field_values,
|
**field_values,
|
||||||
) -> ModelType | None:
|
) -> T | None:
|
||||||
"""Load a model instance from the database.
|
"""Load a model instance from the database.
|
||||||
|
|
||||||
Passing `field_values` allows to filter the item to load. You have to encode the
|
Passing `field_values` allows to filter the item to load. You have to encode the
|
||||||
|
|
@ -295,9 +291,9 @@ async def get(
|
||||||
return fromplain(model, row._mapping, serialized=True) if row else None
|
return fromplain(model, row._mapping, serialized=True) if row else None
|
||||||
|
|
||||||
|
|
||||||
async def get_many(
|
async def get_many[T: Model](
|
||||||
conn: Connection, /, model: Type[ModelType], **field_sets: set | list
|
conn: Connection, /, model: Type[T], **field_sets: set | list
|
||||||
) -> Iterable[ModelType]:
|
) -> Iterable[T]:
|
||||||
"""Return the items with any values matching all given field sets.
|
"""Return the items with any values matching all given field sets.
|
||||||
|
|
||||||
This is similar to `get_all`, but instead of a scalar value a list of values
|
This is similar to `get_all`, but instead of a scalar value a list of values
|
||||||
|
|
@ -314,9 +310,9 @@ async def get_many(
|
||||||
return (fromplain(model, row._mapping, serialized=True) for row in rows)
|
return (fromplain(model, row._mapping, serialized=True) for row in rows)
|
||||||
|
|
||||||
|
|
||||||
async def get_all(
|
async def get_all[T: Model](
|
||||||
conn: Connection, /, model: Type[ModelType], **field_values
|
conn: Connection, /, model: Type[T], **field_values
|
||||||
) -> Iterable[ModelType]:
|
) -> Iterable[T]:
|
||||||
"""Filter all items by comparing all given field values.
|
"""Filter all items by comparing all given field values.
|
||||||
|
|
||||||
If no filters are given, all items will be returned.
|
If no filters are given, all items will be returned.
|
||||||
|
|
@ -333,7 +329,7 @@ async def update(conn: Connection, /, item: Model) -> None:
|
||||||
# Support late initializing - used for optimization.
|
# Support late initializing - used for optimization.
|
||||||
if getattr(item, "_is_lazy", False):
|
if getattr(item, "_is_lazy", False):
|
||||||
assert hasattr(item, "_lazy_init")
|
assert hasattr(item, "_lazy_init")
|
||||||
item._lazy_init() # pyright: ignore [reportGeneralTypeIssues]
|
item._lazy_init() # pyright: ignore[reportAttributeAccessIssue]
|
||||||
|
|
||||||
table: sa.Table = item.__table__
|
table: sa.Table = item.__table__
|
||||||
values = asplain(item, serialize=True)
|
values = asplain(item, serialize=True)
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import logging
|
||||||
from dataclasses import dataclass, fields
|
from dataclasses import dataclass, fields
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Generator, Literal, Type, TypeVar, overload
|
from typing import Generator, Literal, Type, overload
|
||||||
|
|
||||||
from . import config, db, request
|
from . import config, db, request
|
||||||
from .db import add_or_update_many_movies
|
from .db import add_or_update_many_movies
|
||||||
|
|
@ -14,8 +14,6 @@ from .models import Movie
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
# See
|
# See
|
||||||
# - https://developer.imdb.com/non-commercial-datasets/
|
# - https://developer.imdb.com/non-commercial-datasets/
|
||||||
# - https://datasets.imdbws.com/
|
# - https://datasets.imdbws.com/
|
||||||
|
|
@ -127,7 +125,7 @@ def read_imdb_tsv(
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def read_imdb_tsv(
|
def read_imdb_tsv[T](
|
||||||
path: Path, row_type: Type[T], *, unpack: Literal[True] = True
|
path: Path, row_type: Type[T], *, unpack: Literal[True] = True
|
||||||
) -> Generator[T, None, None]: ...
|
) -> Generator[T, None, None]: ...
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,10 +12,8 @@ from typing import (
|
||||||
Mapping,
|
Mapping,
|
||||||
Protocol,
|
Protocol,
|
||||||
Type,
|
Type,
|
||||||
TypeAlias,
|
|
||||||
TypeAliasType,
|
TypeAliasType,
|
||||||
TypedDict,
|
TypedDict,
|
||||||
TypeVar,
|
|
||||||
Union,
|
Union,
|
||||||
get_args,
|
get_args,
|
||||||
get_origin,
|
get_origin,
|
||||||
|
|
@ -27,12 +25,10 @@ from sqlalchemy.orm import registry
|
||||||
from .types import ULID
|
from .types import ULID
|
||||||
from .utils import json_dump
|
from .utils import json_dump
|
||||||
|
|
||||||
JSONScalar: TypeAlias = int | float | str | None
|
type JSONScalar = int | float | str | None
|
||||||
type JSON = JSONScalar | list["JSON"] | dict[str, "JSON"]
|
type JSON = JSONScalar | list["JSON"] | dict[str, "JSON"]
|
||||||
type JSONObject = dict[str, JSON]
|
type JSONObject = dict[str, JSON]
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
class Model(Protocol):
|
class Model(Protocol):
|
||||||
__table__: ClassVar[Table]
|
__table__: ClassVar[Table]
|
||||||
|
|
@ -53,6 +49,10 @@ metadata.naming_convention = {
|
||||||
|
|
||||||
|
|
||||||
def annotations(tp: Type) -> tuple | None:
|
def annotations(tp: Type) -> tuple | None:
|
||||||
|
# Support type aliases and generic aliases.
|
||||||
|
if isinstance(tp, TypeAliasType) or hasattr(tp, "__value__"):
|
||||||
|
tp = tp.__value__
|
||||||
|
|
||||||
return tp.__metadata__ if hasattr(tp, "__metadata__") else None # type: ignore
|
return tp.__metadata__ if hasattr(tp, "__metadata__") else None # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -110,7 +110,7 @@ def optional_fields(o):
|
||||||
yield f
|
yield f
|
||||||
|
|
||||||
|
|
||||||
def _id(x: T) -> T:
|
def _id[T](x: T) -> T:
|
||||||
"""Return the given argument, aka. the identity function."""
|
"""Return the given argument, aka. the identity function."""
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
@ -168,7 +168,7 @@ def asplain(
|
||||||
), f"Type mismatch: {f.name!a} ({target!a} != {type(v)!a})"
|
), f"Type mismatch: {f.name!a} ({target!a} != {type(v)!a})"
|
||||||
d[f.name] = v
|
d[f.name] = v
|
||||||
elif target in {Literal}:
|
elif target in {Literal}:
|
||||||
assert isinstance(v, JSONScalar)
|
assert isinstance(v, JSONScalar.__value__)
|
||||||
d[f.name] = v
|
d[f.name] = v
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported value type: {f.name!a}: {type(v)!a}")
|
raise ValueError(f"Unsupported value type: {f.name!a}: {type(v)!a}")
|
||||||
|
|
@ -176,7 +176,7 @@ def asplain(
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
def fromplain(cls: Type[T], d: Mapping, *, serialized: bool = False) -> T:
|
def fromplain[T](cls: Type[T], d: Mapping, *, serialized: bool = False) -> T:
|
||||||
"""Return an instance of the given model using the given data.
|
"""Return an instance of the given model using the given data.
|
||||||
|
|
||||||
If `serialized` is `True`, collection types (lists, dicts, etc.) will be
|
If `serialized` is `True`, collection types (lists, dicts, etc.) will be
|
||||||
|
|
@ -379,7 +379,7 @@ dataclass containing the ID of the linked data.
|
||||||
The contents of the Relation are ignored or discarded when using
|
The contents of the Relation are ignored or discarded when using
|
||||||
`asplain`, `fromplain`, and `validate`.
|
`asplain`, `fromplain`, and `validate`.
|
||||||
"""
|
"""
|
||||||
Relation: TypeAlias = Annotated[T | None, _RelationSentinel]
|
type Relation[T] = Annotated[T | None, _RelationSentinel]
|
||||||
|
|
||||||
|
|
||||||
type Access = Literal[
|
type Access = Literal[
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from hashlib import md5
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from random import random
|
from random import random
|
||||||
from time import sleep, time
|
from time import sleep, time
|
||||||
from typing import Any, Callable, ParamSpec, TypeVar, cast, overload
|
from typing import Any, Callable, cast, overload
|
||||||
|
|
||||||
import bs4
|
import bs4
|
||||||
import httpx
|
import httpx
|
||||||
|
|
@ -24,13 +24,10 @@ if config.debug and config.cachedir:
|
||||||
config.cachedir.mkdir(exist_ok=True)
|
config.cachedir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
_shared_asession = None
|
|
||||||
|
|
||||||
_ASession_T = httpx.AsyncClient
|
_ASession_T = httpx.AsyncClient
|
||||||
_Response_T = httpx.Response
|
type _Response_T = httpx.Response
|
||||||
|
|
||||||
_T = TypeVar("_T")
|
_shared_asession: _ASession_T | None = None
|
||||||
_P = ParamSpec("_P")
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
|
|
@ -59,17 +56,17 @@ async def asession():
|
||||||
_shared_asession = None
|
_shared_asession = None
|
||||||
|
|
||||||
|
|
||||||
def _throttle(
|
def _throttle[T, **P](
|
||||||
times: int, per_seconds: float, jitter: Callable[[], float] | None = None
|
times: int, per_seconds: float, jitter: Callable[[], float] | None = None
|
||||||
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
|
) -> Callable[[Callable[P, T]], Callable[P, T]]:
|
||||||
calls: deque[float] = deque(maxlen=times)
|
calls: deque[float] = deque(maxlen=times)
|
||||||
|
|
||||||
if jitter is None:
|
if jitter is None:
|
||||||
jitter = lambda: 0.0 # noqa: E731
|
jitter = lambda: 0.0 # noqa: E731
|
||||||
|
|
||||||
def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]:
|
def decorator(func: Callable[P, T]) -> Callable[P, T]:
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def inner(*args: _P.args, **kwds: _P.kwargs):
|
def inner(*args: P.args, **kwds: P.kwargs):
|
||||||
# clean up
|
# clean up
|
||||||
while calls:
|
while calls:
|
||||||
if calls[0] + per_seconds > time():
|
if calls[0] + per_seconds > time():
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue