From 418116afac22246876cc9fde9f6358fd22140a43 Mon Sep 17 00:00:00 2001 From: ducklet Date: Sat, 4 Feb 2023 01:12:09 +0100 Subject: [PATCH] improve typing correctness --- unwind/db.py | 18 ++++++++++-------- unwind/imdb_import.py | 17 +++++++++-------- unwind/models.py | 2 +- unwind/request.py | 18 ++++++++++-------- unwind/utils.py | 5 ++++- unwind/web.py | 15 ++++++++++----- unwind/web_models.py | 2 +- 7 files changed, 45 insertions(+), 32 deletions(-) diff --git a/unwind/db.py b/unwind/db.py index bff8c20..e32d4c0 100644 --- a/unwind/db.py +++ b/unwind/db.py @@ -136,7 +136,7 @@ async def get_import_progress() -> Progress | None: return await get(Progress, type="import-imdb-movies", order_by="started DESC") -async def stop_import_progress(*, error: BaseException = None): +async def stop_import_progress(*, error: BaseException | None = None): """Stop the current import. If an error is given, it will be logged to the progress state. @@ -176,6 +176,8 @@ async def set_import_progress(progress: float) -> Progress: else: await add(current) + return current + _lock = threading.Lock() _prelock = threading.Lock() @@ -243,7 +245,7 @@ ModelType = TypeVar("ModelType") async def get( - model: Type[ModelType], *, order_by: str = None, **kwds + model: Type[ModelType], *, order_by: str | None = None, **kwds ) -> ModelType | None: """Load a model instance from the database. @@ -406,12 +408,12 @@ def sql_escape(s: str, char="#"): async def find_ratings( *, - title: str = None, - media_type: str = None, + title: str | None = None, + media_type: str | None = None, exact: bool = False, ignore_tv_episodes: bool = False, include_unrated: bool = False, - yearcomp: tuple[Literal["<", "=", ">"], int] = None, + yearcomp: tuple[Literal["<", "=", ">"], int] | None = None, limit_rows: int = 10, user_ids: Iterable[str] = [], ): @@ -588,11 +590,11 @@ async def ratings_for_movies( async def find_movies( *, - title: str = None, - media_type: str = None, + title: str | None = None, + media_type: str | None = None, exact: bool = False, ignore_tv_episodes: bool = False, - yearcomp: tuple[Literal["<", "=", ">"], int] = None, + yearcomp: tuple[Literal["<", "=", ">"], int] | None = None, limit_rows: int = 10, skip_rows: int = 0, include_unrated: bool = False, diff --git a/unwind/imdb_import.py b/unwind/imdb_import.py index 61a892c..3557993 100644 --- a/unwind/imdb_import.py +++ b/unwind/imdb_import.py @@ -100,7 +100,7 @@ title_types = { } -def gz_mtime(path) -> datetime: +def gz_mtime(path: Path) -> datetime: """Return the timestamp of the compressed file.""" g = gzip.GzipFile(path, "rb") g.peek(1) # start reading the file to fill the timestamp field @@ -108,7 +108,7 @@ def gz_mtime(path) -> datetime: return datetime.fromtimestamp(g.mtime).replace(tzinfo=timezone.utc) -def count_lines(path) -> int: +def count_lines(path: Path) -> int: i = 0 one_mb = 2 ** 20 @@ -124,20 +124,21 @@ def count_lines(path) -> int: @overload def read_imdb_tsv( - path, row_type, *, unpack: Literal[False] + path: Path, row_type, *, unpack: Literal[False] ) -> Generator[list[str], None, None]: ... @overload def read_imdb_tsv( - path, row_type: Type[T], *, unpack: Literal[True] = True + path: Path, row_type: Type[T], *, unpack: Literal[True] = True ) -> Generator[T, None, None]: ... -def read_imdb_tsv(path, row_type, *, unpack=True): +def read_imdb_tsv(path: Path, row_type, *, unpack=True): with gzip.open(path, "rt", newline="") as f: + rows = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE) # skip header line @@ -161,7 +162,7 @@ def read_imdb_tsv(path, row_type, *, unpack=True): raise -def read_ratings(path): +def read_ratings(path: Path): mtime = gz_mtime(path) rows = read_imdb_tsv(path, RatingRow) @@ -171,13 +172,13 @@ def read_ratings(path): yield m -def read_ratings_as_mapping(path): +def read_ratings_as_mapping(path: Path): """Optimized function to quickly load all ratings.""" rows = read_imdb_tsv(path, RatingRow, unpack=False) return {r[0]: (round(100 * (float(r[1]) - 1) / 9), int(r[2])) for r in rows} -def read_basics(path): +def read_basics(path: Path): mtime = gz_mtime(path) rows = read_imdb_tsv(path, BasicRow) diff --git a/unwind/models.py b/unwind/models.py index 674337d..70ffe26 100644 --- a/unwind/models.py +++ b/unwind/models.py @@ -91,7 +91,7 @@ def _id(x: T) -> T: def asplain( - o: object, *, filter_fields: Container[str] = None, serialize: bool = False + o: object, *, filter_fields: Container[str] | None = None, serialize: bool = False ) -> dict[str, Any]: """Return the given model instance as `dict` with JSON compatible plain datatypes. diff --git a/unwind/request.py b/unwind/request.py index 0b6e07c..3b78872 100644 --- a/unwind/request.py +++ b/unwind/request.py @@ -5,7 +5,7 @@ import os import tempfile from collections import deque from contextlib import contextmanager -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import wraps from hashlib import md5 from pathlib import Path @@ -75,10 +75,10 @@ def Session() -> requests.Session: def throttle( - times: int, per_seconds: float, jitter: Callable[[], float] = None + times: int, per_seconds: float, jitter: Callable[[], float] | None = None ) -> Callable[[Callable], Callable]: - calls: Deque[float] = deque(maxlen=times) + calls: deque[float] = deque(maxlen=times) if jitter is None: jitter = lambda: 0.0 @@ -128,7 +128,7 @@ class CachedResponse: status_code: int text: str url: str - headers: dict[str, str] = None + headers: dict[str, str] = field(default_factory=dict) def json(self): return json.loads(self.text) @@ -215,17 +215,19 @@ def last_modified_from_file(path: Path): def download( url: str, - file_path: Path | str = None, + file_path: Path | str | None = None, *, - replace_existing: bool = None, + replace_existing: bool | None = None, only_if_newer: bool = False, - timeout: float = None, + timeout: float | None = None, verify_ssl: bool = True, chunk_callback=None, response_callback=None, -): +) -> bytes | None: """Download a file. + If `file_path` is `None` return the remote content, otherwise write the + content to the given file path. Existing files will not be overwritten unless `replace_existing` is set. Setting `only_if_newer` will check if the remote file is newer than the local file, otherwise the download will be aborted. diff --git a/unwind/utils.py b/unwind/utils.py index 012d1fb..efe9f17 100644 --- a/unwind/utils.py +++ b/unwind/utils.py @@ -17,7 +17,10 @@ def b64padded(s: str) -> str: def phc_scrypt( - secret: bytes, *, salt: bytes = None, params: dict[Literal["n", "r", "p"], int] = {} + secret: bytes, + *, + salt: bytes | None = None, + params: dict[Literal["n", "r", "p"], int] = {}, ) -> str: """Return the scrypt expanded secret in PHC string format. diff --git a/unwind/web.py b/unwind/web.py index b8705a1..dbe39bc 100644 --- a/unwind/web.py +++ b/unwind/web.py @@ -85,11 +85,14 @@ def truthy(s: str): return bool(s) and s.lower() in {"1", "yes", "true"} -def yearcomp(s: str): +_Yearcomp = Literal["<", "=", ">"] + + +def yearcomp(s: str) -> tuple[_Yearcomp, int] | None: if not s: return - comp: Literal["<", "=", ">"] = "=" + comp: _Yearcomp = "=" if (prefix := s[0]) in "<=>": comp = prefix # type: ignore s = s[len(prefix) :] @@ -97,7 +100,9 @@ def yearcomp(s: str): return comp, int(s) -def as_int(x, *, max: int = None, min: int | None = 1, default: int = None): +def as_int( + x, *, max: int | None = None, min: int | None = 1, default: int | None = None +) -> int: try: if not isinstance(x, int): x = int(x) @@ -135,7 +140,7 @@ async def json_from_body(request, keys: list[str]) -> list: ... -async def json_from_body(request, keys: list[str] = None): +async def json_from_body(request, keys: list[str] | None = None): if not await request.body(): data = {} @@ -176,7 +181,7 @@ async def auth_user(request) -> User | None: _routes = [] -def route(path: str, *, methods: list[str] = None, **kwds): +def route(path: str, *, methods: list[str] | None = None, **kwds): def decorator(func): r = Route(path, func, methods=methods, **kwds) _routes.append(r) diff --git a/unwind/web_models.py b/unwind/web_models.py index e514c5f..6e83e1d 100644 --- a/unwind/web_models.py +++ b/unwind/web_models.py @@ -20,7 +20,7 @@ class Rating: user_score: Score100 | None @classmethod - def from_movie(cls, movie: models.Movie, *, rating: models.Rating = None): + def from_movie(cls, movie: models.Movie, *, rating: models.Rating | None = None): return cls( canonical_title=movie.title, imdb_score=movie.imdb_score,