From 14f2395fa658b328dd529cbc4945ef444f6c2feb Mon Sep 17 00:00:00 2001 From: ducklet Date: Tue, 3 Aug 2021 16:39:36 +0200 Subject: [PATCH] improve some type annotations --- unwind/imdb_import.py | 18 ++++++++++++++++-- unwind/models.py | 5 +++-- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/unwind/imdb_import.py b/unwind/imdb_import.py index 1d514f3..e29a981 100644 --- a/unwind/imdb_import.py +++ b/unwind/imdb_import.py @@ -4,7 +4,7 @@ import logging from dataclasses import dataclass, fields from datetime import datetime, timezone from pathlib import Path -from typing import Optional, cast +from typing import Generator, Literal, Optional, Type, TypeVar, overload from . import config, db, request from .db import add_or_update_many_movies @@ -13,6 +13,7 @@ from .models import Movie log = logging.getLogger(__name__) +T = TypeVar("T") # See # - https://www.imdb.com/interfaces/ @@ -120,6 +121,20 @@ def count_lines(path) -> int: return i +@overload +def read_imdb_tsv( + path, row_type, *, unpack: Literal[False] +) -> Generator[list[str], None, None]: + ... + + +@overload +def read_imdb_tsv( + path, row_type: Type[T], *, unpack: Literal[True] = True +) -> Generator[T, None, None]: + ... + + def read_imdb_tsv(path, row_type, *, unpack=True): with gzip.open(path, "rt", newline="") as f: rows = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE) @@ -158,7 +173,6 @@ def read_ratings(path): def read_ratings_as_mapping(path): """Optimized function to quickly load all ratings.""" rows = read_imdb_tsv(path, RatingRow, unpack=False) - rows = cast(list[list[str]], rows) return {r[0]: (round(100 * (float(r[1]) - 1) / 9), int(r[2])) for r in rows} diff --git a/unwind/models.py b/unwind/models.py index 184696b..ec7f1fb 100644 --- a/unwind/models.py +++ b/unwind/models.py @@ -16,6 +16,8 @@ from typing import ( from .types import ULID +T = TypeVar("T") + def annotations(tp: Type) -> Optional[tuple]: return tp.__metadata__ if hasattr(tp, "__metadata__") else None @@ -102,7 +104,7 @@ def asplain(o, *, fields_: set = None) -> dict[str, Any]: return d -def fromplain(cls, d: dict[str, Any]): +def fromplain(cls: Type[T], d: dict[str, Any]) -> T: dd = {} for f in fields(cls): @@ -232,7 +234,6 @@ class Movie: self._is_lazy = False -T = TypeVar("T") _RelationSentinel = object() """Mark a model field as containing external data.