improve typing defs

This commit is contained in:
ducklet 2023-02-10 21:00:11 +01:00
parent 3a2988b8c7
commit 398736c4ad
3 changed files with 43 additions and 31 deletions

View file

@ -4,12 +4,13 @@ from datetime import datetime
from functools import lru_cache
from pathlib import Path
from random import randint
from typing import Iterable, overload
from typing import Generator, Iterable, Literal, Sequence, overload
from sqlalchemy import (
Column,
DateTime,
Enum,
Executable,
Integer,
MetaData,
String,
@ -70,7 +71,7 @@ def check_integrity(conn: Connection) -> None:
log.info("Database file integrity: %s", state)
def check_parent_ids(conn: Connection):
def check_parent_ids(conn: Connection) -> None:
log.info("Checking parent file associations ... press Ctrl-C to skip!")
try:
@ -79,7 +80,7 @@ def check_parent_ids(conn: Connection):
log.warning("Aborted parent ID rebuild.")
def optimize(conn: Connection, *, vacuum: bool = False):
def optimize(conn: Connection, *, vacuum: bool = False) -> None:
log.info("Optimizing database ...")
conn.execute(text("PRAGMA analysis_limit=400"))
@ -93,7 +94,7 @@ def optimize(conn: Connection, *, vacuum: bool = False):
log.warning("Aborted DB cleanup.")
def autoconf(conn: Connection):
def autoconf(conn: Connection) -> None:
log.info("Configuring database ...")
conn.execute(text("PRAGMA journal_mode=WAL"))
@ -104,13 +105,13 @@ class Db:
engine: "Engine | None" = None
is_dirty: bool = False
def __init__(self, path: Path, *, create_if_missing: bool = True):
def __init__(self, path: Path, *, create_if_missing: bool = True) -> None:
self.open(path, create_if_missing=create_if_missing)
def __del__(self):
def __del__(self) -> None:
self.close()
def open(self, path: Path, *, create_if_missing: bool = True):
def open(self, path: Path, *, create_if_missing: bool = True) -> None:
log.info("Using database: %a", str(path))
if not create_if_missing and not path.exists():
@ -128,7 +129,7 @@ class Db:
autoconf(conn)
check_integrity(conn)
def close(self):
def close(self) -> None:
if self.engine is None:
return
@ -146,7 +147,7 @@ class Db:
@contextmanager
def transaction(
self, *, rollback_on_error: bool = False, force_rollback: bool = False
):
) -> Generator[Connection, None, None]:
if self.engine is None:
raise RuntimeError("DB was closed.")
@ -157,7 +158,7 @@ class Db:
)
err = None
with connect() as conn:
with connect() as conn: # type: ignore[attr-defined] # Mypy 1.0.0 doesn't understand the `connect` object is always a context manager
try:
yield conn
except BaseException as e:
@ -174,7 +175,7 @@ class Db:
self.is_dirty = True
def init(path: Path = Path(":memory:")):
def init(path: Path = Path(":memory:")) -> None:
global engine
log.info("Using database: %a", str(path))
@ -192,7 +193,7 @@ def init(path: Path = Path(":memory:")):
check_integrity(conn)
def close():
def close() -> None:
global engine
chance = 10 # Set the chance for long running actions to happen to 1 in X.
@ -210,7 +211,7 @@ def iter_all(conn: Connection) -> Iterable[Row]:
return conn.execute(select(metadex))
def get_file(conn: Connection, *, location: str, hostname: str):
def get_file(conn: Connection, *, location: str, hostname: str) -> "Row | None":
stmt = select(metadex).where(
and_(
metadex.c.location == location,
@ -220,7 +221,7 @@ def get_file(conn: Connection, *, location: str, hostname: str):
return conn.execute(stmt).one_or_none()
def get_files(conn: Connection, *, parent_id: int):
def get_files(conn: Connection, *, parent_id: int) -> Sequence[Row]:
stmt = select(metadex).where(
metadex.c.parent_id == parent_id,
)
@ -294,7 +295,9 @@ def all_hostnames(conn: Connection) -> Iterable[str]:
return conn.execute(stmt).scalars().all()
def _fake_entry(path: Path, *, hostname=None, now, parent_id) -> dict:
def _fake_entry(
path: Path, *, hostname: "str | None" = None, now: datetime, parent_id: "int | None"
) -> dict:
return dict(
parent_id=parent_id,
added=now,
@ -307,10 +310,10 @@ def _fake_entry(path: Path, *, hostname=None, now, parent_id) -> dict:
)
def _add_parents(conn: Connection, *, location: str, hostname: str):
def _add_parents(conn: Connection, *, location: str, hostname: str) -> "int | None":
p_id: "int | None" = None
for p in reversed(Path(location).parents):
log.warning("Forging parent: %a:%a", hostname, str(p))
d: "dict | Row"
d = _fake_entry(p, hostname=hostname, now=datetime.now(), parent_id=p_id)
d = get_or_add(conn, d)
if isinstance(d, dict):
@ -369,11 +372,14 @@ def get_or_add(conn: Connection, new_data: dict) -> "Row | dict":
return new_data
def upsert_if_changed(conn: Connection, new_data: dict):
def upsert_if_changed(
conn: Connection, new_data: dict
) -> Literal["added", "changed", "unchanged"]:
row = get_or_add(conn, new_data)
is_from_db = isinstance(row, Row)
if not is_from_db:
return "added"
assert isinstance(row, Row) # Required for Mypy 1.0.0.
is_changed = (
new_data["stat_bytes"] != row.stat_bytes
@ -410,7 +416,9 @@ def upsert_if_changed(conn: Connection, new_data: dict):
return "changed"
def remove_all(conn: Connection, location: str, *, hostname=None) -> int:
def remove_all(
conn: Connection, location: str, *, hostname: "str | None" = None
) -> int:
"""Remove the entry with the given path and all its descendants."""
# We're using text comparison here to catch removed descendants even if
@ -421,6 +429,7 @@ def remove_all(conn: Connection, location: str, *, hostname=None) -> int:
# or change our parent-id-mechanism to support skipping intermediates, both of
# which might be valid decisions for sake of optimization. For now we choose
# simple correctness. Let's see how bad the performance can get.
stmt: Executable
if hostname is None:
hostname = config.hostname
@ -443,11 +452,11 @@ def remove_all(conn: Connection, location: str, *, hostname=None) -> int:
@contextmanager
def transaction(rollback_on_error: bool = False):
def transaction(rollback_on_error: bool = False) -> Generator[Connection, None, None]:
connect = engine.connect if config.dryrun else engine.begin
err = None
with connect() as conn:
with connect() as conn: # type: ignore[attr-defined] # Mypy 1.0.0 doesn't understand the `connect` object is always a context manager
try:
yield conn
except BaseException as e:
@ -461,7 +470,9 @@ def transaction(rollback_on_error: bool = False):
raise err
def files_in_dir(conn: Connection, location: str, *, hostname=None) -> Iterable[str]:
def files_in_dir(
conn: Connection, location: str, *, hostname: "str | None" = None
) -> Iterable[str]:
"""Return all file names for the given dir."""
if hostname is None:
hostname = config.hostname
@ -524,7 +535,8 @@ def _parent_id(
return val
def reassign_parent_ids(conn: Connection):
def reassign_parent_ids(conn: Connection) -> None:
stmt: Executable
stmt = select(
metadex.c.id, metadex.c.parent_id, metadex.c.location, metadex.c.hostname
)

View file

@ -208,14 +208,14 @@ def _scan_remove_missing(
dirs.extendleft(subdirs)
for name in expected:
f = str(cwd / name)
if is_ignored(f):
log.info("Ignoring file (for removal): %a", f)
ff = str(cwd / name)
if is_ignored(ff):
log.info("Ignoring file (for removal): %a", ff)
continue
log.info("File removed: %a", f)
log.info("File removed: %a", ff)
context.removed += db.remove_all(conn, f)
context.removed += db.remove_all(conn, ff)
db.recalculate_dir_sizes(conn)
@ -291,7 +291,7 @@ def _parse_pathspec_mapping(map_pathspecs: "list[str]") -> _PathspecMapping:
return maps
def _apply_mapping(maps: dict, d: dict) -> None:
def _apply_mapping(maps: _PathspecMapping, d: dict) -> None:
hostname = d["hostname"]
location = (
d["location"]

View file

@ -31,7 +31,7 @@ class File:
stat_type: StatType
@classmethod
def from_direntry(cls, entry: DirEntry) -> Self:
def from_direntry(cls, entry: DirEntry[str]) -> Self:
now = datetime.now()
pstat = entry.stat(follow_symlinks=False)
return cls(
@ -61,7 +61,7 @@ class File:
)
@staticmethod
def dict_from_entry(entry: "DirEntry | Path") -> dict:
def dict_from_entry(entry: "DirEntry[str] | Path") -> dict:
"""Return the File's data structure as dict.
This can be useful to skip calling `asdict`, which can be quite slow.