Merge branch 'feat/sqlalchemy'

commit 4fbdb26d9c

25 changed files with 2107 additions and 2036 deletions
2  .gitignore  vendored
@@ -2,5 +2,5 @@
 *.pyc
-/.cache
+/.pytest_cache
 /build
 /data/*
 /requirements.txt
1  .python-version  Normal file
@@ -0,0 +1 @@
+3.12
@@ -1,4 +1,4 @@
-FROM docker.io/library/python:3.11-alpine
+FROM docker.io/library/python:3.12-alpine

 RUN apk update --no-cache \
     && apk upgrade --no-cache \
748  poetry.lock  generated
(File diff suppressed because it is too large)
@@ -1,44 +1,46 @@
 [tool.poetry]
 name = "unwind"
-version = "0.1.0"
+version = "0"
 description = ""
 authors = ["ducklet <ducklet@noreply.code.dumpr.org>"]
 license = "LOL"

 [tool.poetry.dependencies]
-python = "^3.11"
+python = "^3.12"
 beautifulsoup4 = "^4.9.3"
 html5lib = "^1.1"
-starlette = "^0.26"
+starlette = "^0.30"
 ulid-py = "^1.1.0"
-databases = {extras = ["sqlite"], version = "^0.7.0"}
-uvicorn = "^0.21"
-httpx = "^0.23.3"
+uvicorn = "^0.23"
+httpx = "^0.24"
+sqlalchemy = {version = "^2.0", extras = ["aiosqlite"]}
+
+[tool.poetry.group.build.dependencies]
+# When we run poetry export, typing-extensions is a transient dependency via
+# sqlalchemy, but the hash won't be included in the requirements.txt.
+# By making it a direct dependency we can fix this issue, otherwise this could
+# be removed.
+typing-extensions = "*"

 [tool.poetry.group.dev]
 optional = true

 [tool.poetry.group.dev.dependencies]
 autoflake = "*"
 pytest = "*"
 pyright = "*"
 black = "*"
 isort = "*"
 pytest-asyncio = "*"
+pytest-cov = "*"
+ruff = "*"
+honcho = "*"

 [build-system]
 requires = ["poetry-core>=1.0.0"]
 build-backend = "poetry.core.masonry.api"

 [tool.pyright]
-pythonVersion = "3.11"
+pythonVersion = "3.12"

 [tool.isort]
 profile = "black"

 [tool.autoflake]
 remove-duplicate-keys = true
 remove-unused-variables = true
 remove-all-unused-imports = true
+
+[tool.ruff]
+target-version = "py312"
+ignore-init-module-imports = true
+ignore-pass-after-docstring = true
+select = ["I", "F401", "F601", "F602", "F841"]
@@ -2,6 +2,12 @@

 cd "$RUN_DIR"

+# Make Uvicorn defaults explicit.
+: "${API_PORT:=8000}"
+: "${API_HOST:=127.0.0.1}"
+export API_PORT
+export API_HOST
+
 [ -z "${DEBUG:-}" ] || set -x

 exec honcho start
@@ -4,4 +4,9 @@ cd "$RUN_DIR"

 [ -z "${DEBUG:-}" ] || set -x

-exec uvicorn unwind:create_app --factory --reload
+exec uvicorn \
+    --host "$API_HOST" \
+    --port "$API_PORT" \
+    --reload \
+    --factory \
+    unwind:create_app
25  scripts/docker-build  Executable file
@@ -0,0 +1,25 @@
+#!/bin/sh -eu
+
+: "${DOCKER_BIN:=docker}"
+
+cd "$RUN_DIR"
+
+builddir=build
+
+[ -z "${DEBUG:-}" ] || set -x
+
+mkdir -p "$builddir"
+
+poetry export \
+    --with=build \
+    --output="$builddir"/requirements.txt
+
+githash=$(git rev-parse --short HEAD)
+today=$(date -u '+%Y.%m.%d')
+version="${today}+${githash}"
+echo "$version" >"$builddir"/version
+
+$DOCKER_BIN build \
+    --pull \
+    --tag "code.dumpr.org/ducklet/unwind":"$version" \
+    .
18  scripts/docker-run  Executable file
@@ -0,0 +1,18 @@
+#!/bin/sh -eu
+
+: "${DOCKER_BIN:=docker}"
+
+cd "$RUN_DIR"
+
+[ -z "${DEBUG:-}" ] || set -x
+
+version=$(cat build/version)
+
+$DOCKER_BIN run \
+    --init \
+    -it --rm \
+    --read-only \
+    --memory '500m' \
+    --publish 127.0.0.1:8000:8000 \
+    --volume "$RUN_DIR"/data:/data \
+    "code.dumpr.org/ducklet/unwind":"$version"
@@ -4,7 +4,7 @@ cd "$RUN_DIR"

 [ -z "${DEBUG:-}" ] || set -x

-autoflake --quiet --check --recursive unwind tests
-isort unwind tests
-black unwind tests
+ruff check --fix . ||:
+ruff format .

 pyright
@@ -11,4 +11,5 @@ export UNWIND_PORT
 exec uvicorn \
     --host 0.0.0.0 \
     --port "$UNWIND_PORT" \
-    --factory unwind:create_app
+    --factory \
+    unwind:create_app
@@ -6,10 +6,10 @@ dbfile="${UNWIND_DATA:-./data}/tests.sqlite"

 # Rollback in Databases is currently broken, so we have to rebuild the database
 # each time; see https://github.com/encode/databases/issues/403
-trap 'rm "$dbfile"' EXIT TERM INT QUIT
+trap 'rm "$dbfile" "${dbfile}-shm" "${dbfile}-wal"' EXIT TERM INT QUIT

 [ -z "${DEBUG:-}" ] || set -x

-SQLALCHEMY_WARN_20=1 \
+export SQLALCHEMY_WARN_20=1 # XXX remove when we switched to SQLAlchemy 2.0
 UNWIND_STORAGE="$dbfile" \
-    python -m pytest "$@"
+    python -m pytest --cov "$@"
@@ -17,16 +17,19 @@ def event_loop():

 @pytest_asyncio.fixture(scope="session")
 async def shared_conn():
-    c = db.shared_connection()
-    await c.connect()
-
-    await db.apply_db_patches(c)
-
-    yield c
-
-    await c.disconnect()
+    """A database connection, ready to use."""
+    await db.open_connection_pool()
+
+    async with db.new_connection() as c:
+        db._test_connection = c
+        yield c
+        db._test_connection = None
+
+    await db.close_connection_pool()


 @pytest_asyncio.fixture
-async def conn(shared_conn):
-    async with shared_conn.transaction(force_rollback=True):
+async def conn(shared_conn: db.Connection):
+    """A transacted database connection, will be rolled back after use."""
+    async with db.transacted(shared_conn, force_rollback=True):
         yield shared_conn
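The rewritten shared_conn fixture pins one connection on the db module so that application code calling db.new_connection() during a test reuses the test's transacted connection. A self-contained sketch of that pin-and-reuse pattern (simplified; the real module also manages an engine, and _fresh_connection here is an illustrative stand-in):

    import asyncio
    import contextlib

    _test_connection = None  # pinned by tests; stays None in production


    @contextlib.asynccontextmanager
    async def _fresh_connection():
        # Stand-in for opening a real database connection (illustrative only).
        yield object()


    @contextlib.asynccontextmanager
    async def new_connection():
        # Reuse the pinned connection when a test set one, otherwise open fresh.
        if _test_connection is not None:
            yield _test_connection
            return
        async with _fresh_connection() as conn:
            yield conn


    async def main() -> None:
        global _test_connection
        _test_connection = "pinned"
        async with new_connection() as a, new_connection() as b:
            assert a is b  # both callers share the pinned connection


    asyncio.run(main())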
353  tests/test_db.py
@@ -4,75 +4,176 @@ import pytest

 from unwind import db, models, web_models

+_movie_imdb_id = 1230000
+
+
+def a_movie(**kwds) -> models.Movie:
+    global _movie_imdb_id
+    _movie_imdb_id += 1
+    args = {
+        "title": "test movie",
+        "release_year": 2013,
+        "media_type": "Movie",
+        "imdb_id": f"tt{_movie_imdb_id}",
+        "genres": {"genre-1"},
+    } | kwds
+    return models.Movie(**args)
+

 @pytest.mark.asyncio
-async def test_add_and_get(shared_conn: db.Database):
-    async with shared_conn.transaction(force_rollback=True):
-        m1 = models.Movie(
-            title="test movie",
-            release_year=2013,
-            media_type="Movie",
-            imdb_id="tt0000000",
-            genres={"genre-1"},
-        )
-        await db.add(m1)
-
-        m2 = models.Movie(
-            title="test movie",
-            release_year=2013,
-            media_type="Movie",
-            imdb_id="tt0000001",
-            genres={"genre-1"},
-        )
-        await db.add(m2)
-
-        assert m1 == await db.get(models.Movie, id=str(m1.id))
-        assert m2 == await db.get(models.Movie, id=str(m2.id))
+async def test_current_patch_level(conn: db.Connection):
+    patch_level = "some-patch-level"
+    assert patch_level != await db.current_patch_level(conn)
+    await db.set_current_patch_level(conn, patch_level)
+    assert patch_level == await db.current_patch_level(conn)
+
+
+@pytest.mark.asyncio
+async def test_get(conn: db.Connection):
+    m1 = a_movie()
+    await db.add(conn, m1)
+
+    m2 = a_movie(release_year=m1.release_year + 1)
+    await db.add(conn, m2)
+
+    assert None is await db.get(conn, models.Movie)
+    assert None is await db.get(conn, models.Movie, id="blerp")
+    assert m1 == await db.get(conn, models.Movie, id=str(m1.id))
+    assert m2 == await db.get(conn, models.Movie, release_year=m2.release_year)
+    assert None is await db.get(
+        conn, models.Movie, id=str(m1.id), release_year=m2.release_year
+    )
+    assert m2 == await db.get(
+        conn, models.Movie, id=str(m2.id), release_year=m2.release_year
+    )
+    assert m1 == await db.get(
+        conn,
+        models.Movie,
+        media_type=m1.media_type,
+        order_by=(models.movies.c.release_year, "asc"),
+    )
+    assert m2 == await db.get(
+        conn,
+        models.Movie,
+        media_type=m1.media_type,
+        order_by=(models.movies.c.release_year, "desc"),
+    )
+
+
+@pytest.mark.asyncio
+async def test_get_all(conn: db.Connection):
+    m1 = a_movie()
+    await db.add(conn, m1)
+
+    m2 = a_movie(release_year=m1.release_year)
+    await db.add(conn, m2)
+
+    m3 = a_movie(release_year=m1.release_year + 1)
+    await db.add(conn, m3)
+
+    assert [] == list(await db.get_all(conn, models.Movie, id="blerp"))
+    assert [m1] == list(await db.get_all(conn, models.Movie, id=str(m1.id)))
+    assert [m1, m2] == list(
+        await db.get_all(conn, models.Movie, release_year=m1.release_year)
+    )
+    assert [m1, m2, m3] == list(await db.get_all(conn, models.Movie))
+
+
+@pytest.mark.asyncio
+async def test_get_many(conn: db.Connection):
+    m1 = a_movie()
+    await db.add(conn, m1)
+
+    m2 = a_movie(release_year=m1.release_year)
+    await db.add(conn, m2)
+
+    m3 = a_movie(release_year=m1.release_year + 1)
+    await db.add(conn, m3)
+
+    assert [] == list(await db.get_many(conn, models.Movie)), "selected nothing"
+    assert [m1] == list(await db.get_many(conn, models.Movie, id=[str(m1.id)]))
+    assert [m1] == list(await db.get_many(conn, models.Movie, id={str(m1.id)}))
+    assert [m1, m2] == list(
+        await db.get_many(conn, models.Movie, release_year=[m1.release_year])
+    )
+    assert [m1, m2, m3] == list(
+        await db.get_many(
+            conn, models.Movie, release_year=[m1.release_year, m3.release_year]
+        )
+    )
+
+
+@pytest.mark.asyncio
+async def test_add_and_get(conn: db.Connection):
+    m1 = a_movie()
+    await db.add(conn, m1)
+
+    m2 = a_movie()
+    await db.add(conn, m2)
+
+    assert m1 == await db.get(conn, models.Movie, id=str(m1.id))
+    assert m2 == await db.get(conn, models.Movie, id=str(m2.id))
+
+
+@pytest.mark.asyncio
+async def test_update(conn: db.Connection):
+    m = a_movie()
+    await db.add(conn, m)
+
+    assert m == await db.get(conn, models.Movie, id=str(m.id))
+    m.title += "something else"
+    assert m != await db.get(conn, models.Movie, id=str(m.id))
+
+    await db.update(conn, m)
+    assert m == await db.get(conn, models.Movie, id=str(m.id))
+
+
+@pytest.mark.asyncio
+async def test_remove(conn: db.Connection):
+    m1 = a_movie()
+    await db.add(conn, m1)
+    assert m1 == await db.get(conn, models.Movie, id=str(m1.id))
+
+    await db.remove(conn, m1)
+    assert None is await db.get(conn, models.Movie, id=str(m1.id))


 @pytest.mark.asyncio
-async def test_find_ratings(shared_conn: db.Database):
-    async with shared_conn.transaction(force_rollback=True):
-        m1 = models.Movie(
+async def test_find_ratings(conn: db.Connection):
+    m1 = a_movie(
         title="test movie",
         release_year=2013,
         media_type="Movie",
         imdb_id="tt0000000",
         genres={"genre-1"},
     )
-    await db.add(m1)
+    await db.add(conn, m1)

-    m2 = models.Movie(
+    m2 = a_movie(
         title="it's anöther Movie, Part 2",
         release_year=2015,
         media_type="Movie",
         imdb_id="tt0000001",
         genres={"genre-2"},
     )
-    await db.add(m2)
+    await db.add(conn, m2)

-    m3 = models.Movie(
+    m3 = a_movie(
         title="movie it's, Part 3",
-        release_year=2015,
         media_type="Movie",
         imdb_id="tt0000002",
-        genres={"genre-2"},
+        release_year=m2.release_year,
+        genres=m2.genres,
     )
-    await db.add(m3)
+    await db.add(conn, m3)

     u1 = models.User(
         imdb_id="u00001",
         name="User1",
         secret="secret1",
     )
-    await db.add(u1)
+    await db.add(conn, u1)

     u2 = models.User(
         imdb_id="u00002",
         name="User2",
         secret="secret2",
     )
-    await db.add(u2)
+    await db.add(conn, u2)

     r1 = models.Rating(
         movie_id=m2.id,
@@ -82,7 +183,7 @@ async def test_find_ratings(shared_conn: db.Database):
         score=66,
         rating_date=datetime.now(),
     )
-    await db.add(r1)
+    await db.add(conn, r1)

     r2 = models.Rating(
         movie_id=m2.id,
@@ -92,11 +193,12 @@ async def test_find_ratings(shared_conn: db.Database):
         score=77,
         rating_date=datetime.now(),
     )
-    await db.add(r2)
+    await db.add(conn, r2)

     # ---

     rows = await db.find_ratings(
+        conn,
         title=m1.title,
         media_type=m1.media_type,
         exact=True,
@@ -111,14 +213,14 @@ async def test_find_ratings(shared_conn: db.Database):
         web_models.aggregate_ratings(ratings, user_ids=[])
     )

-    rows = await db.find_ratings(title="movie", include_unrated=False)
+    rows = await db.find_ratings(conn, title="movie", include_unrated=False)
     ratings = tuple(web_models.Rating(**r) for r in rows)
     assert (
         web_models.Rating.from_movie(m2, rating=r1),
         web_models.Rating.from_movie(m2, rating=r2),
     ) == ratings

-    rows = await db.find_ratings(title="movie", include_unrated=True)
+    rows = await db.find_ratings(conn, title="movie", include_unrated=True)
     ratings = tuple(web_models.Rating(**r) for r in rows)
     assert (
         web_models.Rating.from_movie(m1),
@@ -146,13 +248,172 @@ async def test_find_ratings(shared_conn: db.Database):
         web_models.RatingAggregate.from_movie(m3),
     ) == tuple(aggr)

-    rows = await db.find_ratings(title="movie", include_unrated=True)
+    rows = await db.find_ratings(conn, title="movie", include_unrated=True)
     ratings = (web_models.Rating(**r) for r in rows)
     aggr = web_models.aggregate_ratings(ratings, user_ids=[])
     assert tuple(
         web_models.RatingAggregate.from_movie(m) for m in [m1, m2, m3]
     ) == tuple(aggr)

-    rows = await db.find_ratings(title="test", include_unrated=True)
+    rows = await db.find_ratings(conn, title="test", include_unrated=True)
     ratings = tuple(web_models.Rating(**r) for r in rows)
     assert (web_models.Rating.from_movie(m1),) == ratings
+
+
+@pytest.mark.asyncio
+async def test_ratings_for_movies(conn: db.Connection):
+    m1 = a_movie()
+    await db.add(conn, m1)
+
+    m2 = a_movie()
+    await db.add(conn, m2)
+
+    u1 = models.User(
+        imdb_id="u00001",
+        name="User1",
+        secret="secret1",
+    )
+    await db.add(conn, u1)
+
+    u2 = models.User(
+        imdb_id="u00002",
+        name="User2",
+        secret="secret2",
+    )
+    await db.add(conn, u2)
+
+    r1 = models.Rating(
+        movie_id=m2.id,
+        movie=m2,
+        user_id=u1.id,
+        user=u1,
+        score=66,
+        rating_date=datetime.now(),
+    )
+    await db.add(conn, r1)
+
+    # ---
+
+    movie_ids = [m1.id]
+    user_ids = []
+    assert tuple() == tuple(
+        await db.ratings_for_movies(conn, movie_ids=movie_ids, user_ids=user_ids)
+    )
+
+    movie_ids = [m2.id]
+    user_ids = []
+    assert (r1,) == tuple(
+        await db.ratings_for_movies(conn, movie_ids=movie_ids, user_ids=user_ids)
+    )
+
+    movie_ids = [m2.id]
+    user_ids = [u2.id]
+    assert tuple() == tuple(
+        await db.ratings_for_movies(conn, movie_ids=movie_ids, user_ids=user_ids)
+    )
+
+    movie_ids = [m2.id]
+    user_ids = [u1.id]
+    assert (r1,) == tuple(
+        await db.ratings_for_movies(conn, movie_ids=movie_ids, user_ids=user_ids)
+    )
+
+    movie_ids = [m1.id, m2.id]
+    user_ids = [u1.id, u2.id]
+    assert (r1,) == tuple(
+        await db.ratings_for_movies(conn, movie_ids=movie_ids, user_ids=user_ids)
+    )
+
+
+@pytest.mark.asyncio
+async def test_find_movies(conn: db.Connection):
+    m1 = a_movie(title="movie one")
+    await db.add(conn, m1)
+
+    m2 = a_movie(title="movie two", imdb_score=33, release_year=m1.release_year + 1)
+    await db.add(conn, m2)
+
+    u1 = models.User(
+        imdb_id="u00001",
+        name="User1",
+        secret="secret1",
+    )
+    await db.add(conn, u1)
+
+    u2 = models.User(
+        imdb_id="u00002",
+        name="User2",
+        secret="secret2",
+    )
+    await db.add(conn, u2)
+
+    r1 = models.Rating(
+        movie_id=m2.id,
+        movie=m2,
+        user_id=u1.id,
+        user=u1,
+        score=66,
+        rating_date=datetime.now(),
+    )
+    await db.add(conn, r1)
+
+    # ---
+
+    assert () == tuple(
+        await db.find_movies(conn, title=m1.title, include_unrated=False)
+    )
+    assert ((m1, []),) == tuple(
+        await db.find_movies(conn, title=m1.title, include_unrated=True)
+    )
+
+    assert ((m1, []),) == tuple(
+        await db.find_movies(conn, title="mo on", exact=False, include_unrated=True)
+    )
+    assert ((m1, []),) == tuple(
+        await db.find_movies(conn, title="movie one", exact=True, include_unrated=True)
+    )
+    assert () == tuple(
+        await db.find_movies(conn, title="mo on", exact=True, include_unrated=True)
+    )
+
+    assert ((m2, []),) == tuple(
+        await db.find_movies(conn, title="movie", exact=False, include_unrated=False)
+    )
+    assert ((m2, []), (m1, [])) == tuple(
+        await db.find_movies(conn, title="movie", exact=False, include_unrated=True)
+    )
+
+    assert ((m1, []),) == tuple(
+        await db.find_movies(
+            conn, include_unrated=True, yearcomp=("=", m1.release_year)
+        )
+    )
+    assert ((m2, []),) == tuple(
+        await db.find_movies(
+            conn, include_unrated=True, yearcomp=("=", m2.release_year)
+        )
+    )
+    assert ((m1, []),) == tuple(
+        await db.find_movies(
+            conn, include_unrated=True, yearcomp=("<", m2.release_year)
+        )
+    )
+    assert ((m2, []),) == tuple(
+        await db.find_movies(
+            conn, include_unrated=True, yearcomp=(">", m1.release_year)
+        )
+    )
+
+    assert ((m2, []), (m1, [])) == tuple(
+        await db.find_movies(conn, include_unrated=True)
+    )
+    assert ((m2, []),) == tuple(
+        await db.find_movies(conn, include_unrated=True, limit_rows=1)
+    )
+    assert ((m1, []),) == tuple(
+        await db.find_movies(conn, include_unrated=True, skip_rows=1)
+    )
+
+    assert ((m2, [r1]), (m1, [])) == tuple(
+        await db.find_movies(conn, include_unrated=True, user_ids=[u1.id, u2.id])
+    )
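The a_movie helper above builds unique fixtures by merging per-test overrides into a defaults dict with the | operator (Python 3.9+), where right-hand keys win:

    defaults = {"title": "test movie", "release_year": 2013}
    args = defaults | {"release_year": 2014}
    assert args == {"title": "test movie", "release_year": 2014}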
11  tests/test_models.py  Normal file
@@ -0,0 +1,11 @@
+import pytest
+
+from unwind import models
+
+
+@pytest.mark.parametrize("mapper", models.mapper_registry.mappers)
+def test_fields(mapper):
+    """Test that models.fields() matches exactly all table columns."""
+    dcfields = {f.name for f in models.fields(mapper.class_)}
+    mfields = {c.name for c in mapper.columns}
+    assert dcfields == mfields
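The new test asserts a one-to-one match between dataclass fields and mapped table columns. The same invariant in a standalone sketch using SQLAlchemy 2.0 imperative mapping; the Thing class and things table here are illustrative, not from this repo:

    import dataclasses

    import sqlalchemy as sa
    from sqlalchemy.orm import registry

    mapper_registry = registry()


    @dataclasses.dataclass
    class Thing:
        id: int
        name: str


    things = sa.Table(
        "things",
        mapper_registry.metadata,
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("name", sa.String),
    )
    mapper_registry.map_imperatively(Thing, things)

    # Same check as the test: dataclass fields must mirror the table columns.
    for mapper in mapper_registry.mappers:
        dcfields = {f.name for f in dataclasses.fields(mapper.class_)}
        assert dcfields == {c.name for c in mapper.columns}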
@@ -1,22 +1,138 @@
 from datetime import datetime

 import pytest
 from starlette.testclient import TestClient

-from unwind import create_app, db, imdb, models
+from unwind import config, create_app, db, imdb, models

 app = create_app()


-@pytest.mark.asyncio
-async def test_app(shared_conn: db.Database):
-    async with shared_conn.transaction(force_rollback=True):
-        client = TestClient(app)
-        response = client.get("/api/v1/movies")
-        assert response.status_code == 403
+@pytest.fixture(scope="module")
+def unauthorized_client() -> TestClient:
+    # https://www.starlette.io/testclient/
+    return TestClient(app)
+
+
+@pytest.fixture(scope="module")
+def authorized_client() -> TestClient:
+    client = TestClient(app)
+    client.auth = "user1", "secret1"
+    return client
+
+
+@pytest.fixture(scope="module")
+def admin_client() -> TestClient:
+    client = TestClient(app)
+    for token in config.api_credentials.values():
+        break
+    else:
+        raise RuntimeError("No bearer tokens configured.")
+    client.headers = {"Authorization": f"Bearer {token}"}
+    return client
+
+
+@pytest.mark.asyncio
+async def test_get_ratings_for_group(
+    conn: db.Connection, unauthorized_client: TestClient
+):
+    user = models.User(
+        imdb_id="ur12345678",
+        name="user-1",
+        secret="secret-1",
+        groups=[],
+    )
+    group = models.Group(
+        name="group-1",
+        users=[models.GroupUser(id=str(user.id), name=user.name)],
+    )
+    user.groups = [models.UserGroup(id=str(group.id), access="r")]
+    path = app.url_path_for("get_ratings_for_group", group_id=str(group.id))
+
+    resp = unauthorized_client.get(path)
+    assert resp.status_code == 404, "Group does not exist (yet)"
+
+    await db.add(conn, user)
+    await db.add(conn, group)
+
+    resp = unauthorized_client.get(path)
+    assert resp.status_code == 200
+    assert resp.json() == []
+
+    movie = models.Movie(
+        title="test movie",
+        release_year=2013,
+        media_type="Movie",
+        imdb_id="tt12345678",
+        genres={"genre-1"},
+    )
+    await db.add(conn, movie)
+
+    rating = models.Rating(
+        movie_id=movie.id, user_id=user.id, score=66, rating_date=datetime.now()
+    )
+    await db.add(conn, rating)
+
+    rating_aggregate = {
+        "canonical_title": movie.title,
+        "imdb_score": movie.imdb_score,
+        "imdb_votes": movie.imdb_votes,
+        "link": imdb.movie_url(movie.imdb_id),
+        "media_type": movie.media_type,
+        "original_title": movie.original_title,
+        "user_scores": [rating.score],
+        "year": movie.release_year,
+    }
+
+    resp = unauthorized_client.get(path)
+    assert resp.status_code == 200
+    assert resp.json() == [rating_aggregate]
+
+    filters = {
+        "imdb_id": movie.imdb_id,
+        "unwind_id": str(movie.id),
+        "title": movie.title,
+        "media_type": movie.media_type,
+        "year": movie.release_year,
+    }
+    for k, v in filters.items():
+        resp = unauthorized_client.get(path, params={k: v})
+        assert resp.status_code == 200
+        assert resp.json() == [rating_aggregate]
+
+    resp = unauthorized_client.get(path, params={"title": "no such thing"})
+    assert resp.status_code == 200
+    assert resp.json() == []
+
+    # Test "exact" query param.
+    resp = unauthorized_client.get(
+        path, params={"title": "test movie", "exact": "true"}
+    )
+    assert resp.status_code == 200
+    assert resp.json() == [rating_aggregate]
+    resp = unauthorized_client.get(path, params={"title": "te mo", "exact": "false"})
+    assert resp.status_code == 200
+    assert resp.json() == [rating_aggregate]
+    resp = unauthorized_client.get(path, params={"title": "te mo", "exact": "true"})
+    assert resp.status_code == 200
+    assert resp.json() == []
+
+    # XXX Test "ignore_tv_episodes" query param.
+    # XXX Test "include_unrated" query param.
+    # XXX Test "per_page" query param.
+
+
+@pytest.mark.asyncio
+async def test_list_movies(
+    conn: db.Connection,
+    unauthorized_client: TestClient,
+    authorized_client: TestClient,
+):
+    path = app.url_path_for("list_movies")
+    response = unauthorized_client.get(path)
+    assert response.status_code == 403

-        client.auth = "user1", "secret1"
-
-        response = client.get("/api/v1/movies")
+    response = authorized_client.get(path)
     assert response.status_code == 200
     assert response.json() == []

@@ -27,9 +143,9 @@ async def test_app(shared_conn: db.Database):
         imdb_id="tt12345678",
         genres={"genre-1"},
     )
-    await db.add(m)
+    await db.add(conn, m)

-    response = client.get("/api/v1/movies", params={"include_unrated": 1})
+    response = authorized_client.get(path, params={"include_unrated": 1})
     assert response.status_code == 200
     assert response.json() == [{**models.asplain(m), "user_scores": []}]

@@ -44,10 +160,84 @@ async def test_app(shared_conn: db.Database):
         "year": m.release_year,
     }

-    response = client.get("/api/v1/movies", params={"imdb_id": m.imdb_id})
+    response = authorized_client.get(path, params={"imdb_id": m.imdb_id})
     assert response.status_code == 200
     assert response.json() == [m_plain]

-    response = client.get("/api/v1/movies", params={"unwind_id": str(m.id)})
+    response = authorized_client.get(path, params={"unwind_id": str(m.id)})
     assert response.status_code == 200
     assert response.json() == [m_plain]
+
+
+@pytest.mark.asyncio
+async def test_list_users(
+    conn: db.Connection,
+    unauthorized_client: TestClient,
+    authorized_client: TestClient,
+    admin_client: TestClient,
+):
+    path = app.url_path_for("list_users")
+    response = unauthorized_client.get(path)
+    assert response.status_code == 403
+
+    response = authorized_client.get(path)
+    assert response.status_code == 403
+
+    response = admin_client.get(path)
+    assert response.status_code == 200
+    assert response.json() == []
+
+    m = models.User(
+        imdb_id="ur12345678",
+        name="user-1",
+        secret="secret-1",
+        groups=[],
+    )
+    await db.add(conn, m)
+
+    m_plain = {
+        "groups": m.groups,
+        "id": m.id,
+        "imdb_id": m.imdb_id,
+        "name": m.name,
+        "secret": m.secret,
+    }
+
+    response = admin_client.get(path)
+    assert response.status_code == 200
+    assert response.json() == [m_plain]
+
+
+@pytest.mark.asyncio
+async def test_list_groups(
+    conn: db.Connection,
+    unauthorized_client: TestClient,
+    authorized_client: TestClient,
+    admin_client: TestClient,
+):
+    path = app.url_path_for("list_groups")
+    response = unauthorized_client.get(path)
+    assert response.status_code == 403
+
+    response = authorized_client.get(path)
+    assert response.status_code == 403
+
+    response = admin_client.get(path)
+    assert response.status_code == 200
+    assert response.json() == []
+
+    m = models.Group(
+        name="group-1",
+        users=[models.GroupUser(id="123", name="itsa-me")],
+    )
+    await db.add(conn, m)
+
+    m_plain = {
+        "users": m.users,
+        "id": m.id,
+        "name": m.name,
+    }
+
+    response = admin_client.get(path)
+    assert response.status_code == 200
+    assert response.json() == [m_plain]
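The three client fixtures differ only in how credentials are attached; TestClient is an httpx.Client, so auth and headers can be set once and apply to every request. In isolation (the empty Starlette app and the token value are placeholders):

    from starlette.applications import Starlette
    from starlette.testclient import TestClient

    app = Starlette()

    anonymous = TestClient(app)

    basic = TestClient(app)
    basic.auth = ("user1", "secret1")  # HTTP basic auth on every request

    bearer = TestClient(app)
    bearer.headers = {"Authorization": "Bearer some-token"}  # token is illustrative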
1332  unwind-ui/package-lock.json  generated
(File diff suppressed because it is too large)
@@ -1,13 +1,29 @@
 import { defineConfig } from "vite"
 import vue from "@vitejs/plugin-vue"

+// Vite defaults.
+const vite_host = "localhost"
+const vite_port = 3000
+
+const base = process.env.BASE_URL || "/"
+const proxied_api_url = `http://${vite_host}:${vite_port}/api/`
+const real_api_url = `http://${process.env.API_HOST}:${process.env.API_PORT}/api/`
+
 // https://vitejs.dev/config/
 export default defineConfig({
-  base: process.env.BASE_URL || "/",
+  base,
   define: {
-    "process.env.API_URL": JSON.stringify(
-      process.env.API_URL || "http://localhost:8000/api/",
-    ),
+    "process.env.API_URL": JSON.stringify(process.env.API_URL || proxied_api_url),
   },
+  server: {
+    host: vite_host,
+    port: vite_port,
+    proxy: {
+      [`${base}api`]: {
+        target: real_api_url,
+        prependPath: false,
+      },
+    },
+  },
   plugins: [vue()],
 })
@@ -2,20 +2,20 @@ import os
 import tomllib
 from pathlib import Path

-datadir = Path(os.getenv("UNWIND_DATA") or "./data")
-cachedir = (
-    Path(cachedir)
-    if (cachedir := os.getenv("UNWIND_CACHEDIR", datadir / ".cache"))
-    else None
-)
-debug = os.getenv("DEBUG") == "1"
-loglevel = os.getenv("UNWIND_LOGLEVEL") or ("DEBUG" if debug else "INFO")
-storage_path = os.getenv("UNWIND_STORAGE", datadir / "db.sqlite")
-config_path = os.getenv("UNWIND_CONFIG", datadir / "config.toml")
+datadir: Path = Path(os.getenv("UNWIND_DATA") or "./data")
+cachedir: Path = Path(p) if (p := os.getenv("UNWIND_CACHEDIR")) else datadir / ".cache"
+debug: bool = os.getenv("DEBUG") == "1"
+loglevel: str = os.getenv("UNWIND_LOGLEVEL") or ("DEBUG" if debug else "INFO")
+storage_path: Path = (
+    Path(p) if (p := os.getenv("UNWIND_STORAGE")) else datadir / "db.sqlite"
+)
+config_path: Path = (
+    Path(p) if (p := os.getenv("UNWIND_CONFIG")) else datadir / "config.toml"
+)

 with open(config_path, "rb") as fd:
     _config = tomllib.load(fd)

-api_base = _config["api"].get("base", "/api/")
-api_cors = _config["api"].get("cors", "*")
-api_credentials = _config["api"].get("credentials", {})
+api_base: str = _config["api"].get("base", "/api/")
+api_cors: str = _config["api"].get("cors", "*")
+api_credentials: dict[str, str] = _config["api"].get("credentials", {})
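The `Path(p) if (p := os.getenv(...)) else default` idiom reads the environment variable once and reuses it in the truthy branch. A minimal demonstration, assuming UNWIND_STORAGE is unset:

    import os
    from pathlib import Path

    datadir = Path("./data")
    storage_path = Path(p) if (p := os.getenv("UNWIND_STORAGE")) else datadir / "db.sqlite"
    print(storage_path)  # data/db.sqlite when the variable is unset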
664  unwind/db.py
@@ -1,24 +1,27 @@
-import asyncio
 import contextlib
 import logging
 import re
-import threading
 from pathlib import Path
-from typing import Any, Iterable, Literal, Type, TypeVar
+from typing import Any, AsyncGenerator, Iterable, Literal, Sequence, Type, TypeVar

-import sqlalchemy
-from databases import Database
+import sqlalchemy as sa
+from sqlalchemy.dialects.sqlite import insert
+from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine

 from . import config
 from .models import (
     Model,
     Movie,
     Progress,
     Rating,
     User,
     asplain,
-    fields,
+    db_patches,
     fromplain,
+    metadata,
+    movies,
     optional_fields,
+    progress,
+    ratings,
     utcnow,
 )
 from .types import ULID
@@ -26,7 +29,9 @@ from .types import ULID
 log = logging.getLogger(__name__)
 T = TypeVar("T")

-_shared_connection: Database | None = None
+_engine: AsyncEngine | None = None
+
+type Connection = AsyncConnection


 async def open_connection_pool() -> None:
@@ -34,10 +39,13 @@ async def open_connection_pool() -> None:

     This function needs to be called before any access to the database can happen.
     """
-    db = shared_connection()
-    await db.connect()
+    async with transaction() as conn:
+        await conn.execute(sa.text("PRAGMA journal_mode=WAL"))

-    await apply_db_patches(db)
+        await conn.run_sync(metadata.create_all, tables=[db_patches])
+
+    async with new_connection() as conn:
+        await apply_db_patches(conn)


 async def close_connection_pool() -> None:
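The new startup path switches SQLite into write-ahead logging before anything else touches the database. A standalone sketch of just the WAL step (the in-memory URL is illustrative and reports "memory"; an on-disk database would report "wal"; requires the aiosqlite driver):

    import asyncio

    import sqlalchemy as sa
    from sqlalchemy.ext.asyncio import create_async_engine


    async def main() -> None:
        engine = create_async_engine("sqlite+aiosqlite://")  # in-memory for the demo
        async with engine.begin() as conn:
            # The pragma returns the journal mode that is now in effect.
            mode = await conn.scalar(sa.text("PRAGMA journal_mode=WAL"))
            print(mode)
        await engine.dispose()


    asyncio.run(main())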
@@ -46,48 +54,33 @@ async def close_connection_pool() -> None:
     This function should be called before the app shuts down to ensure all data
     has been flushed to the database.
     """
-    db = shared_connection()
+    engine = _shared_engine()

-    # Run automatic ANALYZE prior to closing the db,
-    # see https://sqlite.com/lang_analyze.html.
-    await db.execute("PRAGMA analysis_limit=400")
-    await db.execute("PRAGMA optimize")
+    async with engine.begin() as conn:
+        # Run automatic ANALYZE prior to closing the db,
+        # see https://sqlite.com/lang_analyze.html.
+        await conn.execute(sa.text("PRAGMA analysis_limit=400"))
+        await conn.execute(sa.text("PRAGMA optimize"))

-    await db.disconnect()
+    await engine.dispose()


-async def _create_patch_db(db):
-    query = """
-        CREATE TABLE IF NOT EXISTS db_patches (
-            id INTEGER PRIMARY KEY,
-            current TEXT
-        )
-    """
-    await db.execute(query)
-
-
-async def current_patch_level(db) -> str:
-    await _create_patch_db(db)
-
-    query = "SELECT current FROM db_patches"
-    current = await db.fetch_val(query)
+async def current_patch_level(conn: Connection, /) -> str:
+    query = sa.select(db_patches.c.current)
+    current = await conn.scalar(query)
     return current or ""


-async def set_current_patch_level(db, current: str):
-    await _create_patch_db(db)
-
-    query = """
-        INSERT INTO db_patches VALUES (1, :current)
-        ON CONFLICT DO UPDATE SET current=excluded.current
-    """
-    await db.execute(query, values={"current": current})
+async def set_current_patch_level(conn: Connection, /, current: str) -> None:
+    stmt = insert(db_patches).values(id=1, current=current)
+    stmt = stmt.on_conflict_do_update(set_={"current": stmt.excluded.current})
+    await conn.execute(stmt)


 db_patches_dir = Path(__file__).parent / "sql"


-async def apply_db_patches(db: Database):
+async def apply_db_patches(conn: Connection, /) -> None:
     """Apply all remaining patches to the database.

     Beware that patches will be applied in lexicographical order,
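The raw ON CONFLICT string becomes a dialect-specific insert construct here. The same upsert in a self-contained sketch; the table definition below is an assumption modeled on the removed CREATE TABLE, not code from the repo:

    import sqlalchemy as sa
    from sqlalchemy.dialects import sqlite
    from sqlalchemy.dialects.sqlite import insert

    metadata = sa.MetaData()
    db_patches = sa.Table(  # assumed shape, mirroring the old CREATE TABLE
        "db_patches",
        metadata,
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("current", sa.Text),
    )

    stmt = insert(db_patches).values(id=1, current="0001-initial")
    # Conflict target omitted, as in the code above; newer SQLite versions
    # accept DO UPDATE without an explicit target.
    stmt = stmt.on_conflict_do_update(set_={"current": stmt.excluded.current})
    print(stmt.compile(dialect=sqlite.dialect()))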
@@ -99,7 +92,7 @@ async def apply_db_patches(db: Database):
     using two consecutive semi-colons (;).
     Failing to do so will result in an error.
     """
-    applied_lvl = await current_patch_level(db)
+    applied_lvl = await current_patch_level(conn)

     did_patch = False

@@ -118,29 +111,52 @@ async def apply_db_patches(db: Database):
             )
             raise RuntimeError("No statement found.")

-        async with db.transaction():
+        async with transacted(conn):
             for query in queries:
-                await db.execute(query)
+                await conn.execute(sa.text(query))

-            await set_current_patch_level(db, patch_lvl)
+            await set_current_patch_level(conn, patch_lvl)

         did_patch = True

     if did_patch:
-        await db.execute("vacuum")
+        await _vacuum(conn)


-async def get_import_progress() -> Progress | None:
+async def _vacuum(conn: Connection, /) -> None:
+    """Vacuum the database.
+
+    This function cannot be run on a connection with an open transaction.
+    """
+    # With SQLAlchemy's "autobegin" behavior we need to switch the connection
+    # to "autocommit" first to keep it from automatically starting a transaction,
+    # as VACUUM cannot be run inside a transaction for most databases.
+    await conn.commit()
+    isolation_level = await conn.get_isolation_level()
+    log.debug("Previous isolation_level: %a", isolation_level)
+    await conn.execution_options(isolation_level="AUTOCOMMIT")
+    try:
+        await conn.execute(sa.text("vacuum"))
+        await conn.commit()
+    finally:
+        await conn.execution_options(isolation_level=isolation_level)
+
+
+async def get_import_progress(conn: Connection, /) -> Progress | None:
     """Return the latest import progress."""
-    return await get(Progress, type="import-imdb-movies", order_by="started DESC")
+    return await get(
+        conn, Progress, type="import-imdb-movies", order_by=(progress.c.started, "desc")
+    )


-async def stop_import_progress(*, error: BaseException | None = None):
+async def stop_import_progress(
+    conn: Connection, /, *, error: BaseException | None = None
+) -> None:
     """Stop the current import.

     If an error is given, it will be logged to the progress state.
     """
-    current = await get_import_progress()
+    current = await get_import_progress(conn)
     is_running = current and current.stopped is None

     if not is_running:
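The _vacuum helper works around SQLAlchemy's autobegin: the connection must hold no transaction before VACUUM can run. A compact standalone sketch of the same sequence (in-memory database for the demo):

    import asyncio

    import sqlalchemy as sa
    from sqlalchemy.ext.asyncio import create_async_engine


    async def main() -> None:
        engine = create_async_engine("sqlite+aiosqlite://")
        async with engine.connect() as conn:
            await conn.execute(sa.text("SELECT 1"))  # autobegins a transaction
            await conn.commit()  # end it; VACUUM cannot run inside a transaction
            await conn.execution_options(isolation_level="AUTOCOMMIT")
            await conn.execute(sa.text("VACUUM"))  # now legal: no open transaction
        await engine.dispose()


    asyncio.run(main())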
@@ -151,17 +167,17 @@ async def stop_import_progress(*, error: BaseException | None = None):
     current.error = repr(error)
     current.stopped = utcnow().isoformat()

-    await update(current)
+    await update(conn, current)


-async def set_import_progress(progress: float) -> Progress:
+async def set_import_progress(conn: Connection, /, progress: float) -> Progress:
     """Set the current import progress percentage.

     If no import is currently running, this will create a new one.
     """
     progress = min(max(0.0, progress), 100.0)  # clamp to 0 <= progress <= 100

-    current = await get_import_progress()
+    current = await get_import_progress(conn)
     is_running = current and current.stopped is None

     if not is_running:
@@ -171,163 +187,211 @@ async def set_import_progress(progress: float) -> Progress:
     current.percent = progress

     if is_running:
-        await update(current)
+        await update(conn, current)
     else:
-        await add(current)
+        await add(conn, current)

     return current


-_lock = threading.Lock()
-_prelock = threading.Lock()
+def _new_engine() -> AsyncEngine:
+    uri = f"sqlite+aiosqlite:///{config.storage_path}"
+
+    return create_async_engine(
+        uri,
+        isolation_level="SERIALIZABLE",
+    )
+
+
+def _shared_engine() -> AsyncEngine:
+    global _engine
+
+    if _engine is None:
+        _engine = _new_engine()
+
+    return _engine
+
+
+def _new_connection() -> Connection:
+    return _shared_engine().connect()


 @contextlib.asynccontextmanager
-async def single_threaded():
-    """Ensure the nested code is run only by a single thread at a time."""
-    wait = 1e-5  # XXX not sure if there's a better magic value here
-
-    # The pre-lock (a lock for the lock) allows for multiple threads to hand of
-    # the main lock.
-    # With only a single lock the contending thread will spend most of its time
-    # in the asyncio.sleep and the reigning thread will have time to finish
-    # whatever it's doing and simply acquire the lock again before the other
-    # thread has had a change to try.
-    # By having another lock (and the same sleep time!) the contending thread
-    # will always have a chance to acquire the main lock.
-    while not _prelock.acquire(blocking=False):
-        await asyncio.sleep(wait)
-
-    try:
-        while not _lock.acquire(blocking=False):
-            await asyncio.sleep(wait)
-    finally:
-        _prelock.release()
-
-    try:
-        yield
-    finally:
-        _lock.release()
+async def transaction(
+    *, force_rollback: bool = False
+) -> AsyncGenerator[Connection, None]:
+    async with new_connection() as conn:
+        yield conn
+
+        if not force_rollback:
+            await conn.commit()
+
+
+# The _test_connection allows pinning a connection that will be shared across the app.
+# This can (and should only) be used when running tests, NOT IN PRODUCTION!
+_test_connection: Connection | None = None


 @contextlib.asynccontextmanager
-async def locked_connection():
-    async with single_threaded():
-        yield shared_connection()
+async def new_connection() -> AsyncGenerator[Connection, None]:
+    """Return a new connection.
+
+    Any changes will be rolled back, unless `.commit()` is called on the
+    connection.
+
+    If you want to commit changes, consider using `transaction()` instead.
+    """
+    conn = _test_connection or _new_connection()
+
+    # Support reusing the same connection for _test_connection.
+    is_started = conn.sync_connection is not None
+    if is_started:
+        yield conn
+        return
+
+    async with conn:
+        yield conn


-def shared_connection() -> Database:
-    global _shared_connection
-
-    if _shared_connection is None:
-        uri = f"sqlite:///{config.storage_path}"
-        _shared_connection = Database(uri)
-
-    return _shared_connection
+@contextlib.asynccontextmanager
+async def transacted(
+    conn: Connection, /, *, force_rollback: bool = False
+) -> AsyncGenerator[None, None]:
+    transaction = contextlib.nullcontext() if conn.in_transaction() else conn.begin()
+
+    async with transaction:
+        try:
+            yield
+        finally:
+            if force_rollback:
+                await conn.rollback()


-async def add(item):
+async def add(conn: Connection, /, item: Model) -> None:
     # Support late initializing - used for optimization.
     if getattr(item, "_is_lazy", False):
-        item._lazy_init()
+        assert hasattr(item, "_lazy_init")
+        item._lazy_init()  # pyright: ignore [reportGeneralTypeIssues]

+    table: sa.Table = item.__table__
     values = asplain(item, serialize=True)
-    keys = ", ".join(f"{k}" for k in values)
-    placeholders = ", ".join(f":{k}" for k in values)
-    query = f"INSERT INTO {item._table} ({keys}) VALUES ({placeholders})"
-    async with locked_connection() as conn:
-        await conn.execute(query=query, values=values)
+    stmt = table.insert().values(values)
+    await conn.execute(stmt)


-ModelType = TypeVar("ModelType")
+async def fetch_all(
+    conn: Connection, /, query: sa.Executable, values: "dict | None" = None
+) -> Sequence[sa.Row]:
+    result = await conn.execute(query, values)
+    return result.all()
+
+
+async def fetch_one(
+    conn: Connection, /, query: sa.Executable, values: "dict | None" = None
+) -> sa.Row | None:
+    result = await conn.execute(query, values)
+    return result.first()
+
+
+ModelType = TypeVar("ModelType", bound=Model)


 async def get(
-    model: Type[ModelType], *, order_by: str | None = None, **kwds
+    conn: Connection,
+    /,
+    model: Type[ModelType],
+    *,
+    order_by: tuple[sa.Column, Literal["asc", "desc"]] | None = None,
+    **field_values,
 ) -> ModelType | None:
     """Load a model instance from the database.

-    Passing `kwds` allows to filter the instance to load. You have to encode the
+    Passing `field_values` allows to filter the item to load. You have to encode the
     values as the appropriate data type for the database prior to passing them
     to this function.
     """
-    values = {k: v for k, v in kwds.items() if v is not None}
-    if not values:
+    if not field_values:
         return

-    fields_ = ", ".join(f.name for f in fields(model))
-    cond = " AND ".join(f"{k}=:{k}" for k in values)
-    query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
+    table: sa.Table = model.__table__
+    query = sa.select(model).where(
+        *(table.c[k] == v for k, v in field_values.items() if v is not None)
+    )
     if order_by:
-        query += f" ORDER BY {order_by}"
-    async with locked_connection() as conn:
-        row = await conn.fetch_one(query=query, values=values)
+        order_col, order_dir = order_by
+        query = query.order_by(
+            order_col.asc() if order_dir == "asc" else order_col.desc()
+        )
+    row = await fetch_one(conn, query)
     return fromplain(model, row._mapping, serialized=True) if row else None


-async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
-    keys = {
-        k: [f"{k}_{i}" for i, _ in enumerate(vs, start=1)] for k, vs in kwds.items()
-    }
-
-    if not keys:
-        return []
-
-    values = {n: v for k, vs in kwds.items() for n, v in zip(keys[k], vs)}
-
-    fields_ = ", ".join(f.name for f in fields(model))
-    cond = " AND ".join(
-        f"{k} IN ({','.join(':'+n for n in ns)})" for k, ns in keys.items()
-    )
-    query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
-    async with locked_connection() as conn:
-        rows = await conn.fetch_all(query=query, values=values)
+async def get_many(
+    conn: Connection, /, model: Type[ModelType], **field_sets: set | list
+) -> Iterable[ModelType]:
+    """Return the items with any values matching all given field sets.
+
+    This is similar to `get_all`, but instead of a scalar value a list of values
+    must be given. If any of the given values is set for that field on an item,
+    the item is considered a match.
+    If no field values are given, no items will be returned.
+    """
+    if not field_sets:
+        return []
+
+    table: sa.Table = model.__table__
+    query = sa.select(model).where(*(table.c[k].in_(v) for k, v in field_sets.items()))
+    rows = await fetch_all(conn, query)
     return (fromplain(model, row._mapping, serialized=True) for row in rows)


-async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
-    values = {k: v for k, v in kwds.items() if v is not None}
-
-    fields_ = ", ".join(f.name for f in fields(model))
-    cond = " AND ".join(f"{k}=:{k}" for k in values) or "1=1"
-    query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
-    async with locked_connection() as conn:
-        rows = await conn.fetch_all(query=query, values=values)
+async def get_all(
+    conn: Connection, /, model: Type[ModelType], **field_values
+) -> Iterable[ModelType]:
+    """Filter all items by comparing all given field values.
+
+    If no filters are given, all items will be returned.
+    """
+    table: sa.Table = model.__table__
+    query = sa.select(model).where(
+        *(table.c[k] == v for k, v in field_values.items() if v is not None)
+    )
+    rows = await fetch_all(conn, query)
     return (fromplain(model, row._mapping, serialized=True) for row in rows)


-async def update(item):
+async def update(conn: Connection, /, item: Model) -> None:
     # Support late initializing - used for optimization.
     if getattr(item, "_is_lazy", False):
-        item._lazy_init()
+        assert hasattr(item, "_lazy_init")
+        item._lazy_init()  # pyright: ignore [reportGeneralTypeIssues]

+    table: sa.Table = item.__table__
     values = asplain(item, serialize=True)
-    keys = ", ".join(f"{k}=:{k}" for k in values if k != "id")
-    query = f"UPDATE {item._table} SET {keys} WHERE id=:id"
-    async with locked_connection() as conn:
-        await conn.execute(query=query, values=values)
+    stmt = table.update().where(table.c.id == values["id"]).values(values)
+    await conn.execute(stmt)


-async def remove(item):
+async def remove(conn: Connection, /, item: Model) -> None:
+    table: sa.Table = item.__table__
     values = asplain(item, filter_fields={"id"}, serialize=True)
-    query = f"DELETE FROM {item._table} WHERE id=:id"
-    async with locked_connection() as conn:
-        await conn.execute(query=query, values=values)
+    stmt = table.delete().where(table.c.id == values["id"])
+    await conn.execute(stmt)


-async def add_or_update_user(user: User):
-    db_user = await get(User, imdb_id=user.imdb_id)
+async def add_or_update_user(conn: Connection, /, user: User) -> None:
+    db_user = await get(conn, User, imdb_id=user.imdb_id)
     if not db_user:
-        await add(user)
+        await add(conn, user)
     else:
         user.id = db_user.id

         if user != db_user:
-            await update(user)
+            await update(conn, user)


-async def add_or_update_many_movies(movies: list[Movie]):
+async def add_or_update_many_movies(conn: Connection, /, movies: list[Movie]) -> None:
     """Add or update Movies in the database.

     This is an optimized version of `add_or_update_movie` for the purpose
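With the new call shape, callers thread a connection through explicitly and scope writes with transacted(), which begins a transaction only when one is not already open. A hypothetical caller, a sketch only, using the names defined above (`some_movie` stands in for a real models.Movie):

    async def rename(some_movie) -> None:
        async with new_connection() as conn:
            async with transacted(conn):  # begins unless already in a transaction
                some_movie.title = "renamed"
                await update(conn, some_movie)
            # the begun transaction commits when the transacted() block exits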
@@ -336,12 +400,13 @@ async def add_or_update_many_movies(movies: list[Movie]):
     # for movie in movies:
     #     await add_or_update_movie(movie)
     db_movies = {
-        m.imdb_id: m for m in await get_many(Movie, imdb_id=[m.imdb_id for m in movies])
+        m.imdb_id: m
+        for m in await get_many(conn, Movie, imdb_id=[m.imdb_id for m in movies])
     }
     for movie in movies:
         # XXX optimize bulk add & update as well
         if movie.imdb_id not in db_movies:
-            await add(movie)
+            await add(conn, movie)
         else:
             db_movie = db_movies[movie.imdb_id]
             movie.id = db_movie.id
@@ -354,10 +419,10 @@ async def add_or_update_many_movies(movies: list[Movie]):
             if movie.updated <= db_movie.updated:
                 return

-            await update(movie)
+            await update(conn, movie)


-async def add_or_update_movie(movie: Movie):
+async def add_or_update_movie(conn: Connection, /, movie: Movie) -> None:
     """Add or update a Movie in the database.

     This is an upsert operation, but it will also update the Movie you pass
@@ -365,9 +430,9 @@ async def add_or_update_movie(movie: Movie):
     set all optional values on your Movie that might be unset but exist in the
     database. It's a bidirectional sync.
     """
-    db_movie = await get(Movie, imdb_id=movie.imdb_id)
+    db_movie = await get(conn, Movie, imdb_id=movie.imdb_id)
     if not db_movie:
-        await add(movie)
+        await add(conn, movie)
     else:
         movie.id = db_movie.id

@@ -379,33 +444,35 @@ async def add_or_update_movie(movie: Movie):
     if movie.updated <= db_movie.updated:
         return

-    await update(movie)
+    await update(conn, movie)


-async def add_or_update_rating(rating: Rating) -> bool:
+async def add_or_update_rating(conn: Connection, /, rating: Rating) -> bool:
     db_rating = await get(
-        Rating, movie_id=str(rating.movie_id), user_id=str(rating.user_id)
+        conn, Rating, movie_id=str(rating.movie_id), user_id=str(rating.user_id)
     )

     if not db_rating:
-        await add(rating)
+        await add(conn, rating)
         return True

     else:
         rating.id = db_rating.id

         if rating != db_rating:
-            await update(rating)
+            await update(conn, rating)
             return True

     return False


-def sql_escape(s: str, char="#"):
+def sql_escape(s: str, char: str = "#") -> str:
     return s.replace(char, 2 * char).replace("%", f"{char}%").replace("_", f"{char}_")


 async def find_ratings(
+    conn: Connection,
+    /,
     *,
     title: str | None = None,
     media_type: str | None = None,
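sql_escape neutralizes LIKE wildcards by doubling the escape character and prefixing % and _, so user input cannot act as a pattern. For example:

    def sql_escape(s: str, char: str = "#") -> str:
        return s.replace(char, 2 * char).replace("%", f"{char}%").replace("_", f"{char}_")

    # "#" doubles, "%" and "_" are prefixed with the escape character.
    assert sql_escape("50%_off#") == "50#%#_off##"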
@ -415,163 +482,129 @@ async def find_ratings(
|
|||
yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
|
||||
limit_rows: int = 10,
|
||||
user_ids: Iterable[str] = [],
|
||||
):
|
||||
values: dict[str, int | str] = {
|
||||
"limit_rows": limit_rows,
|
||||
}
|
||||
|
||||
) -> Iterable[dict[str, Any]]:
|
||||
conditions = []
|
||||
|
||||
if title:
|
||||
values["escape"] = "#"
|
||||
escaped_title = sql_escape(title, char=values["escape"])
|
||||
values["pattern"] = (
|
||||
escape_char = "#"
|
||||
escaped_title = sql_escape(title, char=escape_char)
|
||||
pattern = (
|
||||
"_".join(escaped_title.split())
|
||||
if exact
|
||||
              else "%" + "%".join(escaped_title.split()) + "%"
          )
          conditions.append(
-             f"""
-             (
-                 {Movie._table}.title LIKE :pattern ESCAPE :escape
-                 OR {Movie._table}.original_title LIKE :pattern ESCAPE :escape
+             sa.or_(
+                 movies.c.title.like(pattern, escape=escape_char),
+                 movies.c.original_title.like(pattern, escape=escape_char),
              )
-             """
          )

      if yearcomp:
-         op, year = yearcomp
-         assert op in "<=>"
-         values["year"] = year
-         conditions.append(f"{Movie._table}.release_year{op}:year")
+         match yearcomp:
+             case ("<", year):
+                 conditions.append(movies.c.release_year < year)
+             case ("=", year):
+                 conditions.append(movies.c.release_year == year)
+             case (">", year):
+                 conditions.append(movies.c.release_year > year)

-     if media_type:
-         values["media_type"] = media_type
-         conditions.append(f"{Movie._table}.media_type=:media_type")
+     if media_type is not None:
+         conditions.append(movies.c.media_type == media_type)

      if ignore_tv_episodes:
-         conditions.append(f"{Movie._table}.media_type!='TV Episode'")
+         conditions.append(movies.c.media_type != "TV Episode")

-     user_condition = "1=1"
+     user_condition = []
      if user_ids:
-         uvs = {f"user_id_{i}": v for i, v in enumerate(user_ids, start=1)}
-         values.update(uvs)
-         user_condition = f"{Rating._table}.user_id IN ({','.join(':'+n for n in uvs)})"
+         user_condition.append(ratings.c.user_id.in_(user_ids))

-     query = f"""
-         SELECT DISTINCT {Rating._table}.movie_id
-         FROM {Rating._table}
-         LEFT JOIN {Movie._table} ON {Movie._table}.id={Rating._table}.movie_id
-         WHERE {user_condition}{(' AND ' + ' AND '.join(conditions)) if conditions else ''}
-         ORDER BY length({Movie._table}.title) ASC, {Rating._table}.rating_date DESC, {Movie._table}.imdb_score DESC
-         LIMIT :limit_rows
-     """
-     async with locked_connection() as conn:
-         rows = await conn.fetch_all(bindparams(query, values))
-     movie_ids = tuple(r._mapping["movie_id"] for r in rows)
+     query = (
+         sa.select(ratings.c.movie_id)
+         .distinct()
+         .outerjoin_from(ratings, movies, movies.c.id == ratings.c.movie_id)
+         .where(*conditions, *user_condition)
+         .order_by(
+             sa.func.length(movies.c.title).asc(),
+             ratings.c.rating_date.desc(),
+             movies.c.imdb_score.desc(),
+         )
+         .limit(limit_rows)
+     )
+     rating_rows: sa.CursorResult[Rating] = await conn.execute(query)
+     movie_ids = [r.movie_id for r in rating_rows]

      if include_unrated and len(movie_ids) < limit_rows:
-         sqlin, sqlin_vals = sql_in("id", movie_ids, not_=True)
-         query = f"""
-             SELECT DISTINCT id AS movie_id
-             FROM {Movie._table}
-             WHERE {sqlin}
-             {('AND ' + ' AND '.join(conditions)) if conditions else ''}
-             ORDER BY length(title) ASC, imdb_score DESC, release_year DESC
-             LIMIT :limit_rows
-         """
-         async with locked_connection() as conn:
-             rows = await conn.fetch_all(
-                 bindparams(
-                     query,
-                     {**values, **sqlin_vals, "limit_rows": limit_rows - len(movie_ids)},
+         query = (
+             sa.select(movies.c.id)
+             .distinct()
+             .where(movies.c.id.not_in(movie_ids), *conditions)
+             .order_by(
+                 sa.func.length(movies.c.title).asc(),
+                 movies.c.imdb_score.desc(),
+                 movies.c.release_year.desc(),
+             )
+             .limit(limit_rows - len(movie_ids))
+         )
-         movie_ids += tuple(r._mapping["movie_id"] for r in rows)
+         movie_rows: sa.CursorResult[Movie] = await conn.execute(query)
+         movie_ids += [r.id for r in movie_rows]

-     return await ratings_for_movie_ids(ids=movie_ids)
+     return await ratings_for_movie_ids(conn, ids=movie_ids)

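For reference, a SQLAlchemy Core statement like the one above can be printed to inspect the SQL it will emit; a minimal sketch using the movies/ratings Table objects defined in unwind/models.py (output shown approximately):

import sqlalchemy as sa

stmt = (
    sa.select(ratings.c.movie_id)
    .distinct()
    .outerjoin_from(ratings, movies, movies.c.id == ratings.c.movie_id)
    .limit(10)
)
# compile() with literal_binds inlines the bound values for inspection
print(stmt.compile(compile_kwargs={"literal_binds": True}))
# SELECT DISTINCT ratings.movie_id FROM ratings
# LEFT OUTER JOIN movies ON = ratings.movie_id LIMIT 10  (roughly)
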
  async def ratings_for_movie_ids(
-     ids: Iterable[ULID | str] = [], imdb_ids: Iterable[str] = []
+     conn: Connection,
+     /,
+     ids: Iterable[ULID | str] = [],
+     imdb_ids: Iterable[str] = [],
  ) -> Iterable[dict[str, Any]]:
-     conds: list[str] = []
-     vals: dict[str, str] = {}
+     conds = []

      if ids:
-         sqlin, sqlin_vals = sql_in(f"{Movie._table}.id", (str(x) for x in ids))
-         conds.append(sqlin)
-         vals.update(sqlin_vals)
+         conds.append(movies.c.id.in_([str(x) for x in ids]))

      if imdb_ids:
-         sqlin, sqlin_vals = sql_in(f"{Movie._table}.imdb_id", imdb_ids)
-         conds.append(sqlin)
-         vals.update(sqlin_vals)
+         conds.append(movies.c.imdb_id.in_(imdb_ids))

      if not conds:
          return []

-     query = f"""
-         SELECT
-             {Rating._table}.score AS user_score,
-             {Rating._table}.user_id AS user_id,
-             {Movie._table}.imdb_score,
-             {Movie._table}.imdb_votes,
-             {Movie._table}.imdb_id AS movie_imdb_id,
-             {Movie._table}.media_type AS media_type,
-             {Movie._table}.title AS canonical_title,
-             {Movie._table}.original_title AS original_title,
-             {Movie._table}.release_year AS release_year
-         FROM {Movie._table}
-         LEFT JOIN {Rating._table} ON {Movie._table}.id={Rating._table}.movie_id
-         WHERE {(' OR '.join(conds))}
-     """
-
-     async with locked_connection() as conn:
-         rows = await conn.fetch_all(bindparams(query, vals))
+     query = (
+         sa.select(
+             ratings.c.score.label("user_score"),
+             ratings.c.user_id.label("user_id"),
+             movies.c.imdb_score,
+             movies.c.imdb_votes,
+             movies.c.imdb_id.label("movie_imdb_id"),
+             movies.c.media_type.label("media_type"),
+             movies.c.title.label("canonical_title"),
+             movies.c.original_title.label("original_title"),
+             movies.c.release_year.label("release_year"),
+         )
+         .outerjoin_from(movies, ratings, movies.c.id == ratings.c.movie_id)
+         .where(sa.or_(*conds))
+     )
+     rows = await fetch_all(conn, query)
      return tuple(dict(r._mapping) for r in rows)

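Callers now supply the connection explicitly; the web handlers later in this diff do exactly this, e.g. with an arbitrary IMDb id:

async with db.new_connection() as conn:
    rows = await db.ratings_for_movie_ids(conn, imdb_ids=["tt0133093"])
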
  def sql_fields(tp: Type):
      return (f"{tp._table}.{f.name}" for f in fields(tp))


  def sql_in(column: str, values: Iterable[T], not_=False) -> tuple[str, dict[str, T]]:
      c = column.replace(".", "___")
      value_map = {f"{c}_{i}": v for i, v in enumerate(values, start=1)}
      placeholders = ",".join(":" + k for k in value_map)
      if not_:
          return f"{column} NOT IN ({placeholders})", value_map
      return f"{column} IN ({placeholders})", value_map

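For clarity, the legacy helper expands to a named-placeholder clause plus its value map:

clause, vals = sql_in("movies.id", ["a", "b"])
# clause == "movies.id IN (:movies___id_1,:movies___id_2)"
# vals   == {"movies___id_1": "a", "movies___id_2": "b"}
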
  async def ratings_for_movies(
-     movie_ids: Iterable[ULID], user_ids: Iterable[ULID] = []
+     conn: Connection, /, movie_ids: Iterable[ULID], user_ids: Iterable[ULID] = []
  ) -> Iterable[Rating]:
-     values: dict[str, str] = {}
-     conditions: list[str] = []
-
-     q, vm = sql_in("movie_id", [str(m) for m in movie_ids])
-     conditions.append(q)
-     values.update(vm)
+     conditions = [ratings.c.movie_id.in_(str(x) for x in movie_ids)]

      if user_ids:
-         q, vm = sql_in("user_id", [str(m) for m in user_ids])
-         conditions.append(q)
-         values.update(vm)
+         conditions.append(ratings.c.user_id.in_(str(x) for x in user_ids))

-     query = f"""
-         SELECT {','.join(sql_fields(Rating))}
-         FROM {Rating._table}
-         WHERE {' AND '.join(f'({c})' for c in conditions) if conditions else '1=1'}
-     """
-
-     async with locked_connection() as conn:
-         rows = await conn.fetch_all(query, values)
+     query = sa.select(ratings).where(*conditions)
+     rows = await fetch_all(conn, query)

      return (fromplain(Rating, row._mapping, serialized=True) for row in rows)


  async def find_movies(
+     conn: Connection,
+     /,
      *,
      title: str | None = None,
      media_type: str | None = None,
@@ -583,88 +616,63 @@ async def find_movies(
      include_unrated: bool = False,
      user_ids: list[ULID] = [],
  ) -> Iterable[tuple[Movie, list[Rating]]]:
-     values: dict[str, int | str] = {
-         "limit_rows": limit_rows,
-         "skip_rows": skip_rows,
-     }
-
      conditions = []

      if title:
-         values["escape"] = "#"
-         escaped_title = sql_escape(title, char=values["escape"])
-         values["pattern"] = (
+         escape_char = "#"
+         escaped_title = sql_escape(title, char=escape_char)
+         pattern = (
              "_".join(escaped_title.split())
              if exact
              else "%" + "%".join(escaped_title.split()) + "%"
          )
          conditions.append(
-             f"""
-             (
-                 {Movie._table}.title LIKE :pattern ESCAPE :escape
-                 OR {Movie._table}.original_title LIKE :pattern ESCAPE :escape
+             sa.or_(
+                 movies.c.title.like(pattern, escape=escape_char),
+                 movies.c.original_title.like(pattern, escape=escape_char),
              )
-             """
          )

      if yearcomp:
-         op, year = yearcomp
-         assert op in "<=>"
-         values["year"] = year
-         conditions.append(f"{Movie._table}.release_year{op}:year")
+         match yearcomp:
+             case ("<", year):
+                 conditions.append(movies.c.release_year < year)
+             case ("=", year):
+                 conditions.append(movies.c.release_year == year)
+             case (">", year):
+                 conditions.append(movies.c.release_year > year)

-     if media_type:
-         values["media_type"] = media_type
-         conditions.append(f"{Movie._table}.media_type=:media_type")
+     if media_type is not None:
+         conditions.append(movies.c.media_type == media_type)

      if ignore_tv_episodes:
-         conditions.append(f"{Movie._table}.media_type!='TV Episode'")
+         conditions.append(movies.c.media_type != "TV Episode")

      if not include_unrated:
-         conditions.append(f"{Movie._table}.imdb_score NOTNULL")
+         conditions.append(movies.c.imdb_score.is_not(None))

-     query = f"""
-         SELECT {','.join(sql_fields(Movie))}
-         FROM {Movie._table}
-         WHERE {(' AND '.join(conditions)) if conditions else '1=1'}
-         ORDER BY
-             length({Movie._table}.title) ASC,
-             {Movie._table}.imdb_score DESC,
-             {Movie._table}.release_year DESC
-         LIMIT :skip_rows, :limit_rows
-     """
-     async with locked_connection() as conn:
-         rows = await conn.fetch_all(bindparams(query, values))
+     query = (
+         sa.select(movies)
+         .where(*conditions)
+         .order_by(
+             sa.func.length(movies.c.title).asc(),
+             movies.c.imdb_score.desc(),
+             movies.c.release_year.desc(),
+         )
+         .limit(limit_rows)
+         .offset(skip_rows)
+     )

-     movies = [fromplain(Movie, row._mapping, serialized=True) for row in rows]
+     rows = await fetch_all(conn, query)

+     movies_ = [fromplain(Movie, row._mapping, serialized=True) for row in rows]

      if not user_ids:
-         return ((m, []) for m in movies)
+         return ((m, []) for m in movies_)

-     ratings = await ratings_for_movies((m.id for m in movies), user_ids)
+     ratings = await ratings_for_movies(conn, (m.id for m in movies_), user_ids)

-     aggreg: dict[ULID, tuple[Movie, list[Rating]]] = {m.id: (m, []) for m in movies}
+     aggreg: dict[ULID, tuple[Movie, list[Rating]]] = {m.id: (m, []) for m in movies_}
      for rating in ratings:
          aggreg[rating.movie_id][1].append(rating)

      return aggreg.values()

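A sketch of the title matching, assuming sql_escape (defined elsewhere in this module) escapes "%" and "_" with the "#" character:

escape_char = "#"
pattern = "%" + "%".join("the matrix".split()) + "%"  # -> "%the%matrix%"
cond = movies.c.title.like(pattern, escape=escape_char)
# compiles to: movies.title LIKE '%the%matrix%' ESCAPE '#'
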
  def bindparams(query: str, values: dict):
      """Bind values to a query.

      This is similar to what SQLAlchemy and Databases do, but it allows using
      the same placeholder in multiple places.
      """
      pump_vals = {}
      pump_keys = {}

      def pump(match):
          key = match[1]
          val = values[key]
          pump_keys[key] = 1 + pump_keys.setdefault(key, 0)
          pump_key = f"{key}_{pump_keys[key]}"
          pump_vals[pump_key] = val
          return f":{pump_key}"

      pump_query = re.sub(r":(\w+)\b", pump, query)
      return sqlalchemy.text(pump_query).bindparams(**pump_vals)

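bindparams keeps the legacy textual queries usable with repeated placeholders; a small example of what it produces:

stmt = bindparams("SELECT :pat AS a, :pat AS b", {"pat": "x%"})
# The query is rewritten to "SELECT :pat_1 AS a, :pat_2 AS b"
# and both :pat_1 and :pat_2 are bound to "x%".
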
@@ -4,6 +4,8 @@ from collections import namedtuple
  from datetime import datetime
  from urllib.parse import urljoin

+ import bs4
+
  from . import db
  from .models import Movie, Rating, User
  from .request import asession, asoup_from_url, cache_path
@@ -38,12 +40,14 @@ async def refresh_user_ratings_from_imdb(stop_on_dupe: bool = True):
      async with asession() as s:
          s.headers["Accept-Language"] = "en-US, en;q=0.5"

-         for user in await db.get_all(User):
+         async with db.new_connection() as conn:
+             users = list(await db.get_all(conn, User))
+         for user in users:
              log.info("⚡️ Loading data for %s ...", user.name)

              try:
                  async for rating, is_updated in load_ratings(user.imdb_id):
-                     assert rating.user.id == user.id
+                     assert rating.user is not None and rating.user.id == user.id

                      if stop_on_dupe and not is_updated:
                          break
@@ -94,7 +98,7 @@ find_year = re.compile(
  find_movie_id = re.compile(r"/title/(?P<id>tt\d+)/").search


- def movie_and_rating_from_item(item) -> tuple[Movie, Rating]:
+ def movie_and_rating_from_item(item: bs4.Tag) -> tuple[Movie, Rating]:
      genres = (genre := item.find("span", "genre")) and genre.string or ""
      movie = Movie(
          title=item.h3.a.string.strip(),
@@ -154,13 +158,19 @@ async def parse_page(url: str) -> tuple[list[Rating], str | None]:

      soup = await asoup_from_url(url)

-     meta = soup.find("meta", property="pageId")
-     headline = soup.h1
-     assert meta is not None and headline is not None
+     if (meta := soup.find("meta", property="pageId")) is None:
+         raise RuntimeError("No pageId found.")
+     assert isinstance(meta, bs4.Tag)
      imdb_id = meta["content"]
-     user = await db.get(User, imdb_id=imdb_id) or User(
+     assert isinstance(imdb_id, str)
+     async with db.new_connection() as conn:
+         user = await db.get(conn, User, imdb_id=imdb_id) or User(
              imdb_id=imdb_id, name="", secret=""
          )

+     if (headline := soup.h1) is None:
+         raise RuntimeError("No headline found.")
+     assert isinstance(headline.string, str)
      if match := find_name(headline.string):
          user.name = match["name"]
@@ -184,9 +194,15 @@ async def parse_page(url: str) -> tuple[list[Rating], str | None]:

          ratings.append(rating)

-     footer = soup.find("div", "footer")
-     assert footer is not None
-     next_url = urljoin(url, footer.find(string=re.compile(r"Next")).parent["href"])
+     next_url = None
+     if (footer := soup.find("div", "footer")) is None:
+         raise RuntimeError("No footer found.")
+     assert isinstance(footer, bs4.Tag)
+     if (next_link := footer.find("a", string="Next")) is not None:
+         assert isinstance(next_link, bs4.Tag)
+         next_href = next_link["href"]
+         assert isinstance(next_href, str)
+         next_url = urljoin(url, next_href)

      return (ratings, next_url if url != next_url else None)

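For reference, urljoin resolves the relative "Next" href against the current page URL; a sketch with a dummy user id and query string (the real href shape may differ):

from urllib.parse import urljoin

urljoin("https://www.imdb.com/user/ur0000000/ratings", "?start=100")
# -> "https://www.imdb.com/user/ur0000000/ratings?start=100"
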
@@ -200,14 +216,15 @@ async def load_ratings(user_id: str):
      for i, rating in enumerate(ratings):
          assert rating.user and rating.movie

+         async with db.transaction() as conn:
              if i == 0:
                  # All rating objects share the same user.
-                 await db.add_or_update_user(rating.user)
+                 await db.add_or_update_user(conn, rating.user)
                  rating.user_id = rating.user.id

-             await db.add_or_update_movie(rating.movie)
+             await db.add_or_update_movie(conn, rating.movie)
              rating.movie_id = rating.movie.id

-             is_updated = await db.add_or_update_rating(rating)
+             is_updated = await db.add_or_update_rating(conn, rating)

          yield rating, is_updated

@@ -209,7 +209,8 @@ async def import_from_file(*, basics_path: Path, ratings_path: Path):
      for i, m in enumerate(read_basics(basics_path)):
          perc = 100 * i / total
          if perc >= perc_next_report:
-             await db.set_import_progress(perc)
+             async with db.transaction() as conn:
+                 await db.set_import_progress(conn, perc)
              log.info("⏳ Imported %s%%", round(perc, 1))
              perc_next_report += perc_step

@@ -233,15 +234,18 @@ async def import_from_file(*, basics_path: Path, ratings_path: Path):
          chunk.append(m)

          if len(chunk) > 1000:
-             await add_or_update_many_movies(chunk)
+             async with db.transaction() as conn:
+                 await add_or_update_many_movies(conn, chunk)
              chunk = []

      if chunk:
-         await add_or_update_many_movies(chunk)
+         async with db.transaction() as conn:
+             await add_or_update_many_movies(conn, chunk)
          chunk = []

      log.info("👍 Imported 100%")
-     await db.set_import_progress(100)
+     async with db.transaction() as conn:
+         await db.set_import_progress(conn, 100)


  async def download_datasets(*, basics_path: Path, ratings_path: Path) -> None:
@@ -270,7 +274,8 @@ async def load_from_web(*, force: bool = False) -> None:
      See https://www.imdb.com/interfaces/ and https://datasets.imdbws.com/ for
      more information on the IMDb database dumps.
      """
-     await db.set_import_progress(0)
+     async with db.transaction() as conn:
+         await db.set_import_progress(conn, 0)

      try:
          ratings_file = config.datadir / "imdb/title.ratings.tsv.gz"
@@ -290,8 +295,10 @@ async def load_from_web(*, force: bool = False) -> None:
          await import_from_file(basics_path=basics_file, ratings_path=ratings_file)

      except BaseException as err:
-         await db.stop_import_progress(error=err)
+         async with db.transaction() as conn:
+             await db.stop_import_progress(conn, error=err)
          raise

      else:
-         await db.stop_import_progress()
+         async with db.transaction() as conn:
+             await db.stop_import_progress(conn)

174 unwind/models.py

@@ -11,13 +11,18 @@ from typing import (
      Container,
      Literal,
      Mapping,
+     Protocol,
      Type,
+     TypedDict,
      TypeVar,
      Union,
      get_args,
      get_origin,
  )

+ from sqlalchemy import Column, ForeignKey, Integer, String, Table
+ from sqlalchemy.orm import registry
+
  from .types import ULID

  JSON = int | float | str | None | list["JSON"] | dict[str, "JSON"]

@@ -26,8 +31,16 @@ JSONObject = dict[str, JSON]
  T = TypeVar("T")


+ class Model(Protocol):
+     __table__: ClassVar[Table]
+
+
+ mapper_registry = registry()
+ metadata = mapper_registry.metadata


  def annotations(tp: Type) -> tuple | None:
-     return tp.__metadata__ if hasattr(tp, "__metadata__") else None
+     return tp.__metadata__ if hasattr(tp, "__metadata__") else None  # type: ignore


  def fields(class_or_instance):
@@ -112,7 +125,7 @@ def asplain(
          if filter_fields is not None and f.name not in filter_fields:
              continue

-         target = f.type
+         target: Any = f.type
          # XXX this doesn't properly support any kind of nested types
          if (otype := optional_type(f.type)) is not None:
              target = otype
@@ -156,7 +169,7 @@ def fromplain(cls: Type[T], d: Mapping, *, serialized: bool = False) -> T:

      dd: JSONObject = {}
      for f in fields(cls):
-         target = f.type
+         target: Any = f.type
          otype = optional_type(f.type)
          is_opt = otype is not None
          if is_opt:
@@ -194,12 +207,38 @@ def validate(o: object) -> None:


  def utcnow():
-     return datetime.utcnow().replace(tzinfo=timezone.utc)
+     return datetime.now(timezone.utc)

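The old spelling built an aware datetime from a naive utcnow(); datetime.now(timezone.utc) is aware from the start, and datetime.utcnow() is deprecated as of Python 3.12:

from datetime import datetime, timezone

now = datetime.now(timezone.utc)
assert now.tzinfo is timezone.utc  # aware, no .replace() needed
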
+ @mapper_registry.mapped
+ @dataclass
+ class DbPatch:
+     __table__: ClassVar[Table] = Table(
+         "db_patches",
+         metadata,
+         Column("id", Integer, primary_key=True),
+         Column("current", String),
+     )
+
+     id: int
+     current: str
+
+
+ db_patches = DbPatch.__table__

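With @mapper_registry.mapped plus an explicit __table__, the dataclass is mapped declaratively over a hand-built Table, and the module-level aliases (db_patches, movies, ratings, ...) stay available for Core queries; a minimal sketch:

import sqlalchemy as sa

stmt = sa.select(db_patches.c.current).where(db_patches.c.id == 1)
# SELECT db_patches.current FROM db_patches WHERE = :id_1
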
+ @mapper_registry.mapped
  @dataclass
  class Progress:
-     _table: ClassVar[str] = "progress"
+     __table__: ClassVar[Table] = Table(
+         "progress",
+         metadata,
+         Column("id", String, primary_key=True),  # ULID
+         Column("type", String, nullable=False),
+         Column("state", String, nullable=False),  # JSON {"percent": ..., "error": ...}
+         Column("started", String, nullable=False),  # datetime
+         Column("stopped", String),
+     )

      id: ULID = field(default_factory=ULID)
      type: str = None

@@ -236,9 +275,28 @@ class Progress:
          self._state = state


+ progress = Progress.__table__


+ @mapper_registry.mapped
  @dataclass
  class Movie:
-     _table: ClassVar[str] = "movies"
+     __table__: ClassVar[Table] = Table(
+         "movies",
+         metadata,
+         Column("id", String, primary_key=True),  # ULID
+         Column("title", String, nullable=False),
+         Column("original_title", String),
+         Column("release_year", Integer, nullable=False),
+         Column("media_type", String, nullable=False),
+         Column("imdb_id", String, nullable=False, unique=True),
+         Column("imdb_score", Integer),
+         Column("imdb_votes", Integer),
+         Column("runtime", Integer),
+         Column("genres", String, nullable=False),
+         Column("created", String, nullable=False),  # datetime
+         Column("updated", String, nullable=False),  # datetime
+     )

      id: ULID = field(default_factory=ULID)
      title: str = None  # canonical title (usually English)

@@ -283,6 +341,8 @@ class Movie:
          self._is_lazy = False


+ movies = Movie.__table__

  _RelationSentinel = object()
  """Mark a model field as containing external data.

@@ -294,9 +354,65 @@ The contents of the Relation are ignored or discarded when using
  Relation = Annotated[T | None, _RelationSentinel]


+ Access = Literal[
+     "r",  # read
+     "i",  # index
+     "w",  # write
+ ]
+
+
+ class UserGroup(TypedDict):
+     id: str
+     access: Access
+
+
+ @mapper_registry.mapped
+ @dataclass
+ class User:
+     __table__: ClassVar[Table] = Table(
+         "users",
+         metadata,
+         Column("id", String, primary_key=True),  # ULID
+         Column("imdb_id", String, nullable=False, unique=True),
+         Column("name", String, nullable=False),
+         Column("secret", String, nullable=False),
+         Column("groups", String, nullable=False),  # JSON array
+     )
+
+     id: ULID = field(default_factory=ULID)
+     imdb_id: str = None
+     name: str = None  # canonical user name
+     secret: str = None
+     groups: list[UserGroup] = field(default_factory=list)
+
+     def has_access(self, group_id: ULID | str, access: Access = "r"):
+         group_id = group_id if isinstance(group_id, str) else str(group_id)
+         return any(g["id"] == group_id and access == g["access"] for g in self.groups)
+
+     def set_access(self, group_id: ULID | str, access: Access):
+         group_id = group_id if isinstance(group_id, str) else str(group_id)
+         for g in self.groups:
+             if g["id"] == group_id:
+                 g["access"] = access
+                 break
+         else:
+             self.groups.append({"id": group_id, "access": access})

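Usage of the access helpers, with a dummy ULID string for the group id:

u = User(imdb_id="ur0000000", name="alice", secret="")
u.set_access("01H00000000000000000000000", "w")
assert u.has_access("01H00000000000000000000000", "w")
assert not u.has_access("01H00000000000000000000000", "r")  # exact match only
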
+ @mapper_registry.mapped
  @dataclass
  class Rating:
-     _table: ClassVar[str] = "ratings"
+     __table__: ClassVar[Table] = Table(
+         "ratings",
+         metadata,
+         Column("id", String, primary_key=True),  # ULID
+         Column("movie_id", ForeignKey("movies.id"), nullable=False),  # ULID
+         Column("user_id", ForeignKey("users.id"), nullable=False),  # ULID
+         Column("score", Integer, nullable=False),
+         Column("rating_date", String, nullable=False),  # datetime
+         Column("favorite", Integer),  # bool
+         Column("finished", Integer),  # bool
+     )

      id: ULID = field(default_factory=ULID)

@@ -304,7 +420,7 @@ class Rating:
      movie: Relation[Movie] = None

      user_id: ULID = None
-     user: Relation["User"] = None
+     user: Relation[User] = None

      score: int = None  # range: [0,100]
      rating_date: datetime = None
@@ -324,41 +440,25 @@ class Rating:
      )


- Access = Literal[
-     "r",  # read
-     "i",  # index
-     "w",  # write
- ]
+ ratings = Rating.__table__


- @dataclass
- class User:
-     _table: ClassVar[str] = "users"
-
-     id: ULID = field(default_factory=ULID)
-     imdb_id: str = None
-     name: str = None  # canonical user name
-     secret: str = None
-     groups: list[dict[str, str]] = field(default_factory=list)
-
-     def has_access(self, group_id: ULID | str, access: Access = "r"):
-         group_id = group_id if isinstance(group_id, str) else str(group_id)
-         return any(g["id"] == group_id and access == g["access"] for g in self.groups)
-
-     def set_access(self, group_id: ULID | str, access: Access):
-         group_id = group_id if isinstance(group_id, str) else str(group_id)
-         for g in self.groups:
-             if g["id"] == group_id:
-                 g["access"] = access
-                 break
-         else:
-             self.groups.append({"id": group_id, "access": access})
+ class GroupUser(TypedDict):
+     id: str
+     name: str


+ @mapper_registry.mapped
  @dataclass
  class Group:
-     _table: ClassVar[str] = "groups"
+     __table__: ClassVar[Table] = Table(
+         "groups",
+         metadata,
+         Column("id", String, primary_key=True),  # ULID
+         Column("name", String, nullable=False),
+         Column("users", String, nullable=False),  # JSON array
+     )

      id: ULID = field(default_factory=ULID)
      name: str = None
-     users: list[dict[str, str]] = field(default_factory=list)
+     users: list[GroupUser] = field(default_factory=list)

@@ -11,7 +11,7 @@ from hashlib import md5
  from pathlib import Path
  from random import random
  from time import sleep, time
- from typing import Callable, ParamSpec, TypeVar, cast
+ from typing import Any, Callable, ParamSpec, TypeVar, cast

  import bs4
  import httpx
@@ -190,9 +190,11 @@ async def asoup_from_url(url):
  def _last_modified_from_response(resp: _Response_T) -> float | None:
      if last_mod := resp.headers.get("last-modified"):
          try:
-             return email.utils.parsedate_to_datetime(last_mod).timestamp()
-         except:
+             dt = email.utils.parsedate_to_datetime(last_mod)
+         except ValueError:
              log.exception("🐛 Received invalid value for Last-Modified: %s", last_mod)
+         else:
+             return dt.timestamp()


  def _last_modified_from_file(path: Path) -> float:
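email.utils.parsedate_to_datetime parses the RFC 2822 date and, since Python 3.10, raises ValueError on garbage, which is what the narrowed except clause relies on:

from email.utils import parsedate_to_datetime

parsedate_to_datetime("Wed, 21 Oct 2015 07:28:00 GMT").timestamp()
# -> 1445412480.0
parsedate_to_datetime("not a date")  # raises ValueError (Python >= 3.10)
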
@@ -206,8 +208,8 @@ async def adownload(
      replace_existing: bool | None = None,
      only_if_newer: bool = False,
      timeout: float | None = None,
-     chunk_callback=None,
-     response_callback=None,
+     chunk_callback: Callable[[bytes], Any] | None = None,
+     response_callback: Callable[[_Response_T], Any] | None = None,
  ) -> bytes | None:
      """Download a file.

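The tightened annotations spell out the callback contract; a sketch of a progress-counting chunk callback that satisfies Callable[[bytes], Any]:

from typing import Any, Callable

def make_progress_callback() -> tuple[Callable[[bytes], Any], list[int]]:
    received = [0]

    def on_chunk(chunk: bytes) -> None:
        received[0] += len(chunk)  # called once per downloaded chunk

    return on_chunk, received
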
@@ -246,7 +248,7 @@ async def adownload(
      if response_callback is not None:
          try:
              response_callback(resp)
-         except:
+         except BaseException:
              log.exception("🐛 Error in response callback.")

      log.debug(
@@ -267,7 +269,9 @@ async def adownload(
      resp.raise_for_status()

      if to_path is None:
-         await resp.aread()  # Download the response stream to allow `resp.content` access.
+         await (
+             resp.aread()
+         )  # Download the response stream to allow `resp.content` access.
          return resp.content

      resp_lastmod = _last_modified_from_response(resp)
@@ -275,7 +279,7 @@ async def adownload(
      # Check Last-Modified in case the server ignored If-Modified-Since.
      # XXX also check Content-Length?
      if file_exists and only_if_newer and resp_lastmod is not None:
-         assert file_lastmod
+         assert file_lastmod  # pyright: ignore [reportUnboundVariable]

          if resp_lastmod <= file_lastmod:
              log.debug("✋ Local file is newer, skipping download: %a", req.url)
@@ -299,7 +303,7 @@ async def adownload(
      if chunk_callback:
          try:
              chunk_callback(chunk)
-         except:
+         except BaseException:
              log.exception("🐛 Error in chunk callback.")
      finally:
          os.close(tempfd)

@@ -168,7 +168,8 @@ async def auth_user(request) -> User | None:
      if not isinstance(request.user, AuthedUser):
          return

-     user = await db.get(User, id=request.user.user_id)
+     async with db.new_connection() as conn:
+         user = await db.get(conn, User, id=request.user.user_id)
      if not user:
          return

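The pattern applied throughout the handlers below: reads open a plain connection, writes open a transaction. A sketch with a hypothetical rename_user helper, assuming db.new_connection() and db.transaction() are async context managers yielding a Connection, as the calls in this diff suggest:

async def rename_user(user_id: str, new_name: str) -> None:
    async with db.new_connection() as conn:        # read
        user = await db.get(conn, User, id=user_id)
    if user is not None:
        user.name = new_name
        async with db.transaction() as conn:       # write, committed on exit
            await db.update(conn, user)
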
@@ -179,7 +180,7 @@ async def auth_user(request) -> User | None:
      return user


- _routes = []
+ _routes: list[Route] = []


  def route(path: str, *, methods: list[str] | None = None, **kwds):
@@ -191,15 +192,12 @@ def route(path: str, *, methods: list[str] | None = None, **kwds):
      return decorator


- route.registered = _routes
-
-
  @route("/groups/{group_id}/ratings")
  async def get_ratings_for_group(request):
      group_id = as_ulid(request.path_params["group_id"])
-     group = await db.get(Group, id=str(group_id))
-
-     if not group:
+     async with db.new_connection() as conn:
+         if (group := await db.get(conn, Group, id=str(group_id))) is None:
              return not_found()

      user_ids = {u["id"] for u in group.users}

@@ -211,13 +209,17 @@ async def get_ratings_for_group(request):

      # if (imdb_id or unwind_id) and (movie := await db.get(Movie, id=unwind_id, imdb_id=imdb_id)):
      if unwind_id:
-         rows = await db.ratings_for_movie_ids(ids=[unwind_id])
+         async with db.new_connection() as conn:
+             rows = await db.ratings_for_movie_ids(conn, ids=[unwind_id])

      elif imdb_id:
-         rows = await db.ratings_for_movie_ids(imdb_ids=[imdb_id])
+         async with db.new_connection() as conn:
+             rows = await db.ratings_for_movie_ids(conn, imdb_ids=[imdb_id])

      else:
+         async with db.new_connection() as conn:
          rows = await find_ratings(
+             conn,
              title=params.get("title"),
              media_type=params.get("media_type"),
              exact=truthy(params.get("exact")),

@@ -265,7 +267,8 @@ async def list_movies(request):
      if group_id := params.get("group_id"):
          group_id = as_ulid(group_id)

-         group = await db.get(Group, id=str(group_id))
+         async with db.new_connection() as conn:
+             group = await db.get(conn, Group, id=str(group_id))
          if not group:
              return not_found("Group not found.")

@@ -290,8 +293,11 @@ async def list_movies(request):

      if imdb_id or unwind_id:
          # XXX missing support for user_ids and user_scores
+         async with db.new_connection() as conn:
          movies = (
-             [m] if (m := await db.get(Movie, id=unwind_id, imdb_id=imdb_id)) else []
+             [m]
+             if (m := await db.get(conn, Movie, id=unwind_id, imdb_id=imdb_id))
+             else []
          )

          resp = [asplain(web_models.RatingAggregate.from_movie(m)) for m in movies]

@@ -299,7 +305,9 @@ async def list_movies(request):
      else:
          per_page = as_int(params.get("per_page"), max=1000, default=5)
          page = as_int(params.get("page"), min=1, default=1)
+         async with db.new_connection() as conn:
          movieratings = await find_movies(
+             conn,
              title=params.get("title"),
              media_type=params.get("media_type"),
              exact=truthy(params.get("exact")),

@@ -329,7 +337,8 @@ async def add_movie(request):
  @route("/movies/_reload_imdb", methods=["GET"])
  @requires(["authenticated", "admin"])
  async def progress_for_load_imdb_movies(request):
-     progress = await db.get_import_progress()
+     async with db.new_connection() as conn:
+         progress = await db.get_import_progress(conn)
      if not progress:
          return JSONResponse({"status": "No import exists."}, status_code=404)

@@ -368,14 +377,16 @@ async def load_imdb_movies(request):
      force = truthy(params.get("force"))

      async with _import_lock:
-         progress = await db.get_import_progress()
+         async with db.new_connection() as conn:
+             progress = await db.get_import_progress(conn)
          if progress and not progress.stopped:
              return JSONResponse(
                  {"status": "Import is running.", "progress": progress.percent},
                  status_code=409,
              )

-         await db.set_import_progress(0)
+         async with db.transaction() as conn:
+             await db.set_import_progress(conn, 0)

          task = BackgroundTask(imdb_import.load_from_web, force=force)
          return JSONResponse(

@@ -386,7 +397,8 @@ async def load_imdb_movies(request):
  @route("/users")
  @requires(["authenticated", "admin"])
  async def list_users(request):
-     users = await db.get_all(User)
+     async with db.new_connection() as conn:
+         users = await db.get_all(conn, User)

      return JSONResponse([asplain(u) for u in users])

@@ -402,7 +414,8 @@ async def add_user(request):
      secret = secrets.token_bytes()

      user = User(name=name, imdb_id=imdb_id, secret=phc_scrypt(secret))
-     await db.add(user)
+     async with db.transaction() as conn:
+         await db.add(conn, user)

      return JSONResponse(
          {
@@ -418,7 +431,8 @@ async def show_user(request):
      user_id = as_ulid(request.path_params["user_id"])

      if is_admin(request):
-         user = await db.get(User, id=str(user_id))
+         async with db.new_connection() as conn:
+             user = await db.get(conn, User, id=str(user_id))

      else:
          user = await auth_user(request)

@@ -445,14 +459,15 @@ async def show_user(request):
  async def remove_user(request):
      user_id = as_ulid(request.path_params["user_id"])

-     user = await db.get(User, id=str(user_id))
+     async with db.new_connection() as conn:
+         user = await db.get(conn, User, id=str(user_id))
      if not user:
          return not_found()

-     async with db.shared_connection().transaction():
+     async with db.transaction() as conn:
          # XXX remove user refs from groups and ratings

-         await db.remove(user)
+         await db.remove(conn, user)

      return JSONResponse(asplain(user))

@@ -463,7 +478,8 @@ async def modify_user(request):
      user_id = as_ulid(request.path_params["user_id"])

      if is_admin(request):
-         user = await db.get(User, id=str(user_id))
+         async with db.new_connection() as conn:
+             user = await db.get(conn, User, id=str(user_id))

      else:
          user = await auth_user(request)

@@ -499,7 +515,8 @@ async def modify_user(request):

      user.secret = phc_scrypt(secret)

-     await db.update(user)
+     async with db.transaction() as conn:
+         await db.update(conn, user)

      return JSONResponse(asplain(user))

@@ -509,13 +526,15 @@ async def modify_user(request):
  async def add_group_to_user(request):
      user_id = as_ulid(request.path_params["user_id"])

-     user = await db.get(User, id=str(user_id))
+     async with db.new_connection() as conn:
+         user = await db.get(conn, User, id=str(user_id))
      if not user:
          return not_found("User not found")

      (group_id, access) = await json_from_body(request, ["group", "access"])

-     group = await db.get(Group, id=str(group_id))
+     async with db.new_connection() as conn:
+         group = await db.get(conn, Group, id=str(group_id))
      if not group:
          return not_found("Group not found")

@@ -523,7 +542,8 @@ async def add_group_to_user(request):
          raise HTTPException(422, f"Invalid access level.")

      user.set_access(group_id, access)
-     await db.update(user)
+     async with db.transaction() as conn:
+         await db.update(conn, user)

      return JSONResponse(asplain(user))

@@ -551,7 +571,8 @@ async def load_imdb_user_ratings(request):
  @route("/groups")
  @requires(["authenticated", "admin"])
  async def list_groups(request):
-     groups = await db.get_all(Group)
+     async with db.new_connection() as conn:
+         groups = await db.get_all(conn, Group)

      return JSONResponse([asplain(g) for g in groups])

@@ -564,7 +585,8 @@ async def add_group(request):
      # XXX restrict name

      group = Group(name=name)
-     await db.add(group)
+     async with db.transaction() as conn:
+         await db.add(conn, group)

      return JSONResponse(asplain(group))

@@ -573,7 +595,8 @@ async def add_group(request):
  @requires(["authenticated"])
  async def add_user_to_group(request):
      group_id = as_ulid(request.path_params["group_id"])
-     group = await db.get(Group, id=str(group_id))
+     async with db.new_connection() as conn:
+         group = await db.get(conn, Group, id=str(group_id))

      if not group:
          return not_found()

@@ -600,7 +623,8 @@ async def add_user_to_group(request):
      else:
          group.users.append({"name": name, "id": user_id})

-     await db.update(group)
+     async with db.transaction() as conn:
+         await db.update(conn, group)

      return JSONResponse(asplain(group))

@@ -632,7 +656,7 @@ def create_app():
      return Starlette(
          lifespan=lifespan,
          routes=[
-             Mount(f"{config.api_base}v1", routes=route.registered),
+             Mount(f"{config.api_base}v1", routes=_routes),
          ],
          middleware=[
              Middleware(ResponseTimeMiddleware, header_name="Unwind-Elapsed"),