From 398736c4ad6cf583a05458782b4aa41e5bbe117d Mon Sep 17 00:00:00 2001 From: ducklet Date: Fri, 10 Feb 2023 21:00:11 +0100 Subject: [PATCH] improve typing defs --- metadex/db.py | 58 ++++++++++++++++++++++++++++------------------ metadex/metadex.py | 12 +++++----- metadex/models.py | 4 ++-- 3 files changed, 43 insertions(+), 31 deletions(-) diff --git a/metadex/db.py b/metadex/db.py index 136a3de..e3bd146 100644 --- a/metadex/db.py +++ b/metadex/db.py @@ -4,12 +4,13 @@ from datetime import datetime from functools import lru_cache from pathlib import Path from random import randint -from typing import Iterable, overload +from typing import Generator, Iterable, Literal, Sequence, overload from sqlalchemy import ( Column, DateTime, Enum, + Executable, Integer, MetaData, String, @@ -70,7 +71,7 @@ def check_integrity(conn: Connection) -> None: log.info("Database file integrity: %s", state) -def check_parent_ids(conn: Connection): +def check_parent_ids(conn: Connection) -> None: log.info("Checking parent file associations ... press Ctrl-C to skip!") try: @@ -79,7 +80,7 @@ def check_parent_ids(conn: Connection): log.warning("Aborted parent ID rebuild.") -def optimize(conn: Connection, *, vacuum: bool = False): +def optimize(conn: Connection, *, vacuum: bool = False) -> None: log.info("Optimizing database ...") conn.execute(text("PRAGMA analysis_limit=400")) @@ -93,7 +94,7 @@ def optimize(conn: Connection, *, vacuum: bool = False): log.warning("Aborted DB cleanup.") -def autoconf(conn: Connection): +def autoconf(conn: Connection) -> None: log.info("Configuring database ...") conn.execute(text("PRAGMA journal_mode=WAL")) @@ -104,13 +105,13 @@ class Db: engine: "Engine | None" = None is_dirty: bool = False - def __init__(self, path: Path, *, create_if_missing: bool = True): + def __init__(self, path: Path, *, create_if_missing: bool = True) -> None: self.open(path, create_if_missing=create_if_missing) - def __del__(self): + def __del__(self) -> None: self.close() - def open(self, path: Path, *, create_if_missing: bool = True): + def open(self, path: Path, *, create_if_missing: bool = True) -> None: log.info("Using database: %a", str(path)) if not create_if_missing and not path.exists(): @@ -128,7 +129,7 @@ class Db: autoconf(conn) check_integrity(conn) - def close(self): + def close(self) -> None: if self.engine is None: return @@ -146,7 +147,7 @@ class Db: @contextmanager def transaction( self, *, rollback_on_error: bool = False, force_rollback: bool = False - ): + ) -> Generator[Connection, None, None]: if self.engine is None: raise RuntimeError("DB was closed.") @@ -157,7 +158,7 @@ class Db: ) err = None - with connect() as conn: + with connect() as conn: # type: ignore[attr-defined] # Mypy 1.0.0 doesn't understand the `connect` object is always a context manager try: yield conn except BaseException as e: @@ -174,7 +175,7 @@ class Db: self.is_dirty = True -def init(path: Path = Path(":memory:")): +def init(path: Path = Path(":memory:")) -> None: global engine log.info("Using database: %a", str(path)) @@ -192,7 +193,7 @@ def init(path: Path = Path(":memory:")): check_integrity(conn) -def close(): +def close() -> None: global engine chance = 10 # Set the chance for long running actions to happen to 1 in X. @@ -210,7 +211,7 @@ def iter_all(conn: Connection) -> Iterable[Row]: return conn.execute(select(metadex)) -def get_file(conn: Connection, *, location: str, hostname: str): +def get_file(conn: Connection, *, location: str, hostname: str) -> "Row | None": stmt = select(metadex).where( and_( metadex.c.location == location, @@ -220,7 +221,7 @@ def get_file(conn: Connection, *, location: str, hostname: str): return conn.execute(stmt).one_or_none() -def get_files(conn: Connection, *, parent_id: int): +def get_files(conn: Connection, *, parent_id: int) -> Sequence[Row]: stmt = select(metadex).where( metadex.c.parent_id == parent_id, ) @@ -294,7 +295,9 @@ def all_hostnames(conn: Connection) -> Iterable[str]: return conn.execute(stmt).scalars().all() -def _fake_entry(path: Path, *, hostname=None, now, parent_id) -> dict: +def _fake_entry( + path: Path, *, hostname: "str | None" = None, now: datetime, parent_id: "int | None" +) -> dict: return dict( parent_id=parent_id, added=now, @@ -307,10 +310,10 @@ def _fake_entry(path: Path, *, hostname=None, now, parent_id) -> dict: ) -def _add_parents(conn: Connection, *, location: str, hostname: str): +def _add_parents(conn: Connection, *, location: str, hostname: str) -> "int | None": p_id: "int | None" = None for p in reversed(Path(location).parents): - log.warning("Forging parent: %a:%a", hostname, str(p)) + d: "dict | Row" d = _fake_entry(p, hostname=hostname, now=datetime.now(), parent_id=p_id) d = get_or_add(conn, d) if isinstance(d, dict): @@ -369,11 +372,14 @@ def get_or_add(conn: Connection, new_data: dict) -> "Row | dict": return new_data -def upsert_if_changed(conn: Connection, new_data: dict): +def upsert_if_changed( + conn: Connection, new_data: dict +) -> Literal["added", "changed", "unchanged"]: row = get_or_add(conn, new_data) is_from_db = isinstance(row, Row) if not is_from_db: return "added" + assert isinstance(row, Row) # Required for Mypy 1.0.0. is_changed = ( new_data["stat_bytes"] != row.stat_bytes @@ -410,7 +416,9 @@ def upsert_if_changed(conn: Connection, new_data: dict): return "changed" -def remove_all(conn: Connection, location: str, *, hostname=None) -> int: +def remove_all( + conn: Connection, location: str, *, hostname: "str | None" = None +) -> int: """Remove the entry with the given path and all its descendants.""" # We're using text comparison here to catch removed descendants even if @@ -421,6 +429,7 @@ def remove_all(conn: Connection, location: str, *, hostname=None) -> int: # or change our parent-id-mechanism to support skipping intermediates, both of # which might be valid decisions for sake of optimization. For now we choose # simple correctness. Let's see how bad the performance can get. + stmt: Executable if hostname is None: hostname = config.hostname @@ -443,11 +452,11 @@ def remove_all(conn: Connection, location: str, *, hostname=None) -> int: @contextmanager -def transaction(rollback_on_error: bool = False): +def transaction(rollback_on_error: bool = False) -> Generator[Connection, None, None]: connect = engine.connect if config.dryrun else engine.begin err = None - with connect() as conn: + with connect() as conn: # type: ignore[attr-defined] # Mypy 1.0.0 doesn't understand the `connect` object is always a context manager try: yield conn except BaseException as e: @@ -461,7 +470,9 @@ def transaction(rollback_on_error: bool = False): raise err -def files_in_dir(conn: Connection, location: str, *, hostname=None) -> Iterable[str]: +def files_in_dir( + conn: Connection, location: str, *, hostname: "str | None" = None +) -> Iterable[str]: """Return all file names for the given dir.""" if hostname is None: hostname = config.hostname @@ -524,7 +535,8 @@ def _parent_id( return val -def reassign_parent_ids(conn: Connection): +def reassign_parent_ids(conn: Connection) -> None: + stmt: Executable stmt = select( metadex.c.id, metadex.c.parent_id, metadex.c.location, metadex.c.hostname ) diff --git a/metadex/metadex.py b/metadex/metadex.py index 6fac9b4..fbde390 100644 --- a/metadex/metadex.py +++ b/metadex/metadex.py @@ -208,14 +208,14 @@ def _scan_remove_missing( dirs.extendleft(subdirs) for name in expected: - f = str(cwd / name) - if is_ignored(f): - log.info("Ignoring file (for removal): %a", f) + ff = str(cwd / name) + if is_ignored(ff): + log.info("Ignoring file (for removal): %a", ff) continue - log.info("File removed: %a", f) + log.info("File removed: %a", ff) - context.removed += db.remove_all(conn, f) + context.removed += db.remove_all(conn, ff) db.recalculate_dir_sizes(conn) @@ -291,7 +291,7 @@ def _parse_pathspec_mapping(map_pathspecs: "list[str]") -> _PathspecMapping: return maps -def _apply_mapping(maps: dict, d: dict) -> None: +def _apply_mapping(maps: _PathspecMapping, d: dict) -> None: hostname = d["hostname"] location = ( d["location"] diff --git a/metadex/models.py b/metadex/models.py index 4857edf..8c6b85d 100644 --- a/metadex/models.py +++ b/metadex/models.py @@ -31,7 +31,7 @@ class File: stat_type: StatType @classmethod - def from_direntry(cls, entry: DirEntry) -> Self: + def from_direntry(cls, entry: DirEntry[str]) -> Self: now = datetime.now() pstat = entry.stat(follow_symlinks=False) return cls( @@ -61,7 +61,7 @@ class File: ) @staticmethod - def dict_from_entry(entry: "DirEntry | Path") -> dict: + def dict_from_entry(entry: "DirEntry[str] | Path") -> dict: """Return the File's data structure as dict. This can be useful to skip calling `asdict`, which can be quite slow.