improve typing defs

ducklet committed 2023-02-10 21:00:11 +01:00
parent 3a2988b8c7
commit 398736c4ad
3 changed files with 43 additions and 31 deletions


@@ -4,12 +4,13 @@ from datetime import datetime
 from functools import lru_cache
 from pathlib import Path
 from random import randint
-from typing import Iterable, overload
+from typing import Generator, Iterable, Literal, Sequence, overload
 from sqlalchemy import (
     Column,
     DateTime,
     Enum,
+    Executable,
     Integer,
     MetaData,
     String,
@@ -70,7 +71,7 @@ def check_integrity(conn: Connection) -> None:
     log.info("Database file integrity: %s", state)
 
 
-def check_parent_ids(conn: Connection):
+def check_parent_ids(conn: Connection) -> None:
     log.info("Checking parent file associations ... press Ctrl-C to skip!")
     try:
@@ -79,7 +80,7 @@
         log.warning("Aborted parent ID rebuild.")
 
 
-def optimize(conn: Connection, *, vacuum: bool = False):
+def optimize(conn: Connection, *, vacuum: bool = False) -> None:
     log.info("Optimizing database ...")
     conn.execute(text("PRAGMA analysis_limit=400"))
@@ -93,7 +94,7 @@ def optimize(conn: Connection, *, vacuum: bool = False):
         log.warning("Aborted DB cleanup.")
 
 
-def autoconf(conn: Connection):
+def autoconf(conn: Connection) -> None:
     log.info("Configuring database ...")
     conn.execute(text("PRAGMA journal_mode=WAL"))
@@ -104,13 +105,13 @@ class Db:
     engine: "Engine | None" = None
     is_dirty: bool = False
 
-    def __init__(self, path: Path, *, create_if_missing: bool = True):
+    def __init__(self, path: Path, *, create_if_missing: bool = True) -> None:
         self.open(path, create_if_missing=create_if_missing)
 
-    def __del__(self):
+    def __del__(self) -> None:
         self.close()
 
-    def open(self, path: Path, *, create_if_missing: bool = True):
+    def open(self, path: Path, *, create_if_missing: bool = True) -> None:
         log.info("Using database: %a", str(path))
 
         if not create_if_missing and not path.exists():
@@ -128,7 +129,7 @@ class Db:
             autoconf(conn)
             check_integrity(conn)
 
-    def close(self):
+    def close(self) -> None:
         if self.engine is None:
             return
@@ -146,7 +147,7 @@ class Db:
 
     @contextmanager
     def transaction(
         self, *, rollback_on_error: bool = False, force_rollback: bool = False
-    ):
+    ) -> Generator[Connection, None, None]:
         if self.engine is None:
             raise RuntimeError("DB was closed.")
@@ -157,7 +158,7 @@ class Db:
         )
 
         err = None
-        with connect() as conn:
+        with connect() as conn:  # type: ignore[attr-defined] # Mypy 1.0.0 doesn't understand the `connect` object is always a context manager
             try:
                 yield conn
             except BaseException as e:
@@ -174,7 +175,7 @@ class Db:
         self.is_dirty = True
 
 
-def init(path: Path = Path(":memory:")):
+def init(path: Path = Path(":memory:")) -> None:
     global engine
 
     log.info("Using database: %a", str(path))
@@ -192,7 +193,7 @@ def init(path: Path = Path(":memory:")):
         check_integrity(conn)
 
 
-def close():
+def close() -> None:
     global engine
 
     chance = 10  # Set the chance for long running actions to happen to 1 in X.
@@ -210,7 +211,7 @@ def iter_all(conn: Connection) -> Iterable[Row]:
     return conn.execute(select(metadex))
 
 
-def get_file(conn: Connection, *, location: str, hostname: str):
+def get_file(conn: Connection, *, location: str, hostname: str) -> "Row | None":
     stmt = select(metadex).where(
         and_(
             metadex.c.location == location,
@@ -220,7 +221,7 @@ def get_file(conn: Connection, *, location: str, hostname: str):
     return conn.execute(stmt).one_or_none()
 
 
-def get_files(conn: Connection, *, parent_id: int):
+def get_files(conn: Connection, *, parent_id: int) -> Sequence[Row]:
     stmt = select(metadex).where(
         metadex.c.parent_id == parent_id,
     )
@@ -294,7 +295,9 @@ def all_hostnames(conn: Connection) -> Iterable[str]:
     return conn.execute(stmt).scalars().all()
 
 
-def _fake_entry(path: Path, *, hostname=None, now, parent_id) -> dict:
+def _fake_entry(
+    path: Path, *, hostname: "str | None" = None, now: datetime, parent_id: "int | None"
+) -> dict:
     return dict(
         parent_id=parent_id,
         added=now,
@@ -307,10 +310,10 @@ def _fake_entry(path: Path, *, hostname=None, now, parent_id) -> dict:
     )
 
 
-def _add_parents(conn: Connection, *, location: str, hostname: str):
+def _add_parents(conn: Connection, *, location: str, hostname: str) -> "int | None":
     p_id: "int | None" = None
     for p in reversed(Path(location).parents):
-        log.warning("Forging parent: %a:%a", hostname, str(p))
+        d: "dict | Row"
         d = _fake_entry(p, hostname=hostname, now=datetime.now(), parent_id=p_id)
         d = get_or_add(conn, d)
         if isinstance(d, dict):
@@ -369,11 +372,14 @@ def get_or_add(conn: Connection, new_data: dict) -> "Row | dict":
     return new_data
 
 
-def upsert_if_changed(conn: Connection, new_data: dict):
+def upsert_if_changed(
+    conn: Connection, new_data: dict
+) -> Literal["added", "changed", "unchanged"]:
     row = get_or_add(conn, new_data)
     is_from_db = isinstance(row, Row)
     if not is_from_db:
         return "added"
+    assert isinstance(row, Row)  # Required for Mypy 1.0.0.
 
     is_changed = (
         new_data["stat_bytes"] != row.stat_bytes
@@ -410,7 +416,9 @@ def upsert_if_changed(conn: Connection, new_data: dict):
     return "changed"
 
 
-def remove_all(conn: Connection, location: str, *, hostname=None) -> int:
+def remove_all(
+    conn: Connection, location: str, *, hostname: "str | None" = None
+) -> int:
     """Remove the entry with the given path and all its descendants."""
 
     # We're using text comparison here to catch removed descendants even if
@@ -421,6 +429,7 @@ def remove_all(conn: Connection, location: str, *, hostname=None) -> int:
     # or change our parent-id-mechanism to support skipping intermediates, both of
     # which might be valid decisions for sake of optimization. For now we choose
    # simple correctness. Let's see how bad the performance can get.
 
+    stmt: Executable
     if hostname is None:
         hostname = config.hostname
@@ -443,11 +452,11 @@ def remove_all(conn: Connection, location: str, *, hostname=None) -> int:
 @contextmanager
-def transaction(rollback_on_error: bool = False):
+def transaction(rollback_on_error: bool = False) -> Generator[Connection, None, None]:
     connect = engine.connect if config.dryrun else engine.begin
 
     err = None
-    with connect() as conn:
+    with connect() as conn:  # type: ignore[attr-defined] # Mypy 1.0.0 doesn't understand the `connect` object is always a context manager
         try:
             yield conn
         except BaseException as e:
@@ -461,7 +470,9 @@ def transaction(rollback_on_error: bool = False):
         raise err
 
 
-def files_in_dir(conn: Connection, location: str, *, hostname=None) -> Iterable[str]:
+def files_in_dir(
+    conn: Connection, location: str, *, hostname: "str | None" = None
+) -> Iterable[str]:
     """Return all file names for the given dir."""
     if hostname is None:
         hostname = config.hostname
@@ -524,7 +535,8 @@ def _parent_id(
     return val
 
 
-def reassign_parent_ids(conn: Connection):
+def reassign_parent_ids(conn: Connection) -> None:
+    stmt: Executable
     stmt = select(
         metadex.c.id, metadex.c.parent_id, metadex.c.location, metadex.c.hostname
     )
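
The annotations that recur in this file are the Generator[Connection, None, None] return type on the @contextmanager-decorated transaction helpers and the Literal["added", "changed", "unchanged"] result of upsert_if_changed. Below is a minimal, standalone sketch of both patterns; run_in_transaction and classify are illustrative names rather than repository code, and the sketch assumes SQLAlchemy 1.4+ is installed.

from contextlib import contextmanager
from typing import Generator, Literal

from sqlalchemy.engine import Connection, Engine


@contextmanager
def run_in_transaction(engine: Engine) -> Generator[Connection, None, None]:
    # The generator yields exactly one Connection; Engine.begin() commits on a
    # clean exit and rolls back if the block raises.
    with engine.begin() as conn:
        yield conn


def classify(
    existed: bool, changed: bool
) -> Literal["added", "changed", "unchanged"]:
    # Mypy verifies that only these three exact strings can be returned.
    if not existed:
        return "added"
    if changed:
        return "changed"
    return "unchanged"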


@@ -208,14 +208,14 @@ def _scan_remove_missing(
         dirs.extendleft(subdirs)
 
         for name in expected:
-            f = str(cwd / name)
-            if is_ignored(f):
-                log.info("Ignoring file (for removal): %a", f)
+            ff = str(cwd / name)
+            if is_ignored(ff):
+                log.info("Ignoring file (for removal): %a", ff)
                 continue
-            log.info("File removed: %a", f)
-            context.removed += db.remove_all(conn, f)
+            log.info("File removed: %a", ff)
+            context.removed += db.remove_all(conn, ff)
 
     db.recalculate_dir_sizes(conn)
@@ -291,7 +291,7 @@ def _parse_pathspec_mapping(map_pathspecs: "list[str]") -> _PathspecMapping:
     return maps
 
 
-def _apply_mapping(maps: dict, d: dict) -> None:
+def _apply_mapping(maps: _PathspecMapping, d: dict) -> None:
     hostname = d["hostname"]
     location = (
         d["location"]


@@ -31,7 +31,7 @@ class File:
     stat_type: StatType
 
     @classmethod
-    def from_direntry(cls, entry: DirEntry) -> Self:
+    def from_direntry(cls, entry: DirEntry[str]) -> Self:
         now = datetime.now()
         pstat = entry.stat(follow_symlinks=False)
         return cls(
@@ -61,7 +61,7 @@ class File:
         )
 
     @staticmethod
-    def dict_from_entry(entry: "DirEntry | Path") -> dict:
+    def dict_from_entry(entry: "DirEntry[str] | Path") -> dict:
         """Return the File's data structure as dict.
 
         This can be useful to skip calling `asdict`, which can be quite slow.
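
The last file parametrizes os.DirEntry, which typeshed defines as generic over the path type: os.scandir() called with a str path yields DirEntry[str], while a bytes path yields DirEntry[bytes]. A small standalone sketch of the annotation follows; largest_entry is an illustrative helper, not part of the commit, and the subscripted type is kept in quotes so nothing is evaluated at runtime.

import os
from os import DirEntry


def largest_entry(path: str) -> "DirEntry[str] | None":
    # os.scandir(str) is typed to yield DirEntry[str]; follow_symlinks=False
    # keeps symlinks from being resolved, as in the stat() call above.
    biggest: "DirEntry[str] | None" = None
    with os.scandir(path) as entries:
        for entry in entries:
            if not entry.is_file(follow_symlinks=False):
                continue
            size = entry.stat(follow_symlinks=False).st_size
            if biggest is None or size > biggest.stat(follow_symlinks=False).st_size:
                biggest = entry
    return biggest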