Merge branch 'feat/py311'

commit a020d972f8
23 changed files with 783 additions and 748 deletions
.git-blame-ignore-revs (4, new file)

@@ -0,0 +1,4 @@
# Apply Black v23.1.0 formatting changes.
8a8bfce89de23d987386a35b659532bbac373788
# Apply auto-formatting to tests.
9ffcc5357150cecde26f5e6f8fccceaf92411efb
Dockerfile (16)

@@ -1,4 +1,4 @@
FROM docker.io/library/python:3.10-alpine
FROM docker.io/library/python:3.11-alpine

RUN apk update --no-cache \
&& apk upgrade --no-cache \

@@ -11,20 +11,18 @@ WORKDIR /var/app
COPY requirements.txt ./

# Required to build greenlet on Alpine, dependency of SQLAlchemy 1.4.
RUN apk add --no-cache \
--virtual .build-deps \
g++ gcc musl-dev \
&& pip install --no-cache-dir --upgrade \
--requirement requirements.txt \
&& apk del .build-deps
RUN pip install --no-cache-dir --upgrade \
--requirement requirements.txt

USER 10000:10001

COPY . ./

ENV UNWIND_DATA="/data"
VOLUME ["/data"]
VOLUME $UNWIND_DATA

ENV UNWIND_PORT=8097
EXPOSE $UNWIND_PORT

ENTRYPOINT ["/var/app/run"]
CMD ["server"]
poetry.lock (817, generated): file diff suppressed because it is too large.
@@ -6,26 +6,20 @@ authors = ["ducklet <ducklet@noreply.code.dumpr.org>"]
license = "LOL"

[tool.poetry.dependencies]
python = "^3.10"
requests = "^2.25.1"
python = "^3.11"
beautifulsoup4 = "^4.9.3"
html5lib = "^1.1"
starlette = "^0.17.0"
starlette = "^0.26"
ulid-py = "^1.1.0"
databases = {extras = ["sqlite"], version = "^0.6.1"}
toml = "^0.10.2"
uvicorn = "^0.19.0"

[tool.poetry.group.fixes.dependencies]
# `databases` is having issues with new versions of SQLAlchemy 1.4,
# and `greenlet` is also always a pain.
SQLAlchemy = "1.4.25"
greenlet = "1.1.2"
databases = {extras = ["sqlite"], version = "^0.7.0"}
uvicorn = "^0.21"
httpx = "^0.23.3"

[tool.poetry.group.dev]
optional = true

[tool.poetry.group.dev.dependencies]
autoflake = "*"
pytest = "*"
pyright = "*"
black = "*"

@@ -37,4 +31,14 @@ requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.pyright]
pythonVersion = "3.10"
pythonVersion = "3.11"

[tool.isort]
profile = "black"

[tool.autoflake]
remove-duplicate-keys = true
remove-unused-variables = true
remove-all-unused-imports = true
ignore-init-module-imports = true
ignore-pass-after-docstring = true
@@ -4,6 +4,7 @@ cd "$RUN_DIR"
[ -z "${DEBUG:-}" ] || set -x

isort --profile black unwind
black unwind
autoflake --quiet --check --recursive unwind tests
isort unwind tests
black unwind tests
pyright
scripts/profile (13, new executable file)

@@ -0,0 +1,13 @@
#!/bin/sh -eu

cd "$RUN_DIR"

outfile="profile-$(date '+%Y%m%d-%H%M%S').txt"

[ -z "${DEBUG:-}" ] || set -x

echo "# Writing profiler stats to: $outfile"
python -m cProfile -o "$outfile" -m unwind "$@"

echo "# Loading stats file: $outfile"
python -m pstats "$outfile"
@@ -1,7 +1,14 @@
#!/bin/sh -eu

: "${UNWIND_PORT:=8097}"

cd "$RUN_DIR"

[ -z "${DEBUG:-}" ] || set -x

exec uvicorn --host 0.0.0.0 --factory unwind:create_app
export UNWIND_PORT

exec uvicorn \
--host 0.0.0.0 \
--port "$UNWIND_PORT" \
--factory unwind:create_app
@@ -10,5 +10,6 @@ trap 'rm "$dbfile"' EXIT TERM INT QUIT
[ -z "${DEBUG:-}" ] || set -x

SQLALCHEMY_WARN_20=1 \
UNWIND_STORAGE="$dbfile" \
python -m pytest "$@"
@@ -1,6 +1,8 @@
import asyncio

import pytest
import pytest_asyncio

from unwind import db

@@ -13,7 +15,7 @@ def event_loop():
loop.close()

@pytest.fixture(scope="session")
@pytest_asyncio.fixture(scope="session")
async def shared_conn():
c = db.shared_connection()
await c.connect()

@@ -24,7 +26,7 @@ async def shared_conn():
await c.disconnect()

@pytest.fixture
@pytest_asyncio.fixture
async def conn(shared_conn):
async with shared_conn.transaction(force_rollback=True):
yield shared_conn
@@ -1,14 +1,13 @@
from datetime import datetime

import pytest

from unwind import db, models, web_models

pytestmark = pytest.mark.asyncio

async def test_add_and_get(shared_conn):
@pytest.mark.asyncio
async def test_add_and_get(shared_conn: db.Database):
async with shared_conn.transaction(force_rollback=True):

m1 = models.Movie(
title="test movie",
release_year=2013,

@@ -31,9 +30,9 @@ async def test_add_and_get(shared_conn):
assert m2 == await db.get(models.Movie, id=str(m2.id))

async def test_find_ratings(shared_conn):
@pytest.mark.asyncio
async def test_find_ratings(shared_conn: db.Database):
async with shared_conn.transaction(force_rollback=True):

m1 = models.Movie(
title="test movie",
release_year=2013,

@@ -157,4 +156,3 @@ async def test_find_ratings(shared_conn):
rows = await db.find_ratings(title="test", include_unrated=True)
ratings = tuple(web_models.Rating(**r) for r in rows)
assert (web_models.Rating.from_movie(m1),) == ratings
@@ -1,14 +1,15 @@
import pytest

from unwind.imdb import imdb_rating_from_score, score_from_imdb_rating

@pytest.mark.parametrize("rating", (x / 10 for x in range(10, 101)))
def test_rating_conversion(rating):
def test_rating_conversion(rating: float):
assert rating == imdb_rating_from_score(score_from_imdb_rating(rating))

@pytest.mark.parametrize("score", range(0, 101))
def test_score_conversion(score):
def test_score_conversion(score: int):
# Because our score covers 101 discrete values and IMDb's rating only 91
# discrete values, the mapping is non-injective, i.e. 10 values can't be
# mapped uniquely.
@@ -1,18 +1,14 @@
from starlette.testclient import TestClient
import pytest
from starlette.testclient import TestClient

from unwind import create_app
from unwind import db, models, imdb

# https://pypi.org/project/pytest-asyncio/
pytestmark = pytest.mark.asyncio
from unwind import create_app, db, imdb, models

app = create_app()

async def test_app(shared_conn):
@pytest.mark.asyncio
async def test_app(shared_conn: db.Database):
async with shared_conn.transaction(force_rollback=True):

# https://www.starlette.io/testclient/
client = TestClient(app)
response = client.get("/api/v1/movies")
@@ -6,7 +6,7 @@ from pathlib import Path
from . import config
from .db import close_connection_pool, open_connection_pool
from .imdb import refresh_user_ratings_from_imdb
from .imdb_import import import_from_file
from .imdb_import import download_datasets, import_from_file

log = logging.getLogger(__name__)

@@ -15,7 +15,7 @@ async def run_load_user_ratings_from_imdb():
await open_connection_pool()

i = 0
async for rating in refresh_user_ratings_from_imdb():
async for _ in refresh_user_ratings_from_imdb():
i += 1

log.info("✨ Imported %s new ratings.", i)

@@ -31,6 +31,10 @@ async def run_import_imdb_dataset(basics_path: Path, ratings_path: Path):
await close_connection_pool()

async def run_download_imdb_dataset(basics_path: Path, ratings_path: Path):
await download_datasets(basics_path=basics_path, ratings_path=ratings_path)

def getargs():
parser = argparse.ArgumentParser()
commands = parser.add_subparsers(required=True)

@@ -55,6 +59,25 @@ def getargs():
"--ratings", metavar="ratings_file.tsv.gz", type=Path, required=True
)

parser_download_imdb_dataset = commands.add_parser(
"download-imdb-dataset",
help="Download IMDb datasets.",
description="""
Download IMDb datasets.
""",
)
parser_download_imdb_dataset.add_argument(
dest="mode",
action="store_const",
const="download-imdb-dataset",
)
parser_download_imdb_dataset.add_argument(
"--basics", metavar="basics_file.tsv.gz", type=Path, required=True
)
parser_download_imdb_dataset.add_argument(
"--ratings", metavar="ratings_file.tsv.gz", type=Path, required=True
)

parser_load_user_ratings_from_imdb = commands.add_parser(
"load-user-ratings-from-imdb",
help="Load user ratings from imdb.com.",

@@ -94,6 +117,8 @@ def main():
asyncio.run(run_load_user_ratings_from_imdb())
elif args.mode == "import-imdb-dataset":
asyncio.run(run_import_imdb_dataset(args.basics, args.ratings))
elif args.mode == "download-imdb-dataset":
asyncio.run(run_download_imdb_dataset(args.basics, args.ratings))

main()
@@ -1,8 +1,7 @@
import os
import tomllib
from pathlib import Path

import toml

datadir = Path(os.getenv("UNWIND_DATA") or "./data")
cachedir = (
Path(cachedir)

@@ -14,7 +13,8 @@ loglevel = os.getenv("UNWIND_LOGLEVEL") or ("DEBUG" if debug else "INFO")
storage_path = os.getenv("UNWIND_STORAGE", datadir / "db.sqlite")
config_path = os.getenv("UNWIND_CONFIG", datadir / "config.toml")

_config = toml.load(config_path)
with open(config_path, "rb") as fd:
_config = tomllib.load(fd)

api_base = _config["api"].get("base", "/api/")
api_cors = _config["api"].get("cors", "*")
unwind/db.py (63)

@@ -4,7 +4,7 @@ import logging
import re
import threading
from pathlib import Path
from typing import Any, Iterable, Literal, Optional, Type, TypeVar, Union
from typing import Any, Iterable, Literal, Type, TypeVar

import sqlalchemy
from databases import Database

@@ -26,7 +26,7 @@ from .types import ULID
log = logging.getLogger(__name__)
T = TypeVar("T")

_shared_connection: Optional[Database] = None
_shared_connection: Database | None = None

async def open_connection_pool() -> None:

@@ -119,7 +119,6 @@ async def apply_db_patches(db: Database):
raise RuntimeError("No statement found.")

async with db.transaction():

for query in queries:
await db.execute(query)

@@ -131,12 +130,12 @@ async def apply_db_patches(db: Database):
await db.execute("vacuum")

async def get_import_progress() -> Optional[Progress]:
async def get_import_progress() -> Progress | None:
"""Return the latest import progress."""
return await get(Progress, type="import-imdb-movies", order_by="started DESC")

async def stop_import_progress(*, error: BaseException = None):
async def stop_import_progress(*, error: BaseException | None = None):
"""Stop the current import.

If an error is given, it will be logged to the progress state.

@@ -176,6 +175,8 @@ async def set_import_progress(progress: float) -> Progress:
else:
await add(current)

return current

_lock = threading.Lock()
_prelock = threading.Lock()

@@ -243,8 +244,8 @@ ModelType = TypeVar("ModelType")

async def get(
model: Type[ModelType], *, order_by: str = None, **kwds
) -> Optional[ModelType]:
model: Type[ModelType], *, order_by: str | None = None, **kwds
) -> ModelType | None:
"""Load a model instance from the database.

Passing `kwds` allows to filter the instance to load. You have to encode the

@@ -262,7 +263,7 @@ async def get(
query += f" ORDER BY {order_by}"
async with locked_connection() as conn:
row = await conn.fetch_one(query=query, values=values)
return fromplain(model, row, serialized=True) if row else None
return fromplain(model, row._mapping, serialized=True) if row else None

async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]:

@@ -282,7 +283,7 @@ async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
async with locked_connection() as conn:
rows = await conn.fetch_all(query=query, values=values)
return (fromplain(model, row, serialized=True) for row in rows)
return (fromplain(model, row._mapping, serialized=True) for row in rows)

async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]:

@@ -293,7 +294,7 @@ async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
async with locked_connection() as conn:
rows = await conn.fetch_all(query=query, values=values)
return (fromplain(model, row, serialized=True) for row in rows)
return (fromplain(model, row._mapping, serialized=True) for row in rows)

async def update(item):

@@ -406,16 +407,16 @@ def sql_escape(s: str, char="#"):

async def find_ratings(
*,
title: str = None,
media_type: str = None,
title: str | None = None,
media_type: str | None = None,
exact: bool = False,
ignore_tv_episodes: bool = False,
include_unrated: bool = False,
yearcomp: tuple[Literal["<", "=", ">"], int] = None,
yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
limit_rows: int = 10,
user_ids: Iterable[str] = [],
):
values: dict[str, Union[int, str]] = {
values: dict[str, int | str] = {
"limit_rows": limit_rows,
}

@@ -466,7 +467,7 @@ async def find_ratings(
"""
async with locked_connection() as conn:
rows = await conn.fetch_all(bindparams(query, values))
movie_ids = tuple(r["movie_id"] for r in rows)
movie_ids = tuple(r._mapping["movie_id"] for r in rows)

if include_unrated and len(movie_ids) < limit_rows:
sqlin, sqlin_vals = sql_in("id", movie_ids, not_=True)

@@ -485,7 +486,7 @@ async def find_ratings(
{**values, **sqlin_vals, "limit_rows": limit_rows - len(movie_ids)},
)
)
movie_ids += tuple(r["movie_id"] for r in rows)
movie_ids += tuple(r._mapping["movie_id"] for r in rows)

return await ratings_for_movie_ids(ids=movie_ids)

@@ -527,29 +528,13 @@ async def ratings_for_movie_ids(

async with locked_connection() as conn:
rows = await conn.fetch_all(bindparams(query, vals))
return tuple(dict(r) for r in rows)
return tuple(dict(r._mapping) for r in rows)

def sql_fields(tp: Type):
return (f"{tp._table}.{f.name}" for f in fields(tp))

def sql_fieldmap(tp: Type):
"""-> {alias: (table, field_name)}"""
return {f"{tp._table}_{f.name}": (tp._table, f.name) for f in fields(tp)}

def mux(*tps: Type):
return ", ".join(
f"{t}.{n} AS {k}" for tp in tps for k, (t, n) in sql_fieldmap(tp).items()
)

def demux(tp: Type[ModelType], row) -> ModelType:
d = {n: row[k] for k, (_, n) in sql_fieldmap(tp).items()}
return fromplain(tp, d, serialized=True)

def sql_in(column: str, values: Iterable[T], not_=False) -> tuple[str, dict[str, T]]:
c = column.replace(".", "___")
value_map = {f"{c}_{i}": v for i, v in enumerate(values, start=1)}

@@ -583,22 +568,22 @@ async def ratings_for_movies(
async with locked_connection() as conn:
rows = await conn.fetch_all(query, values)

return (fromplain(Rating, row, serialized=True) for row in rows)
return (fromplain(Rating, row._mapping, serialized=True) for row in rows)

async def find_movies(
*,
title: str = None,
media_type: str = None,
title: str | None = None,
media_type: str | None = None,
exact: bool = False,
ignore_tv_episodes: bool = False,
yearcomp: tuple[Literal["<", "=", ">"], int] = None,
yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
limit_rows: int = 10,
skip_rows: int = 0,
include_unrated: bool = False,
user_ids: list[ULID] = [],
) -> Iterable[tuple[Movie, list[Rating]]]:
values: dict[str, Union[int, str]] = {
values: dict[str, int | str] = {
"limit_rows": limit_rows,
"skip_rows": skip_rows,
}

@@ -650,7 +635,7 @@ async def find_movies(
async with locked_connection() as conn:
rows = await conn.fetch_all(bindparams(query, values))

movies = [fromplain(Movie, row, serialized=True) for row in rows]
movies = [fromplain(Movie, row._mapping, serialized=True) for row in rows]

if not user_ids:
return ((m, []) for m in movies)
@@ -2,12 +2,11 @@ import logging
import re
from collections import namedtuple
from datetime import datetime
from typing import Optional, Tuple
from urllib.parse import urljoin

from . import db
from .models import Movie, Rating, User
from .request import cache_path, session, soup_from_url
from .request import asession, asoup_from_url, cache_path

log = logging.getLogger(__name__)

@@ -35,13 +34,11 @@ log = logging.getLogger(__name__)
# p.text-muted.text-small span[name=nv] [data-value]

async def refresh_user_ratings_from_imdb(stop_on_dupe=True):

with session() as s:
async def refresh_user_ratings_from_imdb(stop_on_dupe: bool = True):
async with asession() as s:
s.headers["Accept-Language"] = "en-US, en;q=0.5"

for user in await db.get_all(User):

log.info("⚡️ Loading data for %s ...", user.name)

try:

@@ -98,7 +95,6 @@ find_movie_id = re.compile(r"/title/(?P<id>tt\d+)/").search

def movie_and_rating_from_item(item) -> tuple[Movie, Rating]:

genres = (genre := item.find("span", "genre")) and genre.string or ""
movie = Movie(
title=item.h3.a.string.strip(),

@@ -153,10 +149,10 @@ def movie_and_rating_from_item(item) -> tuple[Movie, Rating]:
ForgedRequest = namedtuple("ForgedRequest", "url headers")

async def parse_page(url) -> Tuple[list[Rating], Optional[str]]:
async def parse_page(url: str) -> tuple[list[Rating], str | None]:
ratings = []

soup = soup_from_url(url)
soup = await asoup_from_url(url)

meta = soup.find("meta", property="pageId")
headline = soup.h1

@@ -170,7 +166,6 @@ async def parse_page(url) -> Tuple[list[Rating], Optional[str]]:

items = soup.find_all("div", "lister-item-content")
for i, item in enumerate(items):

try:
movie, rating = movie_and_rating_from_item(item)
except Exception as err:

@@ -196,11 +191,10 @@ async def parse_page(url) -> Tuple[list[Rating], Optional[str]]:
return (ratings, next_url if url != next_url else None)

async def load_ratings(user_id):
async def load_ratings(user_id: str):
next_url = user_ratings_url(user_id)

while next_url:

ratings, next_url = await parse_page(next_url)

for i, rating in enumerate(ratings):
@@ -1,10 +1,11 @@
import asyncio
import csv
import gzip
import logging
from dataclasses import dataclass, fields
from datetime import datetime, timezone
from pathlib import Path
from typing import Generator, Literal, Optional, Type, TypeVar, overload
from typing import Generator, Literal, Type, TypeVar, overload

from . import config, db, request
from .db import add_or_update_many_movies

@@ -27,10 +28,10 @@ class BasicRow:
primaryTitle: str
originalTitle: str
isAdult: bool
startYear: Optional[int]
endYear: Optional[int]
runtimeMinutes: Optional[int]
genres: Optional[set[str]]
startYear: int | None
endYear: int | None
runtimeMinutes: int | None
genres: set[str] | None

@classmethod
def from_row(cls, row):

@@ -100,7 +101,7 @@ title_types = {
}

def gz_mtime(path) -> datetime:
def gz_mtime(path: Path) -> datetime:
"""Return the timestamp of the compressed file."""
g = gzip.GzipFile(path, "rb")
g.peek(1) # start reading the file to fill the timestamp field

@@ -108,14 +109,13 @@ def gz_mtime(path) -> datetime:
return datetime.fromtimestamp(g.mtime).replace(tzinfo=timezone.utc)

def count_lines(path) -> int:
def count_lines(path: Path) -> int:
i = 0

one_mb = 2 ** 20
one_mb = 2**20
buf_size = 8 * one_mb # 8 MiB seems to give a good read/process performance.

with gzip.open(path, "rt") as f:

while buf := f.read(buf_size):
i += buf.count("\n")

@@ -124,19 +124,19 @@ def count_lines(path) -> int:

@overload
def read_imdb_tsv(
path, row_type, *, unpack: Literal[False]
path: Path, row_type, *, unpack: Literal[False]
) -> Generator[list[str], None, None]:
...

@overload
def read_imdb_tsv(
path, row_type: Type[T], *, unpack: Literal[True] = True
path: Path, row_type: Type[T], *, unpack: Literal[True] = True
) -> Generator[T, None, None]:
...

def read_imdb_tsv(path, row_type, *, unpack=True):
def read_imdb_tsv(path: Path, row_type, *, unpack=True):
with gzip.open(path, "rt", newline="") as f:
rows = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE)

@@ -161,7 +161,7 @@ def read_imdb_tsv(path, row_type, *, unpack=True):
raise

def read_ratings(path):
def read_ratings(path: Path):
mtime = gz_mtime(path)
rows = read_imdb_tsv(path, RatingRow)

@@ -171,19 +171,20 @@ def read_ratings(path):
yield m

def read_ratings_as_mapping(path):
def read_ratings_as_mapping(path: Path):
"""Optimized function to quickly load all ratings."""
rows = read_imdb_tsv(path, RatingRow, unpack=False)
return {r[0]: (round(100 * (float(r[1]) - 1) / 9), int(r[2])) for r in rows}

def read_basics(path):
def read_basics(path: Path) -> Generator[Movie | None, None, None]:
mtime = gz_mtime(path)
rows = read_imdb_tsv(path, BasicRow)

for row in rows:
if row.startYear is None:
log.debug("Skipping movie, missing year: %s", row)
yield None
continue

m = row.as_movie()

@@ -197,20 +198,24 @@ async def import_from_file(*, basics_path: Path, ratings_path: Path):

log.info("💾 Importing movies ...")
total = count_lines(basics_path)
assert total != 0
log.debug("Found %i movies.", total)
if total == 0:
raise RuntimeError(f"No movies found.")
perc_next_report = 0.0
perc_step = 0.1

chunk = []

for i, m in enumerate(read_basics(basics_path)):

perc = 100 * i / total
if perc >= perc_next_report:
await db.set_import_progress(perc)
log.info("⏳ Imported %s%%", round(perc, 1))
perc_next_report += perc_step

if m is None:
continue

if m.media_type not in {
"Movie",
"Short",

@@ -235,10 +240,27 @@ async def import_from_file(*, basics_path: Path, ratings_path: Path):
await add_or_update_many_movies(chunk)
chunk = []

log.info("👍 Imported 100%")
await db.set_import_progress(100)

async def load_from_web(*, force: bool = False):
async def download_datasets(*, basics_path: Path, ratings_path: Path) -> None:
"""Download IMDb movie database dumps.

See https://www.imdb.com/interfaces/ and https://datasets.imdbws.com/ for
more information on the IMDb database dumps.
"""
basics_url = "https://datasets.imdbws.com/title.basics.tsv.gz"
ratings_url = "https://datasets.imdbws.com/title.ratings.tsv.gz"

async with request.asession():
await asyncio.gather(
request.adownload(ratings_url, to_path=ratings_path, only_if_newer=True),
request.adownload(basics_url, to_path=basics_path, only_if_newer=True),
)

async def load_from_web(*, force: bool = False) -> None:
"""Refresh the full IMDb movie database.

The latest dumps are first downloaded and then imported into the database.

@@ -251,17 +273,13 @@ async def load_from_web(*, force: bool = False):
await db.set_import_progress(0)

try:
basics_url = "https://datasets.imdbws.com/title.basics.tsv.gz"
ratings_url = "https://datasets.imdbws.com/title.ratings.tsv.gz"
ratings_file = config.datadir / "imdb/title.ratings.tsv.gz"
basics_file = config.datadir / "imdb/title.basics.tsv.gz"

ratings_mtime = ratings_file.stat().st_mtime if ratings_file.exists() else None
bastics_mtime = basics_file.stat().st_mtime if basics_file.exists() else None

with request.session():
request.download(ratings_url, ratings_file, only_if_newer=True)
request.download(basics_url, basics_file, only_if_newer=True)
await download_datasets(basics_path=basics_file, ratings_path=ratings_file)

is_changed = (
ratings_mtime != ratings_file.stat().st_mtime
@@ -3,13 +3,14 @@ from dataclasses import dataclass, field
from dataclasses import fields as _fields
from datetime import datetime, timezone
from functools import partial
from types import UnionType
from typing import (
Annotated,
Any,
ClassVar,
Container,
Literal,
Optional,
Mapping,
Type,
TypeVar,
Union,

@@ -19,13 +20,13 @@ from typing import (

from .types import ULID

JSON = Union[int, float, str, None, list["JSON"], dict[str, "JSON"]]
JSON = int | float | str | None | list["JSON"] | dict[str, "JSON"]
JSONObject = dict[str, JSON]

T = TypeVar("T")

def annotations(tp: Type) -> Optional[tuple]:
def annotations(tp: Type) -> tuple | None:
return tp.__metadata__ if hasattr(tp, "__metadata__") else None

@@ -42,7 +43,6 @@ def fields(class_or_instance):
# XXX this might be a little slow (not sure), if so, memoize

for f in _fields(class_or_instance):

if f.name == "_is_lazy":
continue

@@ -54,21 +54,21 @@ def fields(class_or_instance):

def is_optional(tp: Type) -> bool:
"""Return wether the given type is optional."""
if get_origin(tp) is not Union:
if not isinstance(tp, UnionType) and get_origin(tp) is not Union:
return False

args = get_args(tp)
return len(args) == 2 and type(None) in args

def optional_type(tp: Type) -> Optional[Type]:
def optional_type(tp: Type) -> Type | None:
"""Return the wrapped type from an optional type.

For example this will return `int` for `Optional[int]`.
Since they're equivalent this also works for other optioning notations, like
`Union[int, None]` and `int | None`.
"""
if get_origin(tp) is not Union:
if not isinstance(tp, UnionType) and get_origin(tp) is not Union:
return None

args = get_args(tp)

@@ -92,7 +92,7 @@ def _id(x: T) -> T:

def asplain(
o: object, *, filter_fields: Container[str] = None, serialize: bool = False
o: object, *, filter_fields: Container[str] | None = None, serialize: bool = False
) -> dict[str, Any]:
"""Return the given model instance as `dict` with JSON compatible plain datatypes.

@@ -109,7 +109,6 @@ def asplain(

d: JSONObject = {}
for f in fields(o):

if filter_fields is not None and f.name not in filter_fields:
continue

@@ -146,7 +145,7 @@ def asplain(
return d

def fromplain(cls: Type[T], d: dict[str, Any], *, serialized: bool = False) -> T:
def fromplain(cls: Type[T], d: Mapping, *, serialized: bool = False) -> T:
"""Return an instance of the given model using the given data.

If `serialized` is `True`, collection types (lists, dicts, etc.) will be

@@ -157,7 +156,6 @@ def fromplain(cls: Type[T], d: dict[str, Any], *, serialized: bool = False) -> T

dd: JSONObject = {}
for f in fields(cls):

target = f.type
otype = optional_type(f.type)
is_opt = otype is not None

@@ -188,7 +186,8 @@ def validate(o: object) -> None:
vtype = type(getattr(o, f.name))
if vtype is not f.type:
if get_origin(f.type) is vtype or (
get_origin(f.type) is Union and vtype in get_args(f.type)
(isinstance(f.type, UnionType) or get_origin(f.type) is Union)
and vtype in get_args(f.type)
):
continue
raise ValueError(f"Invalid value type: {f.name}: {vtype}")

@@ -206,7 +205,7 @@ class Progress:
type: str = None
state: str = None
started: datetime = field(default_factory=utcnow)
stopped: Optional[str] = None
stopped: str | None = None

@property
def _state(self) -> dict:

@@ -243,15 +242,15 @@ class Movie:

id: ULID = field(default_factory=ULID)
title: str = None # canonical title (usually English)
original_title: Optional[
str
] = None # original title (usually transscribed to latin script)
original_title: str | None = (
None # original title (usually transscribed to latin script)
)
release_year: int = None # canonical release date
media_type: str = None
imdb_id: str = None
imdb_score: Optional[int] = None # range: [0,100]
imdb_votes: Optional[int] = None
runtime: Optional[int] = None # minutes
imdb_score: int | None = None # range: [0,100]
imdb_votes: int | None = None
runtime: int | None = None # minutes
genres: set[str] = None
created: datetime = field(default_factory=utcnow)
updated: datetime = field(default_factory=utcnow)

@@ -292,7 +291,7 @@ dataclass containing the ID of the linked data.
The contents of the Relation are ignored or discarded when using
`asplain`, `fromplain`, and `validate`.
"""
Relation = Annotated[Optional[T], _RelationSentinel]
Relation = Annotated[T | None, _RelationSentinel]

@dataclass

@@ -309,8 +308,8 @@ class Rating:

score: int = None # range: [0,100]
rating_date: datetime = None
favorite: Optional[bool] = None
finished: Optional[bool] = None
favorite: bool | None = None
finished: bool | None = None

def __eq__(self, other):
"""Return wether two Ratings are equal.

@@ -342,11 +341,11 @@ class User:
secret: str = None
groups: list[dict[str, str]] = field(default_factory=list)

def has_access(self, group_id: Union[ULID, str], access: Access = "r"):
def has_access(self, group_id: ULID | str, access: Access = "r"):
group_id = group_id if isinstance(group_id, str) else str(group_id)
return any(g["id"] == group_id and access == g["access"] for g in self.groups)

def set_access(self, group_id: Union[ULID, str], access: Access):
def set_access(self, group_id: ULID | str, access: Access):
group_id = group_id if isinstance(group_id, str) else str(group_id)
for g in self.groups:
if g["id"] == group_id:
@@ -4,19 +4,17 @@ import logging
import os
import tempfile
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from functools import wraps
from hashlib import md5
from pathlib import Path
from random import random
from time import sleep, time
from typing import Callable, Optional, Union
from typing import Callable, ParamSpec, TypeVar, cast

import bs4
import requests
from requests.status_codes import codes
from urllib3.util.retry import Retry
import httpx

from . import config

@@ -26,28 +24,17 @@ if config.debug and config.cachedir:
config.cachedir.mkdir(exist_ok=True)

def set_retries(s: requests.Session, n: int, backoff_factor: float = 0.2):
retry = (
Retry(
total=n,
connect=n,
read=n,
status=n,
status_forcelist=Retry.RETRY_AFTER_STATUS_CODES,
backoff_factor=backoff_factor,
)
if n
else Retry(0, read=False)
)
for a in s.adapters.values():
a.max_retries = retry
_shared_asession = None

_ASession_T = httpx.AsyncClient
_Response_T = httpx.Response

_T = TypeVar("_T")
_P = ParamSpec("_P")

_shared_session = None

@contextmanager
def session():
@asynccontextmanager
async def asession():
"""Return the shared request session.

The session is shared by all request functions and provides cookie

@@ -55,38 +42,34 @@ def session():
Opening the session before making a request allows you to set headers
or change the retry behavior.
"""
global _shared_session
global _shared_asession

if _shared_session:
yield _shared_session
if _shared_asession:
yield _shared_asession
return

_shared_session = Session()
_shared_asession = _ASession_T()
_shared_asession.headers[
"user-agent"
] = "Mozilla/5.0 Gecko/20100101 unwind/20230203"
try:
yield _shared_session
async with _shared_asession:
yield _shared_asession
finally:
_shared_session = None
_shared_asession = None

def Session() -> requests.Session:
s = requests.Session()
s.headers["User-Agent"] = "Mozilla/5.0 Gecko/20100101 unwind/20210506"
return s

def throttle(
times: int, per_seconds: float, jitter: Callable[[], float] = None
) -> Callable[[Callable], Callable]:

calls: Deque[float] = deque(maxlen=times)
def _throttle(
times: int, per_seconds: float, jitter: Callable[[], float] | None = None
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
calls: deque[float] = deque(maxlen=times)

if jitter is None:
jitter = lambda: 0.0

def decorator(func: Callable) -> Callable:
def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]:
@wraps(func)
def inner(*args, **kwds):

def inner(*args: _P.args, **kwds: _P.kwargs):
# clean up
while calls:
if calls[0] + per_seconds > time():

@@ -118,23 +101,19 @@ def throttle(
return decorator

class CachedStr(str):
is_cached = True

@dataclass
class CachedResponse:
class _CachedResponse:
is_cached = True
status_code: int
text: str
url: str
headers: dict[str, str] = None
headers: dict[str, str] = field(default_factory=dict)

def json(self):
return json.loads(self.text)

class RedirectError(RuntimeError):
class _RedirectError(RuntimeError):
def __init__(self, from_url: str, to_url: str, is_cached=False):
self.from_url = from_url
self.to_url = to_url

@@ -142,44 +121,51 @@ class RedirectError(RuntimeError):
super().__init__(f"Redirected: {from_url} -> {to_url}")

def cache_path(req) -> Optional[Path]:
def cache_path(req) -> Path | None:
if not config.cachedir:
return
sig = repr(req.url) # + repr(sorted(req.headers.items()))
return config.cachedir / md5(sig.encode()).hexdigest()

@throttle(1, 1, random)
def http_get(s: requests.Session, url: str, *args, **kwds) -> requests.Response:

req = s.prepare_request(requests.Request("GET", url, *args, **kwds))
@_throttle(1, 1, random)
async def _ahttp_get(s: _ASession_T, url: str, *args, **kwds) -> _Response_T:
req = s.build_request(method="GET", url=url, *args, **kwds)

cachefile = cache_path(req) if config.debug else None

if cachefile:
if cachefile.exists():
log.debug(
f"💾 loading {req.url} ({req.headers!a}) from cache {cachefile} ..."
"💾 loading %s (%a) from cache %s ...", req.url, req.headers, cachefile
)
with cachefile.open() as fp:
resp = CachedResponse(**json.load(fp))
resp = _CachedResponse(**json.load(fp))
if 300 <= resp.status_code <= 399:
raise RedirectError(
raise _RedirectError(
from_url=resp.url, to_url=resp.headers["location"], is_cached=True
)
return resp
return cast(_Response_T, resp)

log.debug(f"⚡️ loading {req.url} ({req.headers!a}) ...")
resp = s.send(req, allow_redirects=False, stream=True)
log.debug("⚡️ loading %s (%a) ...", req.url, req.headers)
resp = await s.send(req, follow_redirects=False, stream=True)
resp.raise_for_status()

await resp.aread() # Download the response stream to allow `resp.text` access.

if cachefile:
log.debug(
"💾 writing response to cache: %s (%a) -> %s",
req.url,
req.headers,
cachefile,
)
with cachefile.open("w") as fp:
json.dump(
{
"status_code": resp.status_code,
"text": resp.text,
"url": resp.url,
"url": str(resp.url),
"headers": dict(resp.headers),
},
fp,

@@ -187,45 +173,46 @@ def http_get(s: requests.Session, url: str, *args, **kwds) -> requests.Response:

if resp.is_redirect:
# Redirects could mean trouble, we need to stay on top of that!
raise RedirectError(from_url=resp.url, to_url=resp.headers["location"])
raise _RedirectError(from_url=str(resp.url), to_url=resp.headers["location"])

return resp

def soup_from_url(url):
async def asoup_from_url(url):
"""Return a BeautifulSoup instance from the contents for the given URL."""
with session() as s:
r = http_get(s, url)
async with asession() as s:
r = await _ahttp_get(s, url)

soup = bs4.BeautifulSoup(r.text, "html5lib")
return soup

def last_modified_from_response(resp):
if last_mod := resp.headers.get("Last-Modified"):
def _last_modified_from_response(resp: _Response_T) -> float | None:
if last_mod := resp.headers.get("last-modified"):
try:
return email.utils.parsedate_to_datetime(last_mod).timestamp()
except:
log.exception("🐛 Received invalid value for Last-Modified: %s", last_mod)

def last_modified_from_file(path: Path):
def _last_modified_from_file(path: Path) -> float:
return path.stat().st_mtime

def download(
async def adownload(
url: str,
file_path: Union[Path, str] = None,
*,
replace_existing: bool = None,
to_path: Path | str | None = None,
replace_existing: bool | None = None,
only_if_newer: bool = False,
timeout: float = None,
verify_ssl: bool = True,
timeout: float | None = None,
chunk_callback=None,
response_callback=None,
):
) -> bytes | None:
"""Download a file.

If `to_path` is `None` return the remote content, otherwise write the
content to the given file path.
Existing files will not be overwritten unless `replace_existing` is set.
Setting `only_if_newer` will check if the remote file is newer than the
local file, otherwise the download will be aborted.

@@ -234,50 +221,56 @@ def download(
replace_existing = only_if_newer

file_exists = None
if file_path is not None:
file_path = Path(file_path)
if to_path is not None:
to_path = Path(to_path)

file_exists = file_path.exists() and file_path.stat().st_size
file_exists = to_path.exists() and to_path.stat().st_size
if file_exists and not replace_existing:
raise FileExistsError(23, "Would replace existing file", str(file_path))

with session() as s:
raise FileExistsError(23, "Would replace existing file", str(to_path))

async with asession() as s:
headers = {}
if file_exists and only_if_newer:
assert file_path
file_lastmod = last_modified_from_file(file_path)
headers["If-Modified-Since"] = email.utils.formatdate(
assert to_path
file_lastmod = _last_modified_from_file(to_path)
headers["if-modified-since"] = email.utils.formatdate(
file_lastmod, usegmt=True
)

req = s.prepare_request(requests.Request("GET", url, headers=headers))
req = s.build_request(method="GET", url=url, headers=headers, timeout=timeout)

log.debug("⚡️ loading %s (%s) ...", req.url, req.headers)
resp = s.send(
req, allow_redirects=True, stream=True, timeout=timeout, verify=verify_ssl
)
log.debug("⚡️ Loading %s (%a) ...", req.url, dict(req.headers))
resp = await s.send(req, follow_redirects=True, stream=True)

try:
if response_callback is not None:
try:
response_callback(resp)
except:
log.exception("🐛 Error in response callback.")

log.debug("☕️ Response status: %s; headers: %s", resp.status_code, resp.headers)
log.debug(
"☕️ %s -> status: %s; headers: %a",
req.url,
resp.status_code,
dict(resp.headers),
)

if resp.status_code == httpx.codes.NOT_MODIFIED:
log.debug(
"✋ Remote file has not changed, skipping download: %s -> %a",
req.url,
to_path,
)
return

resp.raise_for_status()

if resp.status_code == codes.not_modified:
log.debug("✋ Remote file has not changed, skipping download.")
return

if file_path is None:
if to_path is None:
await resp.aread() # Download the response stream to allow `resp.content` access.
return resp.content

assert replace_existing is True

resp_lastmod = last_modified_from_response(resp)
resp_lastmod = _last_modified_from_response(resp)

# Check Last-Modified in case the server ignored If-Modified-Since.
# XXX also check Content-Length?

@@ -285,24 +278,23 @@ def download(
assert file_lastmod

if resp_lastmod <= file_lastmod:
log.debug("✋ Local file is newer, skipping download.")
resp.close()
log.debug("✋ Local file is newer, skipping download: %a", req.url)
return

# Create intermediate directories if necessary.
download_dir = file_path.parent
download_dir = to_path.parent
download_dir.mkdir(parents=True, exist_ok=True)

# Write content to temp file.
tempdir = download_dir
tempfd, tempfile_path = tempfile.mkstemp(
dir=tempdir, prefix=f".download-{file_path.name}."
dir=tempdir, prefix=f".download-{to_path.name}."
)
one_mb = 2 ** 20
one_mb = 2**20
chunk_size = 8 * one_mb
try:
log.debug("💾 Writing to temp file %s ...", tempfile_path)
for chunk in resp.iter_content(chunk_size=chunk_size, decode_unicode=False):
async for chunk in resp.aiter_bytes(chunk_size):
os.write(tempfd, chunk)
if chunk_callback:
try:

@@ -313,10 +305,19 @@ def download(
os.close(tempfd)

# Move downloaded file to destination.
if file_exists:
log.debug("💾 Replacing existing file: %s", file_path)
Path(tempfile_path).replace(file_path)
if to_path.exists():
log.debug("💾 Replacing existing file: %s", to_path)
else:
log.debug("💾 Move to destination: %s", to_path)
if replace_existing:
Path(tempfile_path).replace(to_path)
else:
Path(tempfile_path).rename(to_path)

# Fix file attributes.
if resp_lastmod is not None:
os.utime(file_path, (resp_lastmod, resp_lastmod))
log.debug("💾 Adjusting file timestamp: %s (%s)", to_path, resp_lastmod)
os.utime(to_path, (resp_lastmod, resp_lastmod))

finally:
await resp.aclose()
@@ -1,5 +1,5 @@
import re
from typing import Union, cast
from typing import cast

import ulid
from ulid.hints import Buffer

@@ -16,7 +16,7 @@ class ULID(ulid.ULID):

_pattern = re.compile(r"^[0-9A-HJKMNP-TV-Z]{26}$")

def __init__(self, buffer: Union[Buffer, ulid.ULID, str, None] = None):
def __init__(self, buffer: Buffer | ulid.ULID | str | None = None):
if isinstance(buffer, str):
if not self._pattern.search(buffer):
raise ValueError("Invalid ULID.")
@@ -17,7 +17,10 @@ def b64padded(s: str) -> str:

def phc_scrypt(
secret: bytes, *, salt: bytes = None, params: dict[Literal["n", "r", "p"], int] = {}
secret: bytes,
*,
salt: bytes | None = None,
params: dict[Literal["n", "r", "p"], int] = {},
) -> str:
"""Return the scrypt expanded secret in PHC string format.

@@ -30,7 +33,7 @@ def phc_scrypt(
if salt is None:
salt = secrets.token_bytes(16)

n = params.get("n", 2 ** 14) # CPU/Memory cost factor
n = params.get("n", 2**14) # CPU/Memory cost factor
r = params.get("r", 8) # block size
p = params.get("p", 1) # parallelization factor
# maxmem = 2 * 128 * n * r * p
@@ -1,8 +1,9 @@
import asyncio
import contextlib
import logging
import secrets
from json.decoder import JSONDecodeError
from typing import Literal, Optional, overload
from typing import Literal, overload

from starlette.applications import Starlette
from starlette.authentication import (

@@ -85,11 +86,14 @@ def truthy(s: str):
return bool(s) and s.lower() in {"1", "yes", "true"}

def yearcomp(s: str):
_Yearcomp = Literal["<", "=", ">"]

def yearcomp(s: str) -> tuple[_Yearcomp, int] | None:
if not s:
return

comp: Literal["<", "=", ">"] = "="
comp: _Yearcomp = "="
if (prefix := s[0]) in "<=>":
comp = prefix # type: ignore
s = s[len(prefix) :]

@@ -97,7 +101,9 @@ def yearcomp(s: str):
return comp, int(s)

def as_int(x, *, max: int = None, min: Optional[int] = 1, default: int = None):
def as_int(
x, *, max: int | None = None, min: int | None = 1, default: int | None = None
) -> int:
try:
if not isinstance(x, int):
x = int(x)

@@ -135,7 +141,7 @@ async def json_from_body(request, keys: list[str]) -> list:
...

async def json_from_body(request, keys: list[str] = None):
async def json_from_body(request, keys: list[str] | None = None):
if not await request.body():
data = {}

@@ -158,7 +164,7 @@ def is_admin(request):
return "admin" in request.auth.scopes

async def auth_user(request) -> Optional[User]:
async def auth_user(request) -> User | None:
if not isinstance(request.user, AuthedUser):
return

@@ -176,7 +182,7 @@ async def auth_user(request) -> Optional[User]:
_routes = []

def route(path: str, *, methods: list[str] = None, **kwds):
def route(path: str, *, methods: list[str] | None = None, **kwds):
def decorator(func):
r = Route(path, func, methods=methods, **kwds)
_routes.append(r)

@@ -190,7 +196,6 @@ route.registered = _routes

@route("/groups/{group_id}/ratings")
async def get_ratings_for_group(request):

group_id = as_ulid(request.path_params["group_id"])
group = await db.get(Group, id=str(group_id))

@@ -251,7 +256,6 @@ def not_implemented():
@route("/movies")
@requires(["authenticated"])
async def list_movies(request):

params = request.query_params

user = await auth_user(request)

@@ -319,7 +323,6 @@ async def list_movies(request):
@route("/movies", methods=["POST"])
@requires(["authenticated", "admin"])
async def add_movie(request):

not_implemented()

@@ -361,7 +364,6 @@ _import_lock = asyncio.Lock()
@route("/movies/_reload_imdb", methods=["POST"])
@requires(["authenticated", "admin"])
async def load_imdb_movies(request):

params = request.query_params
force = truthy(params.get("force"))

@@ -384,7 +386,6 @@ async def load_imdb_movies(request):
@route("/users")
@requires(["authenticated", "admin"])
async def list_users(request):

users = await db.get_all(User)

return JSONResponse([asplain(u) for u in users])

@@ -393,7 +394,6 @@ async def list_users(request):
@route("/users", methods=["POST"])
@requires(["authenticated", "admin"])
async def add_user(request):

name, imdb_id = await json_from_body(request, ["name", "imdb_id"])

# XXX restrict name

@@ -415,7 +415,6 @@ async def add_user(request):
@route("/users/{user_id}")
@requires(["authenticated"])
async def show_user(request):

user_id = as_ulid(request.path_params["user_id"])

if is_admin(request):

@@ -444,7 +443,6 @@ async def show_user(request):
@route("/users/{user_id}", methods=["DELETE"])
@requires(["authenticated", "admin"])
async def remove_user(request):

user_id = as_ulid(request.path_params["user_id"])

user = await db.get(User, id=str(user_id))

@@ -462,7 +460,6 @@ async def remove_user(request):
@route("/users/{user_id}", methods=["PATCH"])
@requires(["authenticated"])
async def modify_user(request):

user_id = as_ulid(request.path_params["user_id"])

if is_admin(request):

@@ -510,7 +507,6 @@ async def modify_user(request):
@route("/users/{user_id}/groups", methods=["POST"])
@requires(["authenticated", "admin"])
async def add_group_to_user(request):

user_id = as_ulid(request.path_params["user_id"])

user = await db.get(User, id=str(user_id))

@@ -535,21 +531,18 @@ async def add_group_to_user(request):
@route("/users/{user_id}/ratings")
@requires(["private"])
async def ratings_for_user(request):

not_implemented()

@route("/users/{user_id}/ratings", methods=["PUT"])
@requires("authenticated")
async def set_rating_for_user(request):

not_implemented()

@route("/users/_reload_ratings", methods=["POST"])
@requires(["authenticated", "admin"])
async def load_imdb_user_ratings(request):

ratings = [rating async for rating in imdb.refresh_user_ratings_from_imdb()]

return JSONResponse({"new_ratings": [asplain(r) for r in ratings]})

@@ -558,7 +551,6 @@ async def load_imdb_user_ratings(request):
@route("/groups")
@requires(["authenticated", "admin"])
async def list_groups(request):

groups = await db.get_all(Group)

return JSONResponse([asplain(g) for g in groups])

@@ -567,7 +559,6 @@ async def list_groups(request):
@route("/groups", methods=["POST"])
@requires(["authenticated", "admin"])
async def add_group(request):

(name,) = await json_from_body(request, ["name"])

# XXX restrict name

@@ -581,7 +572,6 @@ async def add_group(request):
@route("/groups/{group_id}/users", methods=["POST"])
@requires(["authenticated"])
async def add_user_to_group(request):

group_id = as_ulid(request.path_params["group_id"])
group = await db.get(Group, id=str(group_id))

@@ -623,6 +613,13 @@ def auth_error(request, err):
return unauthorized(str(err))

@contextlib.asynccontextmanager
async def lifespan(app: Starlette):
await open_connection_pool()
yield
await close_connection_pool()

def create_app():
if config.loglevel == "DEBUG":
logging.basicConfig(

@@ -633,8 +630,7 @@ def create_app():
log.debug(f"Log level: {config.loglevel}")

return Starlette(
on_startup=[open_connection_pool],
on_shutdown=[close_connection_pool],
lifespan=lifespan,
routes=[
Mount(f"{config.api_base}v1", routes=route.registered),
],
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Container, Iterable, Optional
from typing import Container, Iterable

from . import imdb, models

@@ -10,17 +10,17 @@ Score100 = int # [0, 100]
@dataclass
class Rating:
canonical_title: str
imdb_score: Optional[Score100]
imdb_votes: Optional[int]
imdb_score: Score100 | None
imdb_votes: int | None
media_type: str
movie_imdb_id: str
original_title: Optional[str]
original_title: str | None
release_year: int
user_id: Optional[str]
user_score: Optional[Score100]
user_id: str | None
user_score: Score100 | None

@classmethod
def from_movie(cls, movie: models.Movie, *, rating: models.Rating = None):
def from_movie(cls, movie: models.Movie, *, rating: models.Rating | None = None):
return cls(
canonical_title=movie.title,
imdb_score=movie.imdb_score,

@@ -37,11 +37,11 @@ class Rating:
@dataclass
class RatingAggregate:
canonical_title: str
imdb_score: Optional[Score100]
imdb_votes: Optional[int]
imdb_score: Score100 | None
imdb_votes: int | None
link: URL
media_type: str
original_title: Optional[str]
original_title: str | None
user_scores: list[Score100]
year: int