Merge branch 'feat/py311'

ducklet 2023-03-18 01:12:35 +01:00
commit a020d972f8
23 changed files with 783 additions and 748 deletions

4 .git-blame-ignore-revs Normal file

@ -0,0 +1,4 @@
# Apply Black v23.1.0 formatting changes.
8a8bfce89de23d987386a35b659532bbac373788
# Apply auto-formatting to tests.
9ffcc5357150cecde26f5e6f8fccceaf92411efb


@ -1,4 +1,4 @@
FROM docker.io/library/python:3.10-alpine FROM docker.io/library/python:3.11-alpine
RUN apk update --no-cache \ RUN apk update --no-cache \
&& apk upgrade --no-cache \ && apk upgrade --no-cache \
@ -11,20 +11,18 @@ WORKDIR /var/app
COPY requirements.txt ./ COPY requirements.txt ./
# Required to build greenlet on Alpine, dependency of SQLAlchemy 1.4. RUN pip install --no-cache-dir --upgrade \
RUN apk add --no-cache \ --requirement requirements.txt
--virtual .build-deps \
g++ gcc musl-dev \
&& pip install --no-cache-dir --upgrade \
--requirement requirements.txt \
&& apk del .build-deps
USER 10000:10001 USER 10000:10001
COPY . ./ COPY . ./
ENV UNWIND_DATA="/data" ENV UNWIND_DATA="/data"
VOLUME ["/data"] VOLUME $UNWIND_DATA
ENV UNWIND_PORT=8097
EXPOSE $UNWIND_PORT
ENTRYPOINT ["/var/app/run"] ENTRYPOINT ["/var/app/run"]
CMD ["server"] CMD ["server"]

817 poetry.lock generated

File diff suppressed because it is too large.


@ -6,26 +6,20 @@ authors = ["ducklet <ducklet@noreply.code.dumpr.org>"]
license = "LOL" license = "LOL"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.10" python = "^3.11"
requests = "^2.25.1"
beautifulsoup4 = "^4.9.3" beautifulsoup4 = "^4.9.3"
html5lib = "^1.1" html5lib = "^1.1"
starlette = "^0.17.0" starlette = "^0.26"
ulid-py = "^1.1.0" ulid-py = "^1.1.0"
databases = {extras = ["sqlite"], version = "^0.6.1"} databases = {extras = ["sqlite"], version = "^0.7.0"}
toml = "^0.10.2" uvicorn = "^0.21"
uvicorn = "^0.19.0" httpx = "^0.23.3"
[tool.poetry.group.fixes.dependencies]
# `databases` is having issues with new versions of SQLAlchemy 1.4,
# and `greenlet` is also always a pain.
SQLAlchemy = "1.4.25"
greenlet = "1.1.2"
[tool.poetry.group.dev] [tool.poetry.group.dev]
optional = true optional = true
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
autoflake = "*"
pytest = "*" pytest = "*"
pyright = "*" pyright = "*"
black = "*" black = "*"
@ -37,4 +31,14 @@ requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"
[tool.pyright] [tool.pyright]
pythonVersion = "3.10" pythonVersion = "3.11"
[tool.isort]
profile = "black"
[tool.autoflake]
remove-duplicate-keys = true
remove-unused-variables = true
remove-all-unused-imports = true
ignore-init-module-imports = true
ignore-pass-after-docstring = true


@ -4,6 +4,7 @@ cd "$RUN_DIR"
[ -z "${DEBUG:-}" ] || set -x [ -z "${DEBUG:-}" ] || set -x
isort --profile black unwind autoflake --quiet --check --recursive unwind tests
black unwind isort unwind tests
black unwind tests
pyright pyright

13 scripts/profile Executable file

@ -0,0 +1,13 @@
#!/bin/sh -eu
cd "$RUN_DIR"
outfile="profile-$(date '+%Y%m%d-%H%M%S').txt"
[ -z "${DEBUG:-}" ] || set -x
echo "# Writing profiler stats to: $outfile"
python -m cProfile -o "$outfile" -m unwind "$@"
echo "# Loading stats file: $outfile"
python -m pstats "$outfile"


@ -1,7 +1,14 @@
#!/bin/sh -eu #!/bin/sh -eu
: "${UNWIND_PORT:=8097}"
cd "$RUN_DIR" cd "$RUN_DIR"
[ -z "${DEBUG:-}" ] || set -x [ -z "${DEBUG:-}" ] || set -x
exec uvicorn --host 0.0.0.0 --factory unwind:create_app export UNWIND_PORT
exec uvicorn \
--host 0.0.0.0 \
--port "$UNWIND_PORT" \
--factory unwind:create_app
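
(Aside, not part of the commit: the --factory flag tells uvicorn that unwind:create_app is a zero-argument callable returning the ASGI app rather than the app object itself. A minimal sketch of such a factory using Starlette, which this project is built on; the /health route is made up for illustration.)

    from starlette.applications import Starlette
    from starlette.responses import JSONResponse
    from starlette.routing import Route

    def create_app() -> Starlette:
        # uvicorn --factory calls this with no arguments and serves the result.
        async def health(request):
            return JSONResponse({"ok": True})

        return Starlette(routes=[Route("/health", health)])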


@ -10,5 +10,6 @@ trap 'rm "$dbfile"' EXIT TERM INT QUIT
[ -z "${DEBUG:-}" ] || set -x [ -z "${DEBUG:-}" ] || set -x
SQLALCHEMY_WARN_20=1 \
UNWIND_STORAGE="$dbfile" \ UNWIND_STORAGE="$dbfile" \
python -m pytest "$@" python -m pytest "$@"


@ -1,6 +1,8 @@
import asyncio import asyncio
import pytest import pytest
import pytest_asyncio
from unwind import db from unwind import db
@ -13,7 +15,7 @@ def event_loop():
loop.close() loop.close()
@pytest.fixture(scope="session") @pytest_asyncio.fixture(scope="session")
async def shared_conn(): async def shared_conn():
c = db.shared_connection() c = db.shared_connection()
await c.connect() await c.connect()
@ -24,7 +26,7 @@ async def shared_conn():
await c.disconnect() await c.disconnect()
@pytest.fixture @pytest_asyncio.fixture
async def conn(shared_conn): async def conn(shared_conn):
async with shared_conn.transaction(force_rollback=True): async with shared_conn.transaction(force_rollback=True):
yield shared_conn yield shared_conn
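
(Aside, not part of the commit: the decorator swap above matches pytest-asyncio's strict mode, where async fixtures must use pytest_asyncio.fixture and async tests carry pytest.mark.asyncio. A self-contained sketch under that assumption; the fixture and test names are invented.)

    import pytest
    import pytest_asyncio

    @pytest_asyncio.fixture
    async def resource():
        # Async setup/teardown runs on the test's event loop.
        yield "ready"

    @pytest.mark.asyncio
    async def test_resource(resource):
        assert resource == "ready"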


@ -1,14 +1,13 @@
from datetime import datetime from datetime import datetime
import pytest import pytest
from unwind import db, models, web_models from unwind import db, models, web_models
pytestmark = pytest.mark.asyncio
@pytest.mark.asyncio
async def test_add_and_get(shared_conn): async def test_add_and_get(shared_conn: db.Database):
async with shared_conn.transaction(force_rollback=True): async with shared_conn.transaction(force_rollback=True):
m1 = models.Movie( m1 = models.Movie(
title="test movie", title="test movie",
release_year=2013, release_year=2013,
@ -31,9 +30,9 @@ async def test_add_and_get(shared_conn):
assert m2 == await db.get(models.Movie, id=str(m2.id)) assert m2 == await db.get(models.Movie, id=str(m2.id))
async def test_find_ratings(shared_conn): @pytest.mark.asyncio
async def test_find_ratings(shared_conn: db.Database):
async with shared_conn.transaction(force_rollback=True): async with shared_conn.transaction(force_rollback=True):
m1 = models.Movie( m1 = models.Movie(
title="test movie", title="test movie",
release_year=2013, release_year=2013,
@ -157,4 +156,3 @@ async def test_find_ratings(shared_conn):
rows = await db.find_ratings(title="test", include_unrated=True) rows = await db.find_ratings(title="test", include_unrated=True)
ratings = tuple(web_models.Rating(**r) for r in rows) ratings = tuple(web_models.Rating(**r) for r in rows)
assert (web_models.Rating.from_movie(m1),) == ratings assert (web_models.Rating.from_movie(m1),) == ratings


@ -1,14 +1,15 @@
import pytest import pytest
from unwind.imdb import imdb_rating_from_score, score_from_imdb_rating from unwind.imdb import imdb_rating_from_score, score_from_imdb_rating
@pytest.mark.parametrize("rating", (x / 10 for x in range(10, 101))) @pytest.mark.parametrize("rating", (x / 10 for x in range(10, 101)))
def test_rating_conversion(rating): def test_rating_conversion(rating: float):
assert rating == imdb_rating_from_score(score_from_imdb_rating(rating)) assert rating == imdb_rating_from_score(score_from_imdb_rating(rating))
@pytest.mark.parametrize("score", range(0, 101)) @pytest.mark.parametrize("score", range(0, 101))
def test_score_conversion(score): def test_score_conversion(score: int):
# Because our score covers 101 discrete values and IMDb's rating only 91 # Because our score covers 101 discrete values and IMDb's rating only 91
# discrete values, the mapping is non-injective, i.e. 10 values can't be # discrete values, the mapping is non-injective, i.e. 10 values can't be
# mapped uniquely. # mapped uniquely.


@ -1,18 +1,14 @@
from starlette.testclient import TestClient
import pytest import pytest
from starlette.testclient import TestClient
from unwind import create_app from unwind import create_app, db, imdb, models
from unwind import db, models, imdb
# https://pypi.org/project/pytest-asyncio/
pytestmark = pytest.mark.asyncio
app = create_app() app = create_app()
async def test_app(shared_conn): @pytest.mark.asyncio
async def test_app(shared_conn: db.Database):
async with shared_conn.transaction(force_rollback=True): async with shared_conn.transaction(force_rollback=True):
# https://www.starlette.io/testclient/ # https://www.starlette.io/testclient/
client = TestClient(app) client = TestClient(app)
response = client.get("/api/v1/movies") response = client.get("/api/v1/movies")


@ -6,7 +6,7 @@ from pathlib import Path
from . import config from . import config
from .db import close_connection_pool, open_connection_pool from .db import close_connection_pool, open_connection_pool
from .imdb import refresh_user_ratings_from_imdb from .imdb import refresh_user_ratings_from_imdb
from .imdb_import import import_from_file from .imdb_import import download_datasets, import_from_file
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -15,7 +15,7 @@ async def run_load_user_ratings_from_imdb():
await open_connection_pool() await open_connection_pool()
i = 0 i = 0
async for rating in refresh_user_ratings_from_imdb(): async for _ in refresh_user_ratings_from_imdb():
i += 1 i += 1
log.info("✨ Imported %s new ratings.", i) log.info("✨ Imported %s new ratings.", i)
@ -31,6 +31,10 @@ async def run_import_imdb_dataset(basics_path: Path, ratings_path: Path):
await close_connection_pool() await close_connection_pool()
async def run_download_imdb_dataset(basics_path: Path, ratings_path: Path):
await download_datasets(basics_path=basics_path, ratings_path=ratings_path)
def getargs(): def getargs():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
commands = parser.add_subparsers(required=True) commands = parser.add_subparsers(required=True)
@ -55,6 +59,25 @@ def getargs():
"--ratings", metavar="ratings_file.tsv.gz", type=Path, required=True "--ratings", metavar="ratings_file.tsv.gz", type=Path, required=True
) )
parser_download_imdb_dataset = commands.add_parser(
"download-imdb-dataset",
help="Download IMDb datasets.",
description="""
Download IMDb datasets.
""",
)
parser_download_imdb_dataset.add_argument(
dest="mode",
action="store_const",
const="download-imdb-dataset",
)
parser_download_imdb_dataset.add_argument(
"--basics", metavar="basics_file.tsv.gz", type=Path, required=True
)
parser_download_imdb_dataset.add_argument(
"--ratings", metavar="ratings_file.tsv.gz", type=Path, required=True
)
parser_load_user_ratings_from_imdb = commands.add_parser( parser_load_user_ratings_from_imdb = commands.add_parser(
"load-user-ratings-from-imdb", "load-user-ratings-from-imdb",
help="Load user ratings from imdb.com.", help="Load user ratings from imdb.com.",
@ -94,6 +117,8 @@ def main():
asyncio.run(run_load_user_ratings_from_imdb()) asyncio.run(run_load_user_ratings_from_imdb())
elif args.mode == "import-imdb-dataset": elif args.mode == "import-imdb-dataset":
asyncio.run(run_import_imdb_dataset(args.basics, args.ratings)) asyncio.run(run_import_imdb_dataset(args.basics, args.ratings))
elif args.mode == "download-imdb-dataset":
asyncio.run(run_download_imdb_dataset(args.basics, args.ratings))
main() main()


@ -1,8 +1,7 @@
import os import os
import tomllib
from pathlib import Path from pathlib import Path
import toml
datadir = Path(os.getenv("UNWIND_DATA") or "./data") datadir = Path(os.getenv("UNWIND_DATA") or "./data")
cachedir = ( cachedir = (
Path(cachedir) Path(cachedir)
@ -14,7 +13,8 @@ loglevel = os.getenv("UNWIND_LOGLEVEL") or ("DEBUG" if debug else "INFO")
storage_path = os.getenv("UNWIND_STORAGE", datadir / "db.sqlite") storage_path = os.getenv("UNWIND_STORAGE", datadir / "db.sqlite")
config_path = os.getenv("UNWIND_CONFIG", datadir / "config.toml") config_path = os.getenv("UNWIND_CONFIG", datadir / "config.toml")
_config = toml.load(config_path) with open(config_path, "rb") as fd:
_config = tomllib.load(fd)
api_base = _config["api"].get("base", "/api/") api_base = _config["api"].get("base", "/api/")
api_cors = _config["api"].get("cors", "*") api_cors = _config["api"].get("cors", "*")
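
(Aside, not part of the commit: tomllib is the read-only TOML parser added to the standard library in Python 3.11, which is what lets the third-party toml dependency go. tomllib.load expects a binary file object, hence the open(config_path, "rb") above. A small usage sketch with an illustrative file name.)

    import tomllib

    # load() needs a file opened in binary mode ...
    with open("config.toml", "rb") as fd:
        cfg = tomllib.load(fd)

    # ... while loads() parses a str. There is no dump() counterpart.
    defaults = tomllib.loads('[api]\nbase = "/api/"')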


@ -4,7 +4,7 @@ import logging
import re import re
import threading import threading
from pathlib import Path from pathlib import Path
from typing import Any, Iterable, Literal, Optional, Type, TypeVar, Union from typing import Any, Iterable, Literal, Type, TypeVar
import sqlalchemy import sqlalchemy
from databases import Database from databases import Database
@ -26,7 +26,7 @@ from .types import ULID
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
T = TypeVar("T") T = TypeVar("T")
_shared_connection: Optional[Database] = None _shared_connection: Database | None = None
async def open_connection_pool() -> None: async def open_connection_pool() -> None:
@ -119,7 +119,6 @@ async def apply_db_patches(db: Database):
raise RuntimeError("No statement found.") raise RuntimeError("No statement found.")
async with db.transaction(): async with db.transaction():
for query in queries: for query in queries:
await db.execute(query) await db.execute(query)
@ -131,12 +130,12 @@ async def apply_db_patches(db: Database):
await db.execute("vacuum") await db.execute("vacuum")
async def get_import_progress() -> Optional[Progress]: async def get_import_progress() -> Progress | None:
"""Return the latest import progress.""" """Return the latest import progress."""
return await get(Progress, type="import-imdb-movies", order_by="started DESC") return await get(Progress, type="import-imdb-movies", order_by="started DESC")
async def stop_import_progress(*, error: BaseException = None): async def stop_import_progress(*, error: BaseException | None = None):
"""Stop the current import. """Stop the current import.
If an error is given, it will be logged to the progress state. If an error is given, it will be logged to the progress state.
@ -176,6 +175,8 @@ async def set_import_progress(progress: float) -> Progress:
else: else:
await add(current) await add(current)
return current
_lock = threading.Lock() _lock = threading.Lock()
_prelock = threading.Lock() _prelock = threading.Lock()
@ -243,8 +244,8 @@ ModelType = TypeVar("ModelType")
async def get( async def get(
model: Type[ModelType], *, order_by: str = None, **kwds model: Type[ModelType], *, order_by: str | None = None, **kwds
) -> Optional[ModelType]: ) -> ModelType | None:
"""Load a model instance from the database. """Load a model instance from the database.
Passing `kwds` allows to filter the instance to load. You have to encode the Passing `kwds` allows to filter the instance to load. You have to encode the
@ -262,7 +263,7 @@ async def get(
query += f" ORDER BY {order_by}" query += f" ORDER BY {order_by}"
async with locked_connection() as conn: async with locked_connection() as conn:
row = await conn.fetch_one(query=query, values=values) row = await conn.fetch_one(query=query, values=values)
return fromplain(model, row, serialized=True) if row else None return fromplain(model, row._mapping, serialized=True) if row else None
async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]: async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
@ -282,7 +283,7 @@ async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}" query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
async with locked_connection() as conn: async with locked_connection() as conn:
rows = await conn.fetch_all(query=query, values=values) rows = await conn.fetch_all(query=query, values=values)
return (fromplain(model, row, serialized=True) for row in rows) return (fromplain(model, row._mapping, serialized=True) for row in rows)
async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]: async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
@ -293,7 +294,7 @@ async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}" query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
async with locked_connection() as conn: async with locked_connection() as conn:
rows = await conn.fetch_all(query=query, values=values) rows = await conn.fetch_all(query=query, values=values)
return (fromplain(model, row, serialized=True) for row in rows) return (fromplain(model, row._mapping, serialized=True) for row in rows)
async def update(item): async def update(item):
@ -406,16 +407,16 @@ def sql_escape(s: str, char="#"):
async def find_ratings( async def find_ratings(
*, *,
title: str = None, title: str | None = None,
media_type: str = None, media_type: str | None = None,
exact: bool = False, exact: bool = False,
ignore_tv_episodes: bool = False, ignore_tv_episodes: bool = False,
include_unrated: bool = False, include_unrated: bool = False,
yearcomp: tuple[Literal["<", "=", ">"], int] = None, yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
limit_rows: int = 10, limit_rows: int = 10,
user_ids: Iterable[str] = [], user_ids: Iterable[str] = [],
): ):
values: dict[str, Union[int, str]] = { values: dict[str, int | str] = {
"limit_rows": limit_rows, "limit_rows": limit_rows,
} }
@ -466,7 +467,7 @@ async def find_ratings(
""" """
async with locked_connection() as conn: async with locked_connection() as conn:
rows = await conn.fetch_all(bindparams(query, values)) rows = await conn.fetch_all(bindparams(query, values))
movie_ids = tuple(r["movie_id"] for r in rows) movie_ids = tuple(r._mapping["movie_id"] for r in rows)
if include_unrated and len(movie_ids) < limit_rows: if include_unrated and len(movie_ids) < limit_rows:
sqlin, sqlin_vals = sql_in("id", movie_ids, not_=True) sqlin, sqlin_vals = sql_in("id", movie_ids, not_=True)
@ -485,7 +486,7 @@ async def find_ratings(
{**values, **sqlin_vals, "limit_rows": limit_rows - len(movie_ids)}, {**values, **sqlin_vals, "limit_rows": limit_rows - len(movie_ids)},
) )
) )
movie_ids += tuple(r["movie_id"] for r in rows) movie_ids += tuple(r._mapping["movie_id"] for r in rows)
return await ratings_for_movie_ids(ids=movie_ids) return await ratings_for_movie_ids(ids=movie_ids)
@ -527,29 +528,13 @@ async def ratings_for_movie_ids(
async with locked_connection() as conn: async with locked_connection() as conn:
rows = await conn.fetch_all(bindparams(query, vals)) rows = await conn.fetch_all(bindparams(query, vals))
return tuple(dict(r) for r in rows) return tuple(dict(r._mapping) for r in rows)
def sql_fields(tp: Type): def sql_fields(tp: Type):
return (f"{tp._table}.{f.name}" for f in fields(tp)) return (f"{tp._table}.{f.name}" for f in fields(tp))
def sql_fieldmap(tp: Type):
"""-> {alias: (table, field_name)}"""
return {f"{tp._table}_{f.name}": (tp._table, f.name) for f in fields(tp)}
def mux(*tps: Type):
return ", ".join(
f"{t}.{n} AS {k}" for tp in tps for k, (t, n) in sql_fieldmap(tp).items()
)
def demux(tp: Type[ModelType], row) -> ModelType:
d = {n: row[k] for k, (_, n) in sql_fieldmap(tp).items()}
return fromplain(tp, d, serialized=True)
def sql_in(column: str, values: Iterable[T], not_=False) -> tuple[str, dict[str, T]]: def sql_in(column: str, values: Iterable[T], not_=False) -> tuple[str, dict[str, T]]:
c = column.replace(".", "___") c = column.replace(".", "___")
value_map = {f"{c}_{i}": v for i, v in enumerate(values, start=1)} value_map = {f"{c}_{i}": v for i, v in enumerate(values, start=1)}
@ -583,22 +568,22 @@ async def ratings_for_movies(
async with locked_connection() as conn: async with locked_connection() as conn:
rows = await conn.fetch_all(query, values) rows = await conn.fetch_all(query, values)
return (fromplain(Rating, row, serialized=True) for row in rows) return (fromplain(Rating, row._mapping, serialized=True) for row in rows)
async def find_movies( async def find_movies(
*, *,
title: str = None, title: str | None = None,
media_type: str = None, media_type: str | None = None,
exact: bool = False, exact: bool = False,
ignore_tv_episodes: bool = False, ignore_tv_episodes: bool = False,
yearcomp: tuple[Literal["<", "=", ">"], int] = None, yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
limit_rows: int = 10, limit_rows: int = 10,
skip_rows: int = 0, skip_rows: int = 0,
include_unrated: bool = False, include_unrated: bool = False,
user_ids: list[ULID] = [], user_ids: list[ULID] = [],
) -> Iterable[tuple[Movie, list[Rating]]]: ) -> Iterable[tuple[Movie, list[Rating]]]:
values: dict[str, Union[int, str]] = { values: dict[str, int | str] = {
"limit_rows": limit_rows, "limit_rows": limit_rows,
"skip_rows": skip_rows, "skip_rows": skip_rows,
} }
@ -650,7 +635,7 @@ async def find_movies(
async with locked_connection() as conn: async with locked_connection() as conn:
rows = await conn.fetch_all(bindparams(query, values)) rows = await conn.fetch_all(bindparams(query, values))
movies = [fromplain(Movie, row, serialized=True) for row in rows] movies = [fromplain(Movie, row._mapping, serialized=True) for row in rows]
if not user_ids: if not user_ids:
return ((m, []) for m in movies) return ((m, []) for m in movies)
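
(Aside, not part of the commit: the recurring row -> row._mapping change reflects SQLAlchemy 1.4/2.0 semantics, where Row objects are tuple-like and dict-style access goes through the RowMapping view; the databases package used here returns rows with the same interface, as the diff shows. A minimal illustration with plain SQLAlchemy.)

    from sqlalchemy import create_engine, text

    engine = create_engine("sqlite://")
    with engine.connect() as conn:
        row = conn.execute(text("SELECT 1 AS movie_id")).one()
        print(row.movie_id)              # tuple-style / attribute access
        print(row._mapping["movie_id"])  # mapping view, as used above
        print(dict(row._mapping))        # {'movie_id': 1}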


@ -2,12 +2,11 @@ import logging
import re import re
from collections import namedtuple from collections import namedtuple
from datetime import datetime from datetime import datetime
from typing import Optional, Tuple
from urllib.parse import urljoin from urllib.parse import urljoin
from . import db from . import db
from .models import Movie, Rating, User from .models import Movie, Rating, User
from .request import cache_path, session, soup_from_url from .request import asession, asoup_from_url, cache_path
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -35,13 +34,11 @@ log = logging.getLogger(__name__)
# p.text-muted.text-small span[name=nv] [data-value] # p.text-muted.text-small span[name=nv] [data-value]
async def refresh_user_ratings_from_imdb(stop_on_dupe=True): async def refresh_user_ratings_from_imdb(stop_on_dupe: bool = True):
async with asession() as s:
with session() as s:
s.headers["Accept-Language"] = "en-US, en;q=0.5" s.headers["Accept-Language"] = "en-US, en;q=0.5"
for user in await db.get_all(User): for user in await db.get_all(User):
log.info("⚡️ Loading data for %s ...", user.name) log.info("⚡️ Loading data for %s ...", user.name)
try: try:
@ -98,7 +95,6 @@ find_movie_id = re.compile(r"/title/(?P<id>tt\d+)/").search
def movie_and_rating_from_item(item) -> tuple[Movie, Rating]: def movie_and_rating_from_item(item) -> tuple[Movie, Rating]:
genres = (genre := item.find("span", "genre")) and genre.string or "" genres = (genre := item.find("span", "genre")) and genre.string or ""
movie = Movie( movie = Movie(
title=item.h3.a.string.strip(), title=item.h3.a.string.strip(),
@ -153,10 +149,10 @@ def movie_and_rating_from_item(item) -> tuple[Movie, Rating]:
ForgedRequest = namedtuple("ForgedRequest", "url headers") ForgedRequest = namedtuple("ForgedRequest", "url headers")
async def parse_page(url) -> Tuple[list[Rating], Optional[str]]: async def parse_page(url: str) -> tuple[list[Rating], str | None]:
ratings = [] ratings = []
soup = soup_from_url(url) soup = await asoup_from_url(url)
meta = soup.find("meta", property="pageId") meta = soup.find("meta", property="pageId")
headline = soup.h1 headline = soup.h1
@ -170,7 +166,6 @@ async def parse_page(url) -> Tuple[list[Rating], Optional[str]]:
items = soup.find_all("div", "lister-item-content") items = soup.find_all("div", "lister-item-content")
for i, item in enumerate(items): for i, item in enumerate(items):
try: try:
movie, rating = movie_and_rating_from_item(item) movie, rating = movie_and_rating_from_item(item)
except Exception as err: except Exception as err:
@ -196,11 +191,10 @@ async def parse_page(url) -> Tuple[list[Rating], Optional[str]]:
return (ratings, next_url if url != next_url else None) return (ratings, next_url if url != next_url else None)
async def load_ratings(user_id): async def load_ratings(user_id: str):
next_url = user_ratings_url(user_id) next_url = user_ratings_url(user_id)
while next_url: while next_url:
ratings, next_url = await parse_page(next_url) ratings, next_url = await parse_page(next_url)
for i, rating in enumerate(ratings): for i, rating in enumerate(ratings):


@ -1,10 +1,11 @@
import asyncio
import csv import csv
import gzip import gzip
import logging import logging
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Generator, Literal, Optional, Type, TypeVar, overload from typing import Generator, Literal, Type, TypeVar, overload
from . import config, db, request from . import config, db, request
from .db import add_or_update_many_movies from .db import add_or_update_many_movies
@ -27,10 +28,10 @@ class BasicRow:
primaryTitle: str primaryTitle: str
originalTitle: str originalTitle: str
isAdult: bool isAdult: bool
startYear: Optional[int] startYear: int | None
endYear: Optional[int] endYear: int | None
runtimeMinutes: Optional[int] runtimeMinutes: int | None
genres: Optional[set[str]] genres: set[str] | None
@classmethod @classmethod
def from_row(cls, row): def from_row(cls, row):
@ -100,7 +101,7 @@ title_types = {
} }
def gz_mtime(path) -> datetime: def gz_mtime(path: Path) -> datetime:
"""Return the timestamp of the compressed file.""" """Return the timestamp of the compressed file."""
g = gzip.GzipFile(path, "rb") g = gzip.GzipFile(path, "rb")
g.peek(1) # start reading the file to fill the timestamp field g.peek(1) # start reading the file to fill the timestamp field
@ -108,14 +109,13 @@ def gz_mtime(path) -> datetime:
return datetime.fromtimestamp(g.mtime).replace(tzinfo=timezone.utc) return datetime.fromtimestamp(g.mtime).replace(tzinfo=timezone.utc)
def count_lines(path) -> int: def count_lines(path: Path) -> int:
i = 0 i = 0
one_mb = 2 ** 20 one_mb = 2**20
buf_size = 8 * one_mb # 8 MiB seems to give a good read/process performance. buf_size = 8 * one_mb # 8 MiB seems to give a good read/process performance.
with gzip.open(path, "rt") as f: with gzip.open(path, "rt") as f:
while buf := f.read(buf_size): while buf := f.read(buf_size):
i += buf.count("\n") i += buf.count("\n")
@ -124,19 +124,19 @@ def count_lines(path) -> int:
@overload @overload
def read_imdb_tsv( def read_imdb_tsv(
path, row_type, *, unpack: Literal[False] path: Path, row_type, *, unpack: Literal[False]
) -> Generator[list[str], None, None]: ) -> Generator[list[str], None, None]:
... ...
@overload @overload
def read_imdb_tsv( def read_imdb_tsv(
path, row_type: Type[T], *, unpack: Literal[True] = True path: Path, row_type: Type[T], *, unpack: Literal[True] = True
) -> Generator[T, None, None]: ) -> Generator[T, None, None]:
... ...
def read_imdb_tsv(path, row_type, *, unpack=True): def read_imdb_tsv(path: Path, row_type, *, unpack=True):
with gzip.open(path, "rt", newline="") as f: with gzip.open(path, "rt", newline="") as f:
rows = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE) rows = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
@ -161,7 +161,7 @@ def read_imdb_tsv(path, row_type, *, unpack=True):
raise raise
def read_ratings(path): def read_ratings(path: Path):
mtime = gz_mtime(path) mtime = gz_mtime(path)
rows = read_imdb_tsv(path, RatingRow) rows = read_imdb_tsv(path, RatingRow)
@ -171,19 +171,20 @@ def read_ratings(path):
yield m yield m
def read_ratings_as_mapping(path): def read_ratings_as_mapping(path: Path):
"""Optimized function to quickly load all ratings.""" """Optimized function to quickly load all ratings."""
rows = read_imdb_tsv(path, RatingRow, unpack=False) rows = read_imdb_tsv(path, RatingRow, unpack=False)
return {r[0]: (round(100 * (float(r[1]) - 1) / 9), int(r[2])) for r in rows} return {r[0]: (round(100 * (float(r[1]) - 1) / 9), int(r[2])) for r in rows}
def read_basics(path): def read_basics(path: Path) -> Generator[Movie | None, None, None]:
mtime = gz_mtime(path) mtime = gz_mtime(path)
rows = read_imdb_tsv(path, BasicRow) rows = read_imdb_tsv(path, BasicRow)
for row in rows: for row in rows:
if row.startYear is None: if row.startYear is None:
log.debug("Skipping movie, missing year: %s", row) log.debug("Skipping movie, missing year: %s", row)
yield None
continue continue
m = row.as_movie() m = row.as_movie()
@ -197,20 +198,24 @@ async def import_from_file(*, basics_path: Path, ratings_path: Path):
log.info("💾 Importing movies ...") log.info("💾 Importing movies ...")
total = count_lines(basics_path) total = count_lines(basics_path)
assert total != 0 log.debug("Found %i movies.", total)
if total == 0:
raise RuntimeError(f"No movies found.")
perc_next_report = 0.0 perc_next_report = 0.0
perc_step = 0.1 perc_step = 0.1
chunk = [] chunk = []
for i, m in enumerate(read_basics(basics_path)): for i, m in enumerate(read_basics(basics_path)):
perc = 100 * i / total perc = 100 * i / total
if perc >= perc_next_report: if perc >= perc_next_report:
await db.set_import_progress(perc) await db.set_import_progress(perc)
log.info("⏳ Imported %s%%", round(perc, 1)) log.info("⏳ Imported %s%%", round(perc, 1))
perc_next_report += perc_step perc_next_report += perc_step
if m is None:
continue
if m.media_type not in { if m.media_type not in {
"Movie", "Movie",
"Short", "Short",
@ -235,10 +240,27 @@ async def import_from_file(*, basics_path: Path, ratings_path: Path):
await add_or_update_many_movies(chunk) await add_or_update_many_movies(chunk)
chunk = [] chunk = []
log.info("👍 Imported 100%")
await db.set_import_progress(100) await db.set_import_progress(100)
async def load_from_web(*, force: bool = False): async def download_datasets(*, basics_path: Path, ratings_path: Path) -> None:
"""Download IMDb movie database dumps.
See https://www.imdb.com/interfaces/ and https://datasets.imdbws.com/ for
more information on the IMDb database dumps.
"""
basics_url = "https://datasets.imdbws.com/title.basics.tsv.gz"
ratings_url = "https://datasets.imdbws.com/title.ratings.tsv.gz"
async with request.asession():
await asyncio.gather(
request.adownload(ratings_url, to_path=ratings_path, only_if_newer=True),
request.adownload(basics_url, to_path=basics_path, only_if_newer=True),
)
async def load_from_web(*, force: bool = False) -> None:
"""Refresh the full IMDb movie database. """Refresh the full IMDb movie database.
The latest dumps are first downloaded and then imported into the database. The latest dumps are first downloaded and then imported into the database.
@ -251,17 +273,13 @@ async def load_from_web(*, force: bool = False):
await db.set_import_progress(0) await db.set_import_progress(0)
try: try:
basics_url = "https://datasets.imdbws.com/title.basics.tsv.gz"
ratings_url = "https://datasets.imdbws.com/title.ratings.tsv.gz"
ratings_file = config.datadir / "imdb/title.ratings.tsv.gz" ratings_file = config.datadir / "imdb/title.ratings.tsv.gz"
basics_file = config.datadir / "imdb/title.basics.tsv.gz" basics_file = config.datadir / "imdb/title.basics.tsv.gz"
ratings_mtime = ratings_file.stat().st_mtime if ratings_file.exists() else None ratings_mtime = ratings_file.stat().st_mtime if ratings_file.exists() else None
bastics_mtime = basics_file.stat().st_mtime if basics_file.exists() else None bastics_mtime = basics_file.stat().st_mtime if basics_file.exists() else None
with request.session(): await download_datasets(basics_path=basics_file, ratings_path=ratings_file)
request.download(ratings_url, ratings_file, only_if_newer=True)
request.download(basics_url, basics_file, only_if_newer=True)
is_changed = ( is_changed = (
ratings_mtime != ratings_file.stat().st_mtime ratings_mtime != ratings_file.stat().st_mtime


@ -3,13 +3,14 @@ from dataclasses import dataclass, field
from dataclasses import fields as _fields from dataclasses import fields as _fields
from datetime import datetime, timezone from datetime import datetime, timezone
from functools import partial from functools import partial
from types import UnionType
from typing import ( from typing import (
Annotated, Annotated,
Any, Any,
ClassVar, ClassVar,
Container, Container,
Literal, Literal,
Optional, Mapping,
Type, Type,
TypeVar, TypeVar,
Union, Union,
@ -19,13 +20,13 @@ from typing import (
from .types import ULID from .types import ULID
JSON = Union[int, float, str, None, list["JSON"], dict[str, "JSON"]] JSON = int | float | str | None | list["JSON"] | dict[str, "JSON"]
JSONObject = dict[str, JSON] JSONObject = dict[str, JSON]
T = TypeVar("T") T = TypeVar("T")
def annotations(tp: Type) -> Optional[tuple]: def annotations(tp: Type) -> tuple | None:
return tp.__metadata__ if hasattr(tp, "__metadata__") else None return tp.__metadata__ if hasattr(tp, "__metadata__") else None
@ -42,7 +43,6 @@ def fields(class_or_instance):
# XXX this might be a little slow (not sure), if so, memoize # XXX this might be a little slow (not sure), if so, memoize
for f in _fields(class_or_instance): for f in _fields(class_or_instance):
if f.name == "_is_lazy": if f.name == "_is_lazy":
continue continue
@ -54,21 +54,21 @@ def fields(class_or_instance):
def is_optional(tp: Type) -> bool: def is_optional(tp: Type) -> bool:
"""Return wether the given type is optional.""" """Return wether the given type is optional."""
if get_origin(tp) is not Union: if not isinstance(tp, UnionType) and get_origin(tp) is not Union:
return False return False
args = get_args(tp) args = get_args(tp)
return len(args) == 2 and type(None) in args return len(args) == 2 and type(None) in args
def optional_type(tp: Type) -> Optional[Type]: def optional_type(tp: Type) -> Type | None:
"""Return the wrapped type from an optional type. """Return the wrapped type from an optional type.
For example this will return `int` for `Optional[int]`. For example this will return `int` for `Optional[int]`.
Since they're equivalent this also works for other optioning notations, like Since they're equivalent this also works for other optioning notations, like
`Union[int, None]` and `int | None`. `Union[int, None]` and `int | None`.
""" """
if get_origin(tp) is not Union: if not isinstance(tp, UnionType) and get_origin(tp) is not Union:
return None return None
args = get_args(tp) args = get_args(tp)
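
(Aside, not part of the commit: the added isinstance(tp, UnionType) checks are needed because PEP 604 unions such as int | None are types.UnionType instances whose get_origin() is not typing.Union, so the old test missed them. A quick illustration.)

    from types import UnionType
    from typing import Optional, Union, get_args, get_origin

    print(get_origin(Optional[int]) is Union)  # True
    print(get_origin(int | None) is Union)     # False, it is types.UnionType
    print(isinstance(int | None, UnionType))   # True
    print(get_args(int | None))                # (<class 'int'>, <class 'NoneType'>)
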
@ -92,7 +92,7 @@ def _id(x: T) -> T:
def asplain( def asplain(
o: object, *, filter_fields: Container[str] = None, serialize: bool = False o: object, *, filter_fields: Container[str] | None = None, serialize: bool = False
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Return the given model instance as `dict` with JSON compatible plain datatypes. """Return the given model instance as `dict` with JSON compatible plain datatypes.
@ -109,7 +109,6 @@ def asplain(
d: JSONObject = {} d: JSONObject = {}
for f in fields(o): for f in fields(o):
if filter_fields is not None and f.name not in filter_fields: if filter_fields is not None and f.name not in filter_fields:
continue continue
@ -146,7 +145,7 @@ def asplain(
return d return d
def fromplain(cls: Type[T], d: dict[str, Any], *, serialized: bool = False) -> T: def fromplain(cls: Type[T], d: Mapping, *, serialized: bool = False) -> T:
"""Return an instance of the given model using the given data. """Return an instance of the given model using the given data.
If `serialized` is `True`, collection types (lists, dicts, etc.) will be If `serialized` is `True`, collection types (lists, dicts, etc.) will be
@ -157,7 +156,6 @@ def fromplain(cls: Type[T], d: dict[str, Any], *, serialized: bool = False) -> T
dd: JSONObject = {} dd: JSONObject = {}
for f in fields(cls): for f in fields(cls):
target = f.type target = f.type
otype = optional_type(f.type) otype = optional_type(f.type)
is_opt = otype is not None is_opt = otype is not None
@ -188,7 +186,8 @@ def validate(o: object) -> None:
vtype = type(getattr(o, f.name)) vtype = type(getattr(o, f.name))
if vtype is not f.type: if vtype is not f.type:
if get_origin(f.type) is vtype or ( if get_origin(f.type) is vtype or (
get_origin(f.type) is Union and vtype in get_args(f.type) (isinstance(f.type, UnionType) or get_origin(f.type) is Union)
and vtype in get_args(f.type)
): ):
continue continue
raise ValueError(f"Invalid value type: {f.name}: {vtype}") raise ValueError(f"Invalid value type: {f.name}: {vtype}")
@ -206,7 +205,7 @@ class Progress:
type: str = None type: str = None
state: str = None state: str = None
started: datetime = field(default_factory=utcnow) started: datetime = field(default_factory=utcnow)
stopped: Optional[str] = None stopped: str | None = None
@property @property
def _state(self) -> dict: def _state(self) -> dict:
@ -243,15 +242,15 @@ class Movie:
id: ULID = field(default_factory=ULID) id: ULID = field(default_factory=ULID)
title: str = None # canonical title (usually English) title: str = None # canonical title (usually English)
original_title: Optional[ original_title: str | None = (
str None # original title (usually transscribed to latin script)
] = None # original title (usually transscribed to latin script) )
release_year: int = None # canonical release date release_year: int = None # canonical release date
media_type: str = None media_type: str = None
imdb_id: str = None imdb_id: str = None
imdb_score: Optional[int] = None # range: [0,100] imdb_score: int | None = None # range: [0,100]
imdb_votes: Optional[int] = None imdb_votes: int | None = None
runtime: Optional[int] = None # minutes runtime: int | None = None # minutes
genres: set[str] = None genres: set[str] = None
created: datetime = field(default_factory=utcnow) created: datetime = field(default_factory=utcnow)
updated: datetime = field(default_factory=utcnow) updated: datetime = field(default_factory=utcnow)
@ -292,7 +291,7 @@ dataclass containing the ID of the linked data.
The contents of the Relation are ignored or discarded when using The contents of the Relation are ignored or discarded when using
`asplain`, `fromplain`, and `validate`. `asplain`, `fromplain`, and `validate`.
""" """
Relation = Annotated[Optional[T], _RelationSentinel] Relation = Annotated[T | None, _RelationSentinel]
@dataclass @dataclass
@ -309,8 +308,8 @@ class Rating:
score: int = None # range: [0,100] score: int = None # range: [0,100]
rating_date: datetime = None rating_date: datetime = None
favorite: Optional[bool] = None favorite: bool | None = None
finished: Optional[bool] = None finished: bool | None = None
def __eq__(self, other): def __eq__(self, other):
"""Return wether two Ratings are equal. """Return wether two Ratings are equal.
@ -342,11 +341,11 @@ class User:
secret: str = None secret: str = None
groups: list[dict[str, str]] = field(default_factory=list) groups: list[dict[str, str]] = field(default_factory=list)
def has_access(self, group_id: Union[ULID, str], access: Access = "r"): def has_access(self, group_id: ULID | str, access: Access = "r"):
group_id = group_id if isinstance(group_id, str) else str(group_id) group_id = group_id if isinstance(group_id, str) else str(group_id)
return any(g["id"] == group_id and access == g["access"] for g in self.groups) return any(g["id"] == group_id and access == g["access"] for g in self.groups)
def set_access(self, group_id: Union[ULID, str], access: Access): def set_access(self, group_id: ULID | str, access: Access):
group_id = group_id if isinstance(group_id, str) else str(group_id) group_id = group_id if isinstance(group_id, str) else str(group_id)
for g in self.groups: for g in self.groups:
if g["id"] == group_id: if g["id"] == group_id:


@ -4,19 +4,17 @@ import logging
import os import os
import tempfile import tempfile
from collections import deque from collections import deque
from contextlib import contextmanager from contextlib import asynccontextmanager
from dataclasses import dataclass from dataclasses import dataclass, field
from functools import wraps from functools import wraps
from hashlib import md5 from hashlib import md5
from pathlib import Path from pathlib import Path
from random import random from random import random
from time import sleep, time from time import sleep, time
from typing import Callable, Optional, Union from typing import Callable, ParamSpec, TypeVar, cast
import bs4 import bs4
import requests import httpx
from requests.status_codes import codes
from urllib3.util.retry import Retry
from . import config from . import config
@ -26,28 +24,17 @@ if config.debug and config.cachedir:
config.cachedir.mkdir(exist_ok=True) config.cachedir.mkdir(exist_ok=True)
def set_retries(s: requests.Session, n: int, backoff_factor: float = 0.2): _shared_asession = None
retry = (
Retry( _ASession_T = httpx.AsyncClient
total=n, _Response_T = httpx.Response
connect=n,
read=n, _T = TypeVar("_T")
status=n, _P = ParamSpec("_P")
status_forcelist=Retry.RETRY_AFTER_STATUS_CODES,
backoff_factor=backoff_factor,
)
if n
else Retry(0, read=False)
)
for a in s.adapters.values():
a.max_retries = retry
_shared_session = None @asynccontextmanager
async def asession():
@contextmanager
def session():
"""Return the shared request session. """Return the shared request session.
The session is shared by all request functions and provides cookie The session is shared by all request functions and provides cookie
@ -55,38 +42,34 @@ def session():
Opening the session before making a request allows you to set headers Opening the session before making a request allows you to set headers
or change the retry behavior. or change the retry behavior.
""" """
global _shared_session global _shared_asession
if _shared_session: if _shared_asession:
yield _shared_session yield _shared_asession
return return
_shared_session = Session() _shared_asession = _ASession_T()
_shared_asession.headers[
"user-agent"
] = "Mozilla/5.0 Gecko/20100101 unwind/20230203"
try: try:
yield _shared_session async with _shared_asession:
yield _shared_asession
finally: finally:
_shared_session = None _shared_asession = None
def Session() -> requests.Session: def _throttle(
s = requests.Session() times: int, per_seconds: float, jitter: Callable[[], float] | None = None
s.headers["User-Agent"] = "Mozilla/5.0 Gecko/20100101 unwind/20210506" ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
return s calls: deque[float] = deque(maxlen=times)
def throttle(
times: int, per_seconds: float, jitter: Callable[[], float] = None
) -> Callable[[Callable], Callable]:
calls: Deque[float] = deque(maxlen=times)
if jitter is None: if jitter is None:
jitter = lambda: 0.0 jitter = lambda: 0.0
def decorator(func: Callable) -> Callable: def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]:
@wraps(func) @wraps(func)
def inner(*args, **kwds): def inner(*args: _P.args, **kwds: _P.kwargs):
# clean up # clean up
while calls: while calls:
if calls[0] + per_seconds > time(): if calls[0] + per_seconds > time():
@ -118,23 +101,19 @@ def throttle(
return decorator return decorator
class CachedStr(str):
is_cached = True
@dataclass @dataclass
class CachedResponse: class _CachedResponse:
is_cached = True is_cached = True
status_code: int status_code: int
text: str text: str
url: str url: str
headers: dict[str, str] = None headers: dict[str, str] = field(default_factory=dict)
def json(self): def json(self):
return json.loads(self.text) return json.loads(self.text)
class RedirectError(RuntimeError): class _RedirectError(RuntimeError):
def __init__(self, from_url: str, to_url: str, is_cached=False): def __init__(self, from_url: str, to_url: str, is_cached=False):
self.from_url = from_url self.from_url = from_url
self.to_url = to_url self.to_url = to_url
@ -142,44 +121,51 @@ class RedirectError(RuntimeError):
super().__init__(f"Redirected: {from_url} -> {to_url}") super().__init__(f"Redirected: {from_url} -> {to_url}")
def cache_path(req) -> Optional[Path]: def cache_path(req) -> Path | None:
if not config.cachedir: if not config.cachedir:
return return
sig = repr(req.url) # + repr(sorted(req.headers.items())) sig = repr(req.url) # + repr(sorted(req.headers.items()))
return config.cachedir / md5(sig.encode()).hexdigest() return config.cachedir / md5(sig.encode()).hexdigest()
@throttle(1, 1, random) @_throttle(1, 1, random)
def http_get(s: requests.Session, url: str, *args, **kwds) -> requests.Response: async def _ahttp_get(s: _ASession_T, url: str, *args, **kwds) -> _Response_T:
req = s.build_request(method="GET", url=url, *args, **kwds)
req = s.prepare_request(requests.Request("GET", url, *args, **kwds))
cachefile = cache_path(req) if config.debug else None cachefile = cache_path(req) if config.debug else None
if cachefile: if cachefile:
if cachefile.exists(): if cachefile.exists():
log.debug( log.debug(
f"💾 loading {req.url} ({req.headers!a}) from cache {cachefile} ..." "💾 loading %s (%a) from cache %s ...", req.url, req.headers, cachefile
) )
with cachefile.open() as fp: with cachefile.open() as fp:
resp = CachedResponse(**json.load(fp)) resp = _CachedResponse(**json.load(fp))
if 300 <= resp.status_code <= 399: if 300 <= resp.status_code <= 399:
raise RedirectError( raise _RedirectError(
from_url=resp.url, to_url=resp.headers["location"], is_cached=True from_url=resp.url, to_url=resp.headers["location"], is_cached=True
) )
return resp return cast(_Response_T, resp)
log.debug(f"⚡️ loading {req.url} ({req.headers!a}) ...") log.debug("⚡️ loading %s (%a) ...", req.url, req.headers)
resp = s.send(req, allow_redirects=False, stream=True) resp = await s.send(req, follow_redirects=False, stream=True)
resp.raise_for_status() resp.raise_for_status()
await resp.aread() # Download the response stream to allow `resp.text` access.
if cachefile: if cachefile:
log.debug(
"💾 writing response to cache: %s (%a) -> %s",
req.url,
req.headers,
cachefile,
)
with cachefile.open("w") as fp: with cachefile.open("w") as fp:
json.dump( json.dump(
{ {
"status_code": resp.status_code, "status_code": resp.status_code,
"text": resp.text, "text": resp.text,
"url": resp.url, "url": str(resp.url),
"headers": dict(resp.headers), "headers": dict(resp.headers),
}, },
fp, fp,
@ -187,45 +173,46 @@ def http_get(s: requests.Session, url: str, *args, **kwds) -> requests.Response:
if resp.is_redirect: if resp.is_redirect:
# Redirects could mean trouble, we need to stay on top of that! # Redirects could mean trouble, we need to stay on top of that!
raise RedirectError(from_url=resp.url, to_url=resp.headers["location"]) raise _RedirectError(from_url=str(resp.url), to_url=resp.headers["location"])
return resp return resp
def soup_from_url(url): async def asoup_from_url(url):
"""Return a BeautifulSoup instance from the contents for the given URL.""" """Return a BeautifulSoup instance from the contents for the given URL."""
with session() as s: async with asession() as s:
r = http_get(s, url) r = await _ahttp_get(s, url)
soup = bs4.BeautifulSoup(r.text, "html5lib") soup = bs4.BeautifulSoup(r.text, "html5lib")
return soup return soup
def last_modified_from_response(resp): def _last_modified_from_response(resp: _Response_T) -> float | None:
if last_mod := resp.headers.get("Last-Modified"): if last_mod := resp.headers.get("last-modified"):
try: try:
return email.utils.parsedate_to_datetime(last_mod).timestamp() return email.utils.parsedate_to_datetime(last_mod).timestamp()
except: except:
log.exception("🐛 Received invalid value for Last-Modified: %s", last_mod) log.exception("🐛 Received invalid value for Last-Modified: %s", last_mod)
def last_modified_from_file(path: Path): def _last_modified_from_file(path: Path) -> float:
return path.stat().st_mtime return path.stat().st_mtime
def download( async def adownload(
url: str, url: str,
file_path: Union[Path, str] = None,
*, *,
replace_existing: bool = None, to_path: Path | str | None = None,
replace_existing: bool | None = None,
only_if_newer: bool = False, only_if_newer: bool = False,
timeout: float = None, timeout: float | None = None,
verify_ssl: bool = True,
chunk_callback=None, chunk_callback=None,
response_callback=None, response_callback=None,
): ) -> bytes | None:
"""Download a file. """Download a file.
If `to_path` is `None` return the remote content, otherwise write the
content to the given file path.
Existing files will not be overwritten unless `replace_existing` is set. Existing files will not be overwritten unless `replace_existing` is set.
Setting `only_if_newer` will check if the remote file is newer than the Setting `only_if_newer` will check if the remote file is newer than the
local file, otherwise the download will be aborted. local file, otherwise the download will be aborted.
@ -234,89 +221,103 @@ def download(
replace_existing = only_if_newer replace_existing = only_if_newer
file_exists = None file_exists = None
if file_path is not None: if to_path is not None:
file_path = Path(file_path) to_path = Path(to_path)
file_exists = file_path.exists() and file_path.stat().st_size file_exists = to_path.exists() and to_path.stat().st_size
if file_exists and not replace_existing: if file_exists and not replace_existing:
raise FileExistsError(23, "Would replace existing file", str(file_path)) raise FileExistsError(23, "Would replace existing file", str(to_path))
with session() as s:
async with asession() as s:
headers = {} headers = {}
if file_exists and only_if_newer: if file_exists and only_if_newer:
assert file_path assert to_path
file_lastmod = last_modified_from_file(file_path) file_lastmod = _last_modified_from_file(to_path)
headers["If-Modified-Since"] = email.utils.formatdate( headers["if-modified-since"] = email.utils.formatdate(
file_lastmod, usegmt=True file_lastmod, usegmt=True
) )
req = s.prepare_request(requests.Request("GET", url, headers=headers)) req = s.build_request(method="GET", url=url, headers=headers, timeout=timeout)
log.debug("⚡️ loading %s (%s) ...", req.url, req.headers) log.debug("⚡️ Loading %s (%a) ...", req.url, dict(req.headers))
resp = s.send( resp = await s.send(req, follow_redirects=True, stream=True)
req, allow_redirects=True, stream=True, timeout=timeout, verify=verify_ssl
)
if response_callback is not None: try:
try: if response_callback is not None:
response_callback(resp) try:
except: response_callback(resp)
log.exception("🐛 Error in response callback.") except:
log.exception("🐛 Error in response callback.")
log.debug("☕️ Response status: %s; headers: %s", resp.status_code, resp.headers) log.debug(
"☕️ %s -> status: %s; headers: %a",
req.url,
resp.status_code,
dict(resp.headers),
)
resp.raise_for_status() if resp.status_code == httpx.codes.NOT_MODIFIED:
log.debug(
if resp.status_code == codes.not_modified: "✋ Remote file has not changed, skipping download: %s -> %a",
log.debug("✋ Remote file has not changed, skipping download.") req.url,
return to_path,
)
if file_path is None:
return resp.content
assert replace_existing is True
resp_lastmod = last_modified_from_response(resp)
# Check Last-Modified in case the server ignored If-Modified-Since.
# XXX also check Content-Length?
if file_exists and only_if_newer and resp_lastmod is not None:
assert file_lastmod
if resp_lastmod <= file_lastmod:
log.debug("✋ Local file is newer, skipping download.")
resp.close()
return return
# Create intermediate directories if necessary. resp.raise_for_status()
download_dir = file_path.parent
download_dir.mkdir(parents=True, exist_ok=True) if to_path is None:
await resp.aread() # Download the response stream to allow `resp.content` access.
return resp.content
resp_lastmod = _last_modified_from_response(resp)
# Check Last-Modified in case the server ignored If-Modified-Since.
# XXX also check Content-Length?
if file_exists and only_if_newer and resp_lastmod is not None:
assert file_lastmod
if resp_lastmod <= file_lastmod:
log.debug("✋ Local file is newer, skipping download: %a", req.url)
return
# Create intermediate directories if necessary.
download_dir = to_path.parent
download_dir.mkdir(parents=True, exist_ok=True)
# Write content to temp file.
tempdir = download_dir
tempfd, tempfile_path = tempfile.mkstemp(
dir=tempdir, prefix=f".download-{to_path.name}."
)
one_mb = 2**20
chunk_size = 8 * one_mb
try:
log.debug("💾 Writing to temp file %s ...", tempfile_path)
async for chunk in resp.aiter_bytes(chunk_size):
os.write(tempfd, chunk)
if chunk_callback:
try:
chunk_callback(chunk)
except:
log.exception("🐛 Error in chunk callback.")
finally:
os.close(tempfd)
# Move downloaded file to destination.
if to_path.exists():
log.debug("💾 Replacing existing file: %s", to_path)
else:
log.debug("💾 Move to destination: %s", to_path)
if replace_existing:
Path(tempfile_path).replace(to_path)
else:
Path(tempfile_path).rename(to_path)
# Fix file attributes.
if resp_lastmod is not None:
log.debug("💾 Adjusting file timestamp: %s (%s)", to_path, resp_lastmod)
os.utime(to_path, (resp_lastmod, resp_lastmod))
# Write content to temp file.
tempdir = download_dir
tempfd, tempfile_path = tempfile.mkstemp(
dir=tempdir, prefix=f".download-{file_path.name}."
)
one_mb = 2 ** 20
chunk_size = 8 * one_mb
try:
log.debug("💾 Writing to temp file %s ...", tempfile_path)
for chunk in resp.iter_content(chunk_size=chunk_size, decode_unicode=False):
os.write(tempfd, chunk)
if chunk_callback:
try:
chunk_callback(chunk)
except:
log.exception("🐛 Error in chunk callback.")
finally: finally:
os.close(tempfd) await resp.aclose()
# Move downloaded file to destination.
if file_exists:
log.debug("💾 Replacing existing file: %s", file_path)
Path(tempfile_path).replace(file_path)
# Fix file attributes.
if resp_lastmod is not None:
os.utime(file_path, (resp_lastmod, resp_lastmod))
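
(Aside, not part of the commit: the requests -> httpx migration above follows httpx's explicit streaming flow: build_request, send(..., stream=True), then aread() or aiter_bytes(), and a final aclose(). A condensed, self-contained sketch of that flow; the URL and chunk size are illustrative.)

    import asyncio
    import httpx

    async def fetch(url: str) -> bytes:
        async with httpx.AsyncClient() as client:
            request = client.build_request("GET", url)
            response = await client.send(request, stream=True, follow_redirects=True)
            try:
                response.raise_for_status()
                # Stream the body in 8 MiB chunks instead of reading it eagerly.
                chunks = [chunk async for chunk in response.aiter_bytes(8 * 2**20)]
                return b"".join(chunks)
            finally:
                await response.aclose()

    # asyncio.run(fetch("https://example.org/"))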


@ -1,5 +1,5 @@
import re import re
from typing import Union, cast from typing import cast
import ulid import ulid
from ulid.hints import Buffer from ulid.hints import Buffer
@ -16,7 +16,7 @@ class ULID(ulid.ULID):
_pattern = re.compile(r"^[0-9A-HJKMNP-TV-Z]{26}$") _pattern = re.compile(r"^[0-9A-HJKMNP-TV-Z]{26}$")
def __init__(self, buffer: Union[Buffer, ulid.ULID, str, None] = None): def __init__(self, buffer: Buffer | ulid.ULID | str | None = None):
if isinstance(buffer, str): if isinstance(buffer, str):
if not self._pattern.search(buffer): if not self._pattern.search(buffer):
raise ValueError("Invalid ULID.") raise ValueError("Invalid ULID.")


@ -17,7 +17,10 @@ def b64padded(s: str) -> str:
def phc_scrypt( def phc_scrypt(
secret: bytes, *, salt: bytes = None, params: dict[Literal["n", "r", "p"], int] = {} secret: bytes,
*,
salt: bytes | None = None,
params: dict[Literal["n", "r", "p"], int] = {},
) -> str: ) -> str:
"""Return the scrypt expanded secret in PHC string format. """Return the scrypt expanded secret in PHC string format.
@ -30,7 +33,7 @@ def phc_scrypt(
if salt is None: if salt is None:
salt = secrets.token_bytes(16) salt = secrets.token_bytes(16)
n = params.get("n", 2 ** 14) # CPU/Memory cost factor n = params.get("n", 2**14) # CPU/Memory cost factor
r = params.get("r", 8) # block size r = params.get("r", 8) # block size
p = params.get("p", 1) # parallelization factor p = params.get("p", 1) # parallelization factor
# maxmem = 2 * 128 * n * r * p # maxmem = 2 * 128 * n * r * p
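
(Aside, not part of the commit: the n, r, p defaults shown here map directly onto hashlib.scrypt; a minimal sketch, leaving out the PHC string formatting that phc_scrypt adds.)

    import hashlib
    import secrets

    salt = secrets.token_bytes(16)
    key = hashlib.scrypt(
        b"correct horse battery staple",
        salt=salt,
        n=2**14,  # CPU/memory cost factor
        r=8,      # block size
        p=1,      # parallelization factor
        dklen=64,
    )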


@ -1,8 +1,9 @@
import asyncio import asyncio
import contextlib
import logging import logging
import secrets import secrets
from json.decoder import JSONDecodeError from json.decoder import JSONDecodeError
from typing import Literal, Optional, overload from typing import Literal, overload
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.authentication import ( from starlette.authentication import (
@ -85,11 +86,14 @@ def truthy(s: str):
return bool(s) and s.lower() in {"1", "yes", "true"} return bool(s) and s.lower() in {"1", "yes", "true"}
def yearcomp(s: str): _Yearcomp = Literal["<", "=", ">"]
def yearcomp(s: str) -> tuple[_Yearcomp, int] | None:
if not s: if not s:
return return
comp: Literal["<", "=", ">"] = "=" comp: _Yearcomp = "="
if (prefix := s[0]) in "<=>": if (prefix := s[0]) in "<=>":
comp = prefix # type: ignore comp = prefix # type: ignore
s = s[len(prefix) :] s = s[len(prefix) :]
@ -97,7 +101,9 @@ def yearcomp(s: str):
return comp, int(s) return comp, int(s)
def as_int(x, *, max: int = None, min: Optional[int] = 1, default: int = None): def as_int(
x, *, max: int | None = None, min: int | None = 1, default: int | None = None
) -> int:
try: try:
if not isinstance(x, int): if not isinstance(x, int):
x = int(x) x = int(x)
@ -135,7 +141,7 @@ async def json_from_body(request, keys: list[str]) -> list:
... ...
async def json_from_body(request, keys: list[str] = None): async def json_from_body(request, keys: list[str] | None = None):
if not await request.body(): if not await request.body():
data = {} data = {}
@ -158,7 +164,7 @@ def is_admin(request):
return "admin" in request.auth.scopes return "admin" in request.auth.scopes
async def auth_user(request) -> Optional[User]: async def auth_user(request) -> User | None:
if not isinstance(request.user, AuthedUser): if not isinstance(request.user, AuthedUser):
return return
@ -176,7 +182,7 @@ async def auth_user(request) -> Optional[User]:
_routes = [] _routes = []
def route(path: str, *, methods: list[str] = None, **kwds): def route(path: str, *, methods: list[str] | None = None, **kwds):
def decorator(func): def decorator(func):
r = Route(path, func, methods=methods, **kwds) r = Route(path, func, methods=methods, **kwds)
_routes.append(r) _routes.append(r)
@ -190,7 +196,6 @@ route.registered = _routes
@route("/groups/{group_id}/ratings") @route("/groups/{group_id}/ratings")
async def get_ratings_for_group(request): async def get_ratings_for_group(request):
group_id = as_ulid(request.path_params["group_id"]) group_id = as_ulid(request.path_params["group_id"])
group = await db.get(Group, id=str(group_id)) group = await db.get(Group, id=str(group_id))
@ -251,7 +256,6 @@ def not_implemented():
@route("/movies") @route("/movies")
@requires(["authenticated"]) @requires(["authenticated"])
async def list_movies(request): async def list_movies(request):
params = request.query_params params = request.query_params
user = await auth_user(request) user = await auth_user(request)
@ -319,7 +323,6 @@ async def list_movies(request):
@route("/movies", methods=["POST"]) @route("/movies", methods=["POST"])
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def add_movie(request): async def add_movie(request):
not_implemented() not_implemented()
@ -361,7 +364,6 @@ _import_lock = asyncio.Lock()
@route("/movies/_reload_imdb", methods=["POST"]) @route("/movies/_reload_imdb", methods=["POST"])
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def load_imdb_movies(request): async def load_imdb_movies(request):
params = request.query_params params = request.query_params
force = truthy(params.get("force")) force = truthy(params.get("force"))
@ -384,7 +386,6 @@ async def load_imdb_movies(request):
@route("/users") @route("/users")
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def list_users(request): async def list_users(request):
users = await db.get_all(User) users = await db.get_all(User)
return JSONResponse([asplain(u) for u in users]) return JSONResponse([asplain(u) for u in users])
@ -393,7 +394,6 @@ async def list_users(request):
@route("/users", methods=["POST"]) @route("/users", methods=["POST"])
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def add_user(request): async def add_user(request):
name, imdb_id = await json_from_body(request, ["name", "imdb_id"]) name, imdb_id = await json_from_body(request, ["name", "imdb_id"])
# XXX restrict name # XXX restrict name
@ -415,7 +415,6 @@ async def add_user(request):
@route("/users/{user_id}") @route("/users/{user_id}")
@requires(["authenticated"]) @requires(["authenticated"])
async def show_user(request): async def show_user(request):
user_id = as_ulid(request.path_params["user_id"]) user_id = as_ulid(request.path_params["user_id"])
if is_admin(request): if is_admin(request):
@ -444,7 +443,6 @@ async def show_user(request):
@route("/users/{user_id}", methods=["DELETE"]) @route("/users/{user_id}", methods=["DELETE"])
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def remove_user(request): async def remove_user(request):
user_id = as_ulid(request.path_params["user_id"]) user_id = as_ulid(request.path_params["user_id"])
user = await db.get(User, id=str(user_id)) user = await db.get(User, id=str(user_id))
@ -462,7 +460,6 @@ async def remove_user(request):
@route("/users/{user_id}", methods=["PATCH"]) @route("/users/{user_id}", methods=["PATCH"])
@requires(["authenticated"]) @requires(["authenticated"])
async def modify_user(request): async def modify_user(request):
user_id = as_ulid(request.path_params["user_id"]) user_id = as_ulid(request.path_params["user_id"])
if is_admin(request): if is_admin(request):
@ -510,7 +507,6 @@ async def modify_user(request):
@route("/users/{user_id}/groups", methods=["POST"]) @route("/users/{user_id}/groups", methods=["POST"])
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def add_group_to_user(request): async def add_group_to_user(request):
user_id = as_ulid(request.path_params["user_id"]) user_id = as_ulid(request.path_params["user_id"])
user = await db.get(User, id=str(user_id)) user = await db.get(User, id=str(user_id))
@ -535,21 +531,18 @@ async def add_group_to_user(request):
@route("/users/{user_id}/ratings") @route("/users/{user_id}/ratings")
@requires(["private"]) @requires(["private"])
async def ratings_for_user(request): async def ratings_for_user(request):
not_implemented() not_implemented()
@route("/users/{user_id}/ratings", methods=["PUT"]) @route("/users/{user_id}/ratings", methods=["PUT"])
@requires("authenticated") @requires("authenticated")
async def set_rating_for_user(request): async def set_rating_for_user(request):
not_implemented() not_implemented()
@route("/users/_reload_ratings", methods=["POST"]) @route("/users/_reload_ratings", methods=["POST"])
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def load_imdb_user_ratings(request): async def load_imdb_user_ratings(request):
ratings = [rating async for rating in imdb.refresh_user_ratings_from_imdb()] ratings = [rating async for rating in imdb.refresh_user_ratings_from_imdb()]
return JSONResponse({"new_ratings": [asplain(r) for r in ratings]}) return JSONResponse({"new_ratings": [asplain(r) for r in ratings]})
@ -558,7 +551,6 @@ async def load_imdb_user_ratings(request):
@route("/groups") @route("/groups")
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def list_groups(request): async def list_groups(request):
groups = await db.get_all(Group) groups = await db.get_all(Group)
return JSONResponse([asplain(g) for g in groups]) return JSONResponse([asplain(g) for g in groups])
@ -567,7 +559,6 @@ async def list_groups(request):
@route("/groups", methods=["POST"]) @route("/groups", methods=["POST"])
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def add_group(request): async def add_group(request):
(name,) = await json_from_body(request, ["name"]) (name,) = await json_from_body(request, ["name"])
# XXX restrict name # XXX restrict name
@ -581,7 +572,6 @@ async def add_group(request):
@route("/groups/{group_id}/users", methods=["POST"]) @route("/groups/{group_id}/users", methods=["POST"])
@requires(["authenticated"]) @requires(["authenticated"])
async def add_user_to_group(request): async def add_user_to_group(request):
group_id = as_ulid(request.path_params["group_id"]) group_id = as_ulid(request.path_params["group_id"])
group = await db.get(Group, id=str(group_id)) group = await db.get(Group, id=str(group_id))
@ -623,6 +613,13 @@ def auth_error(request, err):
return unauthorized(str(err)) return unauthorized(str(err))
@contextlib.asynccontextmanager
async def lifespan(app: Starlette):
await open_connection_pool()
yield
await close_connection_pool()
def create_app(): def create_app():
if config.loglevel == "DEBUG": if config.loglevel == "DEBUG":
logging.basicConfig( logging.basicConfig(
@ -633,8 +630,7 @@ def create_app():
log.debug(f"Log level: {config.loglevel}") log.debug(f"Log level: {config.loglevel}")
return Starlette( return Starlette(
on_startup=[open_connection_pool], lifespan=lifespan,
on_shutdown=[close_connection_pool],
routes=[ routes=[
Mount(f"{config.api_base}v1", routes=route.registered), Mount(f"{config.api_base}v1", routes=route.registered),
], ],
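
Note: these last two hunks are the core of the Starlette 0.26 migration, replacing the on_startup/on_shutdown hook lists with a single lifespan async context manager. A minimal runnable sketch of the same wiring; the pool helpers and the /health route are stand-ins, not the project's real code.

```python
import contextlib

from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route


async def open_connection_pool() -> None:  # stand-in for the real helper
    ...


async def close_connection_pool() -> None:  # stand-in for the real helper
    ...


@contextlib.asynccontextmanager
async def lifespan(app: Starlette):
    await open_connection_pool()   # runs once, before the first request is served
    yield
    await close_connection_pool()  # runs once, on shutdown


async def health(request):
    return JSONResponse({"ok": True})


app = Starlette(lifespan=lifespan, routes=[Route("/health", health)])
```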

View file

@ -1,5 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Container, Iterable, Optional from typing import Container, Iterable
from . import imdb, models from . import imdb, models
@ -10,17 +10,17 @@ Score100 = int # [0, 100]
@dataclass @dataclass
class Rating: class Rating:
canonical_title: str canonical_title: str
imdb_score: Optional[Score100] imdb_score: Score100 | None
imdb_votes: Optional[int] imdb_votes: int | None
media_type: str media_type: str
movie_imdb_id: str movie_imdb_id: str
original_title: Optional[str] original_title: str | None
release_year: int release_year: int
user_id: Optional[str] user_id: str | None
user_score: Optional[Score100] user_score: Score100 | None
@classmethod @classmethod
def from_movie(cls, movie: models.Movie, *, rating: models.Rating = None): def from_movie(cls, movie: models.Movie, *, rating: models.Rating | None = None):
return cls( return cls(
canonical_title=movie.title, canonical_title=movie.title,
imdb_score=movie.imdb_score, imdb_score=movie.imdb_score,
@ -37,11 +37,11 @@ class Rating:
@dataclass @dataclass
class RatingAggregate: class RatingAggregate:
canonical_title: str canonical_title: str
imdb_score: Optional[Score100] imdb_score: Score100 | None
imdb_votes: Optional[int] imdb_votes: int | None
link: URL link: URL
media_type: str media_type: str
original_title: Optional[str] original_title: str | None
user_scores: list[Score100] user_scores: list[Score100]
year: int year: int
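
Note: the dataclass fields above move from Optional[...] to the equivalent `X | None` spelling, and from_movie earlier in this file now writes its keyword default as `rating: models.Rating | None = None` rather than the implicitly-Optional form. A trimmed-down sketch of that construction with stand-in models; the real ones live in unwind.models, so every field name not shown in the diff is an assumption.

```python
from dataclasses import dataclass

Score100 = int  # [0, 100], as in the module above


@dataclass
class Movie:  # stand-in for models.Movie
    title: str
    imdb_score: Score100 | None = None
    imdb_votes: int | None = None


@dataclass
class UserRating:  # stand-in for models.Rating
    score: Score100
    user_id: str


@dataclass
class Rating:
    canonical_title: str
    imdb_score: Score100 | None
    imdb_votes: int | None
    user_id: str | None
    user_score: Score100 | None

    @classmethod
    def from_movie(cls, movie: Movie, *, rating: UserRating | None = None) -> "Rating":
        # `UserRating | None = None` mirrors the new signature; the default
        # still has to be written out explicitly.
        return cls(
            canonical_title=movie.title,
            imdb_score=movie.imdb_score,
            imdb_votes=movie.imdb_votes,
            user_id=rating.user_id if rating else None,
            user_score=rating.score if rating else None,
        )
```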