improve strict typing
This commit is contained in:
parent
86c3030e31
commit
25f31db756
4 changed files with 34 additions and 24 deletions
|
|
@ -38,21 +38,25 @@ async def test_get(shared_conn: db.Database):
|
||||||
m2 = a_movie(release_year=m1.release_year + 1)
|
m2 = a_movie(release_year=m1.release_year + 1)
|
||||||
await db.add(m2)
|
await db.add(m2)
|
||||||
|
|
||||||
assert None == await db.get(models.Movie)
|
assert None is await db.get(models.Movie)
|
||||||
assert None == await db.get(models.Movie, id="blerp")
|
assert None is await db.get(models.Movie, id="blerp")
|
||||||
assert m1 == await db.get(models.Movie, id=str(m1.id))
|
assert m1 == await db.get(models.Movie, id=str(m1.id))
|
||||||
assert m2 == await db.get(models.Movie, release_year=m2.release_year)
|
assert m2 == await db.get(models.Movie, release_year=m2.release_year)
|
||||||
assert None == await db.get(
|
assert None is await db.get(
|
||||||
models.Movie, id=str(m1.id), release_year=m2.release_year
|
models.Movie, id=str(m1.id), release_year=m2.release_year
|
||||||
)
|
)
|
||||||
assert m2 == await db.get(
|
assert m2 == await db.get(
|
||||||
models.Movie, id=str(m2.id), release_year=m2.release_year
|
models.Movie, id=str(m2.id), release_year=m2.release_year
|
||||||
)
|
)
|
||||||
assert m1 == await db.get(
|
assert m1 == await db.get(
|
||||||
models.Movie, media_type=m1.media_type, order_by=("release_year", "asc")
|
models.Movie,
|
||||||
|
media_type=m1.media_type,
|
||||||
|
order_by=(models.movies.c.release_year, "asc"),
|
||||||
)
|
)
|
||||||
assert m2 == await db.get(
|
assert m2 == await db.get(
|
||||||
models.Movie, media_type=m1.media_type, order_by=("release_year", "desc")
|
models.Movie,
|
||||||
|
media_type=m1.media_type,
|
||||||
|
order_by=(models.movies.c.release_year, "desc"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -136,7 +140,7 @@ async def test_remove(shared_conn: db.Database):
|
||||||
assert m1 == await db.get(models.Movie, id=str(m1.id))
|
assert m1 == await db.get(models.Movie, id=str(m1.id))
|
||||||
|
|
||||||
await db.remove(m1)
|
await db.remove(m1)
|
||||||
assert None == await db.get(models.Movie, id=str(m1.id))
|
assert None is await db.get(models.Movie, id=str(m1.id))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
|
||||||
21
unwind/db.py
21
unwind/db.py
|
|
@ -22,6 +22,7 @@ from .models import (
|
||||||
metadata,
|
metadata,
|
||||||
movies,
|
movies,
|
||||||
optional_fields,
|
optional_fields,
|
||||||
|
progress,
|
||||||
ratings,
|
ratings,
|
||||||
utcnow,
|
utcnow,
|
||||||
)
|
)
|
||||||
|
|
@ -110,19 +111,21 @@ async def apply_db_patches(db: Database) -> None:
|
||||||
|
|
||||||
async with db.transaction():
|
async with db.transaction():
|
||||||
for query in queries:
|
for query in queries:
|
||||||
await db.execute(query)
|
await db.execute(sa.text(query))
|
||||||
|
|
||||||
await set_current_patch_level(db, patch_lvl)
|
await set_current_patch_level(db, patch_lvl)
|
||||||
|
|
||||||
did_patch = True
|
did_patch = True
|
||||||
|
|
||||||
if did_patch:
|
if did_patch:
|
||||||
await db.execute("vacuum")
|
await db.execute(sa.text("vacuum"))
|
||||||
|
|
||||||
|
|
||||||
async def get_import_progress() -> Progress | None:
|
async def get_import_progress() -> Progress | None:
|
||||||
"""Return the latest import progress."""
|
"""Return the latest import progress."""
|
||||||
return await get(Progress, type="import-imdb-movies", order_by=("started", "desc"))
|
return await get(
|
||||||
|
Progress, type="import-imdb-movies", order_by=(progress.c.started, "desc")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def stop_import_progress(*, error: BaseException | None = None) -> None:
|
async def stop_import_progress(*, error: BaseException | None = None) -> None:
|
||||||
|
|
@ -225,7 +228,7 @@ async def add(item: Model) -> None:
|
||||||
# Support late initializing - used for optimization.
|
# Support late initializing - used for optimization.
|
||||||
if getattr(item, "_is_lazy", False):
|
if getattr(item, "_is_lazy", False):
|
||||||
assert hasattr(item, "_lazy_init")
|
assert hasattr(item, "_lazy_init")
|
||||||
item._lazy_init()
|
item._lazy_init() # pyright: ignore [reportGeneralTypeIssues]
|
||||||
|
|
||||||
table: sa.Table = item.__table__
|
table: sa.Table = item.__table__
|
||||||
values = asplain(item, serialize=True)
|
values = asplain(item, serialize=True)
|
||||||
|
|
@ -240,7 +243,7 @@ ModelType = TypeVar("ModelType", bound=Model)
|
||||||
async def get(
|
async def get(
|
||||||
model: Type[ModelType],
|
model: Type[ModelType],
|
||||||
*,
|
*,
|
||||||
order_by: tuple[str, Literal["asc", "desc"]] | None = None,
|
order_by: tuple[sa.Column, Literal["asc", "desc"]] | None = None,
|
||||||
**field_values,
|
**field_values,
|
||||||
) -> ModelType | None:
|
) -> ModelType | None:
|
||||||
"""Load a model instance from the database.
|
"""Load a model instance from the database.
|
||||||
|
|
@ -259,9 +262,7 @@ async def get(
|
||||||
if order_by:
|
if order_by:
|
||||||
order_col, order_dir = order_by
|
order_col, order_dir = order_by
|
||||||
query = query.order_by(
|
query = query.order_by(
|
||||||
table.c[order_col].asc()
|
order_col.asc() if order_dir == "asc" else order_col.desc()
|
||||||
if order_dir == "asc"
|
|
||||||
else table.c[order_col].desc()
|
|
||||||
)
|
)
|
||||||
async with locked_connection() as conn:
|
async with locked_connection() as conn:
|
||||||
row = await conn.fetch_one(query)
|
row = await conn.fetch_one(query)
|
||||||
|
|
@ -306,7 +307,7 @@ async def update(item: Model) -> None:
|
||||||
# Support late initializing - used for optimization.
|
# Support late initializing - used for optimization.
|
||||||
if getattr(item, "_is_lazy", False):
|
if getattr(item, "_is_lazy", False):
|
||||||
assert hasattr(item, "_lazy_init")
|
assert hasattr(item, "_lazy_init")
|
||||||
item._lazy_init()
|
item._lazy_init() # pyright: ignore [reportGeneralTypeIssues]
|
||||||
|
|
||||||
table: sa.Table = item.__table__
|
table: sa.Table = item.__table__
|
||||||
values = asplain(item, serialize=True)
|
values = asplain(item, serialize=True)
|
||||||
|
|
@ -587,7 +588,7 @@ async def find_movies(
|
||||||
conditions.append(movies.c.media_type != "TV Episode")
|
conditions.append(movies.c.media_type != "TV Episode")
|
||||||
|
|
||||||
if not include_unrated:
|
if not include_unrated:
|
||||||
conditions.append(movies.c.imdb_score != None)
|
conditions.append(movies.c.imdb_score.is_not(None))
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
sa.select(movies)
|
sa.select(movies)
|
||||||
|
|
|
||||||
|
|
@ -275,6 +275,9 @@ class Progress:
|
||||||
self._state = state
|
self._state = state
|
||||||
|
|
||||||
|
|
||||||
|
progress = Progress.__table__
|
||||||
|
|
||||||
|
|
||||||
@mapper_registry.mapped
|
@mapper_registry.mapped
|
||||||
@dataclass
|
@dataclass
|
||||||
class Movie:
|
class Movie:
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from hashlib import md5
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from random import random
|
from random import random
|
||||||
from time import sleep, time
|
from time import sleep, time
|
||||||
from typing import Callable, ParamSpec, TypeVar, cast
|
from typing import Any, Callable, ParamSpec, TypeVar, cast
|
||||||
|
|
||||||
import bs4
|
import bs4
|
||||||
import httpx
|
import httpx
|
||||||
|
|
@ -190,9 +190,11 @@ async def asoup_from_url(url):
|
||||||
def _last_modified_from_response(resp: _Response_T) -> float | None:
|
def _last_modified_from_response(resp: _Response_T) -> float | None:
|
||||||
if last_mod := resp.headers.get("last-modified"):
|
if last_mod := resp.headers.get("last-modified"):
|
||||||
try:
|
try:
|
||||||
return email.utils.parsedate_to_datetime(last_mod).timestamp()
|
dt = email.utils.parsedate_to_datetime(last_mod)
|
||||||
except:
|
except ValueError:
|
||||||
log.exception("🐛 Received invalid value for Last-Modified: %s", last_mod)
|
log.exception("🐛 Received invalid value for Last-Modified: %s", last_mod)
|
||||||
|
else:
|
||||||
|
return dt.timestamp()
|
||||||
|
|
||||||
|
|
||||||
def _last_modified_from_file(path: Path) -> float:
|
def _last_modified_from_file(path: Path) -> float:
|
||||||
|
|
@ -206,8 +208,8 @@ async def adownload(
|
||||||
replace_existing: bool | None = None,
|
replace_existing: bool | None = None,
|
||||||
only_if_newer: bool = False,
|
only_if_newer: bool = False,
|
||||||
timeout: float | None = None,
|
timeout: float | None = None,
|
||||||
chunk_callback=None,
|
chunk_callback: Callable[[bytes], Any] | None = None,
|
||||||
response_callback=None,
|
response_callback: Callable[[_Response_T], Any] | None = None,
|
||||||
) -> bytes | None:
|
) -> bytes | None:
|
||||||
"""Download a file.
|
"""Download a file.
|
||||||
|
|
||||||
|
|
@ -246,7 +248,7 @@ async def adownload(
|
||||||
if response_callback is not None:
|
if response_callback is not None:
|
||||||
try:
|
try:
|
||||||
response_callback(resp)
|
response_callback(resp)
|
||||||
except:
|
except BaseException:
|
||||||
log.exception("🐛 Error in response callback.")
|
log.exception("🐛 Error in response callback.")
|
||||||
|
|
||||||
log.debug(
|
log.debug(
|
||||||
|
|
@ -275,7 +277,7 @@ async def adownload(
|
||||||
# Check Last-Modified in case the server ignored If-Modified-Since.
|
# Check Last-Modified in case the server ignored If-Modified-Since.
|
||||||
# XXX also check Content-Length?
|
# XXX also check Content-Length?
|
||||||
if file_exists and only_if_newer and resp_lastmod is not None:
|
if file_exists and only_if_newer and resp_lastmod is not None:
|
||||||
assert file_lastmod
|
assert file_lastmod # pyright: ignore [reportUnboundVariable]
|
||||||
|
|
||||||
if resp_lastmod <= file_lastmod:
|
if resp_lastmod <= file_lastmod:
|
||||||
log.debug("✋ Local file is newer, skipping download: %a", req.url)
|
log.debug("✋ Local file is newer, skipping download: %a", req.url)
|
||||||
|
|
@ -299,7 +301,7 @@ async def adownload(
|
||||||
if chunk_callback:
|
if chunk_callback:
|
||||||
try:
|
try:
|
||||||
chunk_callback(chunk)
|
chunk_callback(chunk)
|
||||||
except:
|
except BaseException:
|
||||||
log.exception("🐛 Error in chunk callback.")
|
log.exception("🐛 Error in chunk callback.")
|
||||||
finally:
|
finally:
|
||||||
os.close(tempfd)
|
os.close(tempfd)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue