Merge branch 'feat/py311'

commit a020d972f8
23 changed files with 783 additions and 748 deletions

.git-blame-ignore-revs | 4 (new file)

@@ -0,0 +1,4 @@
+# Apply Black v23.1.0 formatting changes.
+8a8bfce89de23d987386a35b659532bbac373788
+# Apply auto-formatting to tests.
+9ffcc5357150cecde26f5e6f8fccceaf92411efb

Dockerfile | 16

@@ -1,4 +1,4 @@
-FROM docker.io/library/python:3.10-alpine
+FROM docker.io/library/python:3.11-alpine
 
 RUN apk update --no-cache \
     && apk upgrade --no-cache \
@@ -11,20 +11,18 @@ WORKDIR /var/app
 
 COPY requirements.txt ./
 
-# Required to build greenlet on Alpine, dependency of SQLAlchemy 1.4.
-RUN apk add --no-cache \
-    --virtual .build-deps \
-    g++ gcc musl-dev \
-    && pip install --no-cache-dir --upgrade \
-    --requirement requirements.txt \
-    && apk del .build-deps
+RUN pip install --no-cache-dir --upgrade \
+    --requirement requirements.txt
 
 USER 10000:10001
 
 COPY . ./
 
 ENV UNWIND_DATA="/data"
-VOLUME ["/data"]
+VOLUME $UNWIND_DATA
 
+ENV UNWIND_PORT=8097
+EXPOSE $UNWIND_PORT
+
 ENTRYPOINT ["/var/app/run"]
 CMD ["server"]

poetry.lock | 817 (generated; diff too large, not shown)

pyproject.toml

@@ -6,26 +6,20 @@ authors = ["ducklet <ducklet@noreply.code.dumpr.org>"]
 license = "LOL"
 
 [tool.poetry.dependencies]
-python = "^3.10"
-requests = "^2.25.1"
+python = "^3.11"
 beautifulsoup4 = "^4.9.3"
 html5lib = "^1.1"
-starlette = "^0.17.0"
+starlette = "^0.26"
 ulid-py = "^1.1.0"
-databases = {extras = ["sqlite"], version = "^0.6.1"}
-toml = "^0.10.2"
-uvicorn = "^0.19.0"
+databases = {extras = ["sqlite"], version = "^0.7.0"}
+uvicorn = "^0.21"
+httpx = "^0.23.3"
 
-[tool.poetry.group.fixes.dependencies]
-# `databases` is having issues with new versions of SQLAlchemy 1.4,
-# and `greenlet` is also always a pain.
-SQLAlchemy = "1.4.25"
-greenlet = "1.1.2"
-
 [tool.poetry.group.dev]
 optional = true
 
 [tool.poetry.group.dev.dependencies]
+autoflake = "*"
 pytest = "*"
 pyright = "*"
 black = "*"
@@ -37,4 +31,14 @@ requires = ["poetry-core>=1.0.0"]
 build-backend = "poetry.core.masonry.api"
 
 [tool.pyright]
-pythonVersion = "3.10"
+pythonVersion = "3.11"
+
+[tool.isort]
+profile = "black"
+
+[tool.autoflake]
+remove-duplicate-keys = true
+remove-unused-variables = true
+remove-all-unused-imports = true
+ignore-init-module-imports = true
+ignore-pass-after-docstring = true

@@ -4,6 +4,7 @@ cd "$RUN_DIR"
 
 [ -z "${DEBUG:-}" ] || set -x
 
-isort --profile black unwind
-black unwind
+autoflake --quiet --check --recursive unwind tests
+isort unwind tests
+black unwind tests
 pyright

scripts/profile | 13 (new executable file)

@@ -0,0 +1,13 @@
+#!/bin/sh -eu
+
+cd "$RUN_DIR"
+
+outfile="profile-$(date '+%Y%m%d-%H%M%S').txt"
+
+[ -z "${DEBUG:-}" ] || set -x
+
+echo "# Writing profiler stats to: $outfile"
+python -m cProfile -o "$outfile" -m unwind "$@"
+
+echo "# Loading stats file: $outfile"
+python -m pstats "$outfile"

@@ -1,7 +1,14 @@
 #!/bin/sh -eu
 
+: "${UNWIND_PORT:=8097}"
+
 cd "$RUN_DIR"
 
 [ -z "${DEBUG:-}" ] || set -x
 
-exec uvicorn --host 0.0.0.0 --factory unwind:create_app
+export UNWIND_PORT
+
+exec uvicorn \
+    --host 0.0.0.0 \
+    --port "$UNWIND_PORT" \
+    --factory unwind:create_app

@@ -10,5 +10,6 @@ trap 'rm "$dbfile"' EXIT TERM INT QUIT
 
 [ -z "${DEBUG:-}" ] || set -x
 
+SQLALCHEMY_WARN_20=1 \
 UNWIND_STORAGE="$dbfile" \
     python -m pytest "$@"

tests/conftest.py

@@ -1,6 +1,8 @@
 import asyncio
 
 import pytest
+import pytest_asyncio
 
 from unwind import db
 
+
@@ -13,7 +15,7 @@ def event_loop():
     loop.close()
 
 
-@pytest.fixture(scope="session")
+@pytest_asyncio.fixture(scope="session")
 async def shared_conn():
     c = db.shared_connection()
     await c.connect()
@@ -24,7 +26,7 @@ async def shared_conn():
     await c.disconnect()
 
 
-@pytest.fixture
+@pytest_asyncio.fixture
 async def conn(shared_conn):
     async with shared_conn.transaction(force_rollback=True):
         yield shared_conn

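Note: newer pytest-asyncio releases no longer treat async generator fixtures declared with a plain `@pytest.fixture` as async fixtures; they have to use `@pytest_asyncio.fixture`, and coroutine tests need `@pytest.mark.asyncio` (unless asyncio_mode is set to auto). That is what the conftest and test changes in this commit do. A minimal, self-contained sketch of the pattern; the fixture body and test below are illustrative stand-ins, not code from this repository:

    # Sketch: async fixture + async test with pytest-asyncio installed.
    import asyncio

    import pytest
    import pytest_asyncio


    @pytest_asyncio.fixture  # a plain @pytest.fixture would not drive this async generator
    async def conn():
        await asyncio.sleep(0)  # stand-in for opening a database connection
        yield "connection"      # the value injected into tests
        await asyncio.sleep(0)  # stand-in for closing it again


    @pytest.mark.asyncio        # run the coroutine test on an event loop
    async def test_uses_conn(conn):
        assert conn == "connection"
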
@@ -1,14 +1,13 @@
 from datetime import datetime
 
 import pytest
 
 from unwind import db, models, web_models
 
-pytestmark = pytest.mark.asyncio
-
-async def test_add_and_get(shared_conn):
+@pytest.mark.asyncio
+async def test_add_and_get(shared_conn: db.Database):
     async with shared_conn.transaction(force_rollback=True):
 
         m1 = models.Movie(
             title="test movie",
             release_year=2013,
@@ -31,9 +30,9 @@ async def test_add_and_get(shared_conn):
         assert m2 == await db.get(models.Movie, id=str(m2.id))
 
 
-async def test_find_ratings(shared_conn):
+@pytest.mark.asyncio
+async def test_find_ratings(shared_conn: db.Database):
     async with shared_conn.transaction(force_rollback=True):
 
         m1 = models.Movie(
             title="test movie",
             release_year=2013,
@@ -157,4 +156,3 @@ async def test_find_ratings(shared_conn):
         rows = await db.find_ratings(title="test", include_unrated=True)
         ratings = tuple(web_models.Rating(**r) for r in rows)
         assert (web_models.Rating.from_movie(m1),) == ratings
-

@@ -1,14 +1,15 @@
 import pytest
 
 from unwind.imdb import imdb_rating_from_score, score_from_imdb_rating
 
+
 @pytest.mark.parametrize("rating", (x / 10 for x in range(10, 101)))
-def test_rating_conversion(rating):
+def test_rating_conversion(rating: float):
     assert rating == imdb_rating_from_score(score_from_imdb_rating(rating))
 
 
 @pytest.mark.parametrize("score", range(0, 101))
-def test_score_conversion(score):
+def test_score_conversion(score: int):
     # Because our score covers 101 discrete values and IMDb's rating only 91
     # discrete values, the mapping is non-injective, i.e. 10 values can't be
     # mapped uniquely.

@@ -1,18 +1,14 @@
-from starlette.testclient import TestClient
 import pytest
+from starlette.testclient import TestClient
 
-from unwind import create_app
-from unwind import db, models, imdb
-
-# https://pypi.org/project/pytest-asyncio/
-pytestmark = pytest.mark.asyncio
+from unwind import create_app, db, imdb, models
 
 app = create_app()
 
 
-async def test_app(shared_conn):
+@pytest.mark.asyncio
+async def test_app(shared_conn: db.Database):
     async with shared_conn.transaction(force_rollback=True):
 
         # https://www.starlette.io/testclient/
         client = TestClient(app)
         response = client.get("/api/v1/movies")

unwind/__main__.py

@@ -6,7 +6,7 @@ from pathlib import Path
 from . import config
 from .db import close_connection_pool, open_connection_pool
 from .imdb import refresh_user_ratings_from_imdb
-from .imdb_import import import_from_file
+from .imdb_import import download_datasets, import_from_file
 
 log = logging.getLogger(__name__)
 
@@ -15,7 +15,7 @@ async def run_load_user_ratings_from_imdb():
     await open_connection_pool()
 
     i = 0
-    async for rating in refresh_user_ratings_from_imdb():
+    async for _ in refresh_user_ratings_from_imdb():
         i += 1
 
     log.info("✨ Imported %s new ratings.", i)
@@ -31,6 +31,10 @@ async def run_import_imdb_dataset(basics_path: Path, ratings_path: Path):
     await close_connection_pool()
 
 
+async def run_download_imdb_dataset(basics_path: Path, ratings_path: Path):
+    await download_datasets(basics_path=basics_path, ratings_path=ratings_path)
+
+
 def getargs():
     parser = argparse.ArgumentParser()
     commands = parser.add_subparsers(required=True)
@@ -55,6 +59,25 @@ def getargs():
         "--ratings", metavar="ratings_file.tsv.gz", type=Path, required=True
     )
 
+    parser_download_imdb_dataset = commands.add_parser(
+        "download-imdb-dataset",
+        help="Download IMDb datasets.",
+        description="""
+        Download IMDb datasets.
+        """,
+    )
+    parser_download_imdb_dataset.add_argument(
+        dest="mode",
+        action="store_const",
+        const="download-imdb-dataset",
+    )
+    parser_download_imdb_dataset.add_argument(
+        "--basics", metavar="basics_file.tsv.gz", type=Path, required=True
+    )
+    parser_download_imdb_dataset.add_argument(
+        "--ratings", metavar="ratings_file.tsv.gz", type=Path, required=True
+    )
+
     parser_load_user_ratings_from_imdb = commands.add_parser(
         "load-user-ratings-from-imdb",
         help="Load user ratings from imdb.com.",
@@ -94,6 +117,8 @@ def main():
         asyncio.run(run_load_user_ratings_from_imdb())
     elif args.mode == "import-imdb-dataset":
         asyncio.run(run_import_imdb_dataset(args.basics, args.ratings))
+    elif args.mode == "download-imdb-dataset":
+        asyncio.run(run_download_imdb_dataset(args.basics, args.ratings))
 
 
 main()

unwind/config.py

@@ -1,8 +1,7 @@
 import os
+import tomllib
 from pathlib import Path
 
-import toml
-
 datadir = Path(os.getenv("UNWIND_DATA") or "./data")
 cachedir = (
     Path(cachedir)
@@ -14,7 +13,8 @@ loglevel = os.getenv("UNWIND_LOGLEVEL") or ("DEBUG" if debug else "INFO")
 storage_path = os.getenv("UNWIND_STORAGE", datadir / "db.sqlite")
 config_path = os.getenv("UNWIND_CONFIG", datadir / "config.toml")
 
-_config = toml.load(config_path)
+with open(config_path, "rb") as fd:
+    _config = tomllib.load(fd)
 
 api_base = _config["api"].get("base", "/api/")
 api_cors = _config["api"].get("cors", "*")

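Note: Python 3.11 ships `tomllib` in the standard library, so the third-party `toml` dependency can be dropped. Unlike `toml.load(path)`, `tomllib.load()` only accepts a binary file object, hence the `open(config_path, "rb")` above. A short sketch of the API; the file name and the `[api]` values are invented for illustration:

    # Sketch: reading TOML with the stdlib tomllib (Python >= 3.11).
    import tomllib

    with open("config.toml", "rb") as fd:   # tomllib requires a binary file object
        config = tomllib.load(fd)

    api_base = config["api"].get("base", "/api/")

    # Parsing from a string works too:
    doc = tomllib.loads('[api]\nbase = "/api/v1/"')
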
unwind/db.py | 63

@@ -4,7 +4,7 @@ import logging
 import re
 import threading
 from pathlib import Path
-from typing import Any, Iterable, Literal, Optional, Type, TypeVar, Union
+from typing import Any, Iterable, Literal, Type, TypeVar
 
 import sqlalchemy
 from databases import Database
@@ -26,7 +26,7 @@ from .types import ULID
 log = logging.getLogger(__name__)
 T = TypeVar("T")
 
-_shared_connection: Optional[Database] = None
+_shared_connection: Database | None = None
 
 
 async def open_connection_pool() -> None:
@@ -119,7 +119,6 @@ async def apply_db_patches(db: Database):
         raise RuntimeError("No statement found.")
 
     async with db.transaction():
-
         for query in queries:
             await db.execute(query)
 
@@ -131,12 +130,12 @@ async def apply_db_patches(db: Database):
     await db.execute("vacuum")
 
 
-async def get_import_progress() -> Optional[Progress]:
+async def get_import_progress() -> Progress | None:
     """Return the latest import progress."""
     return await get(Progress, type="import-imdb-movies", order_by="started DESC")
 
 
-async def stop_import_progress(*, error: BaseException = None):
+async def stop_import_progress(*, error: BaseException | None = None):
     """Stop the current import.
 
     If an error is given, it will be logged to the progress state.
@@ -176,6 +175,8 @@ async def set_import_progress(progress: float) -> Progress:
     else:
         await add(current)
 
+    return current
+
 
 _lock = threading.Lock()
 _prelock = threading.Lock()
@@ -243,8 +244,8 @@ ModelType = TypeVar("ModelType")
 
 
 async def get(
-    model: Type[ModelType], *, order_by: str = None, **kwds
-) -> Optional[ModelType]:
+    model: Type[ModelType], *, order_by: str | None = None, **kwds
+) -> ModelType | None:
     """Load a model instance from the database.
 
     Passing `kwds` allows to filter the instance to load. You have to encode the
@@ -262,7 +263,7 @@ async def get(
         query += f" ORDER BY {order_by}"
     async with locked_connection() as conn:
         row = await conn.fetch_one(query=query, values=values)
-    return fromplain(model, row, serialized=True) if row else None
+    return fromplain(model, row._mapping, serialized=True) if row else None
 
 
 async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
@@ -282,7 +283,7 @@ async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
     query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
     async with locked_connection() as conn:
         rows = await conn.fetch_all(query=query, values=values)
-    return (fromplain(model, row, serialized=True) for row in rows)
+    return (fromplain(model, row._mapping, serialized=True) for row in rows)
 
 
 async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
@@ -293,7 +294,7 @@ async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
     query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
     async with locked_connection() as conn:
         rows = await conn.fetch_all(query=query, values=values)
-    return (fromplain(model, row, serialized=True) for row in rows)
+    return (fromplain(model, row._mapping, serialized=True) for row in rows)
 
 
 async def update(item):
@@ -406,16 +407,16 @@ def sql_escape(s: str, char="#"):
 
 async def find_ratings(
     *,
-    title: str = None,
-    media_type: str = None,
+    title: str | None = None,
+    media_type: str | None = None,
     exact: bool = False,
     ignore_tv_episodes: bool = False,
     include_unrated: bool = False,
-    yearcomp: tuple[Literal["<", "=", ">"], int] = None,
+    yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
     limit_rows: int = 10,
     user_ids: Iterable[str] = [],
 ):
-    values: dict[str, Union[int, str]] = {
+    values: dict[str, int | str] = {
         "limit_rows": limit_rows,
     }
 
@@ -466,7 +467,7 @@ async def find_ratings(
     """
    async with locked_connection() as conn:
         rows = await conn.fetch_all(bindparams(query, values))
-    movie_ids = tuple(r["movie_id"] for r in rows)
+    movie_ids = tuple(r._mapping["movie_id"] for r in rows)
 
     if include_unrated and len(movie_ids) < limit_rows:
         sqlin, sqlin_vals = sql_in("id", movie_ids, not_=True)
@@ -485,7 +486,7 @@ async def find_ratings(
                 {**values, **sqlin_vals, "limit_rows": limit_rows - len(movie_ids)},
             )
         )
-        movie_ids += tuple(r["movie_id"] for r in rows)
+        movie_ids += tuple(r._mapping["movie_id"] for r in rows)
 
     return await ratings_for_movie_ids(ids=movie_ids)
 
@@ -527,29 +528,13 @@ async def ratings_for_movie_ids(
 
     async with locked_connection() as conn:
         rows = await conn.fetch_all(bindparams(query, vals))
-    return tuple(dict(r) for r in rows)
+    return tuple(dict(r._mapping) for r in rows)
 
 
 def sql_fields(tp: Type):
     return (f"{tp._table}.{f.name}" for f in fields(tp))
 
 
-def sql_fieldmap(tp: Type):
-    """-> {alias: (table, field_name)}"""
-    return {f"{tp._table}_{f.name}": (tp._table, f.name) for f in fields(tp)}
-
-
-def mux(*tps: Type):
-    return ", ".join(
-        f"{t}.{n} AS {k}" for tp in tps for k, (t, n) in sql_fieldmap(tp).items()
-    )
-
-
-def demux(tp: Type[ModelType], row) -> ModelType:
-    d = {n: row[k] for k, (_, n) in sql_fieldmap(tp).items()}
-    return fromplain(tp, d, serialized=True)
-
-
 def sql_in(column: str, values: Iterable[T], not_=False) -> tuple[str, dict[str, T]]:
     c = column.replace(".", "___")
     value_map = {f"{c}_{i}": v for i, v in enumerate(values, start=1)}
@@ -583,22 +568,22 @@ async def ratings_for_movies(
     async with locked_connection() as conn:
         rows = await conn.fetch_all(query, values)
 
-    return (fromplain(Rating, row, serialized=True) for row in rows)
+    return (fromplain(Rating, row._mapping, serialized=True) for row in rows)
 
 
 async def find_movies(
     *,
-    title: str = None,
-    media_type: str = None,
+    title: str | None = None,
+    media_type: str | None = None,
     exact: bool = False,
     ignore_tv_episodes: bool = False,
-    yearcomp: tuple[Literal["<", "=", ">"], int] = None,
+    yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
     limit_rows: int = 10,
     skip_rows: int = 0,
     include_unrated: bool = False,
     user_ids: list[ULID] = [],
 ) -> Iterable[tuple[Movie, list[Rating]]]:
-    values: dict[str, Union[int, str]] = {
+    values: dict[str, int | str] = {
         "limit_rows": limit_rows,
         "skip_rows": skip_rows,
     }
@@ -650,7 +635,7 @@ async def find_movies(
     async with locked_connection() as conn:
         rows = await conn.fetch_all(bindparams(query, values))
 
-    movies = [fromplain(Movie, row, serialized=True) for row in rows]
+    movies = [fromplain(Movie, row._mapping, serialized=True) for row in rows]
 
     if not user_ids:
         return ((m, []) for m in movies)

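Note: the `row._mapping` changes track SQLAlchemy 1.4+ behaviour (pulled in via databases 0.7), where result rows are tuple-like `Row` objects; dict-style access now goes through the `Row._mapping` view instead of `dict(row)` or `row["col"]`. A hedged, standalone sketch using SQLAlchemy directly (in-memory SQLite, not the project's `locked_connection` helper):

    # Sketch: Row vs. Row._mapping in SQLAlchemy 1.4+ (synchronous API for brevity).
    import sqlalchemy as sa

    engine = sa.create_engine("sqlite://")
    with engine.connect() as conn:
        row = conn.execute(sa.text("SELECT 1 AS movie_id, 'test movie' AS title")).first()

        print(row.movie_id)           # attribute access on the tuple-like Row
        print(row._mapping["title"])  # dict-style access goes through ._mapping
        print(dict(row._mapping))     # {'movie_id': 1, 'title': 'test movie'}
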
unwind/imdb.py

@@ -2,12 +2,11 @@ import logging
 import re
 from collections import namedtuple
 from datetime import datetime
-from typing import Optional, Tuple
 from urllib.parse import urljoin
 
 from . import db
 from .models import Movie, Rating, User
-from .request import cache_path, session, soup_from_url
+from .request import asession, asoup_from_url, cache_path
 
 log = logging.getLogger(__name__)
 
@@ -35,13 +34,11 @@ log = logging.getLogger(__name__)
 # p.text-muted.text-small span[name=nv] [data-value]
 
 
-async def refresh_user_ratings_from_imdb(stop_on_dupe=True):
-
-    with session() as s:
+async def refresh_user_ratings_from_imdb(stop_on_dupe: bool = True):
+    async with asession() as s:
         s.headers["Accept-Language"] = "en-US, en;q=0.5"
 
         for user in await db.get_all(User):
 
             log.info("⚡️ Loading data for %s ...", user.name)
 
             try:
@@ -98,7 +95,6 @@ find_movie_id = re.compile(r"/title/(?P<id>tt\d+)/").search
 
 
 def movie_and_rating_from_item(item) -> tuple[Movie, Rating]:
-
     genres = (genre := item.find("span", "genre")) and genre.string or ""
     movie = Movie(
         title=item.h3.a.string.strip(),
@@ -153,10 +149,10 @@ def movie_and_rating_from_item(item) -> tuple[Movie, Rating]:
 ForgedRequest = namedtuple("ForgedRequest", "url headers")
 
 
-async def parse_page(url) -> Tuple[list[Rating], Optional[str]]:
+async def parse_page(url: str) -> tuple[list[Rating], str | None]:
     ratings = []
 
-    soup = soup_from_url(url)
+    soup = await asoup_from_url(url)
 
     meta = soup.find("meta", property="pageId")
     headline = soup.h1
@@ -170,7 +166,6 @@ async def parse_page(url) -> Tuple[list[Rating], Optional[str]]:
 
     items = soup.find_all("div", "lister-item-content")
     for i, item in enumerate(items):
-
         try:
             movie, rating = movie_and_rating_from_item(item)
         except Exception as err:
@@ -196,11 +191,10 @@ async def parse_page(url) -> Tuple[list[Rating], Optional[str]]:
     return (ratings, next_url if url != next_url else None)
 
 
-async def load_ratings(user_id):
+async def load_ratings(user_id: str):
     next_url = user_ratings_url(user_id)
 
     while next_url:
-
         ratings, next_url = await parse_page(next_url)
 
         for i, rating in enumerate(ratings):

unwind/imdb_import.py

@@ -1,10 +1,11 @@
+import asyncio
 import csv
 import gzip
 import logging
 from dataclasses import dataclass, fields
 from datetime import datetime, timezone
 from pathlib import Path
-from typing import Generator, Literal, Optional, Type, TypeVar, overload
+from typing import Generator, Literal, Type, TypeVar, overload
 
 from . import config, db, request
 from .db import add_or_update_many_movies
@@ -27,10 +28,10 @@ class BasicRow:
     primaryTitle: str
     originalTitle: str
     isAdult: bool
-    startYear: Optional[int]
-    endYear: Optional[int]
-    runtimeMinutes: Optional[int]
-    genres: Optional[set[str]]
+    startYear: int | None
+    endYear: int | None
+    runtimeMinutes: int | None
+    genres: set[str] | None
 
     @classmethod
     def from_row(cls, row):
@@ -100,7 +101,7 @@ title_types = {
 }
 
 
-def gz_mtime(path) -> datetime:
+def gz_mtime(path: Path) -> datetime:
     """Return the timestamp of the compressed file."""
     g = gzip.GzipFile(path, "rb")
     g.peek(1)  # start reading the file to fill the timestamp field
@@ -108,14 +109,13 @@ def gz_mtime(path) -> datetime:
     return datetime.fromtimestamp(g.mtime).replace(tzinfo=timezone.utc)
 
 
-def count_lines(path) -> int:
+def count_lines(path: Path) -> int:
     i = 0
 
     one_mb = 2**20
     buf_size = 8 * one_mb  # 8 MiB seems to give a good read/process performance.
 
     with gzip.open(path, "rt") as f:
-
         while buf := f.read(buf_size):
             i += buf.count("\n")
 
@@ -124,19 +124,19 @@ def count_lines(path) -> int:
 
 @overload
 def read_imdb_tsv(
-    path, row_type, *, unpack: Literal[False]
+    path: Path, row_type, *, unpack: Literal[False]
 ) -> Generator[list[str], None, None]:
     ...
 
 
 @overload
 def read_imdb_tsv(
-    path, row_type: Type[T], *, unpack: Literal[True] = True
+    path: Path, row_type: Type[T], *, unpack: Literal[True] = True
 ) -> Generator[T, None, None]:
     ...
 
 
-def read_imdb_tsv(path, row_type, *, unpack=True):
+def read_imdb_tsv(path: Path, row_type, *, unpack=True):
     with gzip.open(path, "rt", newline="") as f:
         rows = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
 
@@ -161,7 +161,7 @@ def read_imdb_tsv(path, row_type, *, unpack=True):
             raise
 
 
-def read_ratings(path):
+def read_ratings(path: Path):
     mtime = gz_mtime(path)
     rows = read_imdb_tsv(path, RatingRow)
 
@@ -171,19 +171,20 @@ def read_ratings(path):
         yield m
 
 
-def read_ratings_as_mapping(path):
+def read_ratings_as_mapping(path: Path):
     """Optimized function to quickly load all ratings."""
     rows = read_imdb_tsv(path, RatingRow, unpack=False)
     return {r[0]: (round(100 * (float(r[1]) - 1) / 9), int(r[2])) for r in rows}
 
 
-def read_basics(path):
+def read_basics(path: Path) -> Generator[Movie | None, None, None]:
     mtime = gz_mtime(path)
     rows = read_imdb_tsv(path, BasicRow)
 
     for row in rows:
         if row.startYear is None:
             log.debug("Skipping movie, missing year: %s", row)
+            yield None
             continue
 
         m = row.as_movie()
@@ -197,20 +198,24 @@ async def import_from_file(*, basics_path: Path, ratings_path: Path):
 
     log.info("💾 Importing movies ...")
     total = count_lines(basics_path)
-    assert total != 0
+    log.debug("Found %i movies.", total)
+    if total == 0:
+        raise RuntimeError(f"No movies found.")
     perc_next_report = 0.0
     perc_step = 0.1
 
     chunk = []
 
     for i, m in enumerate(read_basics(basics_path)):
 
         perc = 100 * i / total
         if perc >= perc_next_report:
             await db.set_import_progress(perc)
             log.info("⏳ Imported %s%%", round(perc, 1))
             perc_next_report += perc_step
 
+        if m is None:
+            continue
+
         if m.media_type not in {
             "Movie",
             "Short",
@@ -235,10 +240,27 @@ async def import_from_file(*, basics_path: Path, ratings_path: Path):
         await add_or_update_many_movies(chunk)
         chunk = []
 
+    log.info("👍 Imported 100%")
     await db.set_import_progress(100)
 
 
-async def load_from_web(*, force: bool = False):
+async def download_datasets(*, basics_path: Path, ratings_path: Path) -> None:
+    """Download IMDb movie database dumps.
+
+    See https://www.imdb.com/interfaces/ and https://datasets.imdbws.com/ for
+    more information on the IMDb database dumps.
+    """
+    basics_url = "https://datasets.imdbws.com/title.basics.tsv.gz"
+    ratings_url = "https://datasets.imdbws.com/title.ratings.tsv.gz"
+
+    async with request.asession():
+        await asyncio.gather(
+            request.adownload(ratings_url, to_path=ratings_path, only_if_newer=True),
+            request.adownload(basics_url, to_path=basics_path, only_if_newer=True),
+        )
+
+
+async def load_from_web(*, force: bool = False) -> None:
     """Refresh the full IMDb movie database.
 
     The latest dumps are first downloaded and then imported into the database.
@@ -251,17 +273,13 @@ async def load_from_web(*, force: bool = False):
     await db.set_import_progress(0)
 
     try:
-        basics_url = "https://datasets.imdbws.com/title.basics.tsv.gz"
-        ratings_url = "https://datasets.imdbws.com/title.ratings.tsv.gz"
         ratings_file = config.datadir / "imdb/title.ratings.tsv.gz"
         basics_file = config.datadir / "imdb/title.basics.tsv.gz"
 
         ratings_mtime = ratings_file.stat().st_mtime if ratings_file.exists() else None
         bastics_mtime = basics_file.stat().st_mtime if basics_file.exists() else None
 
-        with request.session():
-            request.download(ratings_url, ratings_file, only_if_newer=True)
-            request.download(basics_url, basics_file, only_if_newer=True)
+        await download_datasets(basics_path=basics_file, ratings_path=ratings_file)
 
         is_changed = (
             ratings_mtime != ratings_file.stat().st_mtime

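Note: `download_datasets()` fetches both dumps concurrently with `asyncio.gather` inside a shared client session. A stripped-down sketch of that pattern; the dataset URLs are the ones from the diff, but `fetch()` is an illustrative stand-in for `request.adownload()`, not the project's implementation:

    # Sketch: downloading two files concurrently with asyncio.gather + httpx.
    import asyncio
    from pathlib import Path

    import httpx


    async def fetch(client: httpx.AsyncClient, url: str, to_path: Path) -> None:
        resp = await client.get(url)
        resp.raise_for_status()
        to_path.write_bytes(resp.content)


    async def main() -> None:
        async with httpx.AsyncClient(follow_redirects=True) as client:
            await asyncio.gather(
                fetch(client, "https://datasets.imdbws.com/title.basics.tsv.gz", Path("basics.tsv.gz")),
                fetch(client, "https://datasets.imdbws.com/title.ratings.tsv.gz", Path("ratings.tsv.gz")),
            )


    if __name__ == "__main__":
        asyncio.run(main())
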
unwind/models.py

@@ -3,13 +3,14 @@ from dataclasses import dataclass, field
 from dataclasses import fields as _fields
 from datetime import datetime, timezone
 from functools import partial
+from types import UnionType
 from typing import (
     Annotated,
     Any,
     ClassVar,
     Container,
     Literal,
-    Optional,
+    Mapping,
     Type,
     TypeVar,
     Union,
@@ -19,13 +20,13 @@ from typing import (
 
 from .types import ULID
 
-JSON = Union[int, float, str, None, list["JSON"], dict[str, "JSON"]]
+JSON = int | float | str | None | list["JSON"] | dict[str, "JSON"]
 JSONObject = dict[str, JSON]
 
 T = TypeVar("T")
 
 
-def annotations(tp: Type) -> Optional[tuple]:
+def annotations(tp: Type) -> tuple | None:
     return tp.__metadata__ if hasattr(tp, "__metadata__") else None
 
 
@@ -42,7 +43,6 @@ def fields(class_or_instance):
     # XXX this might be a little slow (not sure), if so, memoize
 
     for f in _fields(class_or_instance):
-
         if f.name == "_is_lazy":
             continue
 
@@ -54,21 +54,21 @@ def fields(class_or_instance):
 
 def is_optional(tp: Type) -> bool:
     """Return wether the given type is optional."""
-    if get_origin(tp) is not Union:
+    if not isinstance(tp, UnionType) and get_origin(tp) is not Union:
         return False
 
     args = get_args(tp)
     return len(args) == 2 and type(None) in args
 
 
-def optional_type(tp: Type) -> Optional[Type]:
+def optional_type(tp: Type) -> Type | None:
     """Return the wrapped type from an optional type.
 
     For example this will return `int` for `Optional[int]`.
     Since they're equivalent this also works for other optioning notations, like
     `Union[int, None]` and `int | None`.
     """
-    if get_origin(tp) is not Union:
+    if not isinstance(tp, UnionType) and get_origin(tp) is not Union:
         return None
 
     args = get_args(tp)
@@ -92,7 +92,7 @@ def _id(x: T) -> T:
 
 
 def asplain(
-    o: object, *, filter_fields: Container[str] = None, serialize: bool = False
+    o: object, *, filter_fields: Container[str] | None = None, serialize: bool = False
 ) -> dict[str, Any]:
     """Return the given model instance as `dict` with JSON compatible plain datatypes.
 
@@ -109,7 +109,6 @@ def asplain(
 
     d: JSONObject = {}
     for f in fields(o):
-
         if filter_fields is not None and f.name not in filter_fields:
             continue
 
@@ -146,7 +145,7 @@ def asplain(
     return d
 
 
-def fromplain(cls: Type[T], d: dict[str, Any], *, serialized: bool = False) -> T:
+def fromplain(cls: Type[T], d: Mapping, *, serialized: bool = False) -> T:
     """Return an instance of the given model using the given data.
 
     If `serialized` is `True`, collection types (lists, dicts, etc.) will be
@@ -157,7 +156,6 @@ def fromplain(cls: Type[T], d: dict[str, Any], *, serialized: bool = False) -> T
 
     dd: JSONObject = {}
     for f in fields(cls):
-
         target = f.type
         otype = optional_type(f.type)
         is_opt = otype is not None
@@ -188,7 +186,8 @@ def validate(o: object) -> None:
         vtype = type(getattr(o, f.name))
         if vtype is not f.type:
             if get_origin(f.type) is vtype or (
-                get_origin(f.type) is Union and vtype in get_args(f.type)
+                (isinstance(f.type, UnionType) or get_origin(f.type) is Union)
+                and vtype in get_args(f.type)
             ):
                 continue
             raise ValueError(f"Invalid value type: {f.name}: {vtype}")
@@ -206,7 +205,7 @@ class Progress:
     type: str = None
     state: str = None
     started: datetime = field(default_factory=utcnow)
-    stopped: Optional[str] = None
+    stopped: str | None = None
 
     @property
     def _state(self) -> dict:
@@ -243,15 +242,15 @@ class Movie:
 
     id: ULID = field(default_factory=ULID)
     title: str = None  # canonical title (usually English)
-    original_title: Optional[
-        str
-    ] = None  # original title (usually transscribed to latin script)
+    original_title: str | None = (
+        None  # original title (usually transscribed to latin script)
+    )
     release_year: int = None  # canonical release date
     media_type: str = None
     imdb_id: str = None
-    imdb_score: Optional[int] = None  # range: [0,100]
-    imdb_votes: Optional[int] = None
-    runtime: Optional[int] = None  # minutes
+    imdb_score: int | None = None  # range: [0,100]
+    imdb_votes: int | None = None
+    runtime: int | None = None  # minutes
     genres: set[str] = None
     created: datetime = field(default_factory=utcnow)
     updated: datetime = field(default_factory=utcnow)
@@ -292,7 +291,7 @@ dataclass containing the ID of the linked data.
 The contents of the Relation are ignored or discarded when using
 `asplain`, `fromplain`, and `validate`.
 """
-Relation = Annotated[Optional[T], _RelationSentinel]
+Relation = Annotated[T | None, _RelationSentinel]
 
 
 @dataclass
@@ -309,8 +308,8 @@ class Rating:
 
     score: int = None  # range: [0,100]
     rating_date: datetime = None
-    favorite: Optional[bool] = None
-    finished: Optional[bool] = None
+    favorite: bool | None = None
+    finished: bool | None = None
 
     def __eq__(self, other):
         """Return wether two Ratings are equal.
@@ -342,11 +341,11 @@ class User:
     secret: str = None
     groups: list[dict[str, str]] = field(default_factory=list)
 
-    def has_access(self, group_id: Union[ULID, str], access: Access = "r"):
+    def has_access(self, group_id: ULID | str, access: Access = "r"):
         group_id = group_id if isinstance(group_id, str) else str(group_id)
         return any(g["id"] == group_id and access == g["access"] for g in self.groups)
 
-    def set_access(self, group_id: Union[ULID, str], access: Access):
+    def set_access(self, group_id: ULID | str, access: Access):
         group_id = group_id if isinstance(group_id, str) else str(group_id)
         for g in self.groups:
             if g["id"] == group_id:

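Note: the `is_optional` / `optional_type` / `validate` changes account for PEP 604 unions. `int | None` evaluated at runtime is a `types.UnionType`, for which `typing.get_origin()` does not return `typing.Union`, so the old `get_origin(tp) is Union` check would miss it, while `get_args()` works for both spellings. A quick demonstration (illustrative only):

    # Sketch: why the Union check needs to handle both typing.Union and types.UnionType.
    from types import UnionType
    from typing import Optional, Union, get_args, get_origin

    old_style = Optional[int]   # typing.Optional / typing.Union[int, None]
    new_style = int | None      # PEP 604 union, evaluates to a types.UnionType

    print(get_origin(old_style) is Union)              # True
    print(get_origin(new_style) is Union)              # False -> the old check misses it
    print(isinstance(new_style, UnionType))            # True
    print(get_args(old_style) == get_args(new_style))  # True: (int, NoneType) for both
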
unwind/request.py

@@ -4,19 +4,17 @@ import logging
 import os
 import tempfile
 from collections import deque
-from contextlib import contextmanager
-from dataclasses import dataclass
+from contextlib import asynccontextmanager
+from dataclasses import dataclass, field
 from functools import wraps
 from hashlib import md5
 from pathlib import Path
 from random import random
 from time import sleep, time
-from typing import Callable, Optional, Union
+from typing import Callable, ParamSpec, TypeVar, cast
 
 import bs4
-import requests
-from requests.status_codes import codes
-from urllib3.util.retry import Retry
+import httpx
 
 from . import config
 
@@ -26,28 +24,17 @@ if config.debug and config.cachedir:
     config.cachedir.mkdir(exist_ok=True)
 
 
-def set_retries(s: requests.Session, n: int, backoff_factor: float = 0.2):
-    retry = (
-        Retry(
-            total=n,
-            connect=n,
-            read=n,
-            status=n,
-            status_forcelist=Retry.RETRY_AFTER_STATUS_CODES,
-            backoff_factor=backoff_factor,
-        )
-        if n
-        else Retry(0, read=False)
-    )
-    for a in s.adapters.values():
-        a.max_retries = retry
+_shared_asession = None
+
+_ASession_T = httpx.AsyncClient
+_Response_T = httpx.Response
+
+_T = TypeVar("_T")
+_P = ParamSpec("_P")
 
 
-_shared_session = None
-
-
-@contextmanager
-def session():
+@asynccontextmanager
+async def asession():
     """Return the shared request session.
 
     The session is shared by all request functions and provides cookie
@@ -55,38 +42,34 @@ def session():
     Opening the session before making a request allows you to set headers
     or change the retry behavior.
     """
-    global _shared_session
+    global _shared_asession
 
-    if _shared_session:
-        yield _shared_session
+    if _shared_asession:
+        yield _shared_asession
         return
 
-    _shared_session = Session()
+    _shared_asession = _ASession_T()
+    _shared_asession.headers[
+        "user-agent"
+    ] = "Mozilla/5.0 Gecko/20100101 unwind/20230203"
    try:
-        yield _shared_session
+        async with _shared_asession:
+            yield _shared_asession
     finally:
-        _shared_session = None
+        _shared_asession = None
 
 
-def Session() -> requests.Session:
-    s = requests.Session()
-    s.headers["User-Agent"] = "Mozilla/5.0 Gecko/20100101 unwind/20210506"
-    return s
-
-
-def throttle(
-    times: int, per_seconds: float, jitter: Callable[[], float] = None
-) -> Callable[[Callable], Callable]:
-
-    calls: Deque[float] = deque(maxlen=times)
+def _throttle(
+    times: int, per_seconds: float, jitter: Callable[[], float] | None = None
+) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
+    calls: deque[float] = deque(maxlen=times)
 
     if jitter is None:
         jitter = lambda: 0.0
 
-    def decorator(func: Callable) -> Callable:
+    def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]:
         @wraps(func)
-        def inner(*args, **kwds):
+        def inner(*args: _P.args, **kwds: _P.kwargs):
 
             # clean up
             while calls:
                 if calls[0] + per_seconds > time():
@@ -118,23 +101,19 @@ def throttle(
     return decorator
 
 
-class CachedStr(str):
-    is_cached = True
-
-
 @dataclass
-class CachedResponse:
+class _CachedResponse:
     is_cached = True
     status_code: int
     text: str
     url: str
-    headers: dict[str, str] = None
+    headers: dict[str, str] = field(default_factory=dict)
 
     def json(self):
         return json.loads(self.text)
 
 
-class RedirectError(RuntimeError):
+class _RedirectError(RuntimeError):
     def __init__(self, from_url: str, to_url: str, is_cached=False):
         self.from_url = from_url
         self.to_url = to_url
@@ -142,44 +121,51 @@ class RedirectError(RuntimeError):
         super().__init__(f"Redirected: {from_url} -> {to_url}")
 
 
-def cache_path(req) -> Optional[Path]:
+def cache_path(req) -> Path | None:
     if not config.cachedir:
         return
     sig = repr(req.url)  # + repr(sorted(req.headers.items()))
     return config.cachedir / md5(sig.encode()).hexdigest()
 
 
-@throttle(1, 1, random)
-def http_get(s: requests.Session, url: str, *args, **kwds) -> requests.Response:
-    req = s.prepare_request(requests.Request("GET", url, *args, **kwds))
+@_throttle(1, 1, random)
+async def _ahttp_get(s: _ASession_T, url: str, *args, **kwds) -> _Response_T:
+    req = s.build_request(method="GET", url=url, *args, **kwds)
 
     cachefile = cache_path(req) if config.debug else None
 
     if cachefile:
         if cachefile.exists():
             log.debug(
-                f"💾 loading {req.url} ({req.headers!a}) from cache {cachefile} ..."
+                "💾 loading %s (%a) from cache %s ...", req.url, req.headers, cachefile
             )
             with cachefile.open() as fp:
-                resp = CachedResponse(**json.load(fp))
+                resp = _CachedResponse(**json.load(fp))
             if 300 <= resp.status_code <= 399:
-                raise RedirectError(
+                raise _RedirectError(
                     from_url=resp.url, to_url=resp.headers["location"], is_cached=True
                 )
-            return resp
+            return cast(_Response_T, resp)
 
-    log.debug(f"⚡️ loading {req.url} ({req.headers!a}) ...")
-    resp = s.send(req, allow_redirects=False, stream=True)
+    log.debug("⚡️ loading %s (%a) ...", req.url, req.headers)
+    resp = await s.send(req, follow_redirects=False, stream=True)
     resp.raise_for_status()
 
+    await resp.aread()  # Download the response stream to allow `resp.text` access.
+
     if cachefile:
+        log.debug(
+            "💾 writing response to cache: %s (%a) -> %s",
+            req.url,
+            req.headers,
+            cachefile,
+        )
         with cachefile.open("w") as fp:
             json.dump(
                 {
                     "status_code": resp.status_code,
                     "text": resp.text,
-                    "url": resp.url,
+                    "url": str(resp.url),
                     "headers": dict(resp.headers),
                 },
                 fp,
@@ -187,45 +173,46 @@ def http_get(s: requests.Session, url: str, *args, **kwds) -> requests.Response:
 
     if resp.is_redirect:
         # Redirects could mean trouble, we need to stay on top of that!
-        raise RedirectError(from_url=resp.url, to_url=resp.headers["location"])
+        raise _RedirectError(from_url=str(resp.url), to_url=resp.headers["location"])
 
     return resp
 
 
-def soup_from_url(url):
+async def asoup_from_url(url):
     """Return a BeautifulSoup instance from the contents for the given URL."""
-    with session() as s:
-        r = http_get(s, url)
+    async with asession() as s:
+        r = await _ahttp_get(s, url)
 
     soup = bs4.BeautifulSoup(r.text, "html5lib")
     return soup
 
 
-def last_modified_from_response(resp):
-    if last_mod := resp.headers.get("Last-Modified"):
+def _last_modified_from_response(resp: _Response_T) -> float | None:
+    if last_mod := resp.headers.get("last-modified"):
         try:
             return email.utils.parsedate_to_datetime(last_mod).timestamp()
         except:
             log.exception("🐛 Received invalid value for Last-Modified: %s", last_mod)
 
 
-def last_modified_from_file(path: Path):
+def _last_modified_from_file(path: Path) -> float:
     return path.stat().st_mtime
 
 
-def download(
+async def adownload(
     url: str,
-    file_path: Union[Path, str] = None,
     *,
-    replace_existing: bool = None,
+    to_path: Path | str | None = None,
+    replace_existing: bool | None = None,
     only_if_newer: bool = False,
-    timeout: float = None,
-    verify_ssl: bool = True,
+    timeout: float | None = None,
     chunk_callback=None,
     response_callback=None,
-):
+) -> bytes | None:
     """Download a file.
 
+    If `to_path` is `None` return the remote content, otherwise write the
+    content to the given file path.
     Existing files will not be overwritten unless `replace_existing` is set.
     Setting `only_if_newer` will check if the remote file is newer than the
     local file, otherwise the download will be aborted.
@@ -234,50 +221,56 @@ def download(
         replace_existing = only_if_newer
 
     file_exists = None
-    if file_path is not None:
-        file_path = Path(file_path)
+    if to_path is not None:
+        to_path = Path(to_path)
 
-        file_exists = file_path.exists() and file_path.stat().st_size
+        file_exists = to_path.exists() and to_path.stat().st_size

|
||||||
if file_exists and not replace_existing:
|
if file_exists and not replace_existing:
|
||||||
raise FileExistsError(23, "Would replace existing file", str(file_path))
|
raise FileExistsError(23, "Would replace existing file", str(to_path))
|
||||||
|
|
||||||
with session() as s:
|
|
||||||
|
|
||||||
|
async with asession() as s:
|
||||||
headers = {}
|
headers = {}
|
||||||
if file_exists and only_if_newer:
|
if file_exists and only_if_newer:
|
||||||
assert file_path
|
assert to_path
|
||||||
file_lastmod = last_modified_from_file(file_path)
|
file_lastmod = _last_modified_from_file(to_path)
|
||||||
headers["If-Modified-Since"] = email.utils.formatdate(
|
headers["if-modified-since"] = email.utils.formatdate(
|
||||||
file_lastmod, usegmt=True
|
file_lastmod, usegmt=True
|
||||||
)
|
)
|
||||||
|
|
||||||
req = s.prepare_request(requests.Request("GET", url, headers=headers))
|
req = s.build_request(method="GET", url=url, headers=headers, timeout=timeout)
|
||||||
|
|
||||||
log.debug("⚡️ loading %s (%s) ...", req.url, req.headers)
|
log.debug("⚡️ Loading %s (%a) ...", req.url, dict(req.headers))
|
||||||
resp = s.send(
|
resp = await s.send(req, follow_redirects=True, stream=True)
|
||||||
req, allow_redirects=True, stream=True, timeout=timeout, verify=verify_ssl
|
|
||||||
)
|
|
||||||
|
|
||||||
|
try:
|
||||||
if response_callback is not None:
|
if response_callback is not None:
|
||||||
try:
|
try:
|
||||||
response_callback(resp)
|
response_callback(resp)
|
||||||
except:
|
except:
|
||||||
log.exception("🐛 Error in response callback.")
|
log.exception("🐛 Error in response callback.")
|
||||||
|
|
||||||
log.debug("☕️ Response status: %s; headers: %s", resp.status_code, resp.headers)
|
log.debug(
|
||||||
|
"☕️ %s -> status: %s; headers: %a",
|
||||||
|
req.url,
|
||||||
|
resp.status_code,
|
||||||
|
dict(resp.headers),
|
||||||
|
)
|
||||||
|
|
||||||
|
if resp.status_code == httpx.codes.NOT_MODIFIED:
|
||||||
|
log.debug(
|
||||||
|
"✋ Remote file has not changed, skipping download: %s -> %a",
|
||||||
|
req.url,
|
||||||
|
to_path,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
|
|
||||||
if resp.status_code == codes.not_modified:
|
if to_path is None:
|
||||||
log.debug("✋ Remote file has not changed, skipping download.")
|
await resp.aread() # Download the response stream to allow `resp.content` access.
|
||||||
return
|
|
||||||
|
|
||||||
if file_path is None:
|
|
||||||
return resp.content
|
return resp.content
|
||||||
|
|
||||||
assert replace_existing is True
|
resp_lastmod = _last_modified_from_response(resp)
|
||||||
|
|
||||||
resp_lastmod = last_modified_from_response(resp)
|
|
||||||
|
|
||||||
# Check Last-Modified in case the server ignored If-Modified-Since.
|
# Check Last-Modified in case the server ignored If-Modified-Since.
|
||||||
# XXX also check Content-Length?
|
# XXX also check Content-Length?
|
||||||
|
|
@ -285,24 +278,23 @@ def download(
|
||||||
assert file_lastmod
|
assert file_lastmod
|
||||||
|
|
||||||
if resp_lastmod <= file_lastmod:
|
if resp_lastmod <= file_lastmod:
|
||||||
log.debug("✋ Local file is newer, skipping download.")
|
log.debug("✋ Local file is newer, skipping download: %a", req.url)
|
||||||
resp.close()
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# Create intermediate directories if necessary.
|
# Create intermediate directories if necessary.
|
||||||
download_dir = file_path.parent
|
download_dir = to_path.parent
|
||||||
download_dir.mkdir(parents=True, exist_ok=True)
|
download_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Write content to temp file.
|
# Write content to temp file.
|
||||||
tempdir = download_dir
|
tempdir = download_dir
|
||||||
tempfd, tempfile_path = tempfile.mkstemp(
|
tempfd, tempfile_path = tempfile.mkstemp(
|
||||||
dir=tempdir, prefix=f".download-{file_path.name}."
|
dir=tempdir, prefix=f".download-{to_path.name}."
|
||||||
)
|
)
|
||||||
one_mb = 2**20
|
one_mb = 2**20
|
||||||
chunk_size = 8 * one_mb
|
chunk_size = 8 * one_mb
|
||||||
try:
|
try:
|
||||||
log.debug("💾 Writing to temp file %s ...", tempfile_path)
|
log.debug("💾 Writing to temp file %s ...", tempfile_path)
|
||||||
for chunk in resp.iter_content(chunk_size=chunk_size, decode_unicode=False):
|
async for chunk in resp.aiter_bytes(chunk_size):
|
||||||
os.write(tempfd, chunk)
|
os.write(tempfd, chunk)
|
||||||
if chunk_callback:
|
if chunk_callback:
|
||||||
try:
|
try:
|
||||||
|
|
@ -313,10 +305,19 @@ def download(
|
||||||
os.close(tempfd)
|
os.close(tempfd)
|
||||||
|
|
||||||
# Move downloaded file to destination.
|
# Move downloaded file to destination.
|
||||||
if file_exists:
|
if to_path.exists():
|
||||||
log.debug("💾 Replacing existing file: %s", file_path)
|
log.debug("💾 Replacing existing file: %s", to_path)
|
||||||
Path(tempfile_path).replace(file_path)
|
else:
|
||||||
|
log.debug("💾 Move to destination: %s", to_path)
|
||||||
|
if replace_existing:
|
||||||
|
Path(tempfile_path).replace(to_path)
|
||||||
|
else:
|
||||||
|
Path(tempfile_path).rename(to_path)
|
||||||
|
|
||||||
# Fix file attributes.
|
# Fix file attributes.
|
||||||
if resp_lastmod is not None:
|
if resp_lastmod is not None:
|
||||||
os.utime(file_path, (resp_lastmod, resp_lastmod))
|
log.debug("💾 Adjusting file timestamp: %s (%s)", to_path, resp_lastmod)
|
||||||
|
os.utime(to_path, (resp_lastmod, resp_lastmod))
|
||||||
|
|
||||||
|
finally:
|
||||||
|
await resp.aclose()
|
||||||
|
|
|
||||||
|
|
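For orientation, here is a minimal sketch of how the new coroutines above might be driven from a script. Only `adownload` and `asoup_from_url` come from this diff; the `unwind.scrape` module path and the example URLs are assumptions, not part of the repository.

    # Hypothetical caller; assumes the helpers keep the signatures shown above.
    import asyncio

    from unwind import scrape  # assumed module path


    async def main() -> None:
        # Fetch into memory: with to_path=None the body is returned as bytes.
        body = await scrape.adownload("https://example.org/export.csv")
        print(len(body or b""))

        # Or write to disk, but only if the remote copy is newer than the local one.
        await scrape.adownload(
            "https://example.org/export.csv",
            to_path="/tmp/export.csv",
            only_if_newer=True,
        )

        soup = await scrape.asoup_from_url("https://example.org/")
        print(soup.title)


    asyncio.run(main())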
@@ -1,5 +1,5 @@
 import re
-from typing import Union, cast
+from typing import cast

 import ulid
 from ulid.hints import Buffer

@@ -16,7 +16,7 @@ class ULID(ulid.ULID):

     _pattern = re.compile(r"^[0-9A-HJKMNP-TV-Z]{26}$")

-    def __init__(self, buffer: Union[Buffer, ulid.ULID, str, None] = None):
+    def __init__(self, buffer: Buffer | ulid.ULID | str | None = None):
         if isinstance(buffer, str):
             if not self._pattern.search(buffer):
                 raise ValueError("Invalid ULID.")
@@ -17,7 +17,10 @@ def b64padded(s: str) -> str:


 def phc_scrypt(
-    secret: bytes, *, salt: bytes = None, params: dict[Literal["n", "r", "p"], int] = {}
+    secret: bytes,
+    *,
+    salt: bytes | None = None,
+    params: dict[Literal["n", "r", "p"], int] = {},
 ) -> str:
     """Return the scrypt expanded secret in PHC string format.
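As background for the `phc_scrypt` signature above, here is a rough, stand-alone sketch of how an scrypt hash is commonly serialized in PHC string format. The parameter defaults, the `ln` (log2 of N) convention, and the unpadded-base64 helper are assumptions for illustration, not code taken from this repository.

    # Hedged sketch: a PHC-style scrypt string, roughly "$scrypt$ln=..,r=..,p=..$<salt>$<hash>".
    import base64
    import hashlib
    import math
    import os


    def _b64unpadded(raw: bytes) -> str:
        # PHC strings conventionally use base64 without '=' padding.
        return base64.b64encode(raw).decode().rstrip("=")


    def phc_scrypt_sketch(
        secret: bytes, *, salt: bytes | None = None, n: int = 2**14, r: int = 8, p: int = 1
    ) -> str:
        salt = salt or os.urandom(16)
        digest = hashlib.scrypt(secret, salt=salt, n=n, r=r, p=p, dklen=32)
        return f"$scrypt$ln={int(math.log2(n))},r={r},p={p}${_b64unpadded(salt)}${_b64unpadded(digest)}"


    print(phc_scrypt_sketch(b"hunter2"))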
@@ -1,8 +1,9 @@
 import asyncio
+import contextlib
 import logging
 import secrets
 from json.decoder import JSONDecodeError
-from typing import Literal, Optional, overload
+from typing import Literal, overload

 from starlette.applications import Starlette
 from starlette.authentication import (

@@ -85,11 +86,14 @@ def truthy(s: str):
     return bool(s) and s.lower() in {"1", "yes", "true"}


-def yearcomp(s: str):
+_Yearcomp = Literal["<", "=", ">"]
+
+
+def yearcomp(s: str) -> tuple[_Yearcomp, int] | None:
     if not s:
         return

-    comp: Literal["<", "=", ">"] = "="
+    comp: _Yearcomp = "="
     if (prefix := s[0]) in "<=>":
         comp = prefix  # type: ignore
         s = s[len(prefix) :]

@@ -97,7 +101,9 @@ def yearcomp(s: str):
     return comp, int(s)


-def as_int(x, *, max: int = None, min: Optional[int] = 1, default: int = None):
+def as_int(
+    x, *, max: int | None = None, min: int | None = 1, default: int | None = None
+) -> int:
     try:
         if not isinstance(x, int):
             x = int(x)
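To make the `yearcomp` contract concrete, a few illustrative calls follow. These reflect my reading of the lines shown above (part of the function body is outside the hunk), not output taken from the repository's tests.

    # Assuming the yearcomp/_Yearcomp definitions from the hunk above:
    assert yearcomp("") is None               # falsy input -> None
    assert yearcomp("1999") == ("=", 1999)    # no prefix defaults to "="
    assert yearcomp("<2005") == ("<", 2005)   # comparison prefix is split off
    assert yearcomp(">1980") == (">", 1980)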
@@ -135,7 +141,7 @@ async def json_from_body(request, keys: list[str]) -> list:
     ...


-async def json_from_body(request, keys: list[str] = None):
+async def json_from_body(request, keys: list[str] | None = None):
     if not await request.body():
         data = {}


@@ -158,7 +164,7 @@ def is_admin(request):
     return "admin" in request.auth.scopes


-async def auth_user(request) -> Optional[User]:
+async def auth_user(request) -> User | None:
     if not isinstance(request.user, AuthedUser):
         return


@@ -176,7 +182,7 @@ async def auth_user(request) -> Optional[User]:
 _routes = []


-def route(path: str, *, methods: list[str] = None, **kwds):
+def route(path: str, *, methods: list[str] | None = None, **kwds):
     def decorator(func):
         r = Route(path, func, methods=methods, **kwds)
         _routes.append(r)
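The `route` decorator touched above follows a simple registry pattern: each decorated handler is wrapped in a Starlette `Route` and appended to a module-level list, which `create_app()` later mounts as `route.registered`. Below is a self-contained sketch of that pattern; the `/ping` endpoint and the `return func` detail are illustrative assumptions, not this project's actual module.

    # Minimal, stand-alone illustration of the decorator-based route registry.
    from starlette.responses import JSONResponse
    from starlette.routing import Route

    _routes: list[Route] = []


    def route(path: str, *, methods: list[str] | None = None, **kwds):
        def decorator(func):
            r = Route(path, func, methods=methods, **kwds)
            _routes.append(r)
            return func  # assumed; the hunk above does not show the return

        return decorator


    route.registered = _routes


    @route("/ping")  # hypothetical endpoint
    async def ping(request):
        return JSONResponse({"ok": True})


    assert route.registered[0].path == "/ping"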
@@ -190,7 +196,6 @@ route.registered = _routes

 @route("/groups/{group_id}/ratings")
 async def get_ratings_for_group(request):
-
     group_id = as_ulid(request.path_params["group_id"])
     group = await db.get(Group, id=str(group_id))

@@ -251,7 +256,6 @@ def not_implemented():
 @route("/movies")
 @requires(["authenticated"])
 async def list_movies(request):
-
     params = request.query_params

     user = await auth_user(request)

@@ -319,7 +323,6 @@ async def list_movies(request):
 @route("/movies", methods=["POST"])
 @requires(["authenticated", "admin"])
 async def add_movie(request):
-
     not_implemented()


@@ -361,7 +364,6 @@ _import_lock = asyncio.Lock()
 @route("/movies/_reload_imdb", methods=["POST"])
 @requires(["authenticated", "admin"])
 async def load_imdb_movies(request):
-
     params = request.query_params
     force = truthy(params.get("force"))

@@ -384,7 +386,6 @@ async def load_imdb_movies(request):
 @route("/users")
 @requires(["authenticated", "admin"])
 async def list_users(request):
-
     users = await db.get_all(User)

     return JSONResponse([asplain(u) for u in users])

@@ -393,7 +394,6 @@ async def list_users(request):
 @route("/users", methods=["POST"])
 @requires(["authenticated", "admin"])
 async def add_user(request):
-
     name, imdb_id = await json_from_body(request, ["name", "imdb_id"])

     # XXX restrict name

@@ -415,7 +415,6 @@ async def add_user(request):
 @route("/users/{user_id}")
 @requires(["authenticated"])
 async def show_user(request):
-
     user_id = as_ulid(request.path_params["user_id"])

     if is_admin(request):

@@ -444,7 +443,6 @@ async def show_user(request):
 @route("/users/{user_id}", methods=["DELETE"])
 @requires(["authenticated", "admin"])
 async def remove_user(request):
-
     user_id = as_ulid(request.path_params["user_id"])

     user = await db.get(User, id=str(user_id))

@@ -462,7 +460,6 @@ async def remove_user(request):
 @route("/users/{user_id}", methods=["PATCH"])
 @requires(["authenticated"])
 async def modify_user(request):
-
     user_id = as_ulid(request.path_params["user_id"])

     if is_admin(request):

@@ -510,7 +507,6 @@ async def modify_user(request):
 @route("/users/{user_id}/groups", methods=["POST"])
 @requires(["authenticated", "admin"])
 async def add_group_to_user(request):
-
     user_id = as_ulid(request.path_params["user_id"])

     user = await db.get(User, id=str(user_id))

@@ -535,21 +531,18 @@ async def add_group_to_user(request):
 @route("/users/{user_id}/ratings")
 @requires(["private"])
 async def ratings_for_user(request):
-
     not_implemented()


 @route("/users/{user_id}/ratings", methods=["PUT"])
 @requires("authenticated")
 async def set_rating_for_user(request):
-
     not_implemented()


 @route("/users/_reload_ratings", methods=["POST"])
 @requires(["authenticated", "admin"])
 async def load_imdb_user_ratings(request):
-
     ratings = [rating async for rating in imdb.refresh_user_ratings_from_imdb()]

     return JSONResponse({"new_ratings": [asplain(r) for r in ratings]})

@@ -558,7 +551,6 @@ async def load_imdb_user_ratings(request):
 @route("/groups")
 @requires(["authenticated", "admin"])
 async def list_groups(request):
-
     groups = await db.get_all(Group)

     return JSONResponse([asplain(g) for g in groups])

@@ -567,7 +559,6 @@ async def list_groups(request):
 @route("/groups", methods=["POST"])
 @requires(["authenticated", "admin"])
 async def add_group(request):
-
     (name,) = await json_from_body(request, ["name"])

     # XXX restrict name

@@ -581,7 +572,6 @@ async def add_group(request):
 @route("/groups/{group_id}/users", methods=["POST"])
 @requires(["authenticated"])
 async def add_user_to_group(request):
-
     group_id = as_ulid(request.path_params["group_id"])
     group = await db.get(Group, id=str(group_id))

@@ -623,6 +613,13 @@ def auth_error(request, err):
     return unauthorized(str(err))


+@contextlib.asynccontextmanager
+async def lifespan(app: Starlette):
+    await open_connection_pool()
+    yield
+    await close_connection_pool()
+
+
 def create_app():
     if config.loglevel == "DEBUG":
         logging.basicConfig(

@@ -633,8 +630,7 @@ def create_app():
     log.debug(f"Log level: {config.loglevel}")

     return Starlette(
-        on_startup=[open_connection_pool],
+        lifespan=lifespan,
-        on_shutdown=[close_connection_pool],
         routes=[
             Mount(f"{config.api_base}v1", routes=route.registered),
         ],
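The hunk above swaps Starlette's older `on_startup`/`on_shutdown` hooks for a single `lifespan` async context manager, which newer Starlette releases prefer. Here is a minimal stand-alone sketch of that pattern; the health endpoint and the pool stand-ins are placeholders, not this project's code.

    # Minimal sketch of the lifespan pattern adopted above; resources are placeholders.
    import contextlib

    from starlette.applications import Starlette
    from starlette.responses import PlainTextResponse
    from starlette.routing import Route


    @contextlib.asynccontextmanager
    async def lifespan(app: Starlette):
        app.state.pool = object()  # stand-in for opening a connection pool
        try:
            yield
        finally:
            app.state.pool = None  # stand-in for closing the connection pool


    async def health(request):
        return PlainTextResponse("ok")


    app = Starlette(lifespan=lifespan, routes=[Route("/health", health)])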
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Container, Iterable, Optional
+from typing import Container, Iterable

 from . import imdb, models

@@ -10,17 +10,17 @@ Score100 = int  # [0, 100]
 @dataclass
 class Rating:
     canonical_title: str
-    imdb_score: Optional[Score100]
+    imdb_score: Score100 | None
-    imdb_votes: Optional[int]
+    imdb_votes: int | None
     media_type: str
     movie_imdb_id: str
-    original_title: Optional[str]
+    original_title: str | None
     release_year: int
-    user_id: Optional[str]
+    user_id: str | None
-    user_score: Optional[Score100]
+    user_score: Score100 | None

     @classmethod
-    def from_movie(cls, movie: models.Movie, *, rating: models.Rating = None):
+    def from_movie(cls, movie: models.Movie, *, rating: models.Rating | None = None):
         return cls(
             canonical_title=movie.title,
             imdb_score=movie.imdb_score,

@@ -37,11 +37,11 @@ class Rating:
 @dataclass
 class RatingAggregate:
     canonical_title: str
-    imdb_score: Optional[Score100]
+    imdb_score: Score100 | None
-    imdb_votes: Optional[int]
+    imdb_votes: int | None
     link: URL
     media_type: str
-    original_title: Optional[str]
+    original_title: str | None
     user_scores: list[Score100]
     year: int