Merge branch 'feat/py311'

ducklet 2023-03-18 01:12:35 +01:00
commit a020d972f8
23 changed files with 783 additions and 748 deletions

4 .git-blame-ignore-revs Normal file
View file

@ -0,0 +1,4 @@
# Apply Black v23.1.0 formatting changes.
8a8bfce89de23d987386a35b659532bbac373788
# Apply auto-formatting to tests.
9ffcc5357150cecde26f5e6f8fccceaf92411efb

View file

@ -1,4 +1,4 @@
FROM docker.io/library/python:3.10-alpine
FROM docker.io/library/python:3.11-alpine
RUN apk update --no-cache \
&& apk upgrade --no-cache \
@ -11,20 +11,18 @@ WORKDIR /var/app
COPY requirements.txt ./
# Required to build greenlet on Alpine, dependency of SQLAlchemy 1.4.
RUN apk add --no-cache \
--virtual .build-deps \
g++ gcc musl-dev \
&& pip install --no-cache-dir --upgrade \
--requirement requirements.txt \
&& apk del .build-deps
RUN pip install --no-cache-dir --upgrade \
--requirement requirements.txt
USER 10000:10001
COPY . ./
ENV UNWIND_DATA="/data"
VOLUME ["/data"]
VOLUME $UNWIND_DATA
ENV UNWIND_PORT=8097
EXPOSE $UNWIND_PORT
ENTRYPOINT ["/var/app/run"]
CMD ["server"]

817 poetry.lock generated

File diff suppressed because it is too large.

View file

@ -6,26 +6,20 @@ authors = ["ducklet <ducklet@noreply.code.dumpr.org>"]
license = "LOL"
[tool.poetry.dependencies]
python = "^3.10"
requests = "^2.25.1"
python = "^3.11"
beautifulsoup4 = "^4.9.3"
html5lib = "^1.1"
starlette = "^0.17.0"
starlette = "^0.26"
ulid-py = "^1.1.0"
databases = {extras = ["sqlite"], version = "^0.6.1"}
toml = "^0.10.2"
uvicorn = "^0.19.0"
[tool.poetry.group.fixes.dependencies]
# `databases` is having issues with new versions of SQLAlchemy 1.4,
# and `greenlet` is also always a pain.
SQLAlchemy = "1.4.25"
greenlet = "1.1.2"
databases = {extras = ["sqlite"], version = "^0.7.0"}
uvicorn = "^0.21"
httpx = "^0.23.3"
[tool.poetry.group.dev]
optional = true
[tool.poetry.group.dev.dependencies]
autoflake = "*"
pytest = "*"
pyright = "*"
black = "*"
@ -37,4 +31,14 @@ requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.pyright]
pythonVersion = "3.10"
pythonVersion = "3.11"
[tool.isort]
profile = "black"
[tool.autoflake]
remove-duplicate-keys = true
remove-unused-variables = true
remove-all-unused-imports = true
ignore-init-module-imports = true
ignore-pass-after-docstring = true

View file

@ -4,6 +4,7 @@ cd "$RUN_DIR"
[ -z "${DEBUG:-}" ] || set -x
isort --profile black unwind
black unwind
autoflake --quiet --check --recursive unwind tests
isort unwind tests
black unwind tests
pyright

13 scripts/profile Executable file
View file

@ -0,0 +1,13 @@
#!/bin/sh -eu
cd "$RUN_DIR"
outfile="profile-$(date '+%Y%m%d-%H%M%S').txt"
[ -z "${DEBUG:-}" ] || set -x
echo "# Writing profiler stats to: $outfile"
python -m cProfile -o "$outfile" -m unwind "$@"
echo "# Loading stats file: $outfile"
python -m pstats "$outfile"

View file

@ -1,7 +1,14 @@
#!/bin/sh -eu
: "${UNWIND_PORT:=8097}"
cd "$RUN_DIR"
[ -z "${DEBUG:-}" ] || set -x
exec uvicorn --host 0.0.0.0 --factory unwind:create_app
export UNWIND_PORT
exec uvicorn \
--host 0.0.0.0 \
--port "$UNWIND_PORT" \
--factory unwind:create_app

View file

@ -10,5 +10,6 @@ trap 'rm "$dbfile"' EXIT TERM INT QUIT
[ -z "${DEBUG:-}" ] || set -x
SQLALCHEMY_WARN_20=1 \
UNWIND_STORAGE="$dbfile" \
python -m pytest "$@"

View file

@ -1,6 +1,8 @@
import asyncio
import pytest
import pytest_asyncio
from unwind import db
@ -13,7 +15,7 @@ def event_loop():
loop.close()
@pytest.fixture(scope="session")
@pytest_asyncio.fixture(scope="session")
async def shared_conn():
c = db.shared_connection()
await c.connect()
@ -24,7 +26,7 @@ async def shared_conn():
await c.disconnect()
@pytest.fixture
@pytest_asyncio.fixture
async def conn(shared_conn):
async with shared_conn.transaction(force_rollback=True):
yield shared_conn
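
The fixture decorators above move from pytest.fixture to pytest_asyncio.fixture, which is what pytest-asyncio expects for async generator fixtures in its strict mode; the per-test pytest.mark.asyncio markers in the test modules below pair with that. A minimal self-contained sketch with made-up names, not code from this repository:

```python
import asyncio

import pytest
import pytest_asyncio


# Async fixtures are registered with @pytest_asyncio.fixture, async tests carry
# an explicit @pytest.mark.asyncio marker (pytest-asyncio strict mode).
@pytest_asyncio.fixture
async def resource():
    await asyncio.sleep(0)  # stands in for `await c.connect()`
    yield "ready"           # teardown after the yield would disconnect


@pytest.mark.asyncio
async def test_resource(resource: str):
    assert resource == "ready"
```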

View file

@ -1,14 +1,13 @@
from datetime import datetime
import pytest
from unwind import db, models, web_models
pytestmark = pytest.mark.asyncio
async def test_add_and_get(shared_conn):
@pytest.mark.asyncio
async def test_add_and_get(shared_conn: db.Database):
async with shared_conn.transaction(force_rollback=True):
m1 = models.Movie(
title="test movie",
release_year=2013,
@ -31,9 +30,9 @@ async def test_add_and_get(shared_conn):
assert m2 == await db.get(models.Movie, id=str(m2.id))
async def test_find_ratings(shared_conn):
@pytest.mark.asyncio
async def test_find_ratings(shared_conn: db.Database):
async with shared_conn.transaction(force_rollback=True):
m1 = models.Movie(
title="test movie",
release_year=2013,
@ -157,4 +156,3 @@ async def test_find_ratings(shared_conn):
rows = await db.find_ratings(title="test", include_unrated=True)
ratings = tuple(web_models.Rating(**r) for r in rows)
assert (web_models.Rating.from_movie(m1),) == ratings

View file

@ -1,14 +1,15 @@
import pytest
from unwind.imdb import imdb_rating_from_score, score_from_imdb_rating
@pytest.mark.parametrize("rating", (x / 10 for x in range(10, 101)))
def test_rating_conversion(rating):
def test_rating_conversion(rating: float):
assert rating == imdb_rating_from_score(score_from_imdb_rating(rating))
@pytest.mark.parametrize("score", range(0, 101))
def test_score_conversion(score):
def test_score_conversion(score: int):
# Because our score covers 101 discrete values and IMDb's rating only 91
# discrete values, the mapping is non-injective, i.e. 10 values can't be
# mapped uniquely.
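
A quick pigeonhole check of the comment above, using the conversion that read_ratings_as_mapping() applies later in this diff (round(100 * (rating - 1) / 9)) and a hypothetical inverse rounded to one decimal; the real imdb helpers may differ in detail:

```python
# 101 score values (0..100) land on only 91 IMDb ratings (1.0..10.0 in 0.1 steps),
# so at least ten scores cannot survive a score -> rating -> score round trip.
def rating_from_score(score: int) -> float:  # hypothetical inverse, 0.1 precision
    return round(1 + 9 * score / 100, 1)


def score_from_rating(rating: float) -> int:  # same formula as read_ratings_as_mapping()
    return round(100 * (rating - 1) / 9)


distinct_ratings = {rating_from_score(s) for s in range(101)}
lossy_scores = [s for s in range(101) if score_from_rating(rating_from_score(s)) != s]
print(len(distinct_ratings), len(lossy_scores))  # 91 10
```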

View file

@ -1,18 +1,14 @@
from starlette.testclient import TestClient
import pytest
from starlette.testclient import TestClient
from unwind import create_app
from unwind import db, models, imdb
# https://pypi.org/project/pytest-asyncio/
pytestmark = pytest.mark.asyncio
from unwind import create_app, db, imdb, models
app = create_app()
async def test_app(shared_conn):
@pytest.mark.asyncio
async def test_app(shared_conn: db.Database):
async with shared_conn.transaction(force_rollback=True):
# https://www.starlette.io/testclient/
client = TestClient(app)
response = client.get("/api/v1/movies")

View file

@ -6,7 +6,7 @@ from pathlib import Path
from . import config
from .db import close_connection_pool, open_connection_pool
from .imdb import refresh_user_ratings_from_imdb
from .imdb_import import import_from_file
from .imdb_import import download_datasets, import_from_file
log = logging.getLogger(__name__)
@ -15,7 +15,7 @@ async def run_load_user_ratings_from_imdb():
await open_connection_pool()
i = 0
async for rating in refresh_user_ratings_from_imdb():
async for _ in refresh_user_ratings_from_imdb():
i += 1
log.info("✨ Imported %s new ratings.", i)
@ -31,6 +31,10 @@ async def run_import_imdb_dataset(basics_path: Path, ratings_path: Path):
await close_connection_pool()
async def run_download_imdb_dataset(basics_path: Path, ratings_path: Path):
await download_datasets(basics_path=basics_path, ratings_path=ratings_path)
def getargs():
parser = argparse.ArgumentParser()
commands = parser.add_subparsers(required=True)
@ -55,6 +59,25 @@ def getargs():
"--ratings", metavar="ratings_file.tsv.gz", type=Path, required=True
)
parser_download_imdb_dataset = commands.add_parser(
"download-imdb-dataset",
help="Download IMDb datasets.",
description="""
Download IMDb datasets.
""",
)
parser_download_imdb_dataset.add_argument(
dest="mode",
action="store_const",
const="download-imdb-dataset",
)
parser_download_imdb_dataset.add_argument(
"--basics", metavar="basics_file.tsv.gz", type=Path, required=True
)
parser_download_imdb_dataset.add_argument(
"--ratings", metavar="ratings_file.tsv.gz", type=Path, required=True
)
parser_load_user_ratings_from_imdb = commands.add_parser(
"load-user-ratings-from-imdb",
help="Load user ratings from imdb.com.",
@ -94,6 +117,8 @@ def main():
asyncio.run(run_load_user_ratings_from_imdb())
elif args.mode == "import-imdb-dataset":
asyncio.run(run_import_imdb_dataset(args.basics, args.ratings))
elif args.mode == "download-imdb-dataset":
asyncio.run(run_download_imdb_dataset(args.basics, args.ratings))
main()

View file

@ -1,8 +1,7 @@
import os
import tomllib
from pathlib import Path
import toml
datadir = Path(os.getenv("UNWIND_DATA") or "./data")
cachedir = (
Path(cachedir)
@ -14,7 +13,8 @@ loglevel = os.getenv("UNWIND_LOGLEVEL") or ("DEBUG" if debug else "INFO")
storage_path = os.getenv("UNWIND_STORAGE", datadir / "db.sqlite")
config_path = os.getenv("UNWIND_CONFIG", datadir / "config.toml")
_config = toml.load(config_path)
with open(config_path, "rb") as fd:
_config = tomllib.load(fd)
api_base = _config["api"].get("base", "/api/")
api_cors = _config["api"].get("cors", "*")
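
The config module now parses its TOML with the standard library's tomllib, new in Python 3.11, instead of the third-party toml package. Two details worth noting: tomllib only reads TOML (it has no dump counterpart), and tomllib.load() requires a binary file object, which is why the file is opened with "rb" above. A minimal sketch with a made-up path:

```python
import tomllib  # stdlib since Python 3.11

with open("config.toml", "rb") as fd:  # binary mode is required by tomllib.load()
    cfg = tomllib.load(fd)

api_base = cfg["api"].get("base", "/api/")  # mirrors the lookups in this module
```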

View file

@ -4,7 +4,7 @@ import logging
import re
import threading
from pathlib import Path
from typing import Any, Iterable, Literal, Optional, Type, TypeVar, Union
from typing import Any, Iterable, Literal, Type, TypeVar
import sqlalchemy
from databases import Database
@ -26,7 +26,7 @@ from .types import ULID
log = logging.getLogger(__name__)
T = TypeVar("T")
_shared_connection: Optional[Database] = None
_shared_connection: Database | None = None
async def open_connection_pool() -> None:
@ -119,7 +119,6 @@ async def apply_db_patches(db: Database):
raise RuntimeError("No statement found.")
async with db.transaction():
for query in queries:
await db.execute(query)
@ -131,12 +130,12 @@ async def apply_db_patches(db: Database):
await db.execute("vacuum")
async def get_import_progress() -> Optional[Progress]:
async def get_import_progress() -> Progress | None:
"""Return the latest import progress."""
return await get(Progress, type="import-imdb-movies", order_by="started DESC")
async def stop_import_progress(*, error: BaseException = None):
async def stop_import_progress(*, error: BaseException | None = None):
"""Stop the current import.
If an error is given, it will be logged to the progress state.
@ -176,6 +175,8 @@ async def set_import_progress(progress: float) -> Progress:
else:
await add(current)
return current
_lock = threading.Lock()
_prelock = threading.Lock()
@ -243,8 +244,8 @@ ModelType = TypeVar("ModelType")
async def get(
model: Type[ModelType], *, order_by: str = None, **kwds
) -> Optional[ModelType]:
model: Type[ModelType], *, order_by: str | None = None, **kwds
) -> ModelType | None:
"""Load a model instance from the database.
Passing `kwds` allows you to filter the instance to load. You have to encode the
@ -262,7 +263,7 @@ async def get(
query += f" ORDER BY {order_by}"
async with locked_connection() as conn:
row = await conn.fetch_one(query=query, values=values)
return fromplain(model, row, serialized=True) if row else None
return fromplain(model, row._mapping, serialized=True) if row else None
async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
@ -282,7 +283,7 @@ async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
async with locked_connection() as conn:
rows = await conn.fetch_all(query=query, values=values)
return (fromplain(model, row, serialized=True) for row in rows)
return (fromplain(model, row._mapping, serialized=True) for row in rows)
async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
@ -293,7 +294,7 @@ async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
async with locked_connection() as conn:
rows = await conn.fetch_all(query=query, values=values)
return (fromplain(model, row, serialized=True) for row in rows)
return (fromplain(model, row._mapping, serialized=True) for row in rows)
async def update(item):
@ -406,16 +407,16 @@ def sql_escape(s: str, char="#"):
async def find_ratings(
*,
title: str = None,
media_type: str = None,
title: str | None = None,
media_type: str | None = None,
exact: bool = False,
ignore_tv_episodes: bool = False,
include_unrated: bool = False,
yearcomp: tuple[Literal["<", "=", ">"], int] = None,
yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
limit_rows: int = 10,
user_ids: Iterable[str] = [],
):
values: dict[str, Union[int, str]] = {
values: dict[str, int | str] = {
"limit_rows": limit_rows,
}
@ -466,7 +467,7 @@ async def find_ratings(
"""
async with locked_connection() as conn:
rows = await conn.fetch_all(bindparams(query, values))
movie_ids = tuple(r["movie_id"] for r in rows)
movie_ids = tuple(r._mapping["movie_id"] for r in rows)
if include_unrated and len(movie_ids) < limit_rows:
sqlin, sqlin_vals = sql_in("id", movie_ids, not_=True)
@ -485,7 +486,7 @@ async def find_ratings(
{**values, **sqlin_vals, "limit_rows": limit_rows - len(movie_ids)},
)
)
movie_ids += tuple(r["movie_id"] for r in rows)
movie_ids += tuple(r._mapping["movie_id"] for r in rows)
return await ratings_for_movie_ids(ids=movie_ids)
@ -527,29 +528,13 @@ async def ratings_for_movie_ids(
async with locked_connection() as conn:
rows = await conn.fetch_all(bindparams(query, vals))
return tuple(dict(r) for r in rows)
return tuple(dict(r._mapping) for r in rows)
def sql_fields(tp: Type):
return (f"{tp._table}.{f.name}" for f in fields(tp))
def sql_fieldmap(tp: Type):
"""-> {alias: (table, field_name)}"""
return {f"{tp._table}_{f.name}": (tp._table, f.name) for f in fields(tp)}
def mux(*tps: Type):
return ", ".join(
f"{t}.{n} AS {k}" for tp in tps for k, (t, n) in sql_fieldmap(tp).items()
)
def demux(tp: Type[ModelType], row) -> ModelType:
d = {n: row[k] for k, (_, n) in sql_fieldmap(tp).items()}
return fromplain(tp, d, serialized=True)
def sql_in(column: str, values: Iterable[T], not_=False) -> tuple[str, dict[str, T]]:
c = column.replace(".", "___")
value_map = {f"{c}_{i}": v for i, v in enumerate(values, start=1)}
@ -583,22 +568,22 @@ async def ratings_for_movies(
async with locked_connection() as conn:
rows = await conn.fetch_all(query, values)
return (fromplain(Rating, row, serialized=True) for row in rows)
return (fromplain(Rating, row._mapping, serialized=True) for row in rows)
async def find_movies(
*,
title: str = None,
media_type: str = None,
title: str | None = None,
media_type: str | None = None,
exact: bool = False,
ignore_tv_episodes: bool = False,
yearcomp: tuple[Literal["<", "=", ">"], int] = None,
yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
limit_rows: int = 10,
skip_rows: int = 0,
include_unrated: bool = False,
user_ids: list[ULID] = [],
) -> Iterable[tuple[Movie, list[Rating]]]:
values: dict[str, Union[int, str]] = {
values: dict[str, int | str] = {
"limit_rows": limit_rows,
"skip_rows": skip_rows,
}
@ -650,7 +635,7 @@ async def find_movies(
async with locked_connection() as conn:
rows = await conn.fetch_all(bindparams(query, values))
movies = [fromplain(Movie, row, serialized=True) for row in rows]
movies = [fromplain(Movie, row._mapping, serialized=True) for row in rows]
if not user_ids:
return ((m, []) for m in movies)
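
Several hunks in this file replace dict-style access on result rows (row["movie_id"], dict(row)) with row._mapping. Under SQLAlchemy 1.4+, as used with the databases 0.7 bump in pyproject.toml, Row objects behave like named tuples and expose their mapping view via Row._mapping. A standalone illustration with a throwaway query, not app code:

```python
import sqlalchemy as sa

engine = sa.create_engine("sqlite://")
with engine.connect() as conn:
    row = conn.execute(sa.text("SELECT 1 AS movie_id, 'test movie' AS title")).one()
    print(row.movie_id, row[1])  # tuple/attribute access works on the Row itself
    print(dict(row._mapping))    # {'movie_id': 1, 'title': 'test movie'}
```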

View file

@ -2,12 +2,11 @@ import logging
import re
from collections import namedtuple
from datetime import datetime
from typing import Optional, Tuple
from urllib.parse import urljoin
from . import db
from .models import Movie, Rating, User
from .request import cache_path, session, soup_from_url
from .request import asession, asoup_from_url, cache_path
log = logging.getLogger(__name__)
@ -35,13 +34,11 @@ log = logging.getLogger(__name__)
# p.text-muted.text-small span[name=nv] [data-value]
async def refresh_user_ratings_from_imdb(stop_on_dupe=True):
with session() as s:
async def refresh_user_ratings_from_imdb(stop_on_dupe: bool = True):
async with asession() as s:
s.headers["Accept-Language"] = "en-US, en;q=0.5"
for user in await db.get_all(User):
log.info("⚡️ Loading data for %s ...", user.name)
try:
@ -98,7 +95,6 @@ find_movie_id = re.compile(r"/title/(?P<id>tt\d+)/").search
def movie_and_rating_from_item(item) -> tuple[Movie, Rating]:
genres = (genre := item.find("span", "genre")) and genre.string or ""
movie = Movie(
title=item.h3.a.string.strip(),
@ -153,10 +149,10 @@ def movie_and_rating_from_item(item) -> tuple[Movie, Rating]:
ForgedRequest = namedtuple("ForgedRequest", "url headers")
async def parse_page(url) -> Tuple[list[Rating], Optional[str]]:
async def parse_page(url: str) -> tuple[list[Rating], str | None]:
ratings = []
soup = soup_from_url(url)
soup = await asoup_from_url(url)
meta = soup.find("meta", property="pageId")
headline = soup.h1
@ -170,7 +166,6 @@ async def parse_page(url) -> Tuple[list[Rating], Optional[str]]:
items = soup.find_all("div", "lister-item-content")
for i, item in enumerate(items):
try:
movie, rating = movie_and_rating_from_item(item)
except Exception as err:
@ -196,11 +191,10 @@ async def parse_page(url) -> Tuple[list[Rating], Optional[str]]:
return (ratings, next_url if url != next_url else None)
async def load_ratings(user_id):
async def load_ratings(user_id: str):
next_url = user_ratings_url(user_id)
while next_url:
ratings, next_url = await parse_page(next_url)
for i, rating in enumerate(ratings):

View file

@ -1,10 +1,11 @@
import asyncio
import csv
import gzip
import logging
from dataclasses import dataclass, fields
from datetime import datetime, timezone
from pathlib import Path
from typing import Generator, Literal, Optional, Type, TypeVar, overload
from typing import Generator, Literal, Type, TypeVar, overload
from . import config, db, request
from .db import add_or_update_many_movies
@ -27,10 +28,10 @@ class BasicRow:
primaryTitle: str
originalTitle: str
isAdult: bool
startYear: Optional[int]
endYear: Optional[int]
runtimeMinutes: Optional[int]
genres: Optional[set[str]]
startYear: int | None
endYear: int | None
runtimeMinutes: int | None
genres: set[str] | None
@classmethod
def from_row(cls, row):
@ -100,7 +101,7 @@ title_types = {
}
def gz_mtime(path) -> datetime:
def gz_mtime(path: Path) -> datetime:
"""Return the timestamp of the compressed file."""
g = gzip.GzipFile(path, "rb")
g.peek(1) # start reading the file to fill the timestamp field
@ -108,14 +109,13 @@ def gz_mtime(path) -> datetime:
return datetime.fromtimestamp(g.mtime).replace(tzinfo=timezone.utc)
def count_lines(path) -> int:
def count_lines(path: Path) -> int:
i = 0
one_mb = 2 ** 20
one_mb = 2**20
buf_size = 8 * one_mb # 8 MiB seems to give a good read/process performance.
with gzip.open(path, "rt") as f:
while buf := f.read(buf_size):
i += buf.count("\n")
@ -124,19 +124,19 @@ def count_lines(path) -> int:
@overload
def read_imdb_tsv(
path, row_type, *, unpack: Literal[False]
path: Path, row_type, *, unpack: Literal[False]
) -> Generator[list[str], None, None]:
...
@overload
def read_imdb_tsv(
path, row_type: Type[T], *, unpack: Literal[True] = True
path: Path, row_type: Type[T], *, unpack: Literal[True] = True
) -> Generator[T, None, None]:
...
def read_imdb_tsv(path, row_type, *, unpack=True):
def read_imdb_tsv(path: Path, row_type, *, unpack=True):
with gzip.open(path, "rt", newline="") as f:
rows = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
@ -161,7 +161,7 @@ def read_imdb_tsv(path, row_type, *, unpack=True):
raise
def read_ratings(path):
def read_ratings(path: Path):
mtime = gz_mtime(path)
rows = read_imdb_tsv(path, RatingRow)
@ -171,19 +171,20 @@ def read_ratings(path):
yield m
def read_ratings_as_mapping(path):
def read_ratings_as_mapping(path: Path):
"""Optimized function to quickly load all ratings."""
rows = read_imdb_tsv(path, RatingRow, unpack=False)
return {r[0]: (round(100 * (float(r[1]) - 1) / 9), int(r[2])) for r in rows}
def read_basics(path):
def read_basics(path: Path) -> Generator[Movie | None, None, None]:
mtime = gz_mtime(path)
rows = read_imdb_tsv(path, BasicRow)
for row in rows:
if row.startYear is None:
log.debug("Skipping movie, missing year: %s", row)
yield None
continue
m = row.as_movie()
@ -197,20 +198,24 @@ async def import_from_file(*, basics_path: Path, ratings_path: Path):
log.info("💾 Importing movies ...")
total = count_lines(basics_path)
assert total != 0
log.debug("Found %i movies.", total)
if total == 0:
raise RuntimeError("No movies found.")
perc_next_report = 0.0
perc_step = 0.1
chunk = []
for i, m in enumerate(read_basics(basics_path)):
perc = 100 * i / total
if perc >= perc_next_report:
await db.set_import_progress(perc)
log.info("⏳ Imported %s%%", round(perc, 1))
perc_next_report += perc_step
if m is None:
continue
if m.media_type not in {
"Movie",
"Short",
@ -235,10 +240,27 @@ async def import_from_file(*, basics_path: Path, ratings_path: Path):
await add_or_update_many_movies(chunk)
chunk = []
log.info("👍 Imported 100%")
await db.set_import_progress(100)
async def load_from_web(*, force: bool = False):
async def download_datasets(*, basics_path: Path, ratings_path: Path) -> None:
"""Download IMDb movie database dumps.
See https://www.imdb.com/interfaces/ and https://datasets.imdbws.com/ for
more information on the IMDb database dumps.
"""
basics_url = "https://datasets.imdbws.com/title.basics.tsv.gz"
ratings_url = "https://datasets.imdbws.com/title.ratings.tsv.gz"
async with request.asession():
await asyncio.gather(
request.adownload(ratings_url, to_path=ratings_path, only_if_newer=True),
request.adownload(basics_url, to_path=basics_path, only_if_newer=True),
)
async def load_from_web(*, force: bool = False) -> None:
"""Refresh the full IMDb movie database.
The latest dumps are first downloaded and then imported into the database.
@ -251,17 +273,13 @@ async def load_from_web(*, force: bool = False):
await db.set_import_progress(0)
try:
basics_url = "https://datasets.imdbws.com/title.basics.tsv.gz"
ratings_url = "https://datasets.imdbws.com/title.ratings.tsv.gz"
ratings_file = config.datadir / "imdb/title.ratings.tsv.gz"
basics_file = config.datadir / "imdb/title.basics.tsv.gz"
ratings_mtime = ratings_file.stat().st_mtime if ratings_file.exists() else None
bastics_mtime = basics_file.stat().st_mtime if basics_file.exists() else None
with request.session():
request.download(ratings_url, ratings_file, only_if_newer=True)
request.download(basics_url, basics_file, only_if_newer=True)
await download_datasets(basics_path=basics_file, ratings_path=ratings_file)
is_changed = (
ratings_mtime != ratings_file.stat().st_mtime

View file

@ -3,13 +3,14 @@ from dataclasses import dataclass, field
from dataclasses import fields as _fields
from datetime import datetime, timezone
from functools import partial
from types import UnionType
from typing import (
Annotated,
Any,
ClassVar,
Container,
Literal,
Optional,
Mapping,
Type,
TypeVar,
Union,
@ -19,13 +20,13 @@ from typing import (
from .types import ULID
JSON = Union[int, float, str, None, list["JSON"], dict[str, "JSON"]]
JSON = int | float | str | None | list["JSON"] | dict[str, "JSON"]
JSONObject = dict[str, JSON]
T = TypeVar("T")
def annotations(tp: Type) -> Optional[tuple]:
def annotations(tp: Type) -> tuple | None:
return tp.__metadata__ if hasattr(tp, "__metadata__") else None
@ -42,7 +43,6 @@ def fields(class_or_instance):
# XXX this might be a little slow (not sure), if so, memoize
for f in _fields(class_or_instance):
if f.name == "_is_lazy":
continue
@ -54,21 +54,21 @@ def fields(class_or_instance):
def is_optional(tp: Type) -> bool:
"""Return whether the given type is optional."""
if get_origin(tp) is not Union:
if not isinstance(tp, UnionType) and get_origin(tp) is not Union:
return False
args = get_args(tp)
return len(args) == 2 and type(None) in args
def optional_type(tp: Type) -> Optional[Type]:
def optional_type(tp: Type) -> Type | None:
"""Return the wrapped type from an optional type.
For example, this will return `int` for `Optional[int]`.
Since they are equivalent, this also works for the other optional notations,
`Union[int, None]` and `int | None`.
"""
if get_origin(tp) is not Union:
if not isinstance(tp, UnionType) and get_origin(tp) is not Union:
return None
args = get_args(tp)
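
The isinstance(tp, UnionType) guard added above is what lets these helpers see PEP 604 unions: at runtime `int | None` is a types.UnionType instance, and typing.get_origin() does not report typing.Union for it, while typing.get_args() behaves the same for both spellings. A small standalone check:

```python
from types import UnionType
from typing import Optional, Union, get_args, get_origin

for tp in (Optional[int], Union[int, None], int | None):
    print(
        get_origin(tp) is Union,    # True for the typing spellings, False for int | None
        isinstance(tp, UnionType),  # False, False, True
        get_args(tp),               # (int, NoneType) in all three cases
    )
```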
@ -92,7 +92,7 @@ def _id(x: T) -> T:
def asplain(
o: object, *, filter_fields: Container[str] = None, serialize: bool = False
o: object, *, filter_fields: Container[str] | None = None, serialize: bool = False
) -> dict[str, Any]:
"""Return the given model instance as `dict` with JSON compatible plain datatypes.
@ -109,7 +109,6 @@ def asplain(
d: JSONObject = {}
for f in fields(o):
if filter_fields is not None and f.name not in filter_fields:
continue
@ -146,7 +145,7 @@ def asplain(
return d
def fromplain(cls: Type[T], d: dict[str, Any], *, serialized: bool = False) -> T:
def fromplain(cls: Type[T], d: Mapping, *, serialized: bool = False) -> T:
"""Return an instance of the given model using the given data.
If `serialized` is `True`, collection types (lists, dicts, etc.) will be
@ -157,7 +156,6 @@ def fromplain(cls: Type[T], d: dict[str, Any], *, serialized: bool = False) -> T
dd: JSONObject = {}
for f in fields(cls):
target = f.type
otype = optional_type(f.type)
is_opt = otype is not None
@ -188,7 +186,8 @@ def validate(o: object) -> None:
vtype = type(getattr(o, f.name))
if vtype is not f.type:
if get_origin(f.type) is vtype or (
get_origin(f.type) is Union and vtype in get_args(f.type)
(isinstance(f.type, UnionType) or get_origin(f.type) is Union)
and vtype in get_args(f.type)
):
continue
raise ValueError(f"Invalid value type: {f.name}: {vtype}")
@ -206,7 +205,7 @@ class Progress:
type: str = None
state: str = None
started: datetime = field(default_factory=utcnow)
stopped: Optional[str] = None
stopped: str | None = None
@property
def _state(self) -> dict:
@ -243,15 +242,15 @@ class Movie:
id: ULID = field(default_factory=ULID)
title: str = None # canonical title (usually English)
original_title: Optional[
str
] = None # original title (usually transcribed to latin script)
original_title: str | None = (
None # original title (usually transcribed to latin script)
)
release_year: int = None # canonical release date
media_type: str = None
imdb_id: str = None
imdb_score: Optional[int] = None # range: [0,100]
imdb_votes: Optional[int] = None
runtime: Optional[int] = None # minutes
imdb_score: int | None = None # range: [0,100]
imdb_votes: int | None = None
runtime: int | None = None # minutes
genres: set[str] = None
created: datetime = field(default_factory=utcnow)
updated: datetime = field(default_factory=utcnow)
@ -292,7 +291,7 @@ dataclass containing the ID of the linked data.
The contents of the Relation are ignored or discarded when using
`asplain`, `fromplain`, and `validate`.
"""
Relation = Annotated[Optional[T], _RelationSentinel]
Relation = Annotated[T | None, _RelationSentinel]
@dataclass
@ -309,8 +308,8 @@ class Rating:
score: int = None # range: [0,100]
rating_date: datetime = None
favorite: Optional[bool] = None
finished: Optional[bool] = None
favorite: bool | None = None
finished: bool | None = None
def __eq__(self, other):
"""Return whether two Ratings are equal.
@ -342,11 +341,11 @@ class User:
secret: str = None
groups: list[dict[str, str]] = field(default_factory=list)
def has_access(self, group_id: Union[ULID, str], access: Access = "r"):
def has_access(self, group_id: ULID | str, access: Access = "r"):
group_id = group_id if isinstance(group_id, str) else str(group_id)
return any(g["id"] == group_id and access == g["access"] for g in self.groups)
def set_access(self, group_id: Union[ULID, str], access: Access):
def set_access(self, group_id: ULID | str, access: Access):
group_id = group_id if isinstance(group_id, str) else str(group_id)
for g in self.groups:
if g["id"] == group_id:

View file

@ -4,19 +4,17 @@ import logging
import os
import tempfile
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from functools import wraps
from hashlib import md5
from pathlib import Path
from random import random
from time import sleep, time
from typing import Callable, Optional, Union
from typing import Callable, ParamSpec, TypeVar, cast
import bs4
import requests
from requests.status_codes import codes
from urllib3.util.retry import Retry
import httpx
from . import config
@ -26,28 +24,17 @@ if config.debug and config.cachedir:
config.cachedir.mkdir(exist_ok=True)
def set_retries(s: requests.Session, n: int, backoff_factor: float = 0.2):
retry = (
Retry(
total=n,
connect=n,
read=n,
status=n,
status_forcelist=Retry.RETRY_AFTER_STATUS_CODES,
backoff_factor=backoff_factor,
)
if n
else Retry(0, read=False)
)
for a in s.adapters.values():
a.max_retries = retry
_shared_asession = None
_ASession_T = httpx.AsyncClient
_Response_T = httpx.Response
_T = TypeVar("_T")
_P = ParamSpec("_P")
_shared_session = None
@contextmanager
def session():
@asynccontextmanager
async def asession():
"""Return the shared request session.
The session is shared by all request functions and provides cookie
@ -55,38 +42,34 @@ def session():
Opening the session before making a request allows you to set headers
or change the retry behavior.
"""
global _shared_session
global _shared_asession
if _shared_session:
yield _shared_session
if _shared_asession:
yield _shared_asession
return
_shared_session = Session()
_shared_asession = _ASession_T()
_shared_asession.headers[
"user-agent"
] = "Mozilla/5.0 Gecko/20100101 unwind/20230203"
try:
yield _shared_session
async with _shared_asession:
yield _shared_asession
finally:
_shared_session = None
_shared_asession = None
def Session() -> requests.Session:
s = requests.Session()
s.headers["User-Agent"] = "Mozilla/5.0 Gecko/20100101 unwind/20210506"
return s
def throttle(
times: int, per_seconds: float, jitter: Callable[[], float] = None
) -> Callable[[Callable], Callable]:
calls: Deque[float] = deque(maxlen=times)
def _throttle(
times: int, per_seconds: float, jitter: Callable[[], float] | None = None
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
calls: deque[float] = deque(maxlen=times)
if jitter is None:
jitter = lambda: 0.0
def decorator(func: Callable) -> Callable:
def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]:
@wraps(func)
def inner(*args, **kwds):
def inner(*args: _P.args, **kwds: _P.kwargs):
# clean up
while calls:
if calls[0] + per_seconds > time():
@ -118,23 +101,19 @@ def throttle(
return decorator
class CachedStr(str):
is_cached = True
@dataclass
class CachedResponse:
class _CachedResponse:
is_cached = True
status_code: int
text: str
url: str
headers: dict[str, str] = None
headers: dict[str, str] = field(default_factory=dict)
def json(self):
return json.loads(self.text)
class RedirectError(RuntimeError):
class _RedirectError(RuntimeError):
def __init__(self, from_url: str, to_url: str, is_cached=False):
self.from_url = from_url
self.to_url = to_url
@ -142,44 +121,51 @@ class RedirectError(RuntimeError):
super().__init__(f"Redirected: {from_url} -> {to_url}")
def cache_path(req) -> Optional[Path]:
def cache_path(req) -> Path | None:
if not config.cachedir:
return
sig = repr(req.url) # + repr(sorted(req.headers.items()))
return config.cachedir / md5(sig.encode()).hexdigest()
@throttle(1, 1, random)
def http_get(s: requests.Session, url: str, *args, **kwds) -> requests.Response:
req = s.prepare_request(requests.Request("GET", url, *args, **kwds))
@_throttle(1, 1, random)
async def _ahttp_get(s: _ASession_T, url: str, *args, **kwds) -> _Response_T:
req = s.build_request(method="GET", url=url, *args, **kwds)
cachefile = cache_path(req) if config.debug else None
if cachefile:
if cachefile.exists():
log.debug(
f"💾 loading {req.url} ({req.headers!a}) from cache {cachefile} ..."
"💾 loading %s (%a) from cache %s ...", req.url, req.headers, cachefile
)
with cachefile.open() as fp:
resp = CachedResponse(**json.load(fp))
resp = _CachedResponse(**json.load(fp))
if 300 <= resp.status_code <= 399:
raise RedirectError(
raise _RedirectError(
from_url=resp.url, to_url=resp.headers["location"], is_cached=True
)
return resp
return cast(_Response_T, resp)
log.debug(f"⚡️ loading {req.url} ({req.headers!a}) ...")
resp = s.send(req, allow_redirects=False, stream=True)
log.debug("⚡️ loading %s (%a) ...", req.url, req.headers)
resp = await s.send(req, follow_redirects=False, stream=True)
resp.raise_for_status()
await resp.aread() # Download the response stream to allow `resp.text` access.
if cachefile:
log.debug(
"💾 writing response to cache: %s (%a) -> %s",
req.url,
req.headers,
cachefile,
)
with cachefile.open("w") as fp:
json.dump(
{
"status_code": resp.status_code,
"text": resp.text,
"url": resp.url,
"url": str(resp.url),
"headers": dict(resp.headers),
},
fp,
@ -187,45 +173,46 @@ def http_get(s: requests.Session, url: str, *args, **kwds) -> requests.Response:
if resp.is_redirect:
# Redirects could mean trouble, we need to stay on top of that!
raise RedirectError(from_url=resp.url, to_url=resp.headers["location"])
raise _RedirectError(from_url=str(resp.url), to_url=resp.headers["location"])
return resp
def soup_from_url(url):
async def asoup_from_url(url):
"""Return a BeautifulSoup instance from the contents of the given URL."""
with session() as s:
r = http_get(s, url)
async with asession() as s:
r = await _ahttp_get(s, url)
soup = bs4.BeautifulSoup(r.text, "html5lib")
return soup
def last_modified_from_response(resp):
if last_mod := resp.headers.get("Last-Modified"):
def _last_modified_from_response(resp: _Response_T) -> float | None:
if last_mod := resp.headers.get("last-modified"):
try:
return email.utils.parsedate_to_datetime(last_mod).timestamp()
except:
log.exception("🐛 Received invalid value for Last-Modified: %s", last_mod)
def last_modified_from_file(path: Path):
def _last_modified_from_file(path: Path) -> float:
return path.stat().st_mtime
def download(
async def adownload(
url: str,
file_path: Union[Path, str] = None,
*,
replace_existing: bool = None,
to_path: Path | str | None = None,
replace_existing: bool | None = None,
only_if_newer: bool = False,
timeout: float = None,
verify_ssl: bool = True,
timeout: float | None = None,
chunk_callback=None,
response_callback=None,
):
) -> bytes | None:
"""Download a file.
If `to_path` is `None` return the remote content, otherwise write the
content to the given file path.
Existing files will not be overwritten unless `replace_existing` is set.
Setting `only_if_newer` downloads only if the remote file is newer than the
local file; otherwise the download is skipped.
@ -234,50 +221,56 @@ def download(
replace_existing = only_if_newer
file_exists = None
if file_path is not None:
file_path = Path(file_path)
if to_path is not None:
to_path = Path(to_path)
file_exists = file_path.exists() and file_path.stat().st_size
file_exists = to_path.exists() and to_path.stat().st_size
if file_exists and not replace_existing:
raise FileExistsError(23, "Would replace existing file", str(file_path))
with session() as s:
raise FileExistsError(23, "Would replace existing file", str(to_path))
async with asession() as s:
headers = {}
if file_exists and only_if_newer:
assert file_path
file_lastmod = last_modified_from_file(file_path)
headers["If-Modified-Since"] = email.utils.formatdate(
assert to_path
file_lastmod = _last_modified_from_file(to_path)
headers["if-modified-since"] = email.utils.formatdate(
file_lastmod, usegmt=True
)
req = s.prepare_request(requests.Request("GET", url, headers=headers))
req = s.build_request(method="GET", url=url, headers=headers, timeout=timeout)
log.debug("⚡️ loading %s (%s) ...", req.url, req.headers)
resp = s.send(
req, allow_redirects=True, stream=True, timeout=timeout, verify=verify_ssl
)
log.debug("⚡️ Loading %s (%a) ...", req.url, dict(req.headers))
resp = await s.send(req, follow_redirects=True, stream=True)
try:
if response_callback is not None:
try:
response_callback(resp)
except:
log.exception("🐛 Error in response callback.")
log.debug("☕️ Response status: %s; headers: %s", resp.status_code, resp.headers)
log.debug(
"☕️ %s -> status: %s; headers: %a",
req.url,
resp.status_code,
dict(resp.headers),
)
if resp.status_code == httpx.codes.NOT_MODIFIED:
log.debug(
"✋ Remote file has not changed, skipping download: %s -> %a",
req.url,
to_path,
)
return
resp.raise_for_status()
if resp.status_code == codes.not_modified:
log.debug("✋ Remote file has not changed, skipping download.")
return
if file_path is None:
if to_path is None:
await resp.aread() # Download the response stream to allow `resp.content` access.
return resp.content
assert replace_existing is True
resp_lastmod = last_modified_from_response(resp)
resp_lastmod = _last_modified_from_response(resp)
# Check Last-Modified in case the server ignored If-Modified-Since.
# XXX also check Content-Length?
@ -285,24 +278,23 @@ def download(
assert file_lastmod
if resp_lastmod <= file_lastmod:
log.debug("✋ Local file is newer, skipping download.")
resp.close()
log.debug("✋ Local file is newer, skipping download: %a", req.url)
return
# Create intermediate directories if necessary.
download_dir = file_path.parent
download_dir = to_path.parent
download_dir.mkdir(parents=True, exist_ok=True)
# Write content to temp file.
tempdir = download_dir
tempfd, tempfile_path = tempfile.mkstemp(
dir=tempdir, prefix=f".download-{file_path.name}."
dir=tempdir, prefix=f".download-{to_path.name}."
)
one_mb = 2 ** 20
one_mb = 2**20
chunk_size = 8 * one_mb
try:
log.debug("💾 Writing to temp file %s ...", tempfile_path)
for chunk in resp.iter_content(chunk_size=chunk_size, decode_unicode=False):
async for chunk in resp.aiter_bytes(chunk_size):
os.write(tempfd, chunk)
if chunk_callback:
try:
@ -313,10 +305,19 @@ def download(
os.close(tempfd)
# Move downloaded file to destination.
if file_exists:
log.debug("💾 Replacing existing file: %s", file_path)
Path(tempfile_path).replace(file_path)
if to_path.exists():
log.debug("💾 Replacing existing file: %s", to_path)
else:
log.debug("💾 Move to destination: %s", to_path)
if replace_existing:
Path(tempfile_path).replace(to_path)
else:
Path(tempfile_path).rename(to_path)
# Fix file attributes.
if resp_lastmod is not None:
os.utime(file_path, (resp_lastmod, resp_lastmod))
log.debug("💾 Adjusting file timestamp: %s (%s)", to_path, resp_lastmod)
os.utime(to_path, (resp_lastmod, resp_lastmod))
finally:
await resp.aclose()
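
A hedged usage sketch of the new httpx-based helpers above, mirroring how imdb_import.download_datasets() drives them; the URL comes from that module, the local path is illustrative:

```python
import asyncio
from pathlib import Path

from unwind import request

async def fetch_ratings_dump() -> None:
    # Open the shared AsyncClient once and reuse it for every call inside the block.
    async with request.asession() as s:
        s.headers["Accept-Language"] = "en-US, en;q=0.5"  # same header trick as imdb.py
        await request.adownload(
            "https://datasets.imdbws.com/title.ratings.tsv.gz",
            to_path=Path("data/imdb/title.ratings.tsv.gz"),
            only_if_newer=True,  # skip the download when the local copy is up to date
        )

if __name__ == "__main__":
    asyncio.run(fetch_ratings_dump())
```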

View file

@ -1,5 +1,5 @@
import re
from typing import Union, cast
from typing import cast
import ulid
from ulid.hints import Buffer
@ -16,7 +16,7 @@ class ULID(ulid.ULID):
_pattern = re.compile(r"^[0-9A-HJKMNP-TV-Z]{26}$")
def __init__(self, buffer: Union[Buffer, ulid.ULID, str, None] = None):
def __init__(self, buffer: Buffer | ulid.ULID | str | None = None):
if isinstance(buffer, str):
if not self._pattern.search(buffer):
raise ValueError("Invalid ULID.")

View file

@ -17,7 +17,10 @@ def b64padded(s: str) -> str:
def phc_scrypt(
secret: bytes, *, salt: bytes = None, params: dict[Literal["n", "r", "p"], int] = {}
secret: bytes,
*,
salt: bytes | None = None,
params: dict[Literal["n", "r", "p"], int] = {},
) -> str:
"""Return the scrypt expanded secret in PHC string format.
@ -30,7 +33,7 @@ def phc_scrypt(
if salt is None:
salt = secrets.token_bytes(16)
n = params.get("n", 2 ** 14) # CPU/Memory cost factor
n = params.get("n", 2**14) # CPU/Memory cost factor
r = params.get("r", 8) # block size
p = params.get("p", 1) # parallelization factor
# maxmem = 2 * 128 * n * r * p
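
For reference, the n/r/p parameters documented above map directly onto the standard library's hashlib.scrypt; a minimal sketch (the PHC string assembly itself is left to phc_scrypt(), and the dklen here is a guess):

```python
import hashlib
import secrets

salt = secrets.token_bytes(16)
n, r, p = 2**14, 8, 1  # CPU/memory cost factor, block size, parallelization factor
digest = hashlib.scrypt(b"example secret", salt=salt, n=n, r=r, p=p, dklen=32)
print(digest.hex())
```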

View file

@ -1,8 +1,9 @@
import asyncio
import contextlib
import logging
import secrets
from json.decoder import JSONDecodeError
from typing import Literal, Optional, overload
from typing import Literal, overload
from starlette.applications import Starlette
from starlette.authentication import (
@ -85,11 +86,14 @@ def truthy(s: str):
return bool(s) and s.lower() in {"1", "yes", "true"}
def yearcomp(s: str):
_Yearcomp = Literal["<", "=", ">"]
def yearcomp(s: str) -> tuple[_Yearcomp, int] | None:
if not s:
return
comp: Literal["<", "=", ">"] = "="
comp: _Yearcomp = "="
if (prefix := s[0]) in "<=>":
comp = prefix # type: ignore
s = s[len(prefix) :]
@ -97,7 +101,9 @@ def yearcomp(s: str):
return comp, int(s)
def as_int(x, *, max: int = None, min: Optional[int] = 1, default: int = None):
def as_int(
x, *, max: int | None = None, min: int | None = 1, default: int | None = None
) -> int:
try:
if not isinstance(x, int):
x = int(x)
@ -135,7 +141,7 @@ async def json_from_body(request, keys: list[str]) -> list:
...
async def json_from_body(request, keys: list[str] = None):
async def json_from_body(request, keys: list[str] | None = None):
if not await request.body():
data = {}
@ -158,7 +164,7 @@ def is_admin(request):
return "admin" in request.auth.scopes
async def auth_user(request) -> Optional[User]:
async def auth_user(request) -> User | None:
if not isinstance(request.user, AuthedUser):
return
@ -176,7 +182,7 @@ async def auth_user(request) -> Optional[User]:
_routes = []
def route(path: str, *, methods: list[str] = None, **kwds):
def route(path: str, *, methods: list[str] | None = None, **kwds):
def decorator(func):
r = Route(path, func, methods=methods, **kwds)
_routes.append(r)
@ -190,7 +196,6 @@ route.registered = _routes
@route("/groups/{group_id}/ratings")
async def get_ratings_for_group(request):
group_id = as_ulid(request.path_params["group_id"])
group = await db.get(Group, id=str(group_id))
@ -251,7 +256,6 @@ def not_implemented():
@route("/movies")
@requires(["authenticated"])
async def list_movies(request):
params = request.query_params
user = await auth_user(request)
@ -319,7 +323,6 @@ async def list_movies(request):
@route("/movies", methods=["POST"])
@requires(["authenticated", "admin"])
async def add_movie(request):
not_implemented()
@ -361,7 +364,6 @@ _import_lock = asyncio.Lock()
@route("/movies/_reload_imdb", methods=["POST"])
@requires(["authenticated", "admin"])
async def load_imdb_movies(request):
params = request.query_params
force = truthy(params.get("force"))
@ -384,7 +386,6 @@ async def load_imdb_movies(request):
@route("/users")
@requires(["authenticated", "admin"])
async def list_users(request):
users = await db.get_all(User)
return JSONResponse([asplain(u) for u in users])
@ -393,7 +394,6 @@ async def list_users(request):
@route("/users", methods=["POST"])
@requires(["authenticated", "admin"])
async def add_user(request):
name, imdb_id = await json_from_body(request, ["name", "imdb_id"])
# XXX restrict name
@ -415,7 +415,6 @@ async def add_user(request):
@route("/users/{user_id}")
@requires(["authenticated"])
async def show_user(request):
user_id = as_ulid(request.path_params["user_id"])
if is_admin(request):
@ -444,7 +443,6 @@ async def show_user(request):
@route("/users/{user_id}", methods=["DELETE"])
@requires(["authenticated", "admin"])
async def remove_user(request):
user_id = as_ulid(request.path_params["user_id"])
user = await db.get(User, id=str(user_id))
@ -462,7 +460,6 @@ async def remove_user(request):
@route("/users/{user_id}", methods=["PATCH"])
@requires(["authenticated"])
async def modify_user(request):
user_id = as_ulid(request.path_params["user_id"])
if is_admin(request):
@ -510,7 +507,6 @@ async def modify_user(request):
@route("/users/{user_id}/groups", methods=["POST"])
@requires(["authenticated", "admin"])
async def add_group_to_user(request):
user_id = as_ulid(request.path_params["user_id"])
user = await db.get(User, id=str(user_id))
@ -535,21 +531,18 @@ async def add_group_to_user(request):
@route("/users/{user_id}/ratings")
@requires(["private"])
async def ratings_for_user(request):
not_implemented()
@route("/users/{user_id}/ratings", methods=["PUT"])
@requires("authenticated")
async def set_rating_for_user(request):
not_implemented()
@route("/users/_reload_ratings", methods=["POST"])
@requires(["authenticated", "admin"])
async def load_imdb_user_ratings(request):
ratings = [rating async for rating in imdb.refresh_user_ratings_from_imdb()]
return JSONResponse({"new_ratings": [asplain(r) for r in ratings]})
@ -558,7 +551,6 @@ async def load_imdb_user_ratings(request):
@route("/groups")
@requires(["authenticated", "admin"])
async def list_groups(request):
groups = await db.get_all(Group)
return JSONResponse([asplain(g) for g in groups])
@ -567,7 +559,6 @@ async def list_groups(request):
@route("/groups", methods=["POST"])
@requires(["authenticated", "admin"])
async def add_group(request):
(name,) = await json_from_body(request, ["name"])
# XXX restrict name
@ -581,7 +572,6 @@ async def add_group(request):
@route("/groups/{group_id}/users", methods=["POST"])
@requires(["authenticated"])
async def add_user_to_group(request):
group_id = as_ulid(request.path_params["group_id"])
group = await db.get(Group, id=str(group_id))
@ -623,6 +613,13 @@ def auth_error(request, err):
return unauthorized(str(err))
@contextlib.asynccontextmanager
async def lifespan(app: Starlette):
await open_connection_pool()
yield
await close_connection_pool()
def create_app():
if config.loglevel == "DEBUG":
logging.basicConfig(
@ -633,8 +630,7 @@ def create_app():
log.debug(f"Log level: {config.loglevel}")
return Starlette(
on_startup=[open_connection_pool],
on_shutdown=[close_connection_pool],
lifespan=lifespan,
routes=[
Mount(f"{config.api_base}v1", routes=route.registered),
],
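
create_app() above swaps Starlette's on_startup/on_shutdown hooks for the lifespan context manager that newer Starlette releases favor. A standalone sketch of the same pattern, with a plain dict standing in for the connection pool; note that TestClient only drives lifespan events when used as a context manager:

```python
import contextlib

from starlette.applications import Starlette
from starlette.responses import PlainTextResponse
from starlette.routing import Route
from starlette.testclient import TestClient

state = {"pool": None}


@contextlib.asynccontextmanager
async def lifespan(app: Starlette):
    state["pool"] = "open"  # startup work, e.g. opening a connection pool
    yield
    state["pool"] = None    # shutdown work


app = Starlette(
    lifespan=lifespan,
    routes=[Route("/ping", lambda request: PlainTextResponse("pong"))],
)

with TestClient(app) as client:  # entering the client runs the lifespan startup
    assert state["pool"] == "open"
    assert client.get("/ping").text == "pong"
```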

View file

@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Container, Iterable, Optional
from typing import Container, Iterable
from . import imdb, models
@ -10,17 +10,17 @@ Score100 = int # [0, 100]
@dataclass
class Rating:
canonical_title: str
imdb_score: Optional[Score100]
imdb_votes: Optional[int]
imdb_score: Score100 | None
imdb_votes: int | None
media_type: str
movie_imdb_id: str
original_title: Optional[str]
original_title: str | None
release_year: int
user_id: Optional[str]
user_score: Optional[Score100]
user_id: str | None
user_score: Score100 | None
@classmethod
def from_movie(cls, movie: models.Movie, *, rating: models.Rating = None):
def from_movie(cls, movie: models.Movie, *, rating: models.Rating | None = None):
return cls(
canonical_title=movie.title,
imdb_score=movie.imdb_score,
@ -37,11 +37,11 @@ class Rating:
@dataclass
class RatingAggregate:
canonical_title: str
imdb_score: Optional[Score100]
imdb_votes: Optional[int]
imdb_score: Score100 | None
imdb_votes: int | None
link: URL
media_type: str
original_title: Optional[str]
original_title: str | None
user_scores: list[Score100]
year: int