improve typing correctness

This commit is contained in:
ducklet 2023-02-04 01:12:09 +01:00
parent 3320d53eda
commit 418116afac
7 changed files with 45 additions and 32 deletions

View file

@ -136,7 +136,7 @@ async def get_import_progress() -> Progress | None:
return await get(Progress, type="import-imdb-movies", order_by="started DESC")
async def stop_import_progress(*, error: BaseException = None):
async def stop_import_progress(*, error: BaseException | None = None):
"""Stop the current import.
If an error is given, it will be logged to the progress state.
@ -176,6 +176,8 @@ async def set_import_progress(progress: float) -> Progress:
else:
await add(current)
return current
_lock = threading.Lock()
_prelock = threading.Lock()
@ -243,7 +245,7 @@ ModelType = TypeVar("ModelType")
async def get(
model: Type[ModelType], *, order_by: str = None, **kwds
model: Type[ModelType], *, order_by: str | None = None, **kwds
) -> ModelType | None:
"""Load a model instance from the database.
@ -406,12 +408,12 @@ def sql_escape(s: str, char="#"):
async def find_ratings(
*,
title: str = None,
media_type: str = None,
title: str | None = None,
media_type: str | None = None,
exact: bool = False,
ignore_tv_episodes: bool = False,
include_unrated: bool = False,
yearcomp: tuple[Literal["<", "=", ">"], int] = None,
yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
limit_rows: int = 10,
user_ids: Iterable[str] = [],
):
@ -588,11 +590,11 @@ async def ratings_for_movies(
async def find_movies(
*,
title: str = None,
media_type: str = None,
title: str | None = None,
media_type: str | None = None,
exact: bool = False,
ignore_tv_episodes: bool = False,
yearcomp: tuple[Literal["<", "=", ">"], int] = None,
yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
limit_rows: int = 10,
skip_rows: int = 0,
include_unrated: bool = False,

View file

@ -100,7 +100,7 @@ title_types = {
}
def gz_mtime(path) -> datetime:
def gz_mtime(path: Path) -> datetime:
"""Return the timestamp of the compressed file."""
g = gzip.GzipFile(path, "rb")
g.peek(1) # start reading the file to fill the timestamp field
@ -108,7 +108,7 @@ def gz_mtime(path) -> datetime:
return datetime.fromtimestamp(g.mtime).replace(tzinfo=timezone.utc)
def count_lines(path) -> int:
def count_lines(path: Path) -> int:
i = 0
one_mb = 2 ** 20
@ -124,20 +124,21 @@ def count_lines(path) -> int:
@overload
def read_imdb_tsv(
path, row_type, *, unpack: Literal[False]
path: Path, row_type, *, unpack: Literal[False]
) -> Generator[list[str], None, None]:
...
@overload
def read_imdb_tsv(
path, row_type: Type[T], *, unpack: Literal[True] = True
path: Path, row_type: Type[T], *, unpack: Literal[True] = True
) -> Generator[T, None, None]:
...
def read_imdb_tsv(path, row_type, *, unpack=True):
def read_imdb_tsv(path: Path, row_type, *, unpack=True):
with gzip.open(path, "rt", newline="") as f:
rows = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
# skip header line
@ -161,7 +162,7 @@ def read_imdb_tsv(path, row_type, *, unpack=True):
raise
def read_ratings(path):
def read_ratings(path: Path):
mtime = gz_mtime(path)
rows = read_imdb_tsv(path, RatingRow)
@ -171,13 +172,13 @@ def read_ratings(path):
yield m
def read_ratings_as_mapping(path):
def read_ratings_as_mapping(path: Path):
"""Optimized function to quickly load all ratings."""
rows = read_imdb_tsv(path, RatingRow, unpack=False)
return {r[0]: (round(100 * (float(r[1]) - 1) / 9), int(r[2])) for r in rows}
def read_basics(path):
def read_basics(path: Path):
mtime = gz_mtime(path)
rows = read_imdb_tsv(path, BasicRow)

View file

@ -91,7 +91,7 @@ def _id(x: T) -> T:
def asplain(
o: object, *, filter_fields: Container[str] = None, serialize: bool = False
o: object, *, filter_fields: Container[str] | None = None, serialize: bool = False
) -> dict[str, Any]:
"""Return the given model instance as `dict` with JSON compatible plain datatypes.

View file

@ -5,7 +5,7 @@ import os
import tempfile
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import wraps
from hashlib import md5
from pathlib import Path
@ -75,10 +75,10 @@ def Session() -> requests.Session:
def throttle(
times: int, per_seconds: float, jitter: Callable[[], float] = None
times: int, per_seconds: float, jitter: Callable[[], float] | None = None
) -> Callable[[Callable], Callable]:
calls: Deque[float] = deque(maxlen=times)
calls: deque[float] = deque(maxlen=times)
if jitter is None:
jitter = lambda: 0.0
@ -128,7 +128,7 @@ class CachedResponse:
status_code: int
text: str
url: str
headers: dict[str, str] = None
headers: dict[str, str] = field(default_factory=dict)
def json(self):
return json.loads(self.text)
@ -215,17 +215,19 @@ def last_modified_from_file(path: Path):
def download(
url: str,
file_path: Path | str = None,
file_path: Path | str | None = None,
*,
replace_existing: bool = None,
replace_existing: bool | None = None,
only_if_newer: bool = False,
timeout: float = None,
timeout: float | None = None,
verify_ssl: bool = True,
chunk_callback=None,
response_callback=None,
):
) -> bytes | None:
"""Download a file.
If `file_path` is `None` return the remote content, otherwise write the
content to the given file path.
Existing files will not be overwritten unless `replace_existing` is set.
Setting `only_if_newer` will check if the remote file is newer than the
local file, otherwise the download will be aborted.

View file

@ -17,7 +17,10 @@ def b64padded(s: str) -> str:
def phc_scrypt(
secret: bytes, *, salt: bytes = None, params: dict[Literal["n", "r", "p"], int] = {}
secret: bytes,
*,
salt: bytes | None = None,
params: dict[Literal["n", "r", "p"], int] = {},
) -> str:
"""Return the scrypt expanded secret in PHC string format.

View file

@ -85,11 +85,14 @@ def truthy(s: str):
return bool(s) and s.lower() in {"1", "yes", "true"}
def yearcomp(s: str):
_Yearcomp = Literal["<", "=", ">"]
def yearcomp(s: str) -> tuple[_Yearcomp, int] | None:
if not s:
return
comp: Literal["<", "=", ">"] = "="
comp: _Yearcomp = "="
if (prefix := s[0]) in "<=>":
comp = prefix # type: ignore
s = s[len(prefix) :]
@ -97,7 +100,9 @@ def yearcomp(s: str):
return comp, int(s)
def as_int(x, *, max: int = None, min: int | None = 1, default: int = None):
def as_int(
x, *, max: int | None = None, min: int | None = 1, default: int | None = None
) -> int:
try:
if not isinstance(x, int):
x = int(x)
@ -135,7 +140,7 @@ async def json_from_body(request, keys: list[str]) -> list:
...
async def json_from_body(request, keys: list[str] = None):
async def json_from_body(request, keys: list[str] | None = None):
if not await request.body():
data = {}
@ -176,7 +181,7 @@ async def auth_user(request) -> User | None:
_routes = []
def route(path: str, *, methods: list[str] = None, **kwds):
def route(path: str, *, methods: list[str] | None = None, **kwds):
def decorator(func):
r = Route(path, func, methods=methods, **kwds)
_routes.append(r)

View file

@ -20,7 +20,7 @@ class Rating:
user_score: Score100 | None
@classmethod
def from_movie(cls, movie: models.Movie, *, rating: models.Rating = None):
def from_movie(cls, movie: models.Movie, *, rating: models.Rating | None = None):
return cls(
canonical_title=movie.title,
imdb_score=movie.imdb_score,