improve strict typing

This commit is contained in:
ducklet 2023-07-22 19:37:01 +02:00
parent 86c3030e31
commit 25f31db756
4 changed files with 34 additions and 24 deletions

View file

@ -38,21 +38,25 @@ async def test_get(shared_conn: db.Database):
m2 = a_movie(release_year=m1.release_year + 1)
await db.add(m2)
assert None == await db.get(models.Movie)
assert None == await db.get(models.Movie, id="blerp")
assert None is await db.get(models.Movie)
assert None is await db.get(models.Movie, id="blerp")
assert m1 == await db.get(models.Movie, id=str(m1.id))
assert m2 == await db.get(models.Movie, release_year=m2.release_year)
assert None == await db.get(
assert None is await db.get(
models.Movie, id=str(m1.id), release_year=m2.release_year
)
assert m2 == await db.get(
models.Movie, id=str(m2.id), release_year=m2.release_year
)
assert m1 == await db.get(
models.Movie, media_type=m1.media_type, order_by=("release_year", "asc")
models.Movie,
media_type=m1.media_type,
order_by=(models.movies.c.release_year, "asc"),
)
assert m2 == await db.get(
models.Movie, media_type=m1.media_type, order_by=("release_year", "desc")
models.Movie,
media_type=m1.media_type,
order_by=(models.movies.c.release_year, "desc"),
)
@ -136,7 +140,7 @@ async def test_remove(shared_conn: db.Database):
assert m1 == await db.get(models.Movie, id=str(m1.id))
await db.remove(m1)
assert None == await db.get(models.Movie, id=str(m1.id))
assert None is await db.get(models.Movie, id=str(m1.id))
@pytest.mark.asyncio

View file

@ -22,6 +22,7 @@ from .models import (
metadata,
movies,
optional_fields,
progress,
ratings,
utcnow,
)
@ -110,19 +111,21 @@ async def apply_db_patches(db: Database) -> None:
async with db.transaction():
for query in queries:
await db.execute(query)
await db.execute(sa.text(query))
await set_current_patch_level(db, patch_lvl)
did_patch = True
if did_patch:
await db.execute("vacuum")
await db.execute(sa.text("vacuum"))
async def get_import_progress() -> Progress | None:
"""Return the latest import progress."""
return await get(Progress, type="import-imdb-movies", order_by=("started", "desc"))
return await get(
Progress, type="import-imdb-movies", order_by=(progress.c.started, "desc")
)
async def stop_import_progress(*, error: BaseException | None = None) -> None:
@ -225,7 +228,7 @@ async def add(item: Model) -> None:
# Support late initializing - used for optimization.
if getattr(item, "_is_lazy", False):
assert hasattr(item, "_lazy_init")
item._lazy_init()
item._lazy_init() # pyright: ignore [reportGeneralTypeIssues]
table: sa.Table = item.__table__
values = asplain(item, serialize=True)
@ -240,7 +243,7 @@ ModelType = TypeVar("ModelType", bound=Model)
async def get(
model: Type[ModelType],
*,
order_by: tuple[str, Literal["asc", "desc"]] | None = None,
order_by: tuple[sa.Column, Literal["asc", "desc"]] | None = None,
**field_values,
) -> ModelType | None:
"""Load a model instance from the database.
@ -259,9 +262,7 @@ async def get(
if order_by:
order_col, order_dir = order_by
query = query.order_by(
table.c[order_col].asc()
if order_dir == "asc"
else table.c[order_col].desc()
order_col.asc() if order_dir == "asc" else order_col.desc()
)
async with locked_connection() as conn:
row = await conn.fetch_one(query)
@ -306,7 +307,7 @@ async def update(item: Model) -> None:
# Support late initializing - used for optimization.
if getattr(item, "_is_lazy", False):
assert hasattr(item, "_lazy_init")
item._lazy_init()
item._lazy_init() # pyright: ignore [reportGeneralTypeIssues]
table: sa.Table = item.__table__
values = asplain(item, serialize=True)
@ -587,7 +588,7 @@ async def find_movies(
conditions.append(movies.c.media_type != "TV Episode")
if not include_unrated:
conditions.append(movies.c.imdb_score != None)
conditions.append(movies.c.imdb_score.is_not(None))
query = (
sa.select(movies)

View file

@ -275,6 +275,9 @@ class Progress:
self._state = state
progress = Progress.__table__
@mapper_registry.mapped
@dataclass
class Movie:

View file

@ -11,7 +11,7 @@ from hashlib import md5
from pathlib import Path
from random import random
from time import sleep, time
from typing import Callable, ParamSpec, TypeVar, cast
from typing import Any, Callable, ParamSpec, TypeVar, cast
import bs4
import httpx
@ -190,9 +190,11 @@ async def asoup_from_url(url):
def _last_modified_from_response(resp: _Response_T) -> float | None:
if last_mod := resp.headers.get("last-modified"):
try:
return email.utils.parsedate_to_datetime(last_mod).timestamp()
except:
dt = email.utils.parsedate_to_datetime(last_mod)
except ValueError:
log.exception("🐛 Received invalid value for Last-Modified: %s", last_mod)
else:
return dt.timestamp()
def _last_modified_from_file(path: Path) -> float:
@ -206,8 +208,8 @@ async def adownload(
replace_existing: bool | None = None,
only_if_newer: bool = False,
timeout: float | None = None,
chunk_callback=None,
response_callback=None,
chunk_callback: Callable[[bytes], Any] | None = None,
response_callback: Callable[[_Response_T], Any] | None = None,
) -> bytes | None:
"""Download a file.
@ -246,7 +248,7 @@ async def adownload(
if response_callback is not None:
try:
response_callback(resp)
except:
except BaseException:
log.exception("🐛 Error in response callback.")
log.debug(
@ -275,7 +277,7 @@ async def adownload(
# Check Last-Modified in case the server ignored If-Modified-Since.
# XXX also check Content-Length?
if file_exists and only_if_newer and resp_lastmod is not None:
assert file_lastmod
assert file_lastmod # pyright: ignore [reportUnboundVariable]
if resp_lastmod <= file_lastmod:
log.debug("✋ Local file is newer, skipping download: %a", req.url)
@ -299,7 +301,7 @@ async def adownload(
if chunk_callback:
try:
chunk_callback(chunk)
except:
except BaseException:
log.exception("🐛 Error in chunk callback.")
finally:
os.close(tempfd)