improve strict typing

This commit is contained in:
ducklet 2023-07-22 19:37:01 +02:00
parent 86c3030e31
commit 25f31db756
4 changed files with 34 additions and 24 deletions

View file

@ -38,21 +38,25 @@ async def test_get(shared_conn: db.Database):
m2 = a_movie(release_year=m1.release_year + 1) m2 = a_movie(release_year=m1.release_year + 1)
await db.add(m2) await db.add(m2)
assert None == await db.get(models.Movie) assert None is await db.get(models.Movie)
assert None == await db.get(models.Movie, id="blerp") assert None is await db.get(models.Movie, id="blerp")
assert m1 == await db.get(models.Movie, id=str(m1.id)) assert m1 == await db.get(models.Movie, id=str(m1.id))
assert m2 == await db.get(models.Movie, release_year=m2.release_year) assert m2 == await db.get(models.Movie, release_year=m2.release_year)
assert None == await db.get( assert None is await db.get(
models.Movie, id=str(m1.id), release_year=m2.release_year models.Movie, id=str(m1.id), release_year=m2.release_year
) )
assert m2 == await db.get( assert m2 == await db.get(
models.Movie, id=str(m2.id), release_year=m2.release_year models.Movie, id=str(m2.id), release_year=m2.release_year
) )
assert m1 == await db.get( assert m1 == await db.get(
models.Movie, media_type=m1.media_type, order_by=("release_year", "asc") models.Movie,
media_type=m1.media_type,
order_by=(models.movies.c.release_year, "asc"),
) )
assert m2 == await db.get( assert m2 == await db.get(
models.Movie, media_type=m1.media_type, order_by=("release_year", "desc") models.Movie,
media_type=m1.media_type,
order_by=(models.movies.c.release_year, "desc"),
) )
@ -136,7 +140,7 @@ async def test_remove(shared_conn: db.Database):
assert m1 == await db.get(models.Movie, id=str(m1.id)) assert m1 == await db.get(models.Movie, id=str(m1.id))
await db.remove(m1) await db.remove(m1)
assert None == await db.get(models.Movie, id=str(m1.id)) assert None is await db.get(models.Movie, id=str(m1.id))
@pytest.mark.asyncio @pytest.mark.asyncio

View file

@ -22,6 +22,7 @@ from .models import (
metadata, metadata,
movies, movies,
optional_fields, optional_fields,
progress,
ratings, ratings,
utcnow, utcnow,
) )
@ -110,19 +111,21 @@ async def apply_db_patches(db: Database) -> None:
async with db.transaction(): async with db.transaction():
for query in queries: for query in queries:
await db.execute(query) await db.execute(sa.text(query))
await set_current_patch_level(db, patch_lvl) await set_current_patch_level(db, patch_lvl)
did_patch = True did_patch = True
if did_patch: if did_patch:
await db.execute("vacuum") await db.execute(sa.text("vacuum"))
async def get_import_progress() -> Progress | None: async def get_import_progress() -> Progress | None:
"""Return the latest import progress.""" """Return the latest import progress."""
return await get(Progress, type="import-imdb-movies", order_by=("started", "desc")) return await get(
Progress, type="import-imdb-movies", order_by=(progress.c.started, "desc")
)
async def stop_import_progress(*, error: BaseException | None = None) -> None: async def stop_import_progress(*, error: BaseException | None = None) -> None:
@ -225,7 +228,7 @@ async def add(item: Model) -> None:
# Support late initializing - used for optimization. # Support late initializing - used for optimization.
if getattr(item, "_is_lazy", False): if getattr(item, "_is_lazy", False):
assert hasattr(item, "_lazy_init") assert hasattr(item, "_lazy_init")
item._lazy_init() item._lazy_init() # pyright: ignore [reportGeneralTypeIssues]
table: sa.Table = item.__table__ table: sa.Table = item.__table__
values = asplain(item, serialize=True) values = asplain(item, serialize=True)
@ -240,7 +243,7 @@ ModelType = TypeVar("ModelType", bound=Model)
async def get( async def get(
model: Type[ModelType], model: Type[ModelType],
*, *,
order_by: tuple[str, Literal["asc", "desc"]] | None = None, order_by: tuple[sa.Column, Literal["asc", "desc"]] | None = None,
**field_values, **field_values,
) -> ModelType | None: ) -> ModelType | None:
"""Load a model instance from the database. """Load a model instance from the database.
@ -259,9 +262,7 @@ async def get(
if order_by: if order_by:
order_col, order_dir = order_by order_col, order_dir = order_by
query = query.order_by( query = query.order_by(
table.c[order_col].asc() order_col.asc() if order_dir == "asc" else order_col.desc()
if order_dir == "asc"
else table.c[order_col].desc()
) )
async with locked_connection() as conn: async with locked_connection() as conn:
row = await conn.fetch_one(query) row = await conn.fetch_one(query)
@ -306,7 +307,7 @@ async def update(item: Model) -> None:
# Support late initializing - used for optimization. # Support late initializing - used for optimization.
if getattr(item, "_is_lazy", False): if getattr(item, "_is_lazy", False):
assert hasattr(item, "_lazy_init") assert hasattr(item, "_lazy_init")
item._lazy_init() item._lazy_init() # pyright: ignore [reportGeneralTypeIssues]
table: sa.Table = item.__table__ table: sa.Table = item.__table__
values = asplain(item, serialize=True) values = asplain(item, serialize=True)
@ -587,7 +588,7 @@ async def find_movies(
conditions.append(movies.c.media_type != "TV Episode") conditions.append(movies.c.media_type != "TV Episode")
if not include_unrated: if not include_unrated:
conditions.append(movies.c.imdb_score != None) conditions.append(movies.c.imdb_score.is_not(None))
query = ( query = (
sa.select(movies) sa.select(movies)

View file

@ -275,6 +275,9 @@ class Progress:
self._state = state self._state = state
progress = Progress.__table__
@mapper_registry.mapped @mapper_registry.mapped
@dataclass @dataclass
class Movie: class Movie:

View file

@ -11,7 +11,7 @@ from hashlib import md5
from pathlib import Path from pathlib import Path
from random import random from random import random
from time import sleep, time from time import sleep, time
from typing import Callable, ParamSpec, TypeVar, cast from typing import Any, Callable, ParamSpec, TypeVar, cast
import bs4 import bs4
import httpx import httpx
@ -190,9 +190,11 @@ async def asoup_from_url(url):
def _last_modified_from_response(resp: _Response_T) -> float | None: def _last_modified_from_response(resp: _Response_T) -> float | None:
if last_mod := resp.headers.get("last-modified"): if last_mod := resp.headers.get("last-modified"):
try: try:
return email.utils.parsedate_to_datetime(last_mod).timestamp() dt = email.utils.parsedate_to_datetime(last_mod)
except: except ValueError:
log.exception("🐛 Received invalid value for Last-Modified: %s", last_mod) log.exception("🐛 Received invalid value for Last-Modified: %s", last_mod)
else:
return dt.timestamp()
def _last_modified_from_file(path: Path) -> float: def _last_modified_from_file(path: Path) -> float:
@ -206,8 +208,8 @@ async def adownload(
replace_existing: bool | None = None, replace_existing: bool | None = None,
only_if_newer: bool = False, only_if_newer: bool = False,
timeout: float | None = None, timeout: float | None = None,
chunk_callback=None, chunk_callback: Callable[[bytes], Any] | None = None,
response_callback=None, response_callback: Callable[[_Response_T], Any] | None = None,
) -> bytes | None: ) -> bytes | None:
"""Download a file. """Download a file.
@ -246,7 +248,7 @@ async def adownload(
if response_callback is not None: if response_callback is not None:
try: try:
response_callback(resp) response_callback(resp)
except: except BaseException:
log.exception("🐛 Error in response callback.") log.exception("🐛 Error in response callback.")
log.debug( log.debug(
@ -275,7 +277,7 @@ async def adownload(
# Check Last-Modified in case the server ignored If-Modified-Since. # Check Last-Modified in case the server ignored If-Modified-Since.
# XXX also check Content-Length? # XXX also check Content-Length?
if file_exists and only_if_newer and resp_lastmod is not None: if file_exists and only_if_newer and resp_lastmod is not None:
assert file_lastmod assert file_lastmod # pyright: ignore [reportUnboundVariable]
if resp_lastmod <= file_lastmod: if resp_lastmod <= file_lastmod:
log.debug("✋ Local file is newer, skipping download: %a", req.url) log.debug("✋ Local file is newer, skipping download: %a", req.url)
@ -299,7 +301,7 @@ async def adownload(
if chunk_callback: if chunk_callback:
try: try:
chunk_callback(chunk) chunk_callback(chunk)
except: except BaseException:
log.exception("🐛 Error in chunk callback.") log.exception("🐛 Error in chunk callback.")
finally: finally:
os.close(tempfd) os.close(tempfd)