Merge branch 'feat/sqlalchemy'

ducklet 2023-11-29 18:01:24 +01:00
commit 4fbdb26d9c
25 changed files with 2107 additions and 2036 deletions

.gitignore (vendored, 2 lines changed)

@@ -2,5 +2,5 @@
*.pyc
/.cache
/.pytest_cache
/build
/data/*
/requirements.txt

.python-version (new file, 1 line)

@@ -0,0 +1 @@
3.12

Dockerfile

@@ -1,4 +1,4 @@
FROM docker.io/library/python:3.11-alpine
FROM docker.io/library/python:3.12-alpine
RUN apk update --no-cache \
&& apk upgrade --no-cache \

poetry.lock (generated, 748 lines changed; diff suppressed because it is too large)

pyproject.toml

@@ -1,44 +1,46 @@
[tool.poetry]
name = "unwind"
version = "0.1.0"
version = "0"
description = ""
authors = ["ducklet <ducklet@noreply.code.dumpr.org>"]
license = "LOL"
[tool.poetry.dependencies]
python = "^3.11"
python = "^3.12"
beautifulsoup4 = "^4.9.3"
html5lib = "^1.1"
starlette = "^0.26"
starlette = "^0.30"
ulid-py = "^1.1.0"
databases = {extras = ["sqlite"], version = "^0.7.0"}
uvicorn = "^0.21"
httpx = "^0.23.3"
uvicorn = "^0.23"
httpx = "^0.24"
sqlalchemy = {version = "^2.0", extras = ["aiosqlite"]}
[tool.poetry.group.build.dependencies]
# When we run poetry export, typing-extensions is a transitive dependency via
# sqlalchemy, but its hash won't be included in the requirements.txt.
# Making it a direct dependency fixes this; otherwise it could be removed.
typing-extensions = "*"
[tool.poetry.group.dev]
optional = true
[tool.poetry.group.dev.dependencies]
autoflake = "*"
pytest = "*"
pyright = "*"
black = "*"
isort = "*"
pytest-asyncio = "*"
pytest-cov = "*"
ruff = "*"
honcho = "*"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.pyright]
pythonVersion = "3.11"
pythonVersion = "3.12"
[tool.isort]
profile = "black"
[tool.autoflake]
remove-duplicate-keys = true
remove-unused-variables = true
remove-all-unused-imports = true
[tool.ruff]
target-version = "py312"
ignore-init-module-imports = true
ignore-pass-after-docstring = true
select = ["I", "F401", "F601", "F602", "F841"]

@@ -2,6 +2,12 @@
cd "$RUN_DIR"
# Make Uvicorn defaults explicit.
: "${API_PORT:=8000}"
: "${API_HOST:=127.0.0.1}"
export API_PORT
export API_HOST
[ -z "${DEBUG:-}" ] || set -x
exec honcho start

@@ -4,4 +4,9 @@ cd "$RUN_DIR"
[ -z "${DEBUG:-}" ] || set -x
exec uvicorn unwind:create_app --factory --reload
exec uvicorn \
--host "$API_HOST" \
--port "$API_PORT" \
--reload \
--factory \
unwind:create_app

scripts/docker-build (new executable file, 25 lines)

@@ -0,0 +1,25 @@
#!/bin/sh -eu
: "${DOCKER_BIN:=docker}"
cd "$RUN_DIR"
builddir=build
[ -z "${DEBUG:-}" ] || set -x
mkdir -p "$builddir"
poetry export \
--with=build \
--output="$builddir"/requirements.txt
githash=$(git rev-parse --short HEAD)
today=$(date -u '+%Y.%m.%d')
version="${today}+${githash}"
echo "$version" >"$builddir"/version
$DOCKER_BIN build \
--pull \
--tag "code.dumpr.org/ducklet/unwind":"$version" \
.

scripts/docker-run (new executable file, 18 lines)

@@ -0,0 +1,18 @@
#!/bin/sh -eu
: "${DOCKER_BIN:=docker}"
cd "$RUN_DIR"
[ -z "${DEBUG:-}" ] || set -x
version=$(cat build/version)
$DOCKER_BIN run \
--init \
-it --rm \
--read-only \
--memory '500m' \
--publish 127.0.0.1:8000:8000 \
--volume "$RUN_DIR"/data:/data \
"code.dumpr.org/ducklet/unwind":"$version"

@@ -4,7 +4,7 @@ cd "$RUN_DIR"
[ -z "${DEBUG:-}" ] || set -x
autoflake --quiet --check --recursive unwind tests
isort unwind tests
black unwind tests
ruff check --fix . ||:
ruff format .
pyright

@@ -11,4 +11,5 @@ export UNWIND_PORT
exec uvicorn \
--host 0.0.0.0 \
--port "$UNWIND_PORT" \
--factory unwind:create_app
--factory \
unwind:create_app

@@ -6,10 +6,10 @@ dbfile="${UNWIND_DATA:-./data}/tests.sqlite"
# Rollback in Databases is currently broken, so we have to rebuild the database
# each time; see https://github.com/encode/databases/issues/403
trap 'rm "$dbfile"' EXIT TERM INT QUIT
trap 'rm "$dbfile" "${dbfile}-shm" "${dbfile}-wal"' EXIT TERM INT QUIT
[ -z "${DEBUG:-}" ] || set -x
SQLALCHEMY_WARN_20=1 \
export SQLALCHEMY_WARN_20=1 # XXX remove once we've switched to SQLAlchemy 2.0
UNWIND_STORAGE="$dbfile" \
python -m pytest "$@"
python -m pytest --cov "$@"

tests/conftest.py

@@ -17,16 +17,19 @@ def event_loop():
@pytest_asyncio.fixture(scope="session")
async def shared_conn():
c = db.shared_connection()
await c.connect()
"""A database connection, ready to use."""
await db.open_connection_pool()
await db.apply_db_patches(c)
yield c
async with db.new_connection() as c:
db._test_connection = c
yield c
db._test_connection = None
await c.disconnect()
await db.close_connection_pool()
@pytest_asyncio.fixture
async def conn(shared_conn):
async with shared_conn.transaction(force_rollback=True):
async def conn(shared_conn: db.Connection):
"""A transacted database connection, will be rolled back after use."""
async with db.transacted(shared_conn, force_rollback=True):
yield shared_conn
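For illustration, a test consuming these fixtures might look like the sketch below (a hypothetical test, not part of this change; `a_movie` is the factory defined in tests/test_db.py). Everything written through `conn` is discarded when the test returns:

    @pytest.mark.asyncio
    async def test_sketch(conn: db.Connection):
        m = a_movie()
        await db.add(conn, m)
        # Rolled back by the force_rollback=True transaction in the fixture.
        assert m == await db.get(conn, models.Movie, id=str(m.id))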

tests/test_db.py

@@ -4,155 +4,416 @@ import pytest
from unwind import db, models, web_models
_movie_imdb_id = 1230000
@pytest.mark.asyncio
async def test_add_and_get(shared_conn: db.Database):
async with shared_conn.transaction(force_rollback=True):
m1 = models.Movie(
title="test movie",
release_year=2013,
media_type="Movie",
imdb_id="tt0000000",
genres={"genre-1"},
)
await db.add(m1)
m2 = models.Movie(
title="test movie",
release_year=2013,
media_type="Movie",
imdb_id="tt0000001",
genres={"genre-1"},
)
await db.add(m2)
assert m1 == await db.get(models.Movie, id=str(m1.id))
assert m2 == await db.get(models.Movie, id=str(m2.id))
def a_movie(**kwds) -> models.Movie:
global _movie_imdb_id
_movie_imdb_id += 1
args = {
"title": "test movie",
"release_year": 2013,
"media_type": "Movie",
"imdb_id": f"tt{_movie_imdb_id}",
"genres": {"genre-1"},
} | kwds
return models.Movie(**args)
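For example (a hypothetical call, following the counter logic above), the first invocation in a session mints imdb_id "tt1230001", and any keyword argument overrides the corresponding default:

    # a_movie(release_year=1999) -> Movie(title="test movie", release_year=1999,
    #                                     media_type="Movie", imdb_id="tt1230001", ...)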
@pytest.mark.asyncio
async def test_find_ratings(shared_conn: db.Database):
async with shared_conn.transaction(force_rollback=True):
m1 = models.Movie(
title="test movie",
release_year=2013,
media_type="Movie",
imdb_id="tt0000000",
genres={"genre-1"},
async def test_current_patch_level(conn: db.Connection):
patch_level = "some-patch-level"
assert patch_level != await db.current_patch_level(conn)
await db.set_current_patch_level(conn, patch_level)
assert patch_level == await db.current_patch_level(conn)
@pytest.mark.asyncio
async def test_get(conn: db.Connection):
m1 = a_movie()
await db.add(conn, m1)
m2 = a_movie(release_year=m1.release_year + 1)
await db.add(conn, m2)
assert None is await db.get(conn, models.Movie)
assert None is await db.get(conn, models.Movie, id="blerp")
assert m1 == await db.get(conn, models.Movie, id=str(m1.id))
assert m2 == await db.get(conn, models.Movie, release_year=m2.release_year)
assert None is await db.get(
conn, models.Movie, id=str(m1.id), release_year=m2.release_year
)
assert m2 == await db.get(
conn, models.Movie, id=str(m2.id), release_year=m2.release_year
)
assert m1 == await db.get(
conn,
models.Movie,
media_type=m1.media_type,
order_by=(models.movies.c.release_year, "asc"),
)
assert m2 == await db.get(
conn,
models.Movie,
media_type=m1.media_type,
order_by=(models.movies.c.release_year, "desc"),
)
@pytest.mark.asyncio
async def test_get_all(conn: db.Connection):
m1 = a_movie()
await db.add(conn, m1)
m2 = a_movie(release_year=m1.release_year)
await db.add(conn, m2)
m3 = a_movie(release_year=m1.release_year + 1)
await db.add(conn, m3)
assert [] == list(await db.get_all(conn, models.Movie, id="blerp"))
assert [m1] == list(await db.get_all(conn, models.Movie, id=str(m1.id)))
assert [m1, m2] == list(
await db.get_all(conn, models.Movie, release_year=m1.release_year)
)
assert [m1, m2, m3] == list(await db.get_all(conn, models.Movie))
@pytest.mark.asyncio
async def test_get_many(conn: db.Connection):
m1 = a_movie()
await db.add(conn, m1)
m2 = a_movie(release_year=m1.release_year)
await db.add(conn, m2)
m3 = a_movie(release_year=m1.release_year + 1)
await db.add(conn, m3)
assert [] == list(await db.get_many(conn, models.Movie)), "selected nothing"
assert [m1] == list(await db.get_many(conn, models.Movie, id=[str(m1.id)]))
assert [m1] == list(await db.get_many(conn, models.Movie, id={str(m1.id)}))
assert [m1, m2] == list(
await db.get_many(conn, models.Movie, release_year=[m1.release_year])
)
assert [m1, m2, m3] == list(
await db.get_many(
conn, models.Movie, release_year=[m1.release_year, m3.release_year]
)
await db.add(m1)
)
m2 = models.Movie(
title="it's anöther Movie, Part 2",
release_year=2015,
media_type="Movie",
imdb_id="tt0000001",
genres={"genre-2"},
@pytest.mark.asyncio
async def test_add_and_get(conn: db.Connection):
m1 = a_movie()
await db.add(conn, m1)
m2 = a_movie()
await db.add(conn, m2)
assert m1 == await db.get(conn, models.Movie, id=str(m1.id))
assert m2 == await db.get(conn, models.Movie, id=str(m2.id))
@pytest.mark.asyncio
async def test_update(conn: db.Connection):
m = a_movie()
await db.add(conn, m)
assert m == await db.get(conn, models.Movie, id=str(m.id))
m.title += "something else"
assert m != await db.get(conn, models.Movie, id=str(m.id))
await db.update(conn, m)
assert m == await db.get(conn, models.Movie, id=str(m.id))
@pytest.mark.asyncio
async def test_remove(conn: db.Connection):
m1 = a_movie()
await db.add(conn, m1)
assert m1 == await db.get(conn, models.Movie, id=str(m1.id))
await db.remove(conn, m1)
assert None is await db.get(conn, models.Movie, id=str(m1.id))
@pytest.mark.asyncio
async def test_find_ratings(conn: db.Connection):
m1 = a_movie(
title="test movie",
release_year=2013,
genres={"genre-1"},
)
await db.add(conn, m1)
m2 = a_movie(
title="it's anöther Movie, Part 2",
release_year=2015,
genres={"genre-2"},
)
await db.add(conn, m2)
m3 = a_movie(
title="movie it's, Part 3",
release_year=m2.release_year,
genres=m2.genres,
)
await db.add(conn, m3)
u1 = models.User(
imdb_id="u00001",
name="User1",
secret="secret1",
)
await db.add(conn, u1)
u2 = models.User(
imdb_id="u00002",
name="User2",
secret="secret2",
)
await db.add(conn, u2)
r1 = models.Rating(
movie_id=m2.id,
movie=m2,
user_id=u1.id,
user=u1,
score=66,
rating_date=datetime.now(),
)
await db.add(conn, r1)
r2 = models.Rating(
movie_id=m2.id,
movie=m2,
user_id=u2.id,
user=u2,
score=77,
rating_date=datetime.now(),
)
await db.add(conn, r2)
# ---
rows = await db.find_ratings(
conn,
title=m1.title,
media_type=m1.media_type,
exact=True,
ignore_tv_episodes=True,
include_unrated=True,
yearcomp=("=", m1.release_year),
limit_rows=3,
user_ids=[],
)
ratings = (web_models.Rating(**r) for r in rows)
assert (web_models.RatingAggregate.from_movie(m1),) == tuple(
web_models.aggregate_ratings(ratings, user_ids=[])
)
rows = await db.find_ratings(conn, title="movie", include_unrated=False)
ratings = tuple(web_models.Rating(**r) for r in rows)
assert (
web_models.Rating.from_movie(m2, rating=r1),
web_models.Rating.from_movie(m2, rating=r2),
) == ratings
rows = await db.find_ratings(conn, title="movie", include_unrated=True)
ratings = tuple(web_models.Rating(**r) for r in rows)
assert (
web_models.Rating.from_movie(m1),
web_models.Rating.from_movie(m2, rating=r1),
web_models.Rating.from_movie(m2, rating=r2),
web_models.Rating.from_movie(m3),
) == ratings
aggr = web_models.aggregate_ratings(ratings, user_ids=[])
assert tuple(
web_models.RatingAggregate.from_movie(m) for m in [m1, m2, m3]
) == tuple(aggr)
aggr = web_models.aggregate_ratings(ratings, user_ids=[str(u1.id)])
assert (
web_models.RatingAggregate.from_movie(m1),
web_models.RatingAggregate.from_movie(m2, ratings=[r1]),
web_models.RatingAggregate.from_movie(m3),
) == tuple(aggr)
aggr = web_models.aggregate_ratings(ratings, user_ids=[str(u1.id), str(u2.id)])
assert (
web_models.RatingAggregate.from_movie(m1),
web_models.RatingAggregate.from_movie(m2, ratings=[r1, r2]),
web_models.RatingAggregate.from_movie(m3),
) == tuple(aggr)
rows = await db.find_ratings(conn, title="movie", include_unrated=True)
ratings = (web_models.Rating(**r) for r in rows)
aggr = web_models.aggregate_ratings(ratings, user_ids=[])
assert tuple(
web_models.RatingAggregate.from_movie(m) for m in [m1, m2, m3]
) == tuple(aggr)
rows = await db.find_ratings(conn, title="test", include_unrated=True)
ratings = tuple(web_models.Rating(**r) for r in rows)
assert (web_models.Rating.from_movie(m1),) == ratings
@pytest.mark.asyncio
async def test_ratings_for_movies(conn: db.Connection):
m1 = a_movie()
await db.add(conn, m1)
m2 = a_movie()
await db.add(conn, m2)
u1 = models.User(
imdb_id="u00001",
name="User1",
secret="secret1",
)
await db.add(conn, u1)
u2 = models.User(
imdb_id="u00002",
name="User2",
secret="secret2",
)
await db.add(conn, u2)
r1 = models.Rating(
movie_id=m2.id,
movie=m2,
user_id=u1.id,
user=u1,
score=66,
rating_date=datetime.now(),
)
await db.add(conn, r1)
# ---
movie_ids = [m1.id]
user_ids = []
assert tuple() == tuple(
await db.ratings_for_movies(conn, movie_ids=movie_ids, user_ids=user_ids)
)
movie_ids = [m2.id]
user_ids = []
assert (r1,) == tuple(
await db.ratings_for_movies(conn, movie_ids=movie_ids, user_ids=user_ids)
)
movie_ids = [m2.id]
user_ids = [u2.id]
assert tuple() == tuple(
await db.ratings_for_movies(conn, movie_ids=movie_ids, user_ids=user_ids)
)
movie_ids = [m2.id]
user_ids = [u1.id]
assert (r1,) == tuple(
await db.ratings_for_movies(conn, movie_ids=movie_ids, user_ids=user_ids)
)
movie_ids = [m1.id, m2.id]
user_ids = [u1.id, u2.id]
assert (r1,) == tuple(
await db.ratings_for_movies(conn, movie_ids=movie_ids, user_ids=user_ids)
)
@pytest.mark.asyncio
async def test_find_movies(conn: db.Connection):
m1 = a_movie(title="movie one")
await db.add(conn, m1)
m2 = a_movie(title="movie two", imdb_score=33, release_year=m1.release_year + 1)
await db.add(conn, m2)
u1 = models.User(
imdb_id="u00001",
name="User1",
secret="secret1",
)
await db.add(conn, u1)
u2 = models.User(
imdb_id="u00002",
name="User2",
secret="secret2",
)
await db.add(conn, u2)
r1 = models.Rating(
movie_id=m2.id,
movie=m2,
user_id=u1.id,
user=u1,
score=66,
rating_date=datetime.now(),
)
await db.add(conn, r1)
# ---
assert () == tuple(
await db.find_movies(conn, title=m1.title, include_unrated=False)
)
assert ((m1, []),) == tuple(
await db.find_movies(conn, title=m1.title, include_unrated=True)
)
assert ((m1, []),) == tuple(
await db.find_movies(conn, title="mo on", exact=False, include_unrated=True)
)
assert ((m1, []),) == tuple(
await db.find_movies(conn, title="movie one", exact=True, include_unrated=True)
)
assert () == tuple(
await db.find_movies(conn, title="mo on", exact=True, include_unrated=True)
)
assert ((m2, []),) == tuple(
await db.find_movies(conn, title="movie", exact=False, include_unrated=False)
)
assert ((m2, []), (m1, [])) == tuple(
await db.find_movies(conn, title="movie", exact=False, include_unrated=True)
)
assert ((m1, []),) == tuple(
await db.find_movies(
conn, include_unrated=True, yearcomp=("=", m1.release_year)
)
await db.add(m2)
m3 = models.Movie(
title="movie it's, Part 3",
release_year=2015,
media_type="Movie",
imdb_id="tt0000002",
genres={"genre-2"},
)
assert ((m2, []),) == tuple(
await db.find_movies(
conn, include_unrated=True, yearcomp=("=", m2.release_year)
)
await db.add(m3)
u1 = models.User(
imdb_id="u00001",
name="User1",
secret="secret1",
)
assert ((m1, []),) == tuple(
await db.find_movies(
conn, include_unrated=True, yearcomp=("<", m2.release_year)
)
await db.add(u1)
u2 = models.User(
imdb_id="u00002",
name="User2",
secret="secret2",
)
assert ((m2, []),) == tuple(
await db.find_movies(
conn, include_unrated=True, yearcomp=(">", m1.release_year)
)
await db.add(u2)
)
r1 = models.Rating(
movie_id=m2.id,
movie=m2,
user_id=u1.id,
user=u1,
score=66,
rating_date=datetime.now(),
)
await db.add(r1)
assert ((m2, []), (m1, [])) == tuple(
await db.find_movies(conn, include_unrated=True)
)
assert ((m2, []),) == tuple(
await db.find_movies(conn, include_unrated=True, limit_rows=1)
)
assert ((m1, []),) == tuple(
await db.find_movies(conn, include_unrated=True, skip_rows=1)
)
r2 = models.Rating(
movie_id=m2.id,
movie=m2,
user_id=u2.id,
user=u2,
score=77,
rating_date=datetime.now(),
)
await db.add(r2)
# ---
rows = await db.find_ratings(
title=m1.title,
media_type=m1.media_type,
exact=True,
ignore_tv_episodes=True,
include_unrated=True,
yearcomp=("=", m1.release_year),
limit_rows=3,
user_ids=[],
)
ratings = (web_models.Rating(**r) for r in rows)
assert (web_models.RatingAggregate.from_movie(m1),) == tuple(
web_models.aggregate_ratings(ratings, user_ids=[])
)
rows = await db.find_ratings(title="movie", include_unrated=False)
ratings = tuple(web_models.Rating(**r) for r in rows)
assert (
web_models.Rating.from_movie(m2, rating=r1),
web_models.Rating.from_movie(m2, rating=r2),
) == ratings
rows = await db.find_ratings(title="movie", include_unrated=True)
ratings = tuple(web_models.Rating(**r) for r in rows)
assert (
web_models.Rating.from_movie(m1),
web_models.Rating.from_movie(m2, rating=r1),
web_models.Rating.from_movie(m2, rating=r2),
web_models.Rating.from_movie(m3),
) == ratings
aggr = web_models.aggregate_ratings(ratings, user_ids=[])
assert tuple(
web_models.RatingAggregate.from_movie(m) for m in [m1, m2, m3]
) == tuple(aggr)
aggr = web_models.aggregate_ratings(ratings, user_ids=[str(u1.id)])
assert (
web_models.RatingAggregate.from_movie(m1),
web_models.RatingAggregate.from_movie(m2, ratings=[r1]),
web_models.RatingAggregate.from_movie(m3),
) == tuple(aggr)
aggr = web_models.aggregate_ratings(ratings, user_ids=[str(u1.id), str(u2.id)])
assert (
web_models.RatingAggregate.from_movie(m1),
web_models.RatingAggregate.from_movie(m2, ratings=[r1, r2]),
web_models.RatingAggregate.from_movie(m3),
) == tuple(aggr)
rows = await db.find_ratings(title="movie", include_unrated=True)
ratings = (web_models.Rating(**r) for r in rows)
aggr = web_models.aggregate_ratings(ratings, user_ids=[])
assert tuple(
web_models.RatingAggregate.from_movie(m) for m in [m1, m2, m3]
) == tuple(aggr)
rows = await db.find_ratings(title="test", include_unrated=True)
ratings = tuple(web_models.Rating(**r) for r in rows)
assert (web_models.Rating.from_movie(m1),) == ratings
assert ((m2, [r1]), (m1, [])) == tuple(
await db.find_movies(conn, include_unrated=True, user_ids=[u1.id, u2.id])
)

tests/test_models.py (new file, 11 lines)

@@ -0,0 +1,11 @@
import pytest
from unwind import models
@pytest.mark.parametrize("mapper", models.mapper_registry.mappers)
def test_fields(mapper):
"""Test that models.fields() matches exactly all table columns."""
dcfields = {f.name for f in models.fields(mapper.class_)}
mfields = {c.name for c in mapper.columns}
assert dcfields == mfields
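The `fields()` helper itself is not shown in this diff; a plausible minimal shape, assuming the models are stdlib dataclasses registered on the imperative mapper_registry, would be:

    import dataclasses

    def fields(model_class):
        # Assumption: models are plain dataclasses, so the mapped columns
        # should mirror dataclasses.fields() one-to-one.
        return dataclasses.fields(model_class)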

tests/test_web.py

@@ -1,53 +1,243 @@
from datetime import datetime
import pytest
from starlette.testclient import TestClient
from unwind import create_app, db, imdb, models
from unwind import config, create_app, db, imdb, models
app = create_app()
@pytest.fixture(scope="module")
def unauthorized_client() -> TestClient:
# https://www.starlette.io/testclient/
return TestClient(app)
@pytest.fixture(scope="module")
def authorized_client() -> TestClient:
client = TestClient(app)
client.auth = "user1", "secret1"
return client
@pytest.fixture(scope="module")
def admin_client() -> TestClient:
client = TestClient(app)
for token in config.api_credentials.values():
break
else:
raise RuntimeError("No bearer tokens configured.")
client.headers = {"Authorization": f"Bearer {token}"}
return client
@pytest.mark.asyncio
async def test_app(shared_conn: db.Database):
async with shared_conn.transaction(force_rollback=True):
# https://www.starlette.io/testclient/
client = TestClient(app)
response = client.get("/api/v1/movies")
assert response.status_code == 403
async def test_get_ratings_for_group(
conn: db.Connection, unauthorized_client: TestClient
):
user = models.User(
imdb_id="ur12345678",
name="user-1",
secret="secret-1",
groups=[],
)
group = models.Group(
name="group-1",
users=[models.GroupUser(id=str(user.id), name=user.name)],
)
user.groups = [models.UserGroup(id=str(group.id), access="r")]
path = app.url_path_for("get_ratings_for_group", group_id=str(group.id))
client.auth = "user1", "secret1"
resp = unauthorized_client.get(path)
assert resp.status_code == 404, "Group does not exist (yet)"
response = client.get("/api/v1/movies")
assert response.status_code == 200
assert response.json() == []
await db.add(conn, user)
await db.add(conn, group)
m = models.Movie(
title="test movie",
release_year=2013,
media_type="Movie",
imdb_id="tt12345678",
genres={"genre-1"},
)
await db.add(m)
resp = unauthorized_client.get(path)
assert resp.status_code == 200
assert resp.json() == []
response = client.get("/api/v1/movies", params={"include_unrated": 1})
assert response.status_code == 200
assert response.json() == [{**models.asplain(m), "user_scores": []}]
movie = models.Movie(
title="test movie",
release_year=2013,
media_type="Movie",
imdb_id="tt12345678",
genres={"genre-1"},
)
await db.add(conn, movie)
m_plain = {
"canonical_title": m.title,
"imdb_score": m.imdb_score,
"imdb_votes": m.imdb_votes,
"link": imdb.movie_url(m.imdb_id),
"media_type": m.media_type,
"original_title": m.original_title,
"user_scores": [],
"year": m.release_year,
}
rating = models.Rating(
movie_id=movie.id, user_id=user.id, score=66, rating_date=datetime.now()
)
await db.add(conn, rating)
response = client.get("/api/v1/movies", params={"imdb_id": m.imdb_id})
assert response.status_code == 200
assert response.json() == [m_plain]
rating_aggregate = {
"canonical_title": movie.title,
"imdb_score": movie.imdb_score,
"imdb_votes": movie.imdb_votes,
"link": imdb.movie_url(movie.imdb_id),
"media_type": movie.media_type,
"original_title": movie.original_title,
"user_scores": [rating.score],
"year": movie.release_year,
}
response = client.get("/api/v1/movies", params={"unwind_id": str(m.id)})
assert response.status_code == 200
assert response.json() == [m_plain]
resp = unauthorized_client.get(path)
assert resp.status_code == 200
assert resp.json() == [rating_aggregate]
filters = {
"imdb_id": movie.imdb_id,
"unwind_id": str(movie.id),
"title": movie.title,
"media_type": movie.media_type,
"year": movie.release_year,
}
for k, v in filters.items():
resp = unauthorized_client.get(path, params={k: v})
assert resp.status_code == 200
assert resp.json() == [rating_aggregate]
resp = unauthorized_client.get(path, params={"title": "no such thing"})
assert resp.status_code == 200
assert resp.json() == []
# Test "exact" query param.
resp = unauthorized_client.get(
path, params={"title": "test movie", "exact": "true"}
)
assert resp.status_code == 200
assert resp.json() == [rating_aggregate]
resp = unauthorized_client.get(path, params={"title": "te mo", "exact": "false"})
assert resp.status_code == 200
assert resp.json() == [rating_aggregate]
resp = unauthorized_client.get(path, params={"title": "te mo", "exact": "true"})
assert resp.status_code == 200
assert resp.json() == []
# XXX Test "ignore_tv_episodes" query param.
# XXX Test "include_unrated" query param.
# XXX Test "per_page" query param.
@pytest.mark.asyncio
async def test_list_movies(
conn: db.Connection,
unauthorized_client: TestClient,
authorized_client: TestClient,
):
path = app.url_path_for("list_movies")
response = unauthorized_client.get(path)
assert response.status_code == 403
response = authorized_client.get(path)
assert response.status_code == 200
assert response.json() == []
m = models.Movie(
title="test movie",
release_year=2013,
media_type="Movie",
imdb_id="tt12345678",
genres={"genre-1"},
)
await db.add(conn, m)
response = authorized_client.get(path, params={"include_unrated": 1})
assert response.status_code == 200
assert response.json() == [{**models.asplain(m), "user_scores": []}]
m_plain = {
"canonical_title": m.title,
"imdb_score": m.imdb_score,
"imdb_votes": m.imdb_votes,
"link": imdb.movie_url(m.imdb_id),
"media_type": m.media_type,
"original_title": m.original_title,
"user_scores": [],
"year": m.release_year,
}
response = authorized_client.get(path, params={"imdb_id": m.imdb_id})
assert response.status_code == 200
assert response.json() == [m_plain]
response = authorized_client.get(path, params={"unwind_id": str(m.id)})
assert response.status_code == 200
assert response.json() == [m_plain]
@pytest.mark.asyncio
async def test_list_users(
conn: db.Connection,
unauthorized_client: TestClient,
authorized_client: TestClient,
admin_client: TestClient,
):
path = app.url_path_for("list_users")
response = unauthorized_client.get(path)
assert response.status_code == 403
response = authorized_client.get(path)
assert response.status_code == 403
response = admin_client.get(path)
assert response.status_code == 200
assert response.json() == []
m = models.User(
imdb_id="ur12345678",
name="user-1",
secret="secret-1",
groups=[],
)
await db.add(conn, m)
m_plain = {
"groups": m.groups,
"id": m.id,
"imdb_id": m.imdb_id,
"name": m.name,
"secret": m.secret,
}
response = admin_client.get(path)
assert response.status_code == 200
assert response.json() == [m_plain]
@pytest.mark.asyncio
async def test_list_groups(
conn: db.Connection,
unauthorized_client: TestClient,
authorized_client: TestClient,
admin_client: TestClient,
):
path = app.url_path_for("list_groups")
response = unauthorized_client.get(path)
assert response.status_code == 403
response = authorized_client.get(path)
assert response.status_code == 403
response = admin_client.get(path)
assert response.status_code == 200
assert response.json() == []
m = models.Group(
name="group-1",
users=[models.GroupUser(id="123", name="itsa-me")],
)
await db.add(conn, m)
m_plain = {
"users": m.users,
"id": m.id,
"name": m.name,
}
response = admin_client.get(path)
assert response.status_code == 200
assert response.json() == [m_plain]

File diff suppressed because it is too large.

vite.config.js

@@ -1,13 +1,29 @@
import { defineConfig } from "vite"
import vue from "@vitejs/plugin-vue"
// Vite defaults.
const vite_host = "localhost"
const vite_port = 3000
const base = process.env.BASE_URL || "/"
const proxied_api_url = `http://${vite_host}:${vite_port}/api/`
const real_api_url = `http://${process.env.API_HOST}:${process.env.API_PORT}/api/`
// https://vitejs.dev/config/
export default defineConfig({
base: process.env.BASE_URL || "/",
base,
define: {
"process.env.API_URL": JSON.stringify(
process.env.API_URL || "http://localhost:8000/api/",
),
"process.env.API_URL": JSON.stringify(process.env.API_URL || proxied_api_url),
},
server: {
host: vite_host,
port: vite_port,
proxy: {
[`${base}api`]: {
target: real_api_url,
prependPath: false,
},
},
},
plugins: [vue()],
})

unwind/config.py

@@ -2,20 +2,20 @@ import os
import tomllib
from pathlib import Path
datadir = Path(os.getenv("UNWIND_DATA") or "./data")
cachedir = (
Path(cachedir)
if (cachedir := os.getenv("UNWIND_CACHEDIR", datadir / ".cache"))
else None
datadir: Path = Path(os.getenv("UNWIND_DATA") or "./data")
cachedir: Path = Path(p) if (p := os.getenv("UNWIND_CACHEDIR")) else datadir / ".cache"
debug: bool = os.getenv("DEBUG") == "1"
loglevel: str = os.getenv("UNWIND_LOGLEVEL") or ("DEBUG" if debug else "INFO")
storage_path: Path = (
Path(p) if (p := os.getenv("UNWIND_STORAGE")) else datadir / "db.sqlite"
)
config_path: Path = (
Path(p) if (p := os.getenv("UNWIND_CONFIG")) else datadir / "config.toml"
)
debug = os.getenv("DEBUG") == "1"
loglevel = os.getenv("UNWIND_LOGLEVEL") or ("DEBUG" if debug else "INFO")
storage_path = os.getenv("UNWIND_STORAGE", datadir / "db.sqlite")
config_path = os.getenv("UNWIND_CONFIG", datadir / "config.toml")
with open(config_path, "rb") as fd:
_config = tomllib.load(fd)
api_base = _config["api"].get("base", "/api/")
api_cors = _config["api"].get("cors", "*")
api_credentials = _config["api"].get("credentials", {})
api_base: str = _config["api"].get("base", "/api/")
api_cors: str = _config["api"].get("cors", "*")
api_credentials: dict[str, str] = _config["api"].get("credentials", {})

unwind/db.py

@@ -1,24 +1,27 @@
import asyncio
import contextlib
import logging
import re
import threading
from pathlib import Path
from typing import Any, Iterable, Literal, Type, TypeVar
from typing import Any, AsyncGenerator, Iterable, Literal, Sequence, Type, TypeVar
import sqlalchemy
from databases import Database
import sqlalchemy as sa
from sqlalchemy.dialects.sqlite import insert
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
from . import config
from .models import (
Model,
Movie,
Progress,
Rating,
User,
asplain,
fields,
db_patches,
fromplain,
metadata,
movies,
optional_fields,
progress,
ratings,
utcnow,
)
from .types import ULID
@@ -26,7 +29,9 @@ from .types import ULID
log = logging.getLogger(__name__)
T = TypeVar("T")
_shared_connection: Database | None = None
_engine: AsyncEngine | None = None
type Connection = AsyncConnection
async def open_connection_pool() -> None:
@@ -34,10 +39,13 @@ async def open_connection_pool() -> None:
This function needs to be called before any access to the database can happen.
"""
db = shared_connection()
await db.connect()
async with transaction() as conn:
await conn.execute(sa.text("PRAGMA journal_mode=WAL"))
await apply_db_patches(db)
await conn.run_sync(metadata.create_all, tables=[db_patches])
async with new_connection() as conn:
await apply_db_patches(conn)
async def close_connection_pool() -> None:
@@ -46,48 +54,33 @@ async def close_connection_pool() -> None:
This function should be called before the app shuts down to ensure all data
has been flushed to the database.
"""
db = shared_connection()
engine = _shared_engine()
# Run automatic ANALYZE prior to closing the db,
# see https://sqlite.com/lang_analyze.html.
await db.execute("PRAGMA analysis_limit=400")
await db.execute("PRAGMA optimize")
async with engine.begin() as conn:
# Run automatic ANALYZE prior to closing the db,
# see https://sqlite.com/lang_analyze.html.
await conn.execute(sa.text("PRAGMA analysis_limit=400"))
await conn.execute(sa.text("PRAGMA optimize"))
await db.disconnect()
await engine.dispose()
async def _create_patch_db(db):
query = """
CREATE TABLE IF NOT EXISTS db_patches (
id INTEGER PRIMARY KEY,
current TEXT
)
"""
await db.execute(query)
async def current_patch_level(db) -> str:
await _create_patch_db(db)
query = "SELECT current FROM db_patches"
current = await db.fetch_val(query)
async def current_patch_level(conn: Connection, /) -> str:
query = sa.select(db_patches.c.current)
current = await conn.scalar(query)
return current or ""
async def set_current_patch_level(db, current: str):
await _create_patch_db(db)
query = """
INSERT INTO db_patches VALUES (1, :current)
ON CONFLICT DO UPDATE SET current=excluded.current
"""
await db.execute(query, values={"current": current})
async def set_current_patch_level(conn: Connection, /, current: str) -> None:
stmt = insert(db_patches).values(id=1, current=current)
stmt = stmt.on_conflict_do_update(set_={"current": stmt.excluded.current})
await conn.execute(stmt)
db_patches_dir = Path(__file__).parent / "sql"
async def apply_db_patches(db: Database):
async def apply_db_patches(conn: Connection, /) -> None:
"""Apply all remaining patches to the database.
Beware that patches will be applied in lexicographical order,
@@ -99,7 +92,7 @@ async def apply_db_patches(db: Database):
using two consecutive semi-colons (;;).
Failing to do so will result in an error.
"""
applied_lvl = await current_patch_level(db)
applied_lvl = await current_patch_level(conn)
did_patch = False
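The double-semicolon convention from the docstring implies a split along these lines (a minimal sketch; the actual parsing sits in the elided part of this hunk, and `patch_sql` is a hypothetical variable holding one patch file's contents):

    queries = [q.strip() for q in patch_sql.split(";;") if q.strip()]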
@@ -118,29 +111,52 @@
)
raise RuntimeError("No statement found.")
async with db.transaction():
async with transacted(conn):
for query in queries:
await db.execute(query)
await conn.execute(sa.text(query))
await set_current_patch_level(db, patch_lvl)
await set_current_patch_level(conn, patch_lvl)
did_patch = True
if did_patch:
await db.execute("vacuum")
await _vacuum(conn)
async def get_import_progress() -> Progress | None:
async def _vacuum(conn: Connection, /) -> None:
"""Vacuum the database.
This function cannot be run on a connection with an open transaction.
"""
# With SQLAlchemy's "autobegin" behavior we need to switch the connection
# to "autocommit" first to keep it from automatically starting a transaction,
# as VACUUM cannot be run inside a transaction for most databases.
await conn.commit()
isolation_level = await conn.get_isolation_level()
log.debug("Previous isolation_level: %a", isolation_level)
await conn.execution_options(isolation_level="AUTOCOMMIT")
try:
await conn.execute(sa.text("vacuum"))
await conn.commit()
finally:
await conn.execution_options(isolation_level=isolation_level)
async def get_import_progress(conn: Connection, /) -> Progress | None:
"""Return the latest import progress."""
return await get(Progress, type="import-imdb-movies", order_by="started DESC")
return await get(
conn, Progress, type="import-imdb-movies", order_by=(progress.c.started, "desc")
)
async def stop_import_progress(*, error: BaseException | None = None):
async def stop_import_progress(
conn: Connection, /, *, error: BaseException | None = None
) -> None:
"""Stop the current import.
If an error is given, it will be logged to the progress state.
"""
current = await get_import_progress()
current = await get_import_progress(conn)
is_running = current and current.stopped is None
if not is_running:
@@ -151,17 +167,17 @@ async def stop_import_progress(*, error: BaseException | None = None):
current.error = repr(error)
current.stopped = utcnow().isoformat()
await update(current)
await update(conn, current)
async def set_import_progress(progress: float) -> Progress:
async def set_import_progress(conn: Connection, /, progress: float) -> Progress:
"""Set the current import progress percentage.
If no import is currently running, this will create a new one.
"""
progress = min(max(0.0, progress), 100.0) # clamp to 0 <= progress <= 100
current = await get_import_progress()
current = await get_import_progress(conn)
is_running = current and current.stopped is None
if not is_running:
@@ -171,163 +187,211 @@ async def set_import_progress(progress: float) -> Progress:
current.percent = progress
if is_running:
await update(current)
await update(conn, current)
else:
await add(current)
await add(conn, current)
return current
_lock = threading.Lock()
_prelock = threading.Lock()
def _new_engine() -> AsyncEngine:
uri = f"sqlite+aiosqlite:///{config.storage_path}"
return create_async_engine(
uri,
isolation_level="SERIALIZABLE",
)
def _shared_engine() -> AsyncEngine:
global _engine
if _engine is None:
_engine = _new_engine()
return _engine
def _new_connection() -> Connection:
return _shared_engine().connect()
@contextlib.asynccontextmanager
async def single_threaded():
"""Ensure the nested code is run only by a single thread at a time."""
wait = 1e-5 # XXX not sure if there's a better magic value here
async def transaction(
*, force_rollback: bool = False
) -> AsyncGenerator[Connection, None]:
async with new_connection() as conn:
yield conn
# The pre-lock (a lock for the lock) allows multiple threads to hand off
# the main lock.
# With only a single lock the contending thread will spend most of its time
# in the asyncio.sleep and the reigning thread will have time to finish
# whatever it's doing and simply acquire the lock again before the other
# thread has had a chance to try.
# By having another lock (and the same sleep time!) the contending thread
# will always have a chance to acquire the main lock.
while not _prelock.acquire(blocking=False):
await asyncio.sleep(wait)
if not force_rollback:
await conn.commit()
try:
while not _lock.acquire(blocking=False):
await asyncio.sleep(wait)
finally:
_prelock.release()
try:
yield
finally:
_lock.release()
# The _test_connection allows pinning a connection that will be shared across the app.
# This can (and should) only be used when running tests, NOT IN PRODUCTION!
_test_connection: Connection | None = None
@contextlib.asynccontextmanager
async def locked_connection():
async with single_threaded():
yield shared_connection()
async def new_connection() -> AsyncGenerator[Connection, None]:
"""Return a new connection.
Any changes will be rolled back, unless `.commit()` is called on the
connection.
If you want to commit changes, consider using `transaction()` instead.
"""
conn = _test_connection or _new_connection()
# Support reusing the same connection for _test_connection.
is_started = conn.sync_connection is not None
if is_started:
yield conn
return
async with conn:
yield conn
def shared_connection() -> Database:
global _shared_connection
@contextlib.asynccontextmanager
async def transacted(
conn: Connection, /, *, force_rollback: bool = False
) -> AsyncGenerator[None, None]:
transaction = contextlib.nullcontext() if conn.in_transaction() else conn.begin()
if _shared_connection is None:
uri = f"sqlite:///{config.storage_path}"
_shared_connection = Database(uri)
async with transaction:
try:
yield
return _shared_connection
finally:
if force_rollback:
await conn.rollback()
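Taken together, code outside the test suite might use these helpers roughly as sketched here (assuming they behave as their docstrings state; `some_movie` is a placeholder):

    async def sketch(some_movie: Movie) -> None:
        async with new_connection() as conn:
            async with transacted(conn):
                await add(conn, some_movie)  # committed when the block exits cleanly
        # Without transacted() or an explicit conn.commit(), changes made
        # on the connection are rolled back when it closes.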
async def add(item):
async def add(conn: Connection, /, item: Model) -> None:
# Support late initializing - used for optimization.
if getattr(item, "_is_lazy", False):
item._lazy_init()
assert hasattr(item, "_lazy_init")
item._lazy_init() # pyright: ignore [reportGeneralTypeIssues]
table: sa.Table = item.__table__
values = asplain(item, serialize=True)
keys = ", ".join(f"{k}" for k in values)
placeholders = ", ".join(f":{k}" for k in values)
query = f"INSERT INTO {item._table} ({keys}) VALUES ({placeholders})"
async with locked_connection() as conn:
await conn.execute(query=query, values=values)
stmt = table.insert().values(values)
await conn.execute(stmt)
ModelType = TypeVar("ModelType")
async def fetch_all(
conn: Connection, /, query: sa.Executable, values: "dict | None" = None
) -> Sequence[sa.Row]:
result = await conn.execute(query, values)
return result.all()
async def fetch_one(
conn: Connection, /, query: sa.Executable, values: "dict | None" = None
) -> sa.Row | None:
result = await conn.execute(query, values)
return result.first()
ModelType = TypeVar("ModelType", bound=Model)
async def get(
model: Type[ModelType], *, order_by: str | None = None, **kwds
conn: Connection,
/,
model: Type[ModelType],
*,
order_by: tuple[sa.Column, Literal["asc", "desc"]] | None = None,
**field_values,
) -> ModelType | None:
"""Load a model instance from the database.
Passing `kwds` allows to filter the instance to load. You have to encode the
Passing `field_values` lets you filter which item to load. You have to encode the
values as the appropriate data type for the database prior to passing them
to this function.
"""
values = {k: v for k, v in kwds.items() if v is not None}
if not values:
if not field_values:
return
fields_ = ", ".join(f.name for f in fields(model))
cond = " AND ".join(f"{k}=:{k}" for k in values)
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
table: sa.Table = model.__table__
query = sa.select(model).where(
*(table.c[k] == v for k, v in field_values.items() if v is not None)
)
if order_by:
query += f" ORDER BY {order_by}"
async with locked_connection() as conn:
row = await conn.fetch_one(query=query, values=values)
order_col, order_dir = order_by
query = query.order_by(
order_col.asc() if order_dir == "asc" else order_col.desc()
)
row = await fetch_one(conn, query)
return fromplain(model, row._mapping, serialized=True) if row else None
async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
keys = {
k: [f"{k}_{i}" for i, _ in enumerate(vs, start=1)] for k, vs in kwds.items()
}
async def get_many(
conn: Connection, /, model: Type[ModelType], **field_sets: set | list
) -> Iterable[ModelType]:
"""Return the items with any values matching all given field sets.
if not keys:
This is similar to `get_all`, but instead of a scalar value a list of values
must be given. If any of the given values is set for that field on an item,
the item is considered a match.
If no field values are given, no items will be returned.
"""
if not field_sets:
return []
values = {n: v for k, vs in kwds.items() for n, v in zip(keys[k], vs)}
table: sa.Table = model.__table__
query = sa.select(model).where(*(table.c[k].in_(v) for k, v in field_sets.items()))
rows = await fetch_all(conn, query)
return (fromplain(model, row._mapping, serialized=True) for row in rows)
fields_ = ", ".join(f.name for f in fields(model))
cond = " AND ".join(
f"{k} IN ({','.join(':'+n for n in ns)})" for k, ns in keys.items()
async def get_all(
conn: Connection, /, model: Type[ModelType], **field_values
) -> Iterable[ModelType]:
"""Filter all items by comparing all given field values.
If no filters are given, all items will be returned.
"""
table: sa.Table = model.__table__
query = sa.select(model).where(
*(table.c[k] == v for k, v in field_values.items() if v is not None)
)
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
async with locked_connection() as conn:
rows = await conn.fetch_all(query=query, values=values)
rows = await fetch_all(conn, query)
return (fromplain(model, row._mapping, serialized=True) for row in rows)
async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
values = {k: v for k, v in kwds.items() if v is not None}
fields_ = ", ".join(f.name for f in fields(model))
cond = " AND ".join(f"{k}=:{k}" for k in values) or "1=1"
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
async with locked_connection() as conn:
rows = await conn.fetch_all(query=query, values=values)
return (fromplain(model, row._mapping, serialized=True) for row in rows)
async def update(item):
async def update(conn: Connection, /, item: Model) -> None:
# Support late initializing - used for optimization.
if getattr(item, "_is_lazy", False):
item._lazy_init()
assert hasattr(item, "_lazy_init")
item._lazy_init() # pyright: ignore [reportGeneralTypeIssues]
table: sa.Table = item.__table__
values = asplain(item, serialize=True)
keys = ", ".join(f"{k}=:{k}" for k in values if k != "id")
query = f"UPDATE {item._table} SET {keys} WHERE id=:id"
async with locked_connection() as conn:
await conn.execute(query=query, values=values)
stmt = table.update().where(table.c.id == values["id"]).values(values)
await conn.execute(stmt)
async def remove(item):
async def remove(conn: Connection, /, item: Model) -> None:
table: sa.Table = item.__table__
values = asplain(item, filter_fields={"id"}, serialize=True)
query = f"DELETE FROM {item._table} WHERE id=:id"
async with locked_connection() as conn:
await conn.execute(query=query, values=values)
stmt = table.delete().where(table.c.id == values["id"])
await conn.execute(stmt)
async def add_or_update_user(user: User):
db_user = await get(User, imdb_id=user.imdb_id)
async def add_or_update_user(conn: Connection, /, user: User) -> None:
db_user = await get(conn, User, imdb_id=user.imdb_id)
if not db_user:
await add(user)
await add(conn, user)
else:
user.id = db_user.id
if user != db_user:
await update(user)
await update(conn, user)
async def add_or_update_many_movies(movies: list[Movie]):
async def add_or_update_many_movies(conn: Connection, /, movies: list[Movie]) -> None:
"""Add or update Movies in the database.
This is an optimized version of `add_or_update_movie` for the purpose
@@ -336,12 +400,13 @@ async def add_or_update_many_movies(movies: list[Movie]):
# for movie in movies:
# await add_or_update_movie(movie)
db_movies = {
m.imdb_id: m for m in await get_many(Movie, imdb_id=[m.imdb_id for m in movies])
m.imdb_id: m
for m in await get_many(conn, Movie, imdb_id=[m.imdb_id for m in movies])
}
for movie in movies:
# XXX optimize bulk add & update as well
if movie.imdb_id not in db_movies:
await add(movie)
await add(conn, movie)
else:
db_movie = db_movies[movie.imdb_id]
movie.id = db_movie.id
@@ -354,10 +419,10 @@ async def add_or_update_many_movies(movies: list[Movie]):
if movie.updated <= db_movie.updated:
return
await update(movie)
await update(conn, movie)
async def add_or_update_movie(movie: Movie):
async def add_or_update_movie(conn: Connection, /, movie: Movie) -> None:
"""Add or update a Movie in the database.
This is an upsert operation, but it will also update the Movie you pass
@@ -365,9 +430,9 @@ async def add_or_update_movie(movie: Movie):
set all optional values on your Movie that might be unset but exist in the
database. It's a bidirectional sync.
"""
db_movie = await get(Movie, imdb_id=movie.imdb_id)
db_movie = await get(conn, Movie, imdb_id=movie.imdb_id)
if not db_movie:
await add(movie)
await add(conn, movie)
else:
movie.id = db_movie.id
@@ -379,33 +444,35 @@ async def add_or_update_movie(movie: Movie):
if movie.updated <= db_movie.updated:
return
await update(movie)
await update(conn, movie)
async def add_or_update_rating(rating: Rating) -> bool:
async def add_or_update_rating(conn: Connection, /, rating: Rating) -> bool:
db_rating = await get(
Rating, movie_id=str(rating.movie_id), user_id=str(rating.user_id)
conn, Rating, movie_id=str(rating.movie_id), user_id=str(rating.user_id)
)
if not db_rating:
await add(rating)
await add(conn, rating)
return True
else:
rating.id = db_rating.id
if rating != db_rating:
await update(rating)
await update(conn, rating)
return True
return False
def sql_escape(s: str, char="#"):
def sql_escape(s: str, char: str = "#") -> str:
return s.replace(char, 2 * char).replace("%", f"{char}%").replace("_", f"{char}_")
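Since sql_escape() is a pure function, its behavior is easy to pin down with an example: it doubles the escape character itself and prefixes the LIKE wildcards % and _:

    # sql_escape("50%_done") == "50#%#_done"
    # sql_escape("a#b") == "a##b"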
async def find_ratings(
conn: Connection,
/,
*,
title: str | None = None,
media_type: str | None = None,
@@ -415,163 +482,129 @@ async def find_ratings(
yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
limit_rows: int = 10,
user_ids: Iterable[str] = [],
):
values: dict[str, int | str] = {
"limit_rows": limit_rows,
}
) -> Iterable[dict[str, Any]]:
conditions = []
if title:
values["escape"] = "#"
escaped_title = sql_escape(title, char=values["escape"])
values["pattern"] = (
escape_char = "#"
escaped_title = sql_escape(title, char=escape_char)
pattern = (
"_".join(escaped_title.split())
if exact
else "%" + "%".join(escaped_title.split()) + "%"
)
conditions.append(
f"""
(
{Movie._table}.title LIKE :pattern ESCAPE :escape
OR {Movie._table}.original_title LIKE :pattern ESCAPE :escape
sa.or_(
movies.c.title.like(pattern, escape=escape_char),
movies.c.original_title.like(pattern, escape=escape_char),
)
"""
)
if yearcomp:
op, year = yearcomp
assert op in "<=>"
values["year"] = year
conditions.append(f"{Movie._table}.release_year{op}:year")
match yearcomp:
case ("<", year):
conditions.append(movies.c.release_year < year)
case ("=", year):
conditions.append(movies.c.release_year == year)
case (">", year):
conditions.append(movies.c.release_year > year)
if media_type:
values["media_type"] = media_type
conditions.append(f"{Movie._table}.media_type=:media_type")
if media_type is not None:
conditions.append(movies.c.media_type == media_type)
if ignore_tv_episodes:
conditions.append(f"{Movie._table}.media_type!='TV Episode'")
conditions.append(movies.c.media_type != "TV Episode")
user_condition = "1=1"
user_condition = []
if user_ids:
uvs = {f"user_id_{i}": v for i, v in enumerate(user_ids, start=1)}
values.update(uvs)
user_condition = f"{Rating._table}.user_id IN ({','.join(':'+n for n in uvs)})"
user_condition.append(ratings.c.user_id.in_(user_ids))
query = f"""
SELECT DISTINCT {Rating._table}.movie_id
FROM {Rating._table}
LEFT JOIN {Movie._table} ON {Movie._table}.id={Rating._table}.movie_id
WHERE {user_condition}{(' AND ' + ' AND '.join(conditions)) if conditions else ''}
ORDER BY length({Movie._table}.title) ASC, {Rating._table}.rating_date DESC, {Movie._table}.imdb_score DESC
LIMIT :limit_rows
"""
async with locked_connection() as conn:
rows = await conn.fetch_all(bindparams(query, values))
movie_ids = tuple(r._mapping["movie_id"] for r in rows)
query = (
sa.select(ratings.c.movie_id)
.distinct()
.outerjoin_from(ratings, movies, movies.c.id == ratings.c.movie_id)
.where(*conditions, *user_condition)
.order_by(
sa.func.length(movies.c.title).asc(),
ratings.c.rating_date.desc(),
movies.c.imdb_score.desc(),
)
.limit(limit_rows)
)
rating_rows: sa.CursorResult[Rating] = await conn.execute(query)
movie_ids = [r.movie_id for r in rating_rows]
if include_unrated and len(movie_ids) < limit_rows:
sqlin, sqlin_vals = sql_in("id", movie_ids, not_=True)
query = f"""
SELECT DISTINCT id AS movie_id
FROM {Movie._table}
WHERE {sqlin}
{('AND ' + ' AND '.join(conditions)) if conditions else ''}
ORDER BY length(title) ASC, imdb_score DESC, release_year DESC
LIMIT :limit_rows
"""
async with locked_connection() as conn:
rows = await conn.fetch_all(
bindparams(
query,
{**values, **sqlin_vals, "limit_rows": limit_rows - len(movie_ids)},
)
query = (
sa.select(movies.c.id)
.distinct()
.where(movies.c.id.not_in(movie_ids), *conditions)
.order_by(
sa.func.length(movies.c.title).asc(),
movies.c.imdb_score.desc(),
movies.c.release_year.desc(),
)
movie_ids += tuple(r._mapping["movie_id"] for r in rows)
.limit(limit_rows - len(movie_ids))
)
movie_rows: sa.CursorResult[Movie] = await conn.execute(query)
movie_ids += [r.id for r in movie_rows]
return await ratings_for_movie_ids(ids=movie_ids)
return await ratings_for_movie_ids(conn, ids=movie_ids)
async def ratings_for_movie_ids(
ids: Iterable[ULID | str] = [], imdb_ids: Iterable[str] = []
conn: Connection,
/,
ids: Iterable[ULID | str] = [],
imdb_ids: Iterable[str] = [],
) -> Iterable[dict[str, Any]]:
conds: list[str] = []
vals: dict[str, str] = {}
conds = []
if ids:
sqlin, sqlin_vals = sql_in(f"{Movie._table}.id", (str(x) for x in ids))
conds.append(sqlin)
vals.update(sqlin_vals)
conds.append(movies.c.id.in_([str(x) for x in ids]))
if imdb_ids:
sqlin, sqlin_vals = sql_in(f"{Movie._table}.imdb_id", imdb_ids)
conds.append(sqlin)
vals.update(sqlin_vals)
conds.append(movies.c.imdb_id.in_(imdb_ids))
if not conds:
return []
query = f"""
SELECT
{Rating._table}.score AS user_score,
{Rating._table}.user_id AS user_id,
{Movie._table}.imdb_score,
{Movie._table}.imdb_votes,
{Movie._table}.imdb_id AS movie_imdb_id,
{Movie._table}.media_type AS media_type,
{Movie._table}.title AS canonical_title,
{Movie._table}.original_title AS original_title,
{Movie._table}.release_year AS release_year
FROM {Movie._table}
LEFT JOIN {Rating._table} ON {Movie._table}.id={Rating._table}.movie_id
WHERE {(' OR '.join(conds))}
"""
async with locked_connection() as conn:
rows = await conn.fetch_all(bindparams(query, vals))
query = (
sa.select(
ratings.c.score.label("user_score"),
ratings.c.user_id.label("user_id"),
movies.c.imdb_score,
movies.c.imdb_votes,
movies.c.imdb_id.label("movie_imdb_id"),
movies.c.media_type.label("media_type"),
movies.c.title.label("canonical_title"),
movies.c.original_title.label("original_title"),
movies.c.release_year.label("release_year"),
)
.outerjoin_from(movies, ratings, movies.c.id == ratings.c.movie_id)
.where(sa.or_(*conds))
)
rows = await fetch_all(conn, query)
return tuple(dict(r._mapping) for r in rows)
def sql_fields(tp: Type):
return (f"{tp._table}.{f.name}" for f in fields(tp))
def sql_in(column: str, values: Iterable[T], not_=False) -> tuple[str, dict[str, T]]:
c = column.replace(".", "___")
value_map = {f"{c}_{i}": v for i, v in enumerate(values, start=1)}
placeholders = ",".join(":" + k for k in value_map)
if not_:
return f"{column} NOT IN ({placeholders})", value_map
return f"{column} IN ({placeholders})", value_map
async def ratings_for_movies(
movie_ids: Iterable[ULID], user_ids: Iterable[ULID] = []
conn: Connection, /, movie_ids: Iterable[ULID], user_ids: Iterable[ULID] = []
) -> Iterable[Rating]:
values: dict[str, str] = {}
conditions: list[str] = []
q, vm = sql_in("movie_id", [str(m) for m in movie_ids])
conditions.append(q)
values.update(vm)
conditions = [ratings.c.movie_id.in_(str(x) for x in movie_ids)]
if user_ids:
q, vm = sql_in("user_id", [str(m) for m in user_ids])
conditions.append(q)
values.update(vm)
conditions.append(ratings.c.user_id.in_(str(x) for x in user_ids))
query = f"""
SELECT {','.join(sql_fields(Rating))}
FROM {Rating._table}
WHERE {' AND '.join(f'({c})' for c in conditions) if conditions else '1=1'}
"""
query = sa.select(ratings).where(*conditions)
async with locked_connection() as conn:
rows = await conn.fetch_all(query, values)
rows = await fetch_all(conn, query)
return (fromplain(Rating, row._mapping, serialized=True) for row in rows)
async def find_movies(
conn: Connection,
/,
*,
title: str | None = None,
media_type: str | None = None,
@@ -583,88 +616,63 @@ async def find_movies(
include_unrated: bool = False,
user_ids: list[ULID] = [],
) -> Iterable[tuple[Movie, list[Rating]]]:
values: dict[str, int | str] = {
"limit_rows": limit_rows,
"skip_rows": skip_rows,
}
conditions = []
if title:
values["escape"] = "#"
escaped_title = sql_escape(title, char=values["escape"])
values["pattern"] = (
escape_char = "#"
escaped_title = sql_escape(title, char=escape_char)
pattern = (
"_".join(escaped_title.split())
if exact
else "%" + "%".join(escaped_title.split()) + "%"
)
conditions.append(
f"""
(
{Movie._table}.title LIKE :pattern ESCAPE :escape
OR {Movie._table}.original_title LIKE :pattern ESCAPE :escape
sa.or_(
movies.c.title.like(pattern, escape=escape_char),
movies.c.original_title.like(pattern, escape=escape_char),
)
"""
)
if yearcomp:
op, year = yearcomp
assert op in "<=>"
values["year"] = year
conditions.append(f"{Movie._table}.release_year{op}:year")
match yearcomp:
case ("<", year):
conditions.append(movies.c.release_year < year)
case ("=", year):
conditions.append(movies.c.release_year == year)
case (">", year):
conditions.append(movies.c.release_year > year)
if media_type:
values["media_type"] = media_type
conditions.append(f"{Movie._table}.media_type=:media_type")
if media_type is not None:
conditions.append(movies.c.media_type == media_type)
if ignore_tv_episodes:
conditions.append(f"{Movie._table}.media_type!='TV Episode'")
conditions.append(movies.c.media_type != "TV Episode")
if not include_unrated:
conditions.append(f"{Movie._table}.imdb_score NOTNULL")
conditions.append(movies.c.imdb_score.is_not(None))
query = f"""
SELECT {','.join(sql_fields(Movie))}
FROM {Movie._table}
WHERE {(' AND '.join(conditions)) if conditions else '1=1'}
ORDER BY
length({Movie._table}.title) ASC,
{Movie._table}.imdb_score DESC,
{Movie._table}.release_year DESC
LIMIT :skip_rows, :limit_rows
"""
async with locked_connection() as conn:
rows = await conn.fetch_all(bindparams(query, values))
query = (
sa.select(movies)
.where(*conditions)
.order_by(
sa.func.length(movies.c.title).asc(),
movies.c.imdb_score.desc(),
movies.c.release_year.desc(),
)
.limit(limit_rows)
.offset(skip_rows)
)
movies = [fromplain(Movie, row._mapping, serialized=True) for row in rows]
rows = await fetch_all(conn, query)
movies_ = [fromplain(Movie, row._mapping, serialized=True) for row in rows]
if not user_ids:
return ((m, []) for m in movies)
return ((m, []) for m in movies_)
ratings = await ratings_for_movies((m.id for m in movies), user_ids)
ratings = await ratings_for_movies(conn, (m.id for m in movies_), user_ids)
aggreg: dict[ULID, tuple[Movie, list[Rating]]] = {m.id: (m, []) for m in movies}
aggreg: dict[ULID, tuple[Movie, list[Rating]]] = {m.id: (m, []) for m in movies_}
for rating in ratings:
aggreg[rating.movie_id][1].append(rating)
return aggreg.values()
def bindparams(query: str, values: dict):
"""Bind values to a query.
This is similar to what SQLAlchemy and Databases do, but it makes it easy
to use the same placeholder in multiple places.
"""
pump_vals = {}
pump_keys = {}
def pump(match):
key = match[1]
val = values[key]
pump_keys[key] = 1 + pump_keys.setdefault(key, 0)
pump_key = f"{key}_{pump_keys[key]}"
pump_vals[pump_key] = val
return f":{pump_key}"
pump_query = re.sub(r":(\w+)\b", pump, query)
return sqlalchemy.text(pump_query).bindparams(**pump_vals)
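For reference, the removed helper rewrote repeated placeholders by numbering each occurrence, roughly:

    # bindparams("SELECT :a, :a", {"a": 1}) returned
    # sqlalchemy.text("SELECT :a_1, :a_2").bindparams(a_1=1, a_2=1)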

unwind/imdb.py

@@ -4,6 +4,8 @@ from collections import namedtuple
from datetime import datetime
from urllib.parse import urljoin
import bs4
from . import db
from .models import Movie, Rating, User
from .request import asession, asoup_from_url, cache_path
@@ -38,12 +40,14 @@ async def refresh_user_ratings_from_imdb(stop_on_dupe: bool = True):
async with asession() as s:
s.headers["Accept-Language"] = "en-US, en;q=0.5"
for user in await db.get_all(User):
async with db.new_connection() as conn:
users = list(await db.get_all(conn, User))
for user in users:
log.info("⚡️ Loading data for %s ...", user.name)
try:
async for rating, is_updated in load_ratings(user.imdb_id):
assert rating.user.id == user.id
assert rating.user is not None and rating.user.id == user.id
if stop_on_dupe and not is_updated:
break
@@ -94,7 +98,7 @@ find_year = re.compile(
find_movie_id = re.compile(r"/title/(?P<id>tt\d+)/").search
def movie_and_rating_from_item(item) -> tuple[Movie, Rating]:
def movie_and_rating_from_item(item: bs4.Tag) -> tuple[Movie, Rating]:
genres = (genre := item.find("span", "genre")) and genre.string or ""
movie = Movie(
title=item.h3.a.string.strip(),
@@ -154,13 +158,19 @@ async def parse_page(url: str) -> tuple[list[Rating], str | None]:
soup = await asoup_from_url(url)
meta = soup.find("meta", property="pageId")
headline = soup.h1
assert meta is not None and headline is not None
if (meta := soup.find("meta", property="pageId")) is None:
raise RuntimeError("No pageId found.")
assert isinstance(meta, bs4.Tag)
imdb_id = meta["content"]
user = await db.get(User, imdb_id=imdb_id) or User(
imdb_id=imdb_id, name="", secret=""
)
assert isinstance(imdb_id, str)
async with db.new_connection() as conn:
user = await db.get(conn, User, imdb_id=imdb_id) or User(
imdb_id=imdb_id, name="", secret=""
)
if (headline := soup.h1) is None:
raise RuntimeError("No headline found.")
assert isinstance(headline.string, str)
if match := find_name(headline.string):
user.name = match["name"]
@ -184,9 +194,15 @@ async def parse_page(url: str) -> tuple[list[Rating], str | None]:
ratings.append(rating)
footer = soup.find("div", "footer")
assert footer is not None
next_url = urljoin(url, footer.find(string=re.compile(r"Next")).parent["href"])
next_url = None
if (footer := soup.find("div", "footer")) is None:
raise RuntimeError("No footer found.")
assert isinstance(footer, bs4.Tag)
if (next_link := footer.find("a", string="Next")) is not None:
assert isinstance(next_link, bs4.Tag)
next_href = next_link["href"]
assert isinstance(next_href, str)
next_url = urljoin(url, next_href)
return (ratings, next_url if url != next_url else None)
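
A self-contained illustration (with hypothetical HTML) of the walrus-plus-isinstance narrowing used above — bs4's find() returns Tag | NavigableString | None, so every hit is checked before element access:

import bs4

html = '<div class="footer"><a href="/page2">Next</a></div>'
soup = bs4.BeautifulSoup(html, "html5lib")

if (footer := soup.find("div", "footer")) is None:
    raise RuntimeError("No footer found.")
assert isinstance(footer, bs4.Tag)
if (next_link := footer.find("a", string="Next")) is not None:
    assert isinstance(next_link, bs4.Tag)
    next_href = next_link["href"]
    assert isinstance(next_href, str)
    print(next_href)  # -> /page2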
@ -200,14 +216,15 @@ async def load_ratings(user_id: str):
for i, rating in enumerate(ratings):
assert rating.user and rating.movie
if i == 0:
# All rating objects share the same user.
await db.add_or_update_user(rating.user)
rating.user_id = rating.user.id
async with db.transaction() as conn:
if i == 0:
# All rating objects share the same user.
await db.add_or_update_user(conn, rating.user)
rating.user_id = rating.user.id
await db.add_or_update_movie(rating.movie)
rating.movie_id = rating.movie.id
await db.add_or_update_movie(conn, rating.movie)
rating.movie_id = rating.movie.id
is_updated = await db.add_or_update_rating(rating)
is_updated = await db.add_or_update_rating(conn, rating)
yield rating, is_updated
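
A hypothetical sketch of what a db.transaction() helper of this shape might look like on top of SQLAlchemy's async engine — engine.begin() commits on success and rolls back on error, which is what lets each rating be written in its own short transaction:

import contextlib
from typing import AsyncIterator

from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine

engine = create_async_engine("sqlite+aiosqlite:///:memory:")

@contextlib.asynccontextmanager
async def transaction() -> AsyncIterator[AsyncConnection]:
    # begin() opens a connection, starts a transaction, and commits or
    # rolls back when the block exits.
    async with engine.begin() as conn:
        yield conn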


@ -209,7 +209,8 @@ async def import_from_file(*, basics_path: Path, ratings_path: Path):
for i, m in enumerate(read_basics(basics_path)):
perc = 100 * i / total
if perc >= perc_next_report:
await db.set_import_progress(perc)
async with db.transaction() as conn:
await db.set_import_progress(conn, perc)
log.info("⏳ Imported %s%%", round(perc, 1))
perc_next_report += perc_step
@ -233,15 +234,18 @@ async def import_from_file(*, basics_path: Path, ratings_path: Path):
chunk.append(m)
if len(chunk) > 1000:
await add_or_update_many_movies(chunk)
async with db.transaction() as conn:
await add_or_update_many_movies(conn, chunk)
chunk = []
if chunk:
await add_or_update_many_movies(chunk)
async with db.transaction() as conn:
await add_or_update_many_movies(conn, chunk)
chunk = []
log.info("👍 Imported 100%")
await db.set_import_progress(100)
async with db.transaction() as conn:
await db.set_import_progress(conn, 100)
async def download_datasets(*, basics_path: Path, ratings_path: Path) -> None:
@ -270,7 +274,8 @@ async def load_from_web(*, force: bool = False) -> None:
See https://www.imdb.com/interfaces/ and https://datasets.imdbws.com/ for
more information on the IMDb database dumps.
"""
await db.set_import_progress(0)
async with db.transaction() as conn:
await db.set_import_progress(conn, 0)
try:
ratings_file = config.datadir / "imdb/title.ratings.tsv.gz"
@ -290,8 +295,10 @@ async def load_from_web(*, force: bool = False) -> None:
await import_from_file(basics_path=basics_file, ratings_path=ratings_file)
except BaseException as err:
await db.stop_import_progress(error=err)
async with db.transaction() as conn:
await db.stop_import_progress(conn, error=err)
raise
else:
await db.stop_import_progress()
async with db.transaction() as conn:
await db.stop_import_progress(conn)
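
The chunking here is the usual batch-write pattern; a standalone sketch (engine and table are passed in, names assumed) — each batch goes through Core's executemany in its own transaction, keeping memory use and lock time bounded:

from typing import Iterable

import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncEngine

CHUNK_SIZE = 1000

async def write_in_chunks(
    engine: AsyncEngine, table: sa.Table, rows: Iterable[dict]
) -> None:
    chunk: list[dict] = []
    for row in rows:
        chunk.append(row)
        if len(chunk) >= CHUNK_SIZE:
            async with engine.begin() as conn:
                # A list of parameter dicts triggers executemany.
                await conn.execute(sa.insert(table), chunk)
            chunk = []
    if chunk:
        async with engine.begin() as conn:
            await conn.execute(sa.insert(table), chunk)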


@ -11,13 +11,18 @@ from typing import (
Container,
Literal,
Mapping,
Protocol,
Type,
TypedDict,
TypeVar,
Union,
get_args,
get_origin,
)
from sqlalchemy import Column, ForeignKey, Integer, String, Table
from sqlalchemy.orm import registry
from .types import ULID
JSON = int | float | str | None | list["JSON"] | dict[str, "JSON"]
@ -26,8 +31,16 @@ JSONObject = dict[str, JSON]
T = TypeVar("T")
class Model(Protocol):
__table__: ClassVar[Table]
mapper_registry = registry()
metadata = mapper_registry.metadata
def annotations(tp: Type) -> tuple | None:
return tp.__metadata__ if hasattr(tp, "__metadata__") else None
return tp.__metadata__ if hasattr(tp, "__metadata__") else None # type: ignore
def fields(class_or_instance):
@ -112,7 +125,7 @@ def asplain(
if filter_fields is not None and f.name not in filter_fields:
continue
target = f.type
target: Any = f.type
# XXX this doesn't properly support any kind of nested types
if (otype := optional_type(f.type)) is not None:
target = otype
@ -156,7 +169,7 @@ def fromplain(cls: Type[T], d: Mapping, *, serialized: bool = False) -> T:
dd: JSONObject = {}
for f in fields(cls):
target = f.type
target: Any = f.type
otype = optional_type(f.type)
is_opt = otype is not None
if is_opt:
@ -194,12 +207,38 @@ def validate(o: object) -> None:
def utcnow():
return datetime.utcnow().replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc)
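
Both spellings produce an aware UTC datetime, but datetime.utcnow() is deprecated as of Python 3.12 and is naive without the .replace(); a quick check:

from datetime import datetime, timezone

now = datetime.now(timezone.utc)
assert now.tzinfo is timezone.utc  # aware, no .replace() needed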
@mapper_registry.mapped
@dataclass
class DbPatch:
__table__: ClassVar[Table] = Table(
"db_patches",
metadata,
Column("id", Integer, primary_key=True),
Column("current", String),
)
id: int
current: str
db_patches = DbPatch.__table__
@mapper_registry.mapped
@dataclass
class Progress:
_table: ClassVar[str] = "progress"
__table__: ClassVar[Table] = Table(
"progress",
metadata,
Column("id", String, primary_key=True), # ULID
Column("type", String, nullable=False),
Column("state", String, nullable=False), # JSON {"percent": ..., "error": ...}
Column("started", String, nullable=False), # datetime
Column("stopped", String),
)
id: ULID = field(default_factory=ULID)
type: str = None
@ -236,9 +275,28 @@ class Progress:
self._state = state
progress = Progress.__table__
@mapper_registry.mapped
@dataclass
class Movie:
_table: ClassVar[str] = "movies"
__table__: ClassVar[Table] = Table(
"movies",
metadata,
Column("id", String, primary_key=True), # ULID
Column("title", String, nullable=False),
Column("original_title", String),
Column("release_year", Integer, nullable=False),
Column("media_type", String, nullable=False),
Column("imdb_id", String, nullable=False, unique=True),
Column("imdb_score", Integer),
Column("imdb_votes", Integer),
Column("runtime", Integer),
Column("genres", String, nullable=False),
Column("created", String, nullable=False), # datetime
Column("updated", String, nullable=False), # datetime
)
id: ULID = field(default_factory=ULID)
title: str = None # canonical title (usually English)
@ -283,6 +341,8 @@ class Movie:
self._is_lazy = False
movies = Movie.__table__
_RelationSentinel = object()
"""Mark a model field as containing external data.
@ -294,9 +354,65 @@ The contents of the Relation are ignored or discarded when using
Relation = Annotated[T | None, _RelationSentinel]
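
A small sketch of how such sentinel metadata can be detected at runtime (the helper name is illustrative): subscripting Relation[Movie] yields Annotated[Movie | None, _RelationSentinel], whose __metadata__ tuple carries the marker:

from typing import Annotated, TypeVar

T = TypeVar("T")
_RelationSentinel = object()
Relation = Annotated[T | None, _RelationSentinel]

def is_relation(tp: object) -> bool:
    # Annotated aliases expose their extra arguments via __metadata__.
    meta = getattr(tp, "__metadata__", None)
    return meta is not None and _RelationSentinel in meta

print(is_relation(Relation[int]))  # True
print(is_relation(int | None))     # False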
Access = Literal[
"r", # read
"i", # index
"w", # write
]
class UserGroup(TypedDict):
id: str
access: Access
@mapper_registry.mapped
@dataclass
class User:
__table__: ClassVar[Table] = Table(
"users",
metadata,
Column("id", String, primary_key=True), # ULID
Column("imdb_id", String, nullable=False, unique=True),
Column("name", String, nullable=False),
Column("secret", String, nullable=False),
Column("groups", String, nullable=False), # JSON array
)
id: ULID = field(default_factory=ULID)
imdb_id: str = None
name: str = None # canonical user name
secret: str = None
groups: list[UserGroup] = field(default_factory=list)
def has_access(self, group_id: ULID | str, access: Access = "r"):
group_id = group_id if isinstance(group_id, str) else str(group_id)
return any(g["id"] == group_id and access == g["access"] for g in self.groups)
def set_access(self, group_id: ULID | str, access: Access):
group_id = group_id if isinstance(group_id, str) else str(group_id)
for g in self.groups:
if g["id"] == group_id:
g["access"] = access
break
else:
self.groups.append({"id": group_id, "access": access})
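
A usage sketch (hypothetical ULID string, assuming the User model above): the helpers act as an upsert on the JSON-backed groups list, one access level per group with the latest write winning:

user = User(imdb_id="ur1234567", name="demo", secret="")
user.set_access("01HGW8K9ZD", "r")
user.set_access("01HGW8K9ZD", "w")  # upsert: replaces "r"
assert user.has_access("01HGW8K9ZD", "w")
assert not user.has_access("01HGW8K9ZD", "r")
assert len(user.groups) == 1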
@mapper_registry.mapped
@dataclass
class Rating:
_table: ClassVar[str] = "ratings"
__table__: ClassVar[Table] = Table(
"ratings",
metadata,
Column("id", String, primary_key=True), # ULID
Column("movie_id", ForeignKey("movies.id"), nullable=False), # ULID
Column("user_id", ForeignKey("users.id"), nullable=False), # ULID
Column("score", Integer, nullable=False),
Column("rating_date", String, nullable=False), # datetime
Column("favorite", Integer), # bool
Column("finished", Integer), # bool
)
id: ULID = field(default_factory=ULID)
@ -304,7 +420,7 @@ class Rating:
movie: Relation[Movie] = None
user_id: ULID = None
user: Relation["User"] = None
user: Relation[User] = None
score: int = None # range: [0,100]
rating_date: datetime = None
@ -324,41 +440,25 @@ class Rating:
)
Access = Literal[
"r", # read
"i", # index
"w", # write
]
ratings = Rating.__table__
@dataclass
class User:
_table: ClassVar[str] = "users"
id: ULID = field(default_factory=ULID)
imdb_id: str = None
name: str = None # canonical user name
secret: str = None
groups: list[dict[str, str]] = field(default_factory=list)
def has_access(self, group_id: ULID | str, access: Access = "r"):
group_id = group_id if isinstance(group_id, str) else str(group_id)
return any(g["id"] == group_id and access == g["access"] for g in self.groups)
def set_access(self, group_id: ULID | str, access: Access):
group_id = group_id if isinstance(group_id, str) else str(group_id)
for g in self.groups:
if g["id"] == group_id:
g["access"] = access
break
else:
self.groups.append({"id": group_id, "access": access})
class GroupUser(TypedDict):
id: str
name: str
@mapper_registry.mapped
@dataclass
class Group:
_table: ClassVar[str] = "groups"
__table__: ClassVar[Table] = Table(
"groups",
metadata,
Column("id", String, primary_key=True), # ULID
Column("name", String, nullable=False),
Column("users", String, nullable=False), # JSON array
)
id: ULID = field(default_factory=ULID)
name: str = None
users: list[dict[str, str]] = field(default_factory=list)
users: list[GroupUser] = field(default_factory=list)
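
A minimal runnable sketch of the mapping style used throughout this module — a plain dataclass plus a ClassVar __table__, registered via the decorator (the model here is illustrative):

from dataclasses import dataclass
from typing import ClassVar

import sqlalchemy as sa
from sqlalchemy.orm import registry

reg = registry()

@reg.mapped
@dataclass
class Tag:
    __table__: ClassVar[sa.Table] = sa.Table(
        "tags",
        reg.metadata,
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("name", sa.String, nullable=False),
    )

    id: int | None = None
    name: str = ""

tags = Tag.__table__
# The class stays a normal dataclass, and the Table doubles as the Core
# entry point for queries:
print(sa.select(tags).where(tags.c.name == "demo"))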


@ -11,7 +11,7 @@ from hashlib import md5
from pathlib import Path
from random import random
from time import sleep, time
from typing import Callable, ParamSpec, TypeVar, cast
from typing import Any, Callable, ParamSpec, TypeVar, cast
import bs4
import httpx
@ -190,9 +190,11 @@ async def asoup_from_url(url):
def _last_modified_from_response(resp: _Response_T) -> float | None:
if last_mod := resp.headers.get("last-modified"):
try:
return email.utils.parsedate_to_datetime(last_mod).timestamp()
except:
dt = email.utils.parsedate_to_datetime(last_mod)
except ValueError:
log.exception("🐛 Received invalid value for Last-Modified: %s", last_mod)
else:
return dt.timestamp()
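
For reference, email.utils.parsedate_to_datetime() raises ValueError on malformed input (before Python 3.10 it raised TypeError), which is what the narrowed except clause above relies on:

import email.utils

dt = email.utils.parsedate_to_datetime("Wed, 29 Nov 2023 17:01:24 GMT")
print(dt.timestamp())  # 1701277284.0

try:
    email.utils.parsedate_to_datetime("not-a-date")
except ValueError:
    print("invalid Last-Modified")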
def _last_modified_from_file(path: Path) -> float:
@ -206,8 +208,8 @@ async def adownload(
replace_existing: bool | None = None,
only_if_newer: bool = False,
timeout: float | None = None,
chunk_callback=None,
response_callback=None,
chunk_callback: Callable[[bytes], Any] | None = None,
response_callback: Callable[[_Response_T], Any] | None = None,
) -> bytes | None:
"""Download a file.
@ -246,7 +248,7 @@ async def adownload(
if response_callback is not None:
try:
response_callback(resp)
except:
except BaseException:
log.exception("🐛 Error in response callback.")
log.debug(
@ -267,7 +269,9 @@ async def adownload(
resp.raise_for_status()
if to_path is None:
await resp.aread() # Download the response stream to allow `resp.content` access.
await (
resp.aread()
) # Download the response stream to allow `resp.content` access.
return resp.content
resp_lastmod = _last_modified_from_response(resp)
@ -275,7 +279,7 @@ async def adownload(
# Check Last-Modified in case the server ignored If-Modified-Since.
# XXX also check Content-Length?
if file_exists and only_if_newer and resp_lastmod is not None:
assert file_lastmod
assert file_lastmod # pyright: ignore [reportUnboundVariable]
if resp_lastmod <= file_lastmod:
log.debug("✋ Local file is newer, skipping download: %a", req.url)
@ -299,7 +303,7 @@ async def adownload(
if chunk_callback:
try:
chunk_callback(chunk)
except:
except BaseException:
log.exception("🐛 Error in chunk callback.")
finally:
os.close(tempfd)
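
Both callback sites use the same defensive guard; condensed into a helper (hypothetical, logger name assumed) — a buggy caller-supplied hook logs instead of aborting the download loop:

import logging
from typing import Any, Callable

log = logging.getLogger(__name__)

def safe_call(callback: Callable[[bytes], Any] | None, chunk: bytes) -> None:
    if callback is None:
        return
    try:
        callback(chunk)
    except BaseException:
        # Swallow and log: the loop must keep draining the response stream.
        log.exception("🐛 Error in chunk callback.")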


@ -168,7 +168,8 @@ async def auth_user(request) -> User | None:
if not isinstance(request.user, AuthedUser):
return
user = await db.get(User, id=request.user.user_id)
async with db.new_connection() as conn:
user = await db.get(conn, User, id=request.user.user_id)
if not user:
return
@ -179,7 +180,7 @@ async def auth_user(request) -> User | None:
return user
_routes = []
_routes: list[Route] = []
def route(path: str, *, methods: list[str] | None = None, **kwds):
@ -191,16 +192,13 @@ def route(path: str, *, methods: list[str] | None = None, **kwds):
return decorator
route.registered = _routes
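
The decorator body is elided in this hunk; a plausible sketch (reconstructed, not verbatim) of how a registry-based route() decorator works with Starlette:

from starlette.responses import JSONResponse
from starlette.routing import Route

_routes: list[Route] = []

def route(path: str, *, methods: list[str] | None = None, **kwds):
    def decorator(func):
        # Register rather than wrap: the handler itself is unchanged.
        _routes.append(Route(path, func, methods=methods, **kwds))
        return func
    return decorator

@route("/ping")
async def ping(request) -> JSONResponse:
    return JSONResponse({"ok": True})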
@route("/groups/{group_id}/ratings")
async def get_ratings_for_group(request):
group_id = as_ulid(request.path_params["group_id"])
group = await db.get(Group, id=str(group_id))
if not group:
return not_found()
async with db.new_connection() as conn:
if (group := await db.get(conn, Group, id=str(group_id))) is None:
return not_found()
user_ids = {u["id"] for u in group.users}
@ -211,22 +209,26 @@ async def get_ratings_for_group(request):
# if (imdb_id or unwind_id) and (movie := await db.get(Movie, id=unwind_id, imdb_id=imdb_id)):
if unwind_id:
rows = await db.ratings_for_movie_ids(ids=[unwind_id])
async with db.new_connection() as conn:
rows = await db.ratings_for_movie_ids(conn, ids=[unwind_id])
elif imdb_id:
rows = await db.ratings_for_movie_ids(imdb_ids=[imdb_id])
async with db.new_connection() as conn:
rows = await db.ratings_for_movie_ids(conn, imdb_ids=[imdb_id])
else:
rows = await find_ratings(
title=params.get("title"),
media_type=params.get("media_type"),
exact=truthy(params.get("exact")),
ignore_tv_episodes=truthy(params.get("ignore_tv_episodes")),
include_unrated=truthy(params.get("include_unrated")),
yearcomp=yearcomp(params["year"]) if "year" in params else None,
limit_rows=as_int(params.get("per_page"), max=10, default=5),
user_ids=user_ids,
)
async with db.new_connection() as conn:
rows = await find_ratings(
conn,
title=params.get("title"),
media_type=params.get("media_type"),
exact=truthy(params.get("exact")),
ignore_tv_episodes=truthy(params.get("ignore_tv_episodes")),
include_unrated=truthy(params.get("include_unrated")),
yearcomp=yearcomp(params["year"]) if "year" in params else None,
limit_rows=as_int(params.get("per_page"), max=10, default=5),
user_ids=user_ids,
)
ratings = (web_models.Rating(**r) for r in rows)
@ -265,7 +267,8 @@ async def list_movies(request):
if group_id := params.get("group_id"):
group_id = as_ulid(group_id)
group = await db.get(Group, id=str(group_id))
async with db.new_connection() as conn:
group = await db.get(conn, Group, id=str(group_id))
if not group:
return not_found("Group not found.")
@ -290,26 +293,31 @@ async def list_movies(request):
if imdb_id or unwind_id:
# XXX missing support for user_ids and user_scores
movies = (
[m] if (m := await db.get(Movie, id=unwind_id, imdb_id=imdb_id)) else []
)
async with db.new_connection() as conn:
movies = (
[m]
if (m := await db.get(conn, Movie, id=unwind_id, imdb_id=imdb_id))
else []
)
resp = [asplain(web_models.RatingAggregate.from_movie(m)) for m in movies]
else:
per_page = as_int(params.get("per_page"), max=1000, default=5)
page = as_int(params.get("page"), min=1, default=1)
movieratings = await find_movies(
title=params.get("title"),
media_type=params.get("media_type"),
exact=truthy(params.get("exact")),
ignore_tv_episodes=truthy(params.get("ignore_tv_episodes")),
include_unrated=truthy(params.get("include_unrated")),
yearcomp=yearcomp(params["year"]) if "year" in params else None,
limit_rows=per_page,
skip_rows=(page - 1) * per_page,
user_ids=list(user_ids),
)
async with db.new_connection() as conn:
movieratings = await find_movies(
conn,
title=params.get("title"),
media_type=params.get("media_type"),
exact=truthy(params.get("exact")),
ignore_tv_episodes=truthy(params.get("ignore_tv_episodes")),
include_unrated=truthy(params.get("include_unrated")),
yearcomp=yearcomp(params["year"]) if "year" in params else None,
limit_rows=per_page,
skip_rows=(page - 1) * per_page,
user_ids=list(user_ids),
)
resp = []
for movie, ratings in movieratings:
@ -329,7 +337,8 @@ async def add_movie(request):
@route("/movies/_reload_imdb", methods=["GET"])
@requires(["authenticated", "admin"])
async def progress_for_load_imdb_movies(request):
progress = await db.get_import_progress()
async with db.new_connection() as conn:
progress = await db.get_import_progress(conn)
if not progress:
return JSONResponse({"status": "No import exists."}, status_code=404)
@ -368,14 +377,16 @@ async def load_imdb_movies(request):
force = truthy(params.get("force"))
async with _import_lock:
progress = await db.get_import_progress()
async with db.new_connection() as conn:
progress = await db.get_import_progress(conn)
if progress and not progress.stopped:
return JSONResponse(
{"status": "Import is running.", "progress": progress.percent},
status_code=409,
)
await db.set_import_progress(0)
async with db.transaction() as conn:
await db.set_import_progress(conn, 0)
task = BackgroundTask(imdb_import.load_from_web, force=force)
return JSONResponse(
@ -386,7 +397,8 @@ async def load_imdb_movies(request):
@route("/users")
@requires(["authenticated", "admin"])
async def list_users(request):
users = await db.get_all(User)
async with db.new_connection() as conn:
users = await db.get_all(conn, User)
return JSONResponse([asplain(u) for u in users])
@ -402,7 +414,8 @@ async def add_user(request):
secret = secrets.token_bytes()
user = User(name=name, imdb_id=imdb_id, secret=phc_scrypt(secret))
await db.add(user)
async with db.transaction() as conn:
await db.add(conn, user)
return JSONResponse(
{
@ -418,7 +431,8 @@ async def show_user(request):
user_id = as_ulid(request.path_params["user_id"])
if is_admin(request):
user = await db.get(User, id=str(user_id))
async with db.new_connection() as conn:
user = await db.get(conn, User, id=str(user_id))
else:
user = await auth_user(request)
@ -445,14 +459,15 @@ async def show_user(request):
async def remove_user(request):
user_id = as_ulid(request.path_params["user_id"])
user = await db.get(User, id=str(user_id))
async with db.new_connection() as conn:
user = await db.get(conn, User, id=str(user_id))
if not user:
return not_found()
async with db.shared_connection().transaction():
async with db.transaction() as conn:
# XXX remove user refs from groups and ratings
await db.remove(user)
await db.remove(conn, user)
return JSONResponse(asplain(user))
@ -463,7 +478,8 @@ async def modify_user(request):
user_id = as_ulid(request.path_params["user_id"])
if is_admin(request):
user = await db.get(User, id=str(user_id))
async with db.new_connection() as conn:
user = await db.get(conn, User, id=str(user_id))
else:
user = await auth_user(request)
@ -499,7 +515,8 @@ async def modify_user(request):
user.secret = phc_scrypt(secret)
await db.update(user)
async with db.transaction() as conn:
await db.update(conn, user)
return JSONResponse(asplain(user))
@ -509,13 +526,15 @@ async def modify_user(request):
async def add_group_to_user(request):
user_id = as_ulid(request.path_params["user_id"])
user = await db.get(User, id=str(user_id))
async with db.new_connection() as conn:
user = await db.get(conn, User, id=str(user_id))
if not user:
return not_found("User not found")
(group_id, access) = await json_from_body(request, ["group", "access"])
group = await db.get(Group, id=str(group_id))
async with db.new_connection() as conn:
group = await db.get(conn, Group, id=str(group_id))
if not group:
return not_found("Group not found")
@ -523,7 +542,8 @@ async def add_group_to_user(request):
raise HTTPException(422, "Invalid access level.")
user.set_access(group_id, access)
await db.update(user)
async with db.transaction() as conn:
await db.update(conn, user)
return JSONResponse(asplain(user))
@ -551,7 +571,8 @@ async def load_imdb_user_ratings(request):
@route("/groups")
@requires(["authenticated", "admin"])
async def list_groups(request):
groups = await db.get_all(Group)
async with db.new_connection() as conn:
groups = await db.get_all(conn, Group)
return JSONResponse([asplain(g) for g in groups])
@ -564,7 +585,8 @@ async def add_group(request):
# XXX restrict name
group = Group(name=name)
await db.add(group)
async with db.transaction() as conn:
await db.add(conn, group)
return JSONResponse(asplain(group))
@ -573,7 +595,8 @@ async def add_group(request):
@requires(["authenticated"])
async def add_user_to_group(request):
group_id = as_ulid(request.path_params["group_id"])
group = await db.get(Group, id=str(group_id))
async with db.new_connection() as conn:
group = await db.get(conn, Group, id=str(group_id))
if not group:
return not_found()
@ -600,7 +623,8 @@ async def add_user_to_group(request):
else:
group.users.append({"name": name, "id": user_id})
await db.update(group)
async with db.transaction() as conn:
await db.update(conn, group)
return JSONResponse(asplain(group))
@ -632,7 +656,7 @@ def create_app():
return Starlette(
lifespan=lifespan,
routes=[
Mount(f"{config.api_base}v1", routes=route.registered),
Mount(f"{config.api_base}v1", routes=_routes),
],
middleware=[
Middleware(ResponseTimeMiddleware, header_name="Unwind-Elapsed"),