improve typing correctness

This commit is contained in:
ducklet 2023-02-04 01:12:09 +01:00
parent 3320d53eda
commit 418116afac
7 changed files with 45 additions and 32 deletions

View file

@ -136,7 +136,7 @@ async def get_import_progress() -> Progress | None:
return await get(Progress, type="import-imdb-movies", order_by="started DESC")
async def stop_import_progress(*, error: BaseException = None): async def stop_import_progress(*, error: BaseException | None = None):
"""Stop the current import.
If an error is given, it will be logged to the progress state.
@ -176,6 +176,8 @@ async def set_import_progress(progress: float) -> Progress:
else:
await add(current)
return current
_lock = threading.Lock()
_prelock = threading.Lock()
@ -243,7 +245,7 @@ ModelType = TypeVar("ModelType")
async def get(
model: Type[ModelType], *, order_by: str = None, **kwds model: Type[ModelType], *, order_by: str | None = None, **kwds
) -> ModelType | None:
"""Load a model instance from the database.
@ -406,12 +408,12 @@ def sql_escape(s: str, char="#"):
async def find_ratings(
*,
title: str = None, title: str | None = None,
media_type: str = None, media_type: str | None = None,
exact: bool = False,
ignore_tv_episodes: bool = False,
include_unrated: bool = False,
yearcomp: tuple[Literal["<", "=", ">"], int] = None, yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
limit_rows: int = 10,
user_ids: Iterable[str] = [],
):
@ -588,11 +590,11 @@ async def ratings_for_movies(
async def find_movies(
*,
title: str = None, title: str | None = None,
media_type: str = None, media_type: str | None = None,
exact: bool = False,
ignore_tv_episodes: bool = False,
yearcomp: tuple[Literal["<", "=", ">"], int] = None, yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
limit_rows: int = 10,
skip_rows: int = 0,
include_unrated: bool = False,

View file

@ -100,7 +100,7 @@ title_types = {
}
def gz_mtime(path) -> datetime: def gz_mtime(path: Path) -> datetime:
"""Return the timestamp of the compressed file."""
g = gzip.GzipFile(path, "rb")
g.peek(1)  # start reading the file to fill the timestamp field
@ -108,7 +108,7 @@ def gz_mtime(path) -> datetime:
return datetime.fromtimestamp(g.mtime).replace(tzinfo=timezone.utc)
def count_lines(path) -> int: def count_lines(path: Path) -> int:
i = 0
one_mb = 2 ** 20
@ -124,20 +124,21 @@ def count_lines(path) -> int:
@overload
def read_imdb_tsv(
path, row_type, *, unpack: Literal[False] path: Path, row_type, *, unpack: Literal[False]
) -> Generator[list[str], None, None]:
...
@overload
def read_imdb_tsv(
path, row_type: Type[T], *, unpack: Literal[True] = True path: Path, row_type: Type[T], *, unpack: Literal[True] = True
) -> Generator[T, None, None]:
...
def read_imdb_tsv(path, row_type, *, unpack=True): def read_imdb_tsv(path: Path, row_type, *, unpack=True):
with gzip.open(path, "rt", newline="") as f:
rows = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
# skip header line
@ -161,7 +162,7 @@ def read_imdb_tsv(path, row_type, *, unpack=True):
raise
def read_ratings(path): def read_ratings(path: Path):
mtime = gz_mtime(path)
rows = read_imdb_tsv(path, RatingRow)
@ -171,13 +172,13 @@ def read_ratings(path):
yield m
def read_ratings_as_mapping(path): def read_ratings_as_mapping(path: Path):
"""Optimized function to quickly load all ratings."""
rows = read_imdb_tsv(path, RatingRow, unpack=False)
return {r[0]: (round(100 * (float(r[1]) - 1) / 9), int(r[2])) for r in rows}
def read_basics(path): def read_basics(path: Path):
mtime = gz_mtime(path)
rows = read_imdb_tsv(path, BasicRow)

View file

@ -91,7 +91,7 @@ def _id(x: T) -> T:
def asplain(
o: object, *, filter_fields: Container[str] = None, serialize: bool = False o: object, *, filter_fields: Container[str] | None = None, serialize: bool = False
) -> dict[str, Any]:
"""Return the given model instance as `dict` with JSON compatible plain datatypes.

View file

@ -5,7 +5,7 @@ import os
import tempfile
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass, field
from functools import wraps
from hashlib import md5
from pathlib import Path
@ -75,10 +75,10 @@ def Session() -> requests.Session:
def throttle(
times: int, per_seconds: float, jitter: Callable[[], float] = None times: int, per_seconds: float, jitter: Callable[[], float] | None = None
) -> Callable[[Callable], Callable]:
calls: Deque[float] = deque(maxlen=times) calls: deque[float] = deque(maxlen=times)
if jitter is None:
jitter = lambda: 0.0
@ -128,7 +128,7 @@ class CachedResponse:
status_code: int
text: str
url: str
headers: dict[str, str] = None headers: dict[str, str] = field(default_factory=dict)
def json(self):
return json.loads(self.text)
@ -215,17 +215,19 @@ def last_modified_from_file(path: Path):
def download(
url: str,
file_path: Path | str = None, file_path: Path | str | None = None,
*,
replace_existing: bool = None, replace_existing: bool | None = None,
only_if_newer: bool = False,
timeout: float = None, timeout: float | None = None,
verify_ssl: bool = True,
chunk_callback=None,
response_callback=None,
): ) -> bytes | None:
"""Download a file.
If `file_path` is `None` return the remote content, otherwise write the
content to the given file path.
Existing files will not be overwritten unless `replace_existing` is set.
Setting `only_if_newer` will check if the remote file is newer than the
local file, otherwise the download will be aborted.

View file

@ -17,7 +17,10 @@ def b64padded(s: str) -> str:
def phc_scrypt(
secret: bytes, *, salt: bytes = None, params: dict[Literal["n", "r", "p"], int] = {} secret: bytes,
*,
salt: bytes | None = None,
params: dict[Literal["n", "r", "p"], int] = {},
) -> str:
"""Return the scrypt expanded secret in PHC string format.

View file

@ -85,11 +85,14 @@ def truthy(s: str):
return bool(s) and s.lower() in {"1", "yes", "true"}
def yearcomp(s: str): _Yearcomp = Literal["<", "=", ">"]
def yearcomp(s: str) -> tuple[_Yearcomp, int] | None:
if not s:
return
comp: Literal["<", "=", ">"] = "=" comp: _Yearcomp = "="
if (prefix := s[0]) in "<=>":
comp = prefix # type: ignore
s = s[len(prefix) :]
@ -97,7 +100,9 @@ def yearcomp(s: str):
return comp, int(s)
def as_int(x, *, max: int = None, min: int | None = 1, default: int = None): def as_int(
x, *, max: int | None = None, min: int | None = 1, default: int | None = None
) -> int:
try:
if not isinstance(x, int):
x = int(x)
@ -135,7 +140,7 @@ async def json_from_body(request, keys: list[str]) -> list:
...
async def json_from_body(request, keys: list[str] = None): async def json_from_body(request, keys: list[str] | None = None):
if not await request.body():
data = {}
@ -176,7 +181,7 @@ async def auth_user(request) -> User | None:
_routes = []
def route(path: str, *, methods: list[str] = None, **kwds): def route(path: str, *, methods: list[str] | None = None, **kwds):
def decorator(func):
r = Route(path, func, methods=methods, **kwds)
_routes.append(r)

View file

@ -20,7 +20,7 @@ class Rating:
user_score: Score100 | None
@classmethod
def from_movie(cls, movie: models.Movie, *, rating: models.Rating = None): def from_movie(cls, movie: models.Movie, *, rating: models.Rating | None = None):
return cls(
canonical_title=movie.title,
imdb_score=movie.imdb_score,