diff --git a/tests/test_db.py b/tests/test_db.py index 7c4c96e..cd5f295 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -38,21 +38,25 @@ async def test_get(shared_conn: db.Database): m2 = a_movie(release_year=m1.release_year + 1) await db.add(m2) - assert None == await db.get(models.Movie) - assert None == await db.get(models.Movie, id="blerp") + assert None is await db.get(models.Movie) + assert None is await db.get(models.Movie, id="blerp") assert m1 == await db.get(models.Movie, id=str(m1.id)) assert m2 == await db.get(models.Movie, release_year=m2.release_year) - assert None == await db.get( + assert None is await db.get( models.Movie, id=str(m1.id), release_year=m2.release_year ) assert m2 == await db.get( models.Movie, id=str(m2.id), release_year=m2.release_year ) assert m1 == await db.get( - models.Movie, media_type=m1.media_type, order_by=("release_year", "asc") + models.Movie, + media_type=m1.media_type, + order_by=(models.movies.c.release_year, "asc"), ) assert m2 == await db.get( - models.Movie, media_type=m1.media_type, order_by=("release_year", "desc") + models.Movie, + media_type=m1.media_type, + order_by=(models.movies.c.release_year, "desc"), ) @@ -136,7 +140,7 @@ async def test_remove(shared_conn: db.Database): assert m1 == await db.get(models.Movie, id=str(m1.id)) await db.remove(m1) - assert None == await db.get(models.Movie, id=str(m1.id)) + assert None is await db.get(models.Movie, id=str(m1.id)) @pytest.mark.asyncio diff --git a/unwind/db.py b/unwind/db.py index 63c042a..7a0169d 100644 --- a/unwind/db.py +++ b/unwind/db.py @@ -22,6 +22,7 @@ from .models import ( metadata, movies, optional_fields, + progress, ratings, utcnow, ) @@ -110,19 +111,21 @@ async def apply_db_patches(db: Database) -> None: async with db.transaction(): for query in queries: - await db.execute(query) + await db.execute(sa.text(query)) await set_current_patch_level(db, patch_lvl) did_patch = True if did_patch: - await db.execute("vacuum") + await db.execute(sa.text("vacuum")) async def get_import_progress() -> Progress | None: """Return the latest import progress.""" - return await get(Progress, type="import-imdb-movies", order_by=("started", "desc")) + return await get( + Progress, type="import-imdb-movies", order_by=(progress.c.started, "desc") + ) async def stop_import_progress(*, error: BaseException | None = None) -> None: @@ -225,7 +228,7 @@ async def add(item: Model) -> None: # Support late initializing - used for optimization. if getattr(item, "_is_lazy", False): assert hasattr(item, "_lazy_init") - item._lazy_init() + item._lazy_init() # pyright: ignore [reportGeneralTypeIssues] table: sa.Table = item.__table__ values = asplain(item, serialize=True) @@ -240,7 +243,7 @@ ModelType = TypeVar("ModelType", bound=Model) async def get( model: Type[ModelType], *, - order_by: tuple[str, Literal["asc", "desc"]] | None = None, + order_by: tuple[sa.Column, Literal["asc", "desc"]] | None = None, **field_values, ) -> ModelType | None: """Load a model instance from the database. @@ -259,9 +262,7 @@ async def get( if order_by: order_col, order_dir = order_by query = query.order_by( - table.c[order_col].asc() - if order_dir == "asc" - else table.c[order_col].desc() + order_col.asc() if order_dir == "asc" else order_col.desc() ) async with locked_connection() as conn: row = await conn.fetch_one(query) @@ -306,7 +307,7 @@ async def update(item: Model) -> None: # Support late initializing - used for optimization. if getattr(item, "_is_lazy", False): assert hasattr(item, "_lazy_init") - item._lazy_init() + item._lazy_init() # pyright: ignore [reportGeneralTypeIssues] table: sa.Table = item.__table__ values = asplain(item, serialize=True) @@ -587,7 +588,7 @@ async def find_movies( conditions.append(movies.c.media_type != "TV Episode") if not include_unrated: - conditions.append(movies.c.imdb_score != None) + conditions.append(movies.c.imdb_score.is_not(None)) query = ( sa.select(movies) diff --git a/unwind/models.py b/unwind/models.py index 77e5b29..13af462 100644 --- a/unwind/models.py +++ b/unwind/models.py @@ -275,6 +275,9 @@ class Progress: self._state = state +progress = Progress.__table__ + + @mapper_registry.mapped @dataclass class Movie: diff --git a/unwind/request.py b/unwind/request.py index 4e57564..afd2b86 100644 --- a/unwind/request.py +++ b/unwind/request.py @@ -11,7 +11,7 @@ from hashlib import md5 from pathlib import Path from random import random from time import sleep, time -from typing import Callable, ParamSpec, TypeVar, cast +from typing import Any, Callable, ParamSpec, TypeVar, cast import bs4 import httpx @@ -190,9 +190,11 @@ async def asoup_from_url(url): def _last_modified_from_response(resp: _Response_T) -> float | None: if last_mod := resp.headers.get("last-modified"): try: - return email.utils.parsedate_to_datetime(last_mod).timestamp() - except: + dt = email.utils.parsedate_to_datetime(last_mod) + except ValueError: log.exception("🐛 Received invalid value for Last-Modified: %s", last_mod) + else: + return dt.timestamp() def _last_modified_from_file(path: Path) -> float: @@ -206,8 +208,8 @@ async def adownload( replace_existing: bool | None = None, only_if_newer: bool = False, timeout: float | None = None, - chunk_callback=None, - response_callback=None, + chunk_callback: Callable[[bytes], Any] | None = None, + response_callback: Callable[[_Response_T], Any] | None = None, ) -> bytes | None: """Download a file. @@ -246,7 +248,7 @@ async def adownload( if response_callback is not None: try: response_callback(resp) - except: + except BaseException: log.exception("🐛 Error in response callback.") log.debug( @@ -275,7 +277,7 @@ async def adownload( # Check Last-Modified in case the server ignored If-Modified-Since. # XXX also check Content-Length? if file_exists and only_if_newer and resp_lastmod is not None: - assert file_lastmod + assert file_lastmod # pyright: ignore [reportUnboundVariable] if resp_lastmod <= file_lastmod: log.debug("✋ Local file is newer, skipping download: %a", req.url) @@ -299,7 +301,7 @@ async def adownload( if chunk_callback: try: chunk_callback(chunk) - except: + except BaseException: log.exception("🐛 Error in chunk callback.") finally: os.close(tempfd)