Merge branch 'feat/sqlalchemy'
This commit is contained in:
commit
4fbdb26d9c
25 changed files with 2107 additions and 2036 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -2,5 +2,5 @@
|
||||||
*.pyc
|
*.pyc
|
||||||
/.cache
|
/.cache
|
||||||
/.pytest_cache
|
/.pytest_cache
|
||||||
|
/build
|
||||||
/data/*
|
/data/*
|
||||||
/requirements.txt
|
|
||||||
|
|
|
||||||
1
.python-version
Normal file
1
.python-version
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
3.12
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
FROM docker.io/library/python:3.11-alpine
|
FROM docker.io/library/python:3.12-alpine
|
||||||
|
|
||||||
RUN apk update --no-cache \
|
RUN apk update --no-cache \
|
||||||
&& apk upgrade --no-cache \
|
&& apk upgrade --no-cache \
|
||||||
|
|
|
||||||
748
poetry.lock
generated
748
poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -1,44 +1,46 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "unwind"
|
name = "unwind"
|
||||||
version = "0.1.0"
|
version = "0"
|
||||||
description = ""
|
description = ""
|
||||||
authors = ["ducklet <ducklet@noreply.code.dumpr.org>"]
|
authors = ["ducklet <ducklet@noreply.code.dumpr.org>"]
|
||||||
license = "LOL"
|
license = "LOL"
|
||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = "^3.11"
|
python = "^3.12"
|
||||||
beautifulsoup4 = "^4.9.3"
|
beautifulsoup4 = "^4.9.3"
|
||||||
html5lib = "^1.1"
|
html5lib = "^1.1"
|
||||||
starlette = "^0.26"
|
starlette = "^0.30"
|
||||||
ulid-py = "^1.1.0"
|
ulid-py = "^1.1.0"
|
||||||
databases = {extras = ["sqlite"], version = "^0.7.0"}
|
uvicorn = "^0.23"
|
||||||
uvicorn = "^0.21"
|
httpx = "^0.24"
|
||||||
httpx = "^0.23.3"
|
sqlalchemy = {version = "^2.0", extras = ["aiosqlite"]}
|
||||||
|
|
||||||
|
[tool.poetry.group.build.dependencies]
|
||||||
|
# When we run poetry export, typing-extensions is a transient dependency via
|
||||||
|
# sqlalchemy, but the hash won't be included in the requirements.txt.
|
||||||
|
# By making it a direct dependency we can fix this issue, otherwise this could
|
||||||
|
# be removed.
|
||||||
|
typing-extensions = "*"
|
||||||
|
|
||||||
[tool.poetry.group.dev]
|
[tool.poetry.group.dev]
|
||||||
optional = true
|
optional = true
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
autoflake = "*"
|
|
||||||
pytest = "*"
|
pytest = "*"
|
||||||
pyright = "*"
|
pyright = "*"
|
||||||
black = "*"
|
|
||||||
isort = "*"
|
|
||||||
pytest-asyncio = "*"
|
pytest-asyncio = "*"
|
||||||
|
pytest-cov = "*"
|
||||||
|
ruff = "*"
|
||||||
|
honcho = "*"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core>=1.0.0"]
|
requires = ["poetry-core>=1.0.0"]
|
||||||
build-backend = "poetry.core.masonry.api"
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
|
||||||
[tool.pyright]
|
[tool.pyright]
|
||||||
pythonVersion = "3.11"
|
pythonVersion = "3.12"
|
||||||
|
|
||||||
[tool.isort]
|
[tool.ruff]
|
||||||
profile = "black"
|
target-version = "py312"
|
||||||
|
|
||||||
[tool.autoflake]
|
|
||||||
remove-duplicate-keys = true
|
|
||||||
remove-unused-variables = true
|
|
||||||
remove-all-unused-imports = true
|
|
||||||
ignore-init-module-imports = true
|
ignore-init-module-imports = true
|
||||||
ignore-pass-after-docstring = true
|
select = ["I", "F401", "F601", "F602", "F841"]
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,12 @@
|
||||||
|
|
||||||
cd "$RUN_DIR"
|
cd "$RUN_DIR"
|
||||||
|
|
||||||
|
# Make Uvicorn defaults explicit.
|
||||||
|
: "${API_PORT:=8000}"
|
||||||
|
: "${API_HOST:=127.0.0.1}"
|
||||||
|
export API_PORT
|
||||||
|
export API_HOST
|
||||||
|
|
||||||
[ -z "${DEBUG:-}" ] || set -x
|
[ -z "${DEBUG:-}" ] || set -x
|
||||||
|
|
||||||
exec honcho start
|
exec honcho start
|
||||||
|
|
|
||||||
|
|
@ -4,4 +4,9 @@ cd "$RUN_DIR"
|
||||||
|
|
||||||
[ -z "${DEBUG:-}" ] || set -x
|
[ -z "${DEBUG:-}" ] || set -x
|
||||||
|
|
||||||
exec uvicorn unwind:create_app --factory --reload
|
exec uvicorn \
|
||||||
|
--host "$API_HOST" \
|
||||||
|
--port "$API_PORT" \
|
||||||
|
--reload \
|
||||||
|
--factory \
|
||||||
|
unwind:create_app
|
||||||
|
|
|
||||||
25
scripts/docker-build
Executable file
25
scripts/docker-build
Executable file
|
|
@ -0,0 +1,25 @@
|
||||||
|
#!/bin/sh -eu
|
||||||
|
|
||||||
|
: "${DOCKER_BIN:=docker}"
|
||||||
|
|
||||||
|
cd "$RUN_DIR"
|
||||||
|
|
||||||
|
builddir=build
|
||||||
|
|
||||||
|
[ -z "${DEBUG:-}" ] || set -x
|
||||||
|
|
||||||
|
mkdir -p "$builddir"
|
||||||
|
|
||||||
|
poetry export \
|
||||||
|
--with=build \
|
||||||
|
--output="$builddir"/requirements.txt
|
||||||
|
|
||||||
|
githash=$(git rev-parse --short HEAD)
|
||||||
|
today=$(date -u '+%Y.%m.%d')
|
||||||
|
version="${today}+${githash}"
|
||||||
|
echo "$version" >"$builddir"/version
|
||||||
|
|
||||||
|
$DOCKER_BIN build \
|
||||||
|
--pull \
|
||||||
|
--tag "code.dumpr.org/ducklet/unwind":"$version" \
|
||||||
|
.
|
||||||
18
scripts/docker-run
Executable file
18
scripts/docker-run
Executable file
|
|
@ -0,0 +1,18 @@
|
||||||
|
#!/bin/sh -eu
|
||||||
|
|
||||||
|
: "${DOCKER_BIN:=docker}"
|
||||||
|
|
||||||
|
cd "$RUN_DIR"
|
||||||
|
|
||||||
|
[ -z "${DEBUG:-}" ] || set -x
|
||||||
|
|
||||||
|
version=$(cat build/version)
|
||||||
|
|
||||||
|
$DOCKER_BIN run \
|
||||||
|
--init \
|
||||||
|
-it --rm \
|
||||||
|
--read-only \
|
||||||
|
--memory '500m' \
|
||||||
|
--publish 127.0.0.1:8000:8000 \
|
||||||
|
--volume "$RUN_DIR"/data:/data \
|
||||||
|
"code.dumpr.org/ducklet/unwind":"$version"
|
||||||
|
|
@ -4,7 +4,7 @@ cd "$RUN_DIR"
|
||||||
|
|
||||||
[ -z "${DEBUG:-}" ] || set -x
|
[ -z "${DEBUG:-}" ] || set -x
|
||||||
|
|
||||||
autoflake --quiet --check --recursive unwind tests
|
ruff check --fix . ||:
|
||||||
isort unwind tests
|
ruff format .
|
||||||
black unwind tests
|
|
||||||
pyright
|
pyright
|
||||||
|
|
|
||||||
|
|
@ -11,4 +11,5 @@ export UNWIND_PORT
|
||||||
exec uvicorn \
|
exec uvicorn \
|
||||||
--host 0.0.0.0 \
|
--host 0.0.0.0 \
|
||||||
--port "$UNWIND_PORT" \
|
--port "$UNWIND_PORT" \
|
||||||
--factory unwind:create_app
|
--factory \
|
||||||
|
unwind:create_app
|
||||||
|
|
|
||||||
|
|
@ -6,10 +6,10 @@ dbfile="${UNWIND_DATA:-./data}/tests.sqlite"
|
||||||
|
|
||||||
# Rollback in Databases is currently broken, so we have to rebuild the database
|
# Rollback in Databases is currently broken, so we have to rebuild the database
|
||||||
# each time; see https://github.com/encode/databases/issues/403
|
# each time; see https://github.com/encode/databases/issues/403
|
||||||
trap 'rm "$dbfile"' EXIT TERM INT QUIT
|
trap 'rm "$dbfile" "${dbfile}-shm" "${dbfile}-wal"' EXIT TERM INT QUIT
|
||||||
|
|
||||||
[ -z "${DEBUG:-}" ] || set -x
|
[ -z "${DEBUG:-}" ] || set -x
|
||||||
|
|
||||||
SQLALCHEMY_WARN_20=1 \
|
export SQLALCHEMY_WARN_20=1 # XXX remove when we switched to SQLAlchemy 2.0
|
||||||
UNWIND_STORAGE="$dbfile" \
|
UNWIND_STORAGE="$dbfile" \
|
||||||
python -m pytest "$@"
|
python -m pytest --cov "$@"
|
||||||
|
|
|
||||||
|
|
@ -17,16 +17,19 @@ def event_loop():
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session")
|
@pytest_asyncio.fixture(scope="session")
|
||||||
async def shared_conn():
|
async def shared_conn():
|
||||||
c = db.shared_connection()
|
"""A database connection, ready to use."""
|
||||||
await c.connect()
|
await db.open_connection_pool()
|
||||||
|
|
||||||
await db.apply_db_patches(c)
|
async with db.new_connection() as c:
|
||||||
yield c
|
db._test_connection = c
|
||||||
|
yield c
|
||||||
|
db._test_connection = None
|
||||||
|
|
||||||
await c.disconnect()
|
await db.close_connection_pool()
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
@pytest_asyncio.fixture
|
||||||
async def conn(shared_conn):
|
async def conn(shared_conn: db.Connection):
|
||||||
async with shared_conn.transaction(force_rollback=True):
|
"""A transacted database connection, will be rolled back after use."""
|
||||||
|
async with db.transacted(shared_conn, force_rollback=True):
|
||||||
yield shared_conn
|
yield shared_conn
|
||||||
|
|
|
||||||
539
tests/test_db.py
539
tests/test_db.py
|
|
@ -4,155 +4,416 @@ import pytest
|
||||||
|
|
||||||
from unwind import db, models, web_models
|
from unwind import db, models, web_models
|
||||||
|
|
||||||
|
_movie_imdb_id = 1230000
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_add_and_get(shared_conn: db.Database):
|
|
||||||
async with shared_conn.transaction(force_rollback=True):
|
|
||||||
m1 = models.Movie(
|
|
||||||
title="test movie",
|
|
||||||
release_year=2013,
|
|
||||||
media_type="Movie",
|
|
||||||
imdb_id="tt0000000",
|
|
||||||
genres={"genre-1"},
|
|
||||||
)
|
|
||||||
await db.add(m1)
|
|
||||||
|
|
||||||
m2 = models.Movie(
|
def a_movie(**kwds) -> models.Movie:
|
||||||
title="test movie",
|
global _movie_imdb_id
|
||||||
release_year=2013,
|
_movie_imdb_id += 1
|
||||||
media_type="Movie",
|
args = {
|
||||||
imdb_id="tt0000001",
|
"title": "test movie",
|
||||||
genres={"genre-1"},
|
"release_year": 2013,
|
||||||
)
|
"media_type": "Movie",
|
||||||
await db.add(m2)
|
"imdb_id": f"tt{_movie_imdb_id}",
|
||||||
|
"genres": {"genre-1"},
|
||||||
assert m1 == await db.get(models.Movie, id=str(m1.id))
|
} | kwds
|
||||||
assert m2 == await db.get(models.Movie, id=str(m2.id))
|
return models.Movie(**args)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_find_ratings(shared_conn: db.Database):
|
async def test_current_patch_level(conn: db.Connection):
|
||||||
async with shared_conn.transaction(force_rollback=True):
|
patch_level = "some-patch-level"
|
||||||
m1 = models.Movie(
|
assert patch_level != await db.current_patch_level(conn)
|
||||||
title="test movie",
|
await db.set_current_patch_level(conn, patch_level)
|
||||||
release_year=2013,
|
assert patch_level == await db.current_patch_level(conn)
|
||||||
media_type="Movie",
|
|
||||||
imdb_id="tt0000000",
|
|
||||||
genres={"genre-1"},
|
@pytest.mark.asyncio
|
||||||
|
async def test_get(conn: db.Connection):
|
||||||
|
m1 = a_movie()
|
||||||
|
await db.add(conn, m1)
|
||||||
|
|
||||||
|
m2 = a_movie(release_year=m1.release_year + 1)
|
||||||
|
await db.add(conn, m2)
|
||||||
|
|
||||||
|
assert None is await db.get(conn, models.Movie)
|
||||||
|
assert None is await db.get(conn, models.Movie, id="blerp")
|
||||||
|
assert m1 == await db.get(conn, models.Movie, id=str(m1.id))
|
||||||
|
assert m2 == await db.get(conn, models.Movie, release_year=m2.release_year)
|
||||||
|
assert None is await db.get(
|
||||||
|
conn, models.Movie, id=str(m1.id), release_year=m2.release_year
|
||||||
|
)
|
||||||
|
assert m2 == await db.get(
|
||||||
|
conn, models.Movie, id=str(m2.id), release_year=m2.release_year
|
||||||
|
)
|
||||||
|
assert m1 == await db.get(
|
||||||
|
conn,
|
||||||
|
models.Movie,
|
||||||
|
media_type=m1.media_type,
|
||||||
|
order_by=(models.movies.c.release_year, "asc"),
|
||||||
|
)
|
||||||
|
assert m2 == await db.get(
|
||||||
|
conn,
|
||||||
|
models.Movie,
|
||||||
|
media_type=m1.media_type,
|
||||||
|
order_by=(models.movies.c.release_year, "desc"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_all(conn: db.Connection):
|
||||||
|
m1 = a_movie()
|
||||||
|
await db.add(conn, m1)
|
||||||
|
|
||||||
|
m2 = a_movie(release_year=m1.release_year)
|
||||||
|
await db.add(conn, m2)
|
||||||
|
|
||||||
|
m3 = a_movie(release_year=m1.release_year + 1)
|
||||||
|
await db.add(conn, m3)
|
||||||
|
|
||||||
|
assert [] == list(await db.get_all(conn, models.Movie, id="blerp"))
|
||||||
|
assert [m1] == list(await db.get_all(conn, models.Movie, id=str(m1.id)))
|
||||||
|
assert [m1, m2] == list(
|
||||||
|
await db.get_all(conn, models.Movie, release_year=m1.release_year)
|
||||||
|
)
|
||||||
|
assert [m1, m2, m3] == list(await db.get_all(conn, models.Movie))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_many(conn: db.Connection):
|
||||||
|
m1 = a_movie()
|
||||||
|
await db.add(conn, m1)
|
||||||
|
|
||||||
|
m2 = a_movie(release_year=m1.release_year)
|
||||||
|
await db.add(conn, m2)
|
||||||
|
|
||||||
|
m3 = a_movie(release_year=m1.release_year + 1)
|
||||||
|
await db.add(conn, m3)
|
||||||
|
|
||||||
|
assert [] == list(await db.get_many(conn, models.Movie)), "selected nothing"
|
||||||
|
assert [m1] == list(await db.get_many(conn, models.Movie, id=[str(m1.id)]))
|
||||||
|
assert [m1] == list(await db.get_many(conn, models.Movie, id={str(m1.id)}))
|
||||||
|
assert [m1, m2] == list(
|
||||||
|
await db.get_many(conn, models.Movie, release_year=[m1.release_year])
|
||||||
|
)
|
||||||
|
assert [m1, m2, m3] == list(
|
||||||
|
await db.get_many(
|
||||||
|
conn, models.Movie, release_year=[m1.release_year, m3.release_year]
|
||||||
)
|
)
|
||||||
await db.add(m1)
|
)
|
||||||
|
|
||||||
m2 = models.Movie(
|
|
||||||
title="it's anöther Movie, Part 2",
|
@pytest.mark.asyncio
|
||||||
release_year=2015,
|
async def test_add_and_get(conn: db.Connection):
|
||||||
media_type="Movie",
|
m1 = a_movie()
|
||||||
imdb_id="tt0000001",
|
await db.add(conn, m1)
|
||||||
genres={"genre-2"},
|
|
||||||
|
m2 = a_movie()
|
||||||
|
await db.add(conn, m2)
|
||||||
|
|
||||||
|
assert m1 == await db.get(conn, models.Movie, id=str(m1.id))
|
||||||
|
assert m2 == await db.get(conn, models.Movie, id=str(m2.id))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update(conn: db.Connection):
|
||||||
|
m = a_movie()
|
||||||
|
await db.add(conn, m)
|
||||||
|
|
||||||
|
assert m == await db.get(conn, models.Movie, id=str(m.id))
|
||||||
|
m.title += "something else"
|
||||||
|
assert m != await db.get(conn, models.Movie, id=str(m.id))
|
||||||
|
|
||||||
|
await db.update(conn, m)
|
||||||
|
assert m == await db.get(conn, models.Movie, id=str(m.id))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_remove(conn: db.Connection):
|
||||||
|
m1 = a_movie()
|
||||||
|
await db.add(conn, m1)
|
||||||
|
assert m1 == await db.get(conn, models.Movie, id=str(m1.id))
|
||||||
|
|
||||||
|
await db.remove(conn, m1)
|
||||||
|
assert None is await db.get(conn, models.Movie, id=str(m1.id))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_find_ratings(conn: db.Connection):
|
||||||
|
m1 = a_movie(
|
||||||
|
title="test movie",
|
||||||
|
release_year=2013,
|
||||||
|
genres={"genre-1"},
|
||||||
|
)
|
||||||
|
await db.add(conn, m1)
|
||||||
|
|
||||||
|
m2 = a_movie(
|
||||||
|
title="it's anöther Movie, Part 2",
|
||||||
|
release_year=2015,
|
||||||
|
genres={"genre-2"},
|
||||||
|
)
|
||||||
|
await db.add(conn, m2)
|
||||||
|
|
||||||
|
m3 = a_movie(
|
||||||
|
title="movie it's, Part 3",
|
||||||
|
release_year=m2.release_year,
|
||||||
|
genres=m2.genres,
|
||||||
|
)
|
||||||
|
await db.add(conn, m3)
|
||||||
|
|
||||||
|
u1 = models.User(
|
||||||
|
imdb_id="u00001",
|
||||||
|
name="User1",
|
||||||
|
secret="secret1",
|
||||||
|
)
|
||||||
|
await db.add(conn, u1)
|
||||||
|
|
||||||
|
u2 = models.User(
|
||||||
|
imdb_id="u00002",
|
||||||
|
name="User2",
|
||||||
|
secret="secret2",
|
||||||
|
)
|
||||||
|
await db.add(conn, u2)
|
||||||
|
|
||||||
|
r1 = models.Rating(
|
||||||
|
movie_id=m2.id,
|
||||||
|
movie=m2,
|
||||||
|
user_id=u1.id,
|
||||||
|
user=u1,
|
||||||
|
score=66,
|
||||||
|
rating_date=datetime.now(),
|
||||||
|
)
|
||||||
|
await db.add(conn, r1)
|
||||||
|
|
||||||
|
r2 = models.Rating(
|
||||||
|
movie_id=m2.id,
|
||||||
|
movie=m2,
|
||||||
|
user_id=u2.id,
|
||||||
|
user=u2,
|
||||||
|
score=77,
|
||||||
|
rating_date=datetime.now(),
|
||||||
|
)
|
||||||
|
await db.add(conn, r2)
|
||||||
|
|
||||||
|
# ---
|
||||||
|
|
||||||
|
rows = await db.find_ratings(
|
||||||
|
conn,
|
||||||
|
title=m1.title,
|
||||||
|
media_type=m1.media_type,
|
||||||
|
exact=True,
|
||||||
|
ignore_tv_episodes=True,
|
||||||
|
include_unrated=True,
|
||||||
|
yearcomp=("=", m1.release_year),
|
||||||
|
limit_rows=3,
|
||||||
|
user_ids=[],
|
||||||
|
)
|
||||||
|
ratings = (web_models.Rating(**r) for r in rows)
|
||||||
|
assert (web_models.RatingAggregate.from_movie(m1),) == tuple(
|
||||||
|
web_models.aggregate_ratings(ratings, user_ids=[])
|
||||||
|
)
|
||||||
|
|
||||||
|
rows = await db.find_ratings(conn, title="movie", include_unrated=False)
|
||||||
|
ratings = tuple(web_models.Rating(**r) for r in rows)
|
||||||
|
assert (
|
||||||
|
web_models.Rating.from_movie(m2, rating=r1),
|
||||||
|
web_models.Rating.from_movie(m2, rating=r2),
|
||||||
|
) == ratings
|
||||||
|
|
||||||
|
rows = await db.find_ratings(conn, title="movie", include_unrated=True)
|
||||||
|
ratings = tuple(web_models.Rating(**r) for r in rows)
|
||||||
|
assert (
|
||||||
|
web_models.Rating.from_movie(m1),
|
||||||
|
web_models.Rating.from_movie(m2, rating=r1),
|
||||||
|
web_models.Rating.from_movie(m2, rating=r2),
|
||||||
|
web_models.Rating.from_movie(m3),
|
||||||
|
) == ratings
|
||||||
|
|
||||||
|
aggr = web_models.aggregate_ratings(ratings, user_ids=[])
|
||||||
|
assert tuple(
|
||||||
|
web_models.RatingAggregate.from_movie(m) for m in [m1, m2, m3]
|
||||||
|
) == tuple(aggr)
|
||||||
|
|
||||||
|
aggr = web_models.aggregate_ratings(ratings, user_ids=[str(u1.id)])
|
||||||
|
assert (
|
||||||
|
web_models.RatingAggregate.from_movie(m1),
|
||||||
|
web_models.RatingAggregate.from_movie(m2, ratings=[r1]),
|
||||||
|
web_models.RatingAggregate.from_movie(m3),
|
||||||
|
) == tuple(aggr)
|
||||||
|
|
||||||
|
aggr = web_models.aggregate_ratings(ratings, user_ids=[str(u1.id), str(u2.id)])
|
||||||
|
assert (
|
||||||
|
web_models.RatingAggregate.from_movie(m1),
|
||||||
|
web_models.RatingAggregate.from_movie(m2, ratings=[r1, r2]),
|
||||||
|
web_models.RatingAggregate.from_movie(m3),
|
||||||
|
) == tuple(aggr)
|
||||||
|
|
||||||
|
rows = await db.find_ratings(conn, title="movie", include_unrated=True)
|
||||||
|
ratings = (web_models.Rating(**r) for r in rows)
|
||||||
|
aggr = web_models.aggregate_ratings(ratings, user_ids=[])
|
||||||
|
assert tuple(
|
||||||
|
web_models.RatingAggregate.from_movie(m) for m in [m1, m2, m3]
|
||||||
|
) == tuple(aggr)
|
||||||
|
|
||||||
|
rows = await db.find_ratings(conn, title="test", include_unrated=True)
|
||||||
|
ratings = tuple(web_models.Rating(**r) for r in rows)
|
||||||
|
assert (web_models.Rating.from_movie(m1),) == ratings
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ratings_for_movies(conn: db.Connection):
|
||||||
|
m1 = a_movie()
|
||||||
|
await db.add(conn, m1)
|
||||||
|
|
||||||
|
m2 = a_movie()
|
||||||
|
await db.add(conn, m2)
|
||||||
|
|
||||||
|
u1 = models.User(
|
||||||
|
imdb_id="u00001",
|
||||||
|
name="User1",
|
||||||
|
secret="secret1",
|
||||||
|
)
|
||||||
|
await db.add(conn, u1)
|
||||||
|
|
||||||
|
u2 = models.User(
|
||||||
|
imdb_id="u00002",
|
||||||
|
name="User2",
|
||||||
|
secret="secret2",
|
||||||
|
)
|
||||||
|
await db.add(conn, u2)
|
||||||
|
|
||||||
|
r1 = models.Rating(
|
||||||
|
movie_id=m2.id,
|
||||||
|
movie=m2,
|
||||||
|
user_id=u1.id,
|
||||||
|
user=u1,
|
||||||
|
score=66,
|
||||||
|
rating_date=datetime.now(),
|
||||||
|
)
|
||||||
|
await db.add(conn, r1)
|
||||||
|
|
||||||
|
# ---
|
||||||
|
|
||||||
|
movie_ids = [m1.id]
|
||||||
|
user_ids = []
|
||||||
|
assert tuple() == tuple(
|
||||||
|
await db.ratings_for_movies(conn, movie_ids=movie_ids, user_ids=user_ids)
|
||||||
|
)
|
||||||
|
|
||||||
|
movie_ids = [m2.id]
|
||||||
|
user_ids = []
|
||||||
|
assert (r1,) == tuple(
|
||||||
|
await db.ratings_for_movies(conn, movie_ids=movie_ids, user_ids=user_ids)
|
||||||
|
)
|
||||||
|
|
||||||
|
movie_ids = [m2.id]
|
||||||
|
user_ids = [u2.id]
|
||||||
|
assert tuple() == tuple(
|
||||||
|
await db.ratings_for_movies(conn, movie_ids=movie_ids, user_ids=user_ids)
|
||||||
|
)
|
||||||
|
|
||||||
|
movie_ids = [m2.id]
|
||||||
|
user_ids = [u1.id]
|
||||||
|
assert (r1,) == tuple(
|
||||||
|
await db.ratings_for_movies(conn, movie_ids=movie_ids, user_ids=user_ids)
|
||||||
|
)
|
||||||
|
|
||||||
|
movie_ids = [m1.id, m2.id]
|
||||||
|
user_ids = [u1.id, u2.id]
|
||||||
|
assert (r1,) == tuple(
|
||||||
|
await db.ratings_for_movies(conn, movie_ids=movie_ids, user_ids=user_ids)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_find_movies(conn: db.Connection):
|
||||||
|
m1 = a_movie(title="movie one")
|
||||||
|
await db.add(conn, m1)
|
||||||
|
|
||||||
|
m2 = a_movie(title="movie two", imdb_score=33, release_year=m1.release_year + 1)
|
||||||
|
await db.add(conn, m2)
|
||||||
|
|
||||||
|
u1 = models.User(
|
||||||
|
imdb_id="u00001",
|
||||||
|
name="User1",
|
||||||
|
secret="secret1",
|
||||||
|
)
|
||||||
|
await db.add(conn, u1)
|
||||||
|
|
||||||
|
u2 = models.User(
|
||||||
|
imdb_id="u00002",
|
||||||
|
name="User2",
|
||||||
|
secret="secret2",
|
||||||
|
)
|
||||||
|
await db.add(conn, u2)
|
||||||
|
|
||||||
|
r1 = models.Rating(
|
||||||
|
movie_id=m2.id,
|
||||||
|
movie=m2,
|
||||||
|
user_id=u1.id,
|
||||||
|
user=u1,
|
||||||
|
score=66,
|
||||||
|
rating_date=datetime.now(),
|
||||||
|
)
|
||||||
|
await db.add(conn, r1)
|
||||||
|
|
||||||
|
# ---
|
||||||
|
|
||||||
|
assert () == tuple(
|
||||||
|
await db.find_movies(conn, title=m1.title, include_unrated=False)
|
||||||
|
)
|
||||||
|
assert ((m1, []),) == tuple(
|
||||||
|
await db.find_movies(conn, title=m1.title, include_unrated=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ((m1, []),) == tuple(
|
||||||
|
await db.find_movies(conn, title="mo on", exact=False, include_unrated=True)
|
||||||
|
)
|
||||||
|
assert ((m1, []),) == tuple(
|
||||||
|
await db.find_movies(conn, title="movie one", exact=True, include_unrated=True)
|
||||||
|
)
|
||||||
|
assert () == tuple(
|
||||||
|
await db.find_movies(conn, title="mo on", exact=True, include_unrated=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ((m2, []),) == tuple(
|
||||||
|
await db.find_movies(conn, title="movie", exact=False, include_unrated=False)
|
||||||
|
)
|
||||||
|
assert ((m2, []), (m1, [])) == tuple(
|
||||||
|
await db.find_movies(conn, title="movie", exact=False, include_unrated=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ((m1, []),) == tuple(
|
||||||
|
await db.find_movies(
|
||||||
|
conn, include_unrated=True, yearcomp=("=", m1.release_year)
|
||||||
)
|
)
|
||||||
await db.add(m2)
|
)
|
||||||
|
assert ((m2, []),) == tuple(
|
||||||
m3 = models.Movie(
|
await db.find_movies(
|
||||||
title="movie it's, Part 3",
|
conn, include_unrated=True, yearcomp=("=", m2.release_year)
|
||||||
release_year=2015,
|
|
||||||
media_type="Movie",
|
|
||||||
imdb_id="tt0000002",
|
|
||||||
genres={"genre-2"},
|
|
||||||
)
|
)
|
||||||
await db.add(m3)
|
)
|
||||||
|
assert ((m1, []),) == tuple(
|
||||||
u1 = models.User(
|
await db.find_movies(
|
||||||
imdb_id="u00001",
|
conn, include_unrated=True, yearcomp=("<", m2.release_year)
|
||||||
name="User1",
|
|
||||||
secret="secret1",
|
|
||||||
)
|
)
|
||||||
await db.add(u1)
|
)
|
||||||
|
assert ((m2, []),) == tuple(
|
||||||
u2 = models.User(
|
await db.find_movies(
|
||||||
imdb_id="u00002",
|
conn, include_unrated=True, yearcomp=(">", m1.release_year)
|
||||||
name="User2",
|
|
||||||
secret="secret2",
|
|
||||||
)
|
)
|
||||||
await db.add(u2)
|
)
|
||||||
|
|
||||||
r1 = models.Rating(
|
assert ((m2, []), (m1, [])) == tuple(
|
||||||
movie_id=m2.id,
|
await db.find_movies(conn, include_unrated=True)
|
||||||
movie=m2,
|
)
|
||||||
user_id=u1.id,
|
assert ((m2, []),) == tuple(
|
||||||
user=u1,
|
await db.find_movies(conn, include_unrated=True, limit_rows=1)
|
||||||
score=66,
|
)
|
||||||
rating_date=datetime.now(),
|
assert ((m1, []),) == tuple(
|
||||||
)
|
await db.find_movies(conn, include_unrated=True, skip_rows=1)
|
||||||
await db.add(r1)
|
)
|
||||||
|
|
||||||
r2 = models.Rating(
|
assert ((m2, [r1]), (m1, [])) == tuple(
|
||||||
movie_id=m2.id,
|
await db.find_movies(conn, include_unrated=True, user_ids=[u1.id, u2.id])
|
||||||
movie=m2,
|
)
|
||||||
user_id=u2.id,
|
|
||||||
user=u2,
|
|
||||||
score=77,
|
|
||||||
rating_date=datetime.now(),
|
|
||||||
)
|
|
||||||
await db.add(r2)
|
|
||||||
|
|
||||||
# ---
|
|
||||||
|
|
||||||
rows = await db.find_ratings(
|
|
||||||
title=m1.title,
|
|
||||||
media_type=m1.media_type,
|
|
||||||
exact=True,
|
|
||||||
ignore_tv_episodes=True,
|
|
||||||
include_unrated=True,
|
|
||||||
yearcomp=("=", m1.release_year),
|
|
||||||
limit_rows=3,
|
|
||||||
user_ids=[],
|
|
||||||
)
|
|
||||||
ratings = (web_models.Rating(**r) for r in rows)
|
|
||||||
assert (web_models.RatingAggregate.from_movie(m1),) == tuple(
|
|
||||||
web_models.aggregate_ratings(ratings, user_ids=[])
|
|
||||||
)
|
|
||||||
|
|
||||||
rows = await db.find_ratings(title="movie", include_unrated=False)
|
|
||||||
ratings = tuple(web_models.Rating(**r) for r in rows)
|
|
||||||
assert (
|
|
||||||
web_models.Rating.from_movie(m2, rating=r1),
|
|
||||||
web_models.Rating.from_movie(m2, rating=r2),
|
|
||||||
) == ratings
|
|
||||||
|
|
||||||
rows = await db.find_ratings(title="movie", include_unrated=True)
|
|
||||||
ratings = tuple(web_models.Rating(**r) for r in rows)
|
|
||||||
assert (
|
|
||||||
web_models.Rating.from_movie(m1),
|
|
||||||
web_models.Rating.from_movie(m2, rating=r1),
|
|
||||||
web_models.Rating.from_movie(m2, rating=r2),
|
|
||||||
web_models.Rating.from_movie(m3),
|
|
||||||
) == ratings
|
|
||||||
|
|
||||||
aggr = web_models.aggregate_ratings(ratings, user_ids=[])
|
|
||||||
assert tuple(
|
|
||||||
web_models.RatingAggregate.from_movie(m) for m in [m1, m2, m3]
|
|
||||||
) == tuple(aggr)
|
|
||||||
|
|
||||||
aggr = web_models.aggregate_ratings(ratings, user_ids=[str(u1.id)])
|
|
||||||
assert (
|
|
||||||
web_models.RatingAggregate.from_movie(m1),
|
|
||||||
web_models.RatingAggregate.from_movie(m2, ratings=[r1]),
|
|
||||||
web_models.RatingAggregate.from_movie(m3),
|
|
||||||
) == tuple(aggr)
|
|
||||||
|
|
||||||
aggr = web_models.aggregate_ratings(ratings, user_ids=[str(u1.id), str(u2.id)])
|
|
||||||
assert (
|
|
||||||
web_models.RatingAggregate.from_movie(m1),
|
|
||||||
web_models.RatingAggregate.from_movie(m2, ratings=[r1, r2]),
|
|
||||||
web_models.RatingAggregate.from_movie(m3),
|
|
||||||
) == tuple(aggr)
|
|
||||||
|
|
||||||
rows = await db.find_ratings(title="movie", include_unrated=True)
|
|
||||||
ratings = (web_models.Rating(**r) for r in rows)
|
|
||||||
aggr = web_models.aggregate_ratings(ratings, user_ids=[])
|
|
||||||
assert tuple(
|
|
||||||
web_models.RatingAggregate.from_movie(m) for m in [m1, m2, m3]
|
|
||||||
) == tuple(aggr)
|
|
||||||
|
|
||||||
rows = await db.find_ratings(title="test", include_unrated=True)
|
|
||||||
ratings = tuple(web_models.Rating(**r) for r in rows)
|
|
||||||
assert (web_models.Rating.from_movie(m1),) == ratings
|
|
||||||
|
|
|
||||||
11
tests/test_models.py
Normal file
11
tests/test_models.py
Normal file
|
|
@ -0,0 +1,11 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from unwind import models
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("mapper", models.mapper_registry.mappers)
|
||||||
|
def test_fields(mapper):
|
||||||
|
"""Test that models.fields() matches exactly all table columns."""
|
||||||
|
dcfields = {f.name for f in models.fields(mapper.class_)}
|
||||||
|
mfields = {c.name for c in mapper.columns}
|
||||||
|
assert dcfields == mfields
|
||||||
|
|
@ -1,53 +1,243 @@
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from starlette.testclient import TestClient
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
from unwind import create_app, db, imdb, models
|
from unwind import config, create_app, db, imdb, models
|
||||||
|
|
||||||
app = create_app()
|
app = create_app()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def unauthorized_client() -> TestClient:
|
||||||
|
# https://www.starlette.io/testclient/
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def authorized_client() -> TestClient:
|
||||||
|
client = TestClient(app)
|
||||||
|
client.auth = "user1", "secret1"
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def admin_client() -> TestClient:
|
||||||
|
client = TestClient(app)
|
||||||
|
for token in config.api_credentials.values():
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise RuntimeError("No bearer tokens configured.")
|
||||||
|
client.headers = {"Authorization": f"Bearer {token}"}
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_app(shared_conn: db.Database):
|
async def test_get_ratings_for_group(
|
||||||
async with shared_conn.transaction(force_rollback=True):
|
conn: db.Connection, unauthorized_client: TestClient
|
||||||
# https://www.starlette.io/testclient/
|
):
|
||||||
client = TestClient(app)
|
user = models.User(
|
||||||
response = client.get("/api/v1/movies")
|
imdb_id="ur12345678",
|
||||||
assert response.status_code == 403
|
name="user-1",
|
||||||
|
secret="secret-1",
|
||||||
|
groups=[],
|
||||||
|
)
|
||||||
|
group = models.Group(
|
||||||
|
name="group-1",
|
||||||
|
users=[models.GroupUser(id=str(user.id), name=user.name)],
|
||||||
|
)
|
||||||
|
user.groups = [models.UserGroup(id=str(group.id), access="r")]
|
||||||
|
path = app.url_path_for("get_ratings_for_group", group_id=str(group.id))
|
||||||
|
|
||||||
client.auth = "user1", "secret1"
|
resp = unauthorized_client.get(path)
|
||||||
|
assert resp.status_code == 404, "Group does not exist (yet)"
|
||||||
|
|
||||||
response = client.get("/api/v1/movies")
|
await db.add(conn, user)
|
||||||
assert response.status_code == 200
|
await db.add(conn, group)
|
||||||
assert response.json() == []
|
|
||||||
|
|
||||||
m = models.Movie(
|
resp = unauthorized_client.get(path)
|
||||||
title="test movie",
|
assert resp.status_code == 200
|
||||||
release_year=2013,
|
assert resp.json() == []
|
||||||
media_type="Movie",
|
|
||||||
imdb_id="tt12345678",
|
|
||||||
genres={"genre-1"},
|
|
||||||
)
|
|
||||||
await db.add(m)
|
|
||||||
|
|
||||||
response = client.get("/api/v1/movies", params={"include_unrated": 1})
|
movie = models.Movie(
|
||||||
assert response.status_code == 200
|
title="test movie",
|
||||||
assert response.json() == [{**models.asplain(m), "user_scores": []}]
|
release_year=2013,
|
||||||
|
media_type="Movie",
|
||||||
|
imdb_id="tt12345678",
|
||||||
|
genres={"genre-1"},
|
||||||
|
)
|
||||||
|
await db.add(conn, movie)
|
||||||
|
|
||||||
m_plain = {
|
rating = models.Rating(
|
||||||
"canonical_title": m.title,
|
movie_id=movie.id, user_id=user.id, score=66, rating_date=datetime.now()
|
||||||
"imdb_score": m.imdb_score,
|
)
|
||||||
"imdb_votes": m.imdb_votes,
|
await db.add(conn, rating)
|
||||||
"link": imdb.movie_url(m.imdb_id),
|
|
||||||
"media_type": m.media_type,
|
|
||||||
"original_title": m.original_title,
|
|
||||||
"user_scores": [],
|
|
||||||
"year": m.release_year,
|
|
||||||
}
|
|
||||||
|
|
||||||
response = client.get("/api/v1/movies", params={"imdb_id": m.imdb_id})
|
rating_aggregate = {
|
||||||
assert response.status_code == 200
|
"canonical_title": movie.title,
|
||||||
assert response.json() == [m_plain]
|
"imdb_score": movie.imdb_score,
|
||||||
|
"imdb_votes": movie.imdb_votes,
|
||||||
|
"link": imdb.movie_url(movie.imdb_id),
|
||||||
|
"media_type": movie.media_type,
|
||||||
|
"original_title": movie.original_title,
|
||||||
|
"user_scores": [rating.score],
|
||||||
|
"year": movie.release_year,
|
||||||
|
}
|
||||||
|
|
||||||
response = client.get("/api/v1/movies", params={"unwind_id": str(m.id)})
|
resp = unauthorized_client.get(path)
|
||||||
assert response.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert response.json() == [m_plain]
|
assert resp.json() == [rating_aggregate]
|
||||||
|
|
||||||
|
filters = {
|
||||||
|
"imdb_id": movie.imdb_id,
|
||||||
|
"unwind_id": str(movie.id),
|
||||||
|
"title": movie.title,
|
||||||
|
"media_type": movie.media_type,
|
||||||
|
"year": movie.release_year,
|
||||||
|
}
|
||||||
|
for k, v in filters.items():
|
||||||
|
resp = unauthorized_client.get(path, params={k: v})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json() == [rating_aggregate]
|
||||||
|
|
||||||
|
resp = unauthorized_client.get(path, params={"title": "no such thing"})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json() == []
|
||||||
|
|
||||||
|
# Test "exact" query param.
|
||||||
|
resp = unauthorized_client.get(
|
||||||
|
path, params={"title": "test movie", "exact": "true"}
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json() == [rating_aggregate]
|
||||||
|
resp = unauthorized_client.get(path, params={"title": "te mo", "exact": "false"})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json() == [rating_aggregate]
|
||||||
|
resp = unauthorized_client.get(path, params={"title": "te mo", "exact": "true"})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json() == []
|
||||||
|
|
||||||
|
# XXX Test "ignore_tv_episodes" query param.
|
||||||
|
# XXX Test "include_unrated" query param.
|
||||||
|
# XXX Test "per_page" query param.
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_movies(
|
||||||
|
conn: db.Connection,
|
||||||
|
unauthorized_client: TestClient,
|
||||||
|
authorized_client: TestClient,
|
||||||
|
):
|
||||||
|
path = app.url_path_for("list_movies")
|
||||||
|
response = unauthorized_client.get(path)
|
||||||
|
assert response.status_code == 403
|
||||||
|
|
||||||
|
response = authorized_client.get(path)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == []
|
||||||
|
|
||||||
|
m = models.Movie(
|
||||||
|
title="test movie",
|
||||||
|
release_year=2013,
|
||||||
|
media_type="Movie",
|
||||||
|
imdb_id="tt12345678",
|
||||||
|
genres={"genre-1"},
|
||||||
|
)
|
||||||
|
await db.add(conn, m)
|
||||||
|
|
||||||
|
response = authorized_client.get(path, params={"include_unrated": 1})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == [{**models.asplain(m), "user_scores": []}]
|
||||||
|
|
||||||
|
m_plain = {
|
||||||
|
"canonical_title": m.title,
|
||||||
|
"imdb_score": m.imdb_score,
|
||||||
|
"imdb_votes": m.imdb_votes,
|
||||||
|
"link": imdb.movie_url(m.imdb_id),
|
||||||
|
"media_type": m.media_type,
|
||||||
|
"original_title": m.original_title,
|
||||||
|
"user_scores": [],
|
||||||
|
"year": m.release_year,
|
||||||
|
}
|
||||||
|
|
||||||
|
response = authorized_client.get(path, params={"imdb_id": m.imdb_id})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == [m_plain]
|
||||||
|
|
||||||
|
response = authorized_client.get(path, params={"unwind_id": str(m.id)})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == [m_plain]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_users(
|
||||||
|
conn: db.Connection,
|
||||||
|
unauthorized_client: TestClient,
|
||||||
|
authorized_client: TestClient,
|
||||||
|
admin_client: TestClient,
|
||||||
|
):
|
||||||
|
path = app.url_path_for("list_users")
|
||||||
|
response = unauthorized_client.get(path)
|
||||||
|
assert response.status_code == 403
|
||||||
|
|
||||||
|
response = authorized_client.get(path)
|
||||||
|
assert response.status_code == 403
|
||||||
|
|
||||||
|
response = admin_client.get(path)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == []
|
||||||
|
|
||||||
|
m = models.User(
|
||||||
|
imdb_id="ur12345678",
|
||||||
|
name="user-1",
|
||||||
|
secret="secret-1",
|
||||||
|
groups=[],
|
||||||
|
)
|
||||||
|
await db.add(conn, m)
|
||||||
|
|
||||||
|
m_plain = {
|
||||||
|
"groups": m.groups,
|
||||||
|
"id": m.id,
|
||||||
|
"imdb_id": m.imdb_id,
|
||||||
|
"name": m.name,
|
||||||
|
"secret": m.secret,
|
||||||
|
}
|
||||||
|
|
||||||
|
response = admin_client.get(path)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == [m_plain]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_groups(
|
||||||
|
conn: db.Connection,
|
||||||
|
unauthorized_client: TestClient,
|
||||||
|
authorized_client: TestClient,
|
||||||
|
admin_client: TestClient,
|
||||||
|
):
|
||||||
|
path = app.url_path_for("list_groups")
|
||||||
|
response = unauthorized_client.get(path)
|
||||||
|
assert response.status_code == 403
|
||||||
|
|
||||||
|
response = authorized_client.get(path)
|
||||||
|
assert response.status_code == 403
|
||||||
|
|
||||||
|
response = admin_client.get(path)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == []
|
||||||
|
|
||||||
|
m = models.Group(
|
||||||
|
name="group-1",
|
||||||
|
users=[models.GroupUser(id="123", name="itsa-me")],
|
||||||
|
)
|
||||||
|
await db.add(conn, m)
|
||||||
|
|
||||||
|
m_plain = {
|
||||||
|
"users": m.users,
|
||||||
|
"id": m.id,
|
||||||
|
"name": m.name,
|
||||||
|
}
|
||||||
|
|
||||||
|
response = admin_client.get(path)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == [m_plain]
|
||||||
|
|
|
||||||
1332
unwind-ui/package-lock.json
generated
1332
unwind-ui/package-lock.json
generated
File diff suppressed because it is too large
Load diff
|
|
@ -1,13 +1,29 @@
|
||||||
import { defineConfig } from "vite"
|
import { defineConfig } from "vite"
|
||||||
import vue from "@vitejs/plugin-vue"
|
import vue from "@vitejs/plugin-vue"
|
||||||
|
|
||||||
|
// Vite defaults.
|
||||||
|
const vite_host = "localhost"
|
||||||
|
const vite_port = 3000
|
||||||
|
|
||||||
|
const base = process.env.BASE_URL || "/"
|
||||||
|
const proxied_api_url = `http://${vite_host}:${vite_port}/api/`
|
||||||
|
const real_api_url = `http://${process.env.API_HOST}:${process.env.API_PORT}/api/`
|
||||||
|
|
||||||
// https://vitejs.dev/config/
|
// https://vitejs.dev/config/
|
||||||
export default defineConfig({
|
export default defineConfig({
|
||||||
base: process.env.BASE_URL || "/",
|
base,
|
||||||
define: {
|
define: {
|
||||||
"process.env.API_URL": JSON.stringify(
|
"process.env.API_URL": JSON.stringify(process.env.API_URL || proxied_api_url),
|
||||||
process.env.API_URL || "http://localhost:8000/api/",
|
},
|
||||||
),
|
server: {
|
||||||
|
host: vite_host,
|
||||||
|
port: vite_port,
|
||||||
|
proxy: {
|
||||||
|
[`${base}api`]: {
|
||||||
|
target: real_api_url,
|
||||||
|
prependPath: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
plugins: [vue()],
|
plugins: [vue()],
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -2,20 +2,20 @@ import os
|
||||||
import tomllib
|
import tomllib
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
datadir = Path(os.getenv("UNWIND_DATA") or "./data")
|
datadir: Path = Path(os.getenv("UNWIND_DATA") or "./data")
|
||||||
cachedir = (
|
cachedir: Path = Path(p) if (p := os.getenv("UNWIND_CACHEDIR")) else datadir / ".cache"
|
||||||
Path(cachedir)
|
debug: bool = os.getenv("DEBUG") == "1"
|
||||||
if (cachedir := os.getenv("UNWIND_CACHEDIR", datadir / ".cache"))
|
loglevel: str = os.getenv("UNWIND_LOGLEVEL") or ("DEBUG" if debug else "INFO")
|
||||||
else None
|
storage_path: Path = (
|
||||||
|
Path(p) if (p := os.getenv("UNWIND_STORAGE")) else datadir / "db.sqlite"
|
||||||
|
)
|
||||||
|
config_path: Path = (
|
||||||
|
Path(p) if (p := os.getenv("UNWIND_CONFIG")) else datadir / "config.toml"
|
||||||
)
|
)
|
||||||
debug = os.getenv("DEBUG") == "1"
|
|
||||||
loglevel = os.getenv("UNWIND_LOGLEVEL") or ("DEBUG" if debug else "INFO")
|
|
||||||
storage_path = os.getenv("UNWIND_STORAGE", datadir / "db.sqlite")
|
|
||||||
config_path = os.getenv("UNWIND_CONFIG", datadir / "config.toml")
|
|
||||||
|
|
||||||
with open(config_path, "rb") as fd:
|
with open(config_path, "rb") as fd:
|
||||||
_config = tomllib.load(fd)
|
_config = tomllib.load(fd)
|
||||||
|
|
||||||
api_base = _config["api"].get("base", "/api/")
|
api_base: str = _config["api"].get("base", "/api/")
|
||||||
api_cors = _config["api"].get("cors", "*")
|
api_cors: str = _config["api"].get("cors", "*")
|
||||||
api_credentials = _config["api"].get("credentials", {})
|
api_credentials: dict[str, str] = _config["api"].get("credentials", {})
|
||||||
|
|
|
||||||
664
unwind/db.py
664
unwind/db.py
|
|
@ -1,24 +1,27 @@
|
||||||
import asyncio
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
import re
|
|
||||||
import threading
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Iterable, Literal, Type, TypeVar
|
from typing import Any, AsyncGenerator, Iterable, Literal, Sequence, Type, TypeVar
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy as sa
|
||||||
from databases import Database
|
from sqlalchemy.dialects.sqlite import insert
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
|
||||||
|
|
||||||
from . import config
|
from . import config
|
||||||
from .models import (
|
from .models import (
|
||||||
|
Model,
|
||||||
Movie,
|
Movie,
|
||||||
Progress,
|
Progress,
|
||||||
Rating,
|
Rating,
|
||||||
User,
|
User,
|
||||||
asplain,
|
asplain,
|
||||||
fields,
|
db_patches,
|
||||||
fromplain,
|
fromplain,
|
||||||
|
metadata,
|
||||||
|
movies,
|
||||||
optional_fields,
|
optional_fields,
|
||||||
|
progress,
|
||||||
|
ratings,
|
||||||
utcnow,
|
utcnow,
|
||||||
)
|
)
|
||||||
from .types import ULID
|
from .types import ULID
|
||||||
|
|
@ -26,7 +29,9 @@ from .types import ULID
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
_shared_connection: Database | None = None
|
_engine: AsyncEngine | None = None
|
||||||
|
|
||||||
|
type Connection = AsyncConnection
|
||||||
|
|
||||||
|
|
||||||
async def open_connection_pool() -> None:
|
async def open_connection_pool() -> None:
|
||||||
|
|
@ -34,10 +39,13 @@ async def open_connection_pool() -> None:
|
||||||
|
|
||||||
This function needs to be called before any access to the database can happen.
|
This function needs to be called before any access to the database can happen.
|
||||||
"""
|
"""
|
||||||
db = shared_connection()
|
async with transaction() as conn:
|
||||||
await db.connect()
|
await conn.execute(sa.text("PRAGMA journal_mode=WAL"))
|
||||||
|
|
||||||
await apply_db_patches(db)
|
await conn.run_sync(metadata.create_all, tables=[db_patches])
|
||||||
|
|
||||||
|
async with new_connection() as conn:
|
||||||
|
await apply_db_patches(conn)
|
||||||
|
|
||||||
|
|
||||||
async def close_connection_pool() -> None:
|
async def close_connection_pool() -> None:
|
||||||
|
|
@ -46,48 +54,33 @@ async def close_connection_pool() -> None:
|
||||||
This function should be called before the app shuts down to ensure all data
|
This function should be called before the app shuts down to ensure all data
|
||||||
has been flushed to the database.
|
has been flushed to the database.
|
||||||
"""
|
"""
|
||||||
db = shared_connection()
|
engine = _shared_engine()
|
||||||
|
|
||||||
# Run automatic ANALYZE prior to closing the db,
|
async with engine.begin() as conn:
|
||||||
# see https://sqlite.com/lang_analyze.html.
|
# Run automatic ANALYZE prior to closing the db,
|
||||||
await db.execute("PRAGMA analysis_limit=400")
|
# see https://sqlite.com/lang_analyze.html.
|
||||||
await db.execute("PRAGMA optimize")
|
await conn.execute(sa.text("PRAGMA analysis_limit=400"))
|
||||||
|
await conn.execute(sa.text("PRAGMA optimize"))
|
||||||
|
|
||||||
await db.disconnect()
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
async def _create_patch_db(db):
|
async def current_patch_level(conn: Connection, /) -> str:
|
||||||
query = """
|
query = sa.select(db_patches.c.current)
|
||||||
CREATE TABLE IF NOT EXISTS db_patches (
|
current = await conn.scalar(query)
|
||||||
id INTEGER PRIMARY KEY,
|
|
||||||
current TEXT
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
await db.execute(query)
|
|
||||||
|
|
||||||
|
|
||||||
async def current_patch_level(db) -> str:
|
|
||||||
await _create_patch_db(db)
|
|
||||||
|
|
||||||
query = "SELECT current FROM db_patches"
|
|
||||||
current = await db.fetch_val(query)
|
|
||||||
return current or ""
|
return current or ""
|
||||||
|
|
||||||
|
|
||||||
async def set_current_patch_level(db, current: str):
|
async def set_current_patch_level(conn: Connection, /, current: str) -> None:
|
||||||
await _create_patch_db(db)
|
stmt = insert(db_patches).values(id=1, current=current)
|
||||||
|
stmt = stmt.on_conflict_do_update(set_={"current": stmt.excluded.current})
|
||||||
query = """
|
await conn.execute(stmt)
|
||||||
INSERT INTO db_patches VALUES (1, :current)
|
|
||||||
ON CONFLICT DO UPDATE SET current=excluded.current
|
|
||||||
"""
|
|
||||||
await db.execute(query, values={"current": current})
|
|
||||||
|
|
||||||
|
|
||||||
db_patches_dir = Path(__file__).parent / "sql"
|
db_patches_dir = Path(__file__).parent / "sql"
|
||||||
|
|
||||||
|
|
||||||
async def apply_db_patches(db: Database):
|
async def apply_db_patches(conn: Connection, /) -> None:
|
||||||
"""Apply all remaining patches to the database.
|
"""Apply all remaining patches to the database.
|
||||||
|
|
||||||
Beware that patches will be applied in lexicographical order,
|
Beware that patches will be applied in lexicographical order,
|
||||||
|
|
@ -99,7 +92,7 @@ async def apply_db_patches(db: Database):
|
||||||
using two consecutive semi-colons (;).
|
using two consecutive semi-colons (;).
|
||||||
Failing to do so will result in an error.
|
Failing to do so will result in an error.
|
||||||
"""
|
"""
|
||||||
applied_lvl = await current_patch_level(db)
|
applied_lvl = await current_patch_level(conn)
|
||||||
|
|
||||||
did_patch = False
|
did_patch = False
|
||||||
|
|
||||||
|
|
@ -118,29 +111,52 @@ async def apply_db_patches(db: Database):
|
||||||
)
|
)
|
||||||
raise RuntimeError("No statement found.")
|
raise RuntimeError("No statement found.")
|
||||||
|
|
||||||
async with db.transaction():
|
async with transacted(conn):
|
||||||
for query in queries:
|
for query in queries:
|
||||||
await db.execute(query)
|
await conn.execute(sa.text(query))
|
||||||
|
|
||||||
await set_current_patch_level(db, patch_lvl)
|
await set_current_patch_level(conn, patch_lvl)
|
||||||
|
|
||||||
did_patch = True
|
did_patch = True
|
||||||
|
|
||||||
if did_patch:
|
if did_patch:
|
||||||
await db.execute("vacuum")
|
await _vacuum(conn)
|
||||||
|
|
||||||
|
|
||||||
async def get_import_progress() -> Progress | None:
|
async def _vacuum(conn: Connection, /) -> None:
|
||||||
|
"""Vacuum the database.
|
||||||
|
|
||||||
|
This function cannot be run on a connection with an open transaction.
|
||||||
|
"""
|
||||||
|
# With SQLAlchemy's "autobegin" behavior we need to switch the connection
|
||||||
|
# to "autocommit" first to keep it from automatically starting a transaction,
|
||||||
|
# as VACUUM cannot be run inside a transaction for most databases.
|
||||||
|
await conn.commit()
|
||||||
|
isolation_level = await conn.get_isolation_level()
|
||||||
|
log.debug("Previous isolation_level: %a", isolation_level)
|
||||||
|
await conn.execution_options(isolation_level="AUTOCOMMIT")
|
||||||
|
try:
|
||||||
|
await conn.execute(sa.text("vacuum"))
|
||||||
|
await conn.commit()
|
||||||
|
finally:
|
||||||
|
await conn.execution_options(isolation_level=isolation_level)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_import_progress(conn: Connection, /) -> Progress | None:
|
||||||
"""Return the latest import progress."""
|
"""Return the latest import progress."""
|
||||||
return await get(Progress, type="import-imdb-movies", order_by="started DESC")
|
return await get(
|
||||||
|
conn, Progress, type="import-imdb-movies", order_by=(progress.c.started, "desc")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def stop_import_progress(*, error: BaseException | None = None):
|
async def stop_import_progress(
|
||||||
|
conn: Connection, /, *, error: BaseException | None = None
|
||||||
|
) -> None:
|
||||||
"""Stop the current import.
|
"""Stop the current import.
|
||||||
|
|
||||||
If an error is given, it will be logged to the progress state.
|
If an error is given, it will be logged to the progress state.
|
||||||
"""
|
"""
|
||||||
current = await get_import_progress()
|
current = await get_import_progress(conn)
|
||||||
is_running = current and current.stopped is None
|
is_running = current and current.stopped is None
|
||||||
|
|
||||||
if not is_running:
|
if not is_running:
|
||||||
|
|
@ -151,17 +167,17 @@ async def stop_import_progress(*, error: BaseException | None = None):
|
||||||
current.error = repr(error)
|
current.error = repr(error)
|
||||||
current.stopped = utcnow().isoformat()
|
current.stopped = utcnow().isoformat()
|
||||||
|
|
||||||
await update(current)
|
await update(conn, current)
|
||||||
|
|
||||||
|
|
||||||
async def set_import_progress(progress: float) -> Progress:
|
async def set_import_progress(conn: Connection, /, progress: float) -> Progress:
|
||||||
"""Set the current import progress percentage.
|
"""Set the current import progress percentage.
|
||||||
|
|
||||||
If no import is currently running, this will create a new one.
|
If no import is currently running, this will create a new one.
|
||||||
"""
|
"""
|
||||||
progress = min(max(0.0, progress), 100.0) # clamp to 0 <= progress <= 100
|
progress = min(max(0.0, progress), 100.0) # clamp to 0 <= progress <= 100
|
||||||
|
|
||||||
current = await get_import_progress()
|
current = await get_import_progress(conn)
|
||||||
is_running = current and current.stopped is None
|
is_running = current and current.stopped is None
|
||||||
|
|
||||||
if not is_running:
|
if not is_running:
|
||||||
|
|
@ -171,163 +187,211 @@ async def set_import_progress(progress: float) -> Progress:
|
||||||
current.percent = progress
|
current.percent = progress
|
||||||
|
|
||||||
if is_running:
|
if is_running:
|
||||||
await update(current)
|
await update(conn, current)
|
||||||
else:
|
else:
|
||||||
await add(current)
|
await add(conn, current)
|
||||||
|
|
||||||
return current
|
return current
|
||||||
|
|
||||||
|
|
||||||
_lock = threading.Lock()
|
def _new_engine() -> AsyncEngine:
|
||||||
_prelock = threading.Lock()
|
uri = f"sqlite+aiosqlite:///{config.storage_path}"
|
||||||
|
|
||||||
|
return create_async_engine(
|
||||||
|
uri,
|
||||||
|
isolation_level="SERIALIZABLE",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _shared_engine() -> AsyncEngine:
|
||||||
|
global _engine
|
||||||
|
|
||||||
|
if _engine is None:
|
||||||
|
_engine = _new_engine()
|
||||||
|
|
||||||
|
return _engine
|
||||||
|
|
||||||
|
|
||||||
|
def _new_connection() -> Connection:
|
||||||
|
return _shared_engine().connect()
|
||||||
|
|
||||||
|
|
||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
async def single_threaded():
|
async def transaction(
|
||||||
"""Ensure the nested code is run only by a single thread at a time."""
|
*, force_rollback: bool = False
|
||||||
wait = 1e-5 # XXX not sure if there's a better magic value here
|
) -> AsyncGenerator[Connection, None]:
|
||||||
|
async with new_connection() as conn:
|
||||||
|
yield conn
|
||||||
|
|
||||||
# The pre-lock (a lock for the lock) allows for multiple threads to hand of
|
if not force_rollback:
|
||||||
# the main lock.
|
await conn.commit()
|
||||||
# With only a single lock the contending thread will spend most of its time
|
|
||||||
# in the asyncio.sleep and the reigning thread will have time to finish
|
|
||||||
# whatever it's doing and simply acquire the lock again before the other
|
|
||||||
# thread has had a change to try.
|
|
||||||
# By having another lock (and the same sleep time!) the contending thread
|
|
||||||
# will always have a chance to acquire the main lock.
|
|
||||||
while not _prelock.acquire(blocking=False):
|
|
||||||
await asyncio.sleep(wait)
|
|
||||||
|
|
||||||
try:
|
|
||||||
while not _lock.acquire(blocking=False):
|
|
||||||
await asyncio.sleep(wait)
|
|
||||||
finally:
|
|
||||||
_prelock.release()
|
|
||||||
|
|
||||||
try:
|
# The _test_connection allows pinning a connection that will be shared across the app.
|
||||||
yield
|
# This can (and should only) be used when running tests, NOT IN PRODUCTION!
|
||||||
|
_test_connection: Connection | None = None
|
||||||
finally:
|
|
||||||
_lock.release()
|
|
||||||
|
|
||||||
|
|
||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
async def locked_connection():
|
async def new_connection() -> AsyncGenerator[Connection, None]:
|
||||||
async with single_threaded():
|
"""Return a new connection.
|
||||||
yield shared_connection()
|
|
||||||
|
Any changes will be rolled back, unless `.commit()` is called on the
|
||||||
|
connection.
|
||||||
|
|
||||||
|
If you want to commit changes, consider using `transaction()` instead.
|
||||||
|
"""
|
||||||
|
conn = _test_connection or _new_connection()
|
||||||
|
|
||||||
|
# Support reusing the same connection for _test_connection.
|
||||||
|
is_started = conn.sync_connection is not None
|
||||||
|
if is_started:
|
||||||
|
yield conn
|
||||||
|
return
|
||||||
|
|
||||||
|
async with conn:
|
||||||
|
yield conn
|
||||||
|
|
||||||
|
|
||||||
def shared_connection() -> Database:
|
@contextlib.asynccontextmanager
|
||||||
global _shared_connection
|
async def transacted(
|
||||||
|
conn: Connection, /, *, force_rollback: bool = False
|
||||||
|
) -> AsyncGenerator[None, None]:
|
||||||
|
transaction = contextlib.nullcontext() if conn.in_transaction() else conn.begin()
|
||||||
|
|
||||||
if _shared_connection is None:
|
async with transaction:
|
||||||
uri = f"sqlite:///{config.storage_path}"
|
try:
|
||||||
_shared_connection = Database(uri)
|
yield
|
||||||
|
|
||||||
return _shared_connection
|
finally:
|
||||||
|
if force_rollback:
|
||||||
|
await conn.rollback()
|
||||||
|
|
||||||
|
|
||||||
async def add(item):
|
async def add(conn: Connection, /, item: Model) -> None:
|
||||||
# Support late initializing - used for optimization.
|
# Support late initializing - used for optimization.
|
||||||
if getattr(item, "_is_lazy", False):
|
if getattr(item, "_is_lazy", False):
|
||||||
item._lazy_init()
|
assert hasattr(item, "_lazy_init")
|
||||||
|
item._lazy_init() # pyright: ignore [reportGeneralTypeIssues]
|
||||||
|
|
||||||
|
table: sa.Table = item.__table__
|
||||||
values = asplain(item, serialize=True)
|
values = asplain(item, serialize=True)
|
||||||
keys = ", ".join(f"{k}" for k in values)
|
stmt = table.insert().values(values)
|
||||||
placeholders = ", ".join(f":{k}" for k in values)
|
await conn.execute(stmt)
|
||||||
query = f"INSERT INTO {item._table} ({keys}) VALUES ({placeholders})"
|
|
||||||
async with locked_connection() as conn:
|
|
||||||
await conn.execute(query=query, values=values)
|
|
||||||
|
|
||||||
|
|
||||||
ModelType = TypeVar("ModelType")
|
async def fetch_all(
|
||||||
|
conn: Connection, /, query: sa.Executable, values: "dict | None" = None
|
||||||
|
) -> Sequence[sa.Row]:
|
||||||
|
result = await conn.execute(query, values)
|
||||||
|
return result.all()
|
||||||
|
|
||||||
|
|
||||||
|
async def fetch_one(
|
||||||
|
conn: Connection, /, query: sa.Executable, values: "dict | None" = None
|
||||||
|
) -> sa.Row | None:
|
||||||
|
result = await conn.execute(query, values)
|
||||||
|
return result.first()
|
||||||
|
|
||||||
|
|
||||||
|
ModelType = TypeVar("ModelType", bound=Model)
|
||||||
|
|
||||||
|
|
||||||
async def get(
|
async def get(
|
||||||
model: Type[ModelType], *, order_by: str | None = None, **kwds
|
conn: Connection,
|
||||||
|
/,
|
||||||
|
model: Type[ModelType],
|
||||||
|
*,
|
||||||
|
order_by: tuple[sa.Column, Literal["asc", "desc"]] | None = None,
|
||||||
|
**field_values,
|
||||||
) -> ModelType | None:
|
) -> ModelType | None:
|
||||||
"""Load a model instance from the database.
|
"""Load a model instance from the database.
|
||||||
|
|
||||||
Passing `kwds` allows to filter the instance to load. You have to encode the
|
Passing `field_values` allows to filter the item to load. You have to encode the
|
||||||
values as the appropriate data type for the database prior to passing them
|
values as the appropriate data type for the database prior to passing them
|
||||||
to this function.
|
to this function.
|
||||||
"""
|
"""
|
||||||
values = {k: v for k, v in kwds.items() if v is not None}
|
if not field_values:
|
||||||
if not values:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
fields_ = ", ".join(f.name for f in fields(model))
|
table: sa.Table = model.__table__
|
||||||
cond = " AND ".join(f"{k}=:{k}" for k in values)
|
query = sa.select(model).where(
|
||||||
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
|
*(table.c[k] == v for k, v in field_values.items() if v is not None)
|
||||||
|
)
|
||||||
if order_by:
|
if order_by:
|
||||||
query += f" ORDER BY {order_by}"
|
order_col, order_dir = order_by
|
||||||
async with locked_connection() as conn:
|
query = query.order_by(
|
||||||
row = await conn.fetch_one(query=query, values=values)
|
order_col.asc() if order_dir == "asc" else order_col.desc()
|
||||||
|
)
|
||||||
|
row = await fetch_one(conn, query)
|
||||||
return fromplain(model, row._mapping, serialized=True) if row else None
|
return fromplain(model, row._mapping, serialized=True) if row else None
|
||||||
|
|
||||||
|
|
||||||
async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
|
async def get_many(
|
||||||
keys = {
|
conn: Connection, /, model: Type[ModelType], **field_sets: set | list
|
||||||
k: [f"{k}_{i}" for i, _ in enumerate(vs, start=1)] for k, vs in kwds.items()
|
) -> Iterable[ModelType]:
|
||||||
}
|
"""Return the items with any values matching all given field sets.
|
||||||
|
|
||||||
if not keys:
|
This is similar to `get_all`, but instead of a scalar value a list of values
|
||||||
|
must be given. If any of the given values is set for that field on an item,
|
||||||
|
the item is considered a match.
|
||||||
|
If no field values are given, no items will be returned.
|
||||||
|
"""
|
||||||
|
if not field_sets:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
values = {n: v for k, vs in kwds.items() for n, v in zip(keys[k], vs)}
|
table: sa.Table = model.__table__
|
||||||
|
query = sa.select(model).where(*(table.c[k].in_(v) for k, v in field_sets.items()))
|
||||||
|
rows = await fetch_all(conn, query)
|
||||||
|
return (fromplain(model, row._mapping, serialized=True) for row in rows)
|
||||||
|
|
||||||
fields_ = ", ".join(f.name for f in fields(model))
|
|
||||||
cond = " AND ".join(
|
async def get_all(
|
||||||
f"{k} IN ({','.join(':'+n for n in ns)})" for k, ns in keys.items()
|
conn: Connection, /, model: Type[ModelType], **field_values
|
||||||
|
) -> Iterable[ModelType]:
|
||||||
|
"""Filter all items by comparing all given field values.
|
||||||
|
|
||||||
|
If no filters are given, all items will be returned.
|
||||||
|
"""
|
||||||
|
table: sa.Table = model.__table__
|
||||||
|
query = sa.select(model).where(
|
||||||
|
*(table.c[k] == v for k, v in field_values.items() if v is not None)
|
||||||
)
|
)
|
||||||
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
|
rows = await fetch_all(conn, query)
|
||||||
async with locked_connection() as conn:
|
|
||||||
rows = await conn.fetch_all(query=query, values=values)
|
|
||||||
return (fromplain(model, row._mapping, serialized=True) for row in rows)
|
return (fromplain(model, row._mapping, serialized=True) for row in rows)
|
||||||
|
|
||||||
|
|
||||||
async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
|
async def update(conn: Connection, /, item: Model) -> None:
|
||||||
values = {k: v for k, v in kwds.items() if v is not None}
|
|
||||||
|
|
||||||
fields_ = ", ".join(f.name for f in fields(model))
|
|
||||||
cond = " AND ".join(f"{k}=:{k}" for k in values) or "1=1"
|
|
||||||
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
|
|
||||||
async with locked_connection() as conn:
|
|
||||||
rows = await conn.fetch_all(query=query, values=values)
|
|
||||||
return (fromplain(model, row._mapping, serialized=True) for row in rows)
|
|
||||||
|
|
||||||
|
|
||||||
async def update(item):
|
|
||||||
# Support late initializing - used for optimization.
|
# Support late initializing - used for optimization.
|
||||||
if getattr(item, "_is_lazy", False):
|
if getattr(item, "_is_lazy", False):
|
||||||
item._lazy_init()
|
assert hasattr(item, "_lazy_init")
|
||||||
|
item._lazy_init() # pyright: ignore [reportGeneralTypeIssues]
|
||||||
|
|
||||||
|
table: sa.Table = item.__table__
|
||||||
values = asplain(item, serialize=True)
|
values = asplain(item, serialize=True)
|
||||||
keys = ", ".join(f"{k}=:{k}" for k in values if k != "id")
|
stmt = table.update().where(table.c.id == values["id"]).values(values)
|
||||||
query = f"UPDATE {item._table} SET {keys} WHERE id=:id"
|
await conn.execute(stmt)
|
||||||
async with locked_connection() as conn:
|
|
||||||
await conn.execute(query=query, values=values)
|
|
||||||
|
|
||||||
|
|
||||||
async def remove(item):
|
async def remove(conn: Connection, /, item: Model) -> None:
|
||||||
|
table: sa.Table = item.__table__
|
||||||
values = asplain(item, filter_fields={"id"}, serialize=True)
|
values = asplain(item, filter_fields={"id"}, serialize=True)
|
||||||
query = f"DELETE FROM {item._table} WHERE id=:id"
|
stmt = table.delete().where(table.c.id == values["id"])
|
||||||
async with locked_connection() as conn:
|
await conn.execute(stmt)
|
||||||
await conn.execute(query=query, values=values)
|
|
||||||
|
|
||||||
|
|
||||||
async def add_or_update_user(user: User):
|
async def add_or_update_user(conn: Connection, /, user: User) -> None:
|
||||||
db_user = await get(User, imdb_id=user.imdb_id)
|
db_user = await get(conn, User, imdb_id=user.imdb_id)
|
||||||
if not db_user:
|
if not db_user:
|
||||||
await add(user)
|
await add(conn, user)
|
||||||
else:
|
else:
|
||||||
user.id = db_user.id
|
user.id = db_user.id
|
||||||
|
|
||||||
if user != db_user:
|
if user != db_user:
|
||||||
await update(user)
|
await update(conn, user)
|
||||||
|
|
||||||
|
|
||||||
async def add_or_update_many_movies(movies: list[Movie]):
|
async def add_or_update_many_movies(conn: Connection, /, movies: list[Movie]) -> None:
|
||||||
"""Add or update Movies in the database.
|
"""Add or update Movies in the database.
|
||||||
|
|
||||||
This is an optimized version of `add_or_update_movie` for the purpose
|
This is an optimized version of `add_or_update_movie` for the purpose
|
||||||
|
|
@ -336,12 +400,13 @@ async def add_or_update_many_movies(movies: list[Movie]):
|
||||||
# for movie in movies:
|
# for movie in movies:
|
||||||
# await add_or_update_movie(movie)
|
# await add_or_update_movie(movie)
|
||||||
db_movies = {
|
db_movies = {
|
||||||
m.imdb_id: m for m in await get_many(Movie, imdb_id=[m.imdb_id for m in movies])
|
m.imdb_id: m
|
||||||
|
for m in await get_many(conn, Movie, imdb_id=[m.imdb_id for m in movies])
|
||||||
}
|
}
|
||||||
for movie in movies:
|
for movie in movies:
|
||||||
# XXX optimize bulk add & update as well
|
# XXX optimize bulk add & update as well
|
||||||
if movie.imdb_id not in db_movies:
|
if movie.imdb_id not in db_movies:
|
||||||
await add(movie)
|
await add(conn, movie)
|
||||||
else:
|
else:
|
||||||
db_movie = db_movies[movie.imdb_id]
|
db_movie = db_movies[movie.imdb_id]
|
||||||
movie.id = db_movie.id
|
movie.id = db_movie.id
|
||||||
|
|
@ -354,10 +419,10 @@ async def add_or_update_many_movies(movies: list[Movie]):
|
||||||
if movie.updated <= db_movie.updated:
|
if movie.updated <= db_movie.updated:
|
||||||
return
|
return
|
||||||
|
|
||||||
await update(movie)
|
await update(conn, movie)
|
||||||
|
|
||||||
|
|
||||||
async def add_or_update_movie(movie: Movie):
|
async def add_or_update_movie(conn: Connection, /, movie: Movie) -> None:
|
||||||
"""Add or update a Movie in the database.
|
"""Add or update a Movie in the database.
|
||||||
|
|
||||||
This is an upsert operation, but it will also update the Movie you pass
|
This is an upsert operation, but it will also update the Movie you pass
|
||||||
|
|
@ -365,9 +430,9 @@ async def add_or_update_movie(movie: Movie):
|
||||||
set all optional values on your Movie that might be unset but exist in the
|
set all optional values on your Movie that might be unset but exist in the
|
||||||
database. It's a bidirectional sync.
|
database. It's a bidirectional sync.
|
||||||
"""
|
"""
|
||||||
db_movie = await get(Movie, imdb_id=movie.imdb_id)
|
db_movie = await get(conn, Movie, imdb_id=movie.imdb_id)
|
||||||
if not db_movie:
|
if not db_movie:
|
||||||
await add(movie)
|
await add(conn, movie)
|
||||||
else:
|
else:
|
||||||
movie.id = db_movie.id
|
movie.id = db_movie.id
|
||||||
|
|
||||||
|
|
@ -379,33 +444,35 @@ async def add_or_update_movie(movie: Movie):
|
||||||
if movie.updated <= db_movie.updated:
|
if movie.updated <= db_movie.updated:
|
||||||
return
|
return
|
||||||
|
|
||||||
await update(movie)
|
await update(conn, movie)
|
||||||
|
|
||||||
|
|
||||||
async def add_or_update_rating(rating: Rating) -> bool:
|
async def add_or_update_rating(conn: Connection, /, rating: Rating) -> bool:
|
||||||
db_rating = await get(
|
db_rating = await get(
|
||||||
Rating, movie_id=str(rating.movie_id), user_id=str(rating.user_id)
|
conn, Rating, movie_id=str(rating.movie_id), user_id=str(rating.user_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
if not db_rating:
|
if not db_rating:
|
||||||
await add(rating)
|
await add(conn, rating)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
else:
|
else:
|
||||||
rating.id = db_rating.id
|
rating.id = db_rating.id
|
||||||
|
|
||||||
if rating != db_rating:
|
if rating != db_rating:
|
||||||
await update(rating)
|
await update(conn, rating)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def sql_escape(s: str, char="#"):
|
def sql_escape(s: str, char: str = "#") -> str:
|
||||||
return s.replace(char, 2 * char).replace("%", f"{char}%").replace("_", f"{char}_")
|
return s.replace(char, 2 * char).replace("%", f"{char}%").replace("_", f"{char}_")
|
||||||
|
|
||||||
|
|
||||||
async def find_ratings(
|
async def find_ratings(
|
||||||
|
conn: Connection,
|
||||||
|
/,
|
||||||
*,
|
*,
|
||||||
title: str | None = None,
|
title: str | None = None,
|
||||||
media_type: str | None = None,
|
media_type: str | None = None,
|
||||||
|
|
@ -415,163 +482,129 @@ async def find_ratings(
|
||||||
yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
|
yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
|
||||||
limit_rows: int = 10,
|
limit_rows: int = 10,
|
||||||
user_ids: Iterable[str] = [],
|
user_ids: Iterable[str] = [],
|
||||||
):
|
) -> Iterable[dict[str, Any]]:
|
||||||
values: dict[str, int | str] = {
|
|
||||||
"limit_rows": limit_rows,
|
|
||||||
}
|
|
||||||
|
|
||||||
conditions = []
|
conditions = []
|
||||||
|
|
||||||
if title:
|
if title:
|
||||||
values["escape"] = "#"
|
escape_char = "#"
|
||||||
escaped_title = sql_escape(title, char=values["escape"])
|
escaped_title = sql_escape(title, char=escape_char)
|
||||||
values["pattern"] = (
|
pattern = (
|
||||||
"_".join(escaped_title.split())
|
"_".join(escaped_title.split())
|
||||||
if exact
|
if exact
|
||||||
else "%" + "%".join(escaped_title.split()) + "%"
|
else "%" + "%".join(escaped_title.split()) + "%"
|
||||||
)
|
)
|
||||||
conditions.append(
|
conditions.append(
|
||||||
f"""
|
sa.or_(
|
||||||
(
|
movies.c.title.like(pattern, escape=escape_char),
|
||||||
{Movie._table}.title LIKE :pattern ESCAPE :escape
|
movies.c.original_title.like(pattern, escape=escape_char),
|
||||||
OR {Movie._table}.original_title LIKE :pattern ESCAPE :escape
|
|
||||||
)
|
)
|
||||||
"""
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if yearcomp:
|
match yearcomp:
|
||||||
op, year = yearcomp
|
case ("<", year):
|
||||||
assert op in "<=>"
|
conditions.append(movies.c.release_year < year)
|
||||||
values["year"] = year
|
case ("=", year):
|
||||||
conditions.append(f"{Movie._table}.release_year{op}:year")
|
conditions.append(movies.c.release_year == year)
|
||||||
|
case (">", year):
|
||||||
|
conditions.append(movies.c.release_year > year)
|
||||||
|
|
||||||
if media_type:
|
if media_type is not None:
|
||||||
values["media_type"] = media_type
|
conditions.append(movies.c.media_type == media_type)
|
||||||
conditions.append(f"{Movie._table}.media_type=:media_type")
|
|
||||||
|
|
||||||
if ignore_tv_episodes:
|
if ignore_tv_episodes:
|
||||||
conditions.append(f"{Movie._table}.media_type!='TV Episode'")
|
conditions.append(movies.c.media_type != "TV Episode")
|
||||||
|
|
||||||
user_condition = "1=1"
|
user_condition = []
|
||||||
if user_ids:
|
if user_ids:
|
||||||
uvs = {f"user_id_{i}": v for i, v in enumerate(user_ids, start=1)}
|
user_condition.append(ratings.c.user_id.in_(user_ids))
|
||||||
values.update(uvs)
|
|
||||||
user_condition = f"{Rating._table}.user_id IN ({','.join(':'+n for n in uvs)})"
|
|
||||||
|
|
||||||
query = f"""
|
query = (
|
||||||
SELECT DISTINCT {Rating._table}.movie_id
|
sa.select(ratings.c.movie_id)
|
||||||
FROM {Rating._table}
|
.distinct()
|
||||||
LEFT JOIN {Movie._table} ON {Movie._table}.id={Rating._table}.movie_id
|
.outerjoin_from(ratings, movies, movies.c.id == ratings.c.movie_id)
|
||||||
WHERE {user_condition}{(' AND ' + ' AND '.join(conditions)) if conditions else ''}
|
.where(*conditions, *user_condition)
|
||||||
ORDER BY length({Movie._table}.title) ASC, {Rating._table}.rating_date DESC, {Movie._table}.imdb_score DESC
|
.order_by(
|
||||||
LIMIT :limit_rows
|
sa.func.length(movies.c.title).asc(),
|
||||||
"""
|
ratings.c.rating_date.desc(),
|
||||||
async with locked_connection() as conn:
|
movies.c.imdb_score.desc(),
|
||||||
rows = await conn.fetch_all(bindparams(query, values))
|
)
|
||||||
movie_ids = tuple(r._mapping["movie_id"] for r in rows)
|
.limit(limit_rows)
|
||||||
|
)
|
||||||
|
rating_rows: sa.CursorResult[Rating] = await conn.execute(query)
|
||||||
|
movie_ids = [r.movie_id for r in rating_rows]
|
||||||
|
|
||||||
if include_unrated and len(movie_ids) < limit_rows:
|
if include_unrated and len(movie_ids) < limit_rows:
|
||||||
sqlin, sqlin_vals = sql_in("id", movie_ids, not_=True)
|
query = (
|
||||||
query = f"""
|
sa.select(movies.c.id)
|
||||||
SELECT DISTINCT id AS movie_id
|
.distinct()
|
||||||
FROM {Movie._table}
|
.where(movies.c.id.not_in(movie_ids), *conditions)
|
||||||
WHERE {sqlin}
|
.order_by(
|
||||||
{('AND ' + ' AND '.join(conditions)) if conditions else ''}
|
sa.func.length(movies.c.title).asc(),
|
||||||
ORDER BY length(title) ASC, imdb_score DESC, release_year DESC
|
movies.c.imdb_score.desc(),
|
||||||
LIMIT :limit_rows
|
movies.c.release_year.desc(),
|
||||||
"""
|
|
||||||
async with locked_connection() as conn:
|
|
||||||
rows = await conn.fetch_all(
|
|
||||||
bindparams(
|
|
||||||
query,
|
|
||||||
{**values, **sqlin_vals, "limit_rows": limit_rows - len(movie_ids)},
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
movie_ids += tuple(r._mapping["movie_id"] for r in rows)
|
.limit(limit_rows - len(movie_ids))
|
||||||
|
)
|
||||||
|
movie_rows: sa.CursorResult[Movie] = await conn.execute(query)
|
||||||
|
movie_ids += [r.id for r in movie_rows]
|
||||||
|
|
||||||
return await ratings_for_movie_ids(ids=movie_ids)
|
return await ratings_for_movie_ids(conn, ids=movie_ids)
|
||||||
|
|
||||||
|
|
||||||
async def ratings_for_movie_ids(
|
async def ratings_for_movie_ids(
|
||||||
ids: Iterable[ULID | str] = [], imdb_ids: Iterable[str] = []
|
conn: Connection,
|
||||||
|
/,
|
||||||
|
ids: Iterable[ULID | str] = [],
|
||||||
|
imdb_ids: Iterable[str] = [],
|
||||||
) -> Iterable[dict[str, Any]]:
|
) -> Iterable[dict[str, Any]]:
|
||||||
conds: list[str] = []
|
conds = []
|
||||||
vals: dict[str, str] = {}
|
|
||||||
|
|
||||||
if ids:
|
if ids:
|
||||||
sqlin, sqlin_vals = sql_in(f"{Movie._table}.id", (str(x) for x in ids))
|
conds.append(movies.c.id.in_([str(x) for x in ids]))
|
||||||
conds.append(sqlin)
|
|
||||||
vals.update(sqlin_vals)
|
|
||||||
|
|
||||||
if imdb_ids:
|
if imdb_ids:
|
||||||
sqlin, sqlin_vals = sql_in(f"{Movie._table}.imdb_id", imdb_ids)
|
conds.append(movies.c.imdb_id.in_(imdb_ids))
|
||||||
conds.append(sqlin)
|
|
||||||
vals.update(sqlin_vals)
|
|
||||||
|
|
||||||
if not conds:
|
if not conds:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
query = f"""
|
query = (
|
||||||
SELECT
|
sa.select(
|
||||||
{Rating._table}.score AS user_score,
|
ratings.c.score.label("user_score"),
|
||||||
{Rating._table}.user_id AS user_id,
|
ratings.c.user_id.label("user_id"),
|
||||||
{Movie._table}.imdb_score,
|
movies.c.imdb_score,
|
||||||
{Movie._table}.imdb_votes,
|
movies.c.imdb_votes,
|
||||||
{Movie._table}.imdb_id AS movie_imdb_id,
|
movies.c.imdb_id.label("movie_imdb_id"),
|
||||||
{Movie._table}.media_type AS media_type,
|
movies.c.media_type.label("media_type"),
|
||||||
{Movie._table}.title AS canonical_title,
|
movies.c.title.label("canonical_title"),
|
||||||
{Movie._table}.original_title AS original_title,
|
movies.c.original_title.label("original_title"),
|
||||||
{Movie._table}.release_year AS release_year
|
movies.c.release_year.label("release_year"),
|
||||||
FROM {Movie._table}
|
)
|
||||||
LEFT JOIN {Rating._table} ON {Movie._table}.id={Rating._table}.movie_id
|
.outerjoin_from(movies, ratings, movies.c.id == ratings.c.movie_id)
|
||||||
WHERE {(' OR '.join(conds))}
|
.where(sa.or_(*conds))
|
||||||
"""
|
)
|
||||||
|
rows = await fetch_all(conn, query)
|
||||||
async with locked_connection() as conn:
|
|
||||||
rows = await conn.fetch_all(bindparams(query, vals))
|
|
||||||
return tuple(dict(r._mapping) for r in rows)
|
return tuple(dict(r._mapping) for r in rows)
|
||||||
|
|
||||||
|
|
||||||
def sql_fields(tp: Type):
|
|
||||||
return (f"{tp._table}.{f.name}" for f in fields(tp))
|
|
||||||
|
|
||||||
|
|
||||||
def sql_in(column: str, values: Iterable[T], not_=False) -> tuple[str, dict[str, T]]:
|
|
||||||
c = column.replace(".", "___")
|
|
||||||
value_map = {f"{c}_{i}": v for i, v in enumerate(values, start=1)}
|
|
||||||
placeholders = ",".join(":" + k for k in value_map)
|
|
||||||
if not_:
|
|
||||||
return f"{column} NOT IN ({placeholders})", value_map
|
|
||||||
return f"{column} IN ({placeholders})", value_map
|
|
||||||
|
|
||||||
|
|
||||||
async def ratings_for_movies(
|
async def ratings_for_movies(
|
||||||
movie_ids: Iterable[ULID], user_ids: Iterable[ULID] = []
|
conn: Connection, /, movie_ids: Iterable[ULID], user_ids: Iterable[ULID] = []
|
||||||
) -> Iterable[Rating]:
|
) -> Iterable[Rating]:
|
||||||
values: dict[str, str] = {}
|
conditions = [ratings.c.movie_id.in_(str(x) for x in movie_ids)]
|
||||||
conditions: list[str] = []
|
|
||||||
|
|
||||||
q, vm = sql_in("movie_id", [str(m) for m in movie_ids])
|
|
||||||
conditions.append(q)
|
|
||||||
values.update(vm)
|
|
||||||
|
|
||||||
if user_ids:
|
if user_ids:
|
||||||
q, vm = sql_in("user_id", [str(m) for m in user_ids])
|
conditions.append(ratings.c.user_id.in_(str(x) for x in user_ids))
|
||||||
conditions.append(q)
|
|
||||||
values.update(vm)
|
|
||||||
|
|
||||||
query = f"""
|
query = sa.select(ratings).where(*conditions)
|
||||||
SELECT {','.join(sql_fields(Rating))}
|
|
||||||
FROM {Rating._table}
|
|
||||||
WHERE {' AND '.join(f'({c})' for c in conditions) if conditions else '1=1'}
|
|
||||||
"""
|
|
||||||
|
|
||||||
async with locked_connection() as conn:
|
rows = await fetch_all(conn, query)
|
||||||
rows = await conn.fetch_all(query, values)
|
|
||||||
|
|
||||||
return (fromplain(Rating, row._mapping, serialized=True) for row in rows)
|
return (fromplain(Rating, row._mapping, serialized=True) for row in rows)
|
||||||
|
|
||||||
|
|
||||||
async def find_movies(
|
async def find_movies(
|
||||||
|
conn: Connection,
|
||||||
|
/,
|
||||||
*,
|
*,
|
||||||
title: str | None = None,
|
title: str | None = None,
|
||||||
media_type: str | None = None,
|
media_type: str | None = None,
|
||||||
|
|
@ -583,88 +616,63 @@ async def find_movies(
|
||||||
include_unrated: bool = False,
|
include_unrated: bool = False,
|
||||||
user_ids: list[ULID] = [],
|
user_ids: list[ULID] = [],
|
||||||
) -> Iterable[tuple[Movie, list[Rating]]]:
|
) -> Iterable[tuple[Movie, list[Rating]]]:
|
||||||
values: dict[str, int | str] = {
|
|
||||||
"limit_rows": limit_rows,
|
|
||||||
"skip_rows": skip_rows,
|
|
||||||
}
|
|
||||||
|
|
||||||
conditions = []
|
conditions = []
|
||||||
|
|
||||||
if title:
|
if title:
|
||||||
values["escape"] = "#"
|
escape_char = "#"
|
||||||
escaped_title = sql_escape(title, char=values["escape"])
|
escaped_title = sql_escape(title, char=escape_char)
|
||||||
values["pattern"] = (
|
pattern = (
|
||||||
"_".join(escaped_title.split())
|
"_".join(escaped_title.split())
|
||||||
if exact
|
if exact
|
||||||
else "%" + "%".join(escaped_title.split()) + "%"
|
else "%" + "%".join(escaped_title.split()) + "%"
|
||||||
)
|
)
|
||||||
conditions.append(
|
conditions.append(
|
||||||
f"""
|
sa.or_(
|
||||||
(
|
movies.c.title.like(pattern, escape=escape_char),
|
||||||
{Movie._table}.title LIKE :pattern ESCAPE :escape
|
movies.c.original_title.like(pattern, escape=escape_char),
|
||||||
OR {Movie._table}.original_title LIKE :pattern ESCAPE :escape
|
|
||||||
)
|
)
|
||||||
"""
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if yearcomp:
|
match yearcomp:
|
||||||
op, year = yearcomp
|
case ("<", year):
|
||||||
assert op in "<=>"
|
conditions.append(movies.c.release_year < year)
|
||||||
values["year"] = year
|
case ("=", year):
|
||||||
conditions.append(f"{Movie._table}.release_year{op}:year")
|
conditions.append(movies.c.release_year == year)
|
||||||
|
case (">", year):
|
||||||
|
conditions.append(movies.c.release_year > year)
|
||||||
|
|
||||||
if media_type:
|
if media_type is not None:
|
||||||
values["media_type"] = media_type
|
conditions.append(movies.c.media_type == media_type)
|
||||||
conditions.append(f"{Movie._table}.media_type=:media_type")
|
|
||||||
|
|
||||||
if ignore_tv_episodes:
|
if ignore_tv_episodes:
|
||||||
conditions.append(f"{Movie._table}.media_type!='TV Episode'")
|
conditions.append(movies.c.media_type != "TV Episode")
|
||||||
|
|
||||||
if not include_unrated:
|
if not include_unrated:
|
||||||
conditions.append(f"{Movie._table}.imdb_score NOTNULL")
|
conditions.append(movies.c.imdb_score.is_not(None))
|
||||||
|
|
||||||
query = f"""
|
query = (
|
||||||
SELECT {','.join(sql_fields(Movie))}
|
sa.select(movies)
|
||||||
FROM {Movie._table}
|
.where(*conditions)
|
||||||
WHERE {(' AND '.join(conditions)) if conditions else '1=1'}
|
.order_by(
|
||||||
ORDER BY
|
sa.func.length(movies.c.title).asc(),
|
||||||
length({Movie._table}.title) ASC,
|
movies.c.imdb_score.desc(),
|
||||||
{Movie._table}.imdb_score DESC,
|
movies.c.release_year.desc(),
|
||||||
{Movie._table}.release_year DESC
|
)
|
||||||
LIMIT :skip_rows, :limit_rows
|
.limit(limit_rows)
|
||||||
"""
|
.offset(skip_rows)
|
||||||
async with locked_connection() as conn:
|
)
|
||||||
rows = await conn.fetch_all(bindparams(query, values))
|
|
||||||
|
|
||||||
movies = [fromplain(Movie, row._mapping, serialized=True) for row in rows]
|
rows = await fetch_all(conn, query)
|
||||||
|
|
||||||
|
movies_ = [fromplain(Movie, row._mapping, serialized=True) for row in rows]
|
||||||
|
|
||||||
if not user_ids:
|
if not user_ids:
|
||||||
return ((m, []) for m in movies)
|
return ((m, []) for m in movies_)
|
||||||
|
|
||||||
ratings = await ratings_for_movies((m.id for m in movies), user_ids)
|
ratings = await ratings_for_movies(conn, (m.id for m in movies_), user_ids)
|
||||||
|
|
||||||
aggreg: dict[ULID, tuple[Movie, list[Rating]]] = {m.id: (m, []) for m in movies}
|
aggreg: dict[ULID, tuple[Movie, list[Rating]]] = {m.id: (m, []) for m in movies_}
|
||||||
for rating in ratings:
|
for rating in ratings:
|
||||||
aggreg[rating.movie_id][1].append(rating)
|
aggreg[rating.movie_id][1].append(rating)
|
||||||
|
|
||||||
return aggreg.values()
|
return aggreg.values()
|
||||||
|
|
||||||
|
|
||||||
def bindparams(query: str, values: dict):
|
|
||||||
"""Bind values to a query.
|
|
||||||
|
|
||||||
This is similar to what SQLAlchemy and Databases do, but it allows to
|
|
||||||
easily use the same placeholder in multiple places.
|
|
||||||
"""
|
|
||||||
pump_vals = {}
|
|
||||||
pump_keys = {}
|
|
||||||
|
|
||||||
def pump(match):
|
|
||||||
key = match[1]
|
|
||||||
val = values[key]
|
|
||||||
pump_keys[key] = 1 + pump_keys.setdefault(key, 0)
|
|
||||||
pump_key = f"{key}_{pump_keys[key]}"
|
|
||||||
pump_vals[pump_key] = val
|
|
||||||
return f":{pump_key}"
|
|
||||||
|
|
||||||
pump_query = re.sub(r":(\w+)\b", pump, query)
|
|
||||||
return sqlalchemy.text(pump_query).bindparams(**pump_vals)
|
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,8 @@ from collections import namedtuple
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
|
import bs4
|
||||||
|
|
||||||
from . import db
|
from . import db
|
||||||
from .models import Movie, Rating, User
|
from .models import Movie, Rating, User
|
||||||
from .request import asession, asoup_from_url, cache_path
|
from .request import asession, asoup_from_url, cache_path
|
||||||
|
|
@ -38,12 +40,14 @@ async def refresh_user_ratings_from_imdb(stop_on_dupe: bool = True):
|
||||||
async with asession() as s:
|
async with asession() as s:
|
||||||
s.headers["Accept-Language"] = "en-US, en;q=0.5"
|
s.headers["Accept-Language"] = "en-US, en;q=0.5"
|
||||||
|
|
||||||
for user in await db.get_all(User):
|
async with db.new_connection() as conn:
|
||||||
|
users = list(await db.get_all(conn, User))
|
||||||
|
for user in users:
|
||||||
log.info("⚡️ Loading data for %s ...", user.name)
|
log.info("⚡️ Loading data for %s ...", user.name)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for rating, is_updated in load_ratings(user.imdb_id):
|
async for rating, is_updated in load_ratings(user.imdb_id):
|
||||||
assert rating.user.id == user.id
|
assert rating.user is not None and rating.user.id == user.id
|
||||||
|
|
||||||
if stop_on_dupe and not is_updated:
|
if stop_on_dupe and not is_updated:
|
||||||
break
|
break
|
||||||
|
|
@ -94,7 +98,7 @@ find_year = re.compile(
|
||||||
find_movie_id = re.compile(r"/title/(?P<id>tt\d+)/").search
|
find_movie_id = re.compile(r"/title/(?P<id>tt\d+)/").search
|
||||||
|
|
||||||
|
|
||||||
def movie_and_rating_from_item(item) -> tuple[Movie, Rating]:
|
def movie_and_rating_from_item(item: bs4.Tag) -> tuple[Movie, Rating]:
|
||||||
genres = (genre := item.find("span", "genre")) and genre.string or ""
|
genres = (genre := item.find("span", "genre")) and genre.string or ""
|
||||||
movie = Movie(
|
movie = Movie(
|
||||||
title=item.h3.a.string.strip(),
|
title=item.h3.a.string.strip(),
|
||||||
|
|
@ -154,13 +158,19 @@ async def parse_page(url: str) -> tuple[list[Rating], str | None]:
|
||||||
|
|
||||||
soup = await asoup_from_url(url)
|
soup = await asoup_from_url(url)
|
||||||
|
|
||||||
meta = soup.find("meta", property="pageId")
|
if (meta := soup.find("meta", property="pageId")) is None:
|
||||||
headline = soup.h1
|
raise RuntimeError("No pageId found.")
|
||||||
assert meta is not None and headline is not None
|
assert isinstance(meta, bs4.Tag)
|
||||||
imdb_id = meta["content"]
|
imdb_id = meta["content"]
|
||||||
user = await db.get(User, imdb_id=imdb_id) or User(
|
assert isinstance(imdb_id, str)
|
||||||
imdb_id=imdb_id, name="", secret=""
|
async with db.new_connection() as conn:
|
||||||
)
|
user = await db.get(conn, User, imdb_id=imdb_id) or User(
|
||||||
|
imdb_id=imdb_id, name="", secret=""
|
||||||
|
)
|
||||||
|
|
||||||
|
if (headline := soup.h1) is None:
|
||||||
|
raise RuntimeError("No headline found.")
|
||||||
|
assert isinstance(headline.string, str)
|
||||||
if match := find_name(headline.string):
|
if match := find_name(headline.string):
|
||||||
user.name = match["name"]
|
user.name = match["name"]
|
||||||
|
|
||||||
|
|
@ -184,9 +194,15 @@ async def parse_page(url: str) -> tuple[list[Rating], str | None]:
|
||||||
|
|
||||||
ratings.append(rating)
|
ratings.append(rating)
|
||||||
|
|
||||||
footer = soup.find("div", "footer")
|
next_url = None
|
||||||
assert footer is not None
|
if (footer := soup.find("div", "footer")) is None:
|
||||||
next_url = urljoin(url, footer.find(string=re.compile(r"Next")).parent["href"])
|
raise RuntimeError("No footer found.")
|
||||||
|
assert isinstance(footer, bs4.Tag)
|
||||||
|
if (next_link := footer.find("a", string="Next")) is not None:
|
||||||
|
assert isinstance(next_link, bs4.Tag)
|
||||||
|
next_href = next_link["href"]
|
||||||
|
assert isinstance(next_href, str)
|
||||||
|
next_url = urljoin(url, next_href)
|
||||||
|
|
||||||
return (ratings, next_url if url != next_url else None)
|
return (ratings, next_url if url != next_url else None)
|
||||||
|
|
||||||
|
|
@ -200,14 +216,15 @@ async def load_ratings(user_id: str):
|
||||||
for i, rating in enumerate(ratings):
|
for i, rating in enumerate(ratings):
|
||||||
assert rating.user and rating.movie
|
assert rating.user and rating.movie
|
||||||
|
|
||||||
if i == 0:
|
async with db.transaction() as conn:
|
||||||
# All rating objects share the same user.
|
if i == 0:
|
||||||
await db.add_or_update_user(rating.user)
|
# All rating objects share the same user.
|
||||||
rating.user_id = rating.user.id
|
await db.add_or_update_user(conn, rating.user)
|
||||||
|
rating.user_id = rating.user.id
|
||||||
|
|
||||||
await db.add_or_update_movie(rating.movie)
|
await db.add_or_update_movie(conn, rating.movie)
|
||||||
rating.movie_id = rating.movie.id
|
rating.movie_id = rating.movie.id
|
||||||
|
|
||||||
is_updated = await db.add_or_update_rating(rating)
|
is_updated = await db.add_or_update_rating(conn, rating)
|
||||||
|
|
||||||
yield rating, is_updated
|
yield rating, is_updated
|
||||||
|
|
|
||||||
|
|
@ -209,7 +209,8 @@ async def import_from_file(*, basics_path: Path, ratings_path: Path):
|
||||||
for i, m in enumerate(read_basics(basics_path)):
|
for i, m in enumerate(read_basics(basics_path)):
|
||||||
perc = 100 * i / total
|
perc = 100 * i / total
|
||||||
if perc >= perc_next_report:
|
if perc >= perc_next_report:
|
||||||
await db.set_import_progress(perc)
|
async with db.transaction() as conn:
|
||||||
|
await db.set_import_progress(conn, perc)
|
||||||
log.info("⏳ Imported %s%%", round(perc, 1))
|
log.info("⏳ Imported %s%%", round(perc, 1))
|
||||||
perc_next_report += perc_step
|
perc_next_report += perc_step
|
||||||
|
|
||||||
|
|
@ -233,15 +234,18 @@ async def import_from_file(*, basics_path: Path, ratings_path: Path):
|
||||||
chunk.append(m)
|
chunk.append(m)
|
||||||
|
|
||||||
if len(chunk) > 1000:
|
if len(chunk) > 1000:
|
||||||
await add_or_update_many_movies(chunk)
|
async with db.transaction() as conn:
|
||||||
|
await add_or_update_many_movies(conn, chunk)
|
||||||
chunk = []
|
chunk = []
|
||||||
|
|
||||||
if chunk:
|
if chunk:
|
||||||
await add_or_update_many_movies(chunk)
|
async with db.transaction() as conn:
|
||||||
|
await add_or_update_many_movies(conn, chunk)
|
||||||
chunk = []
|
chunk = []
|
||||||
|
|
||||||
log.info("👍 Imported 100%")
|
log.info("👍 Imported 100%")
|
||||||
await db.set_import_progress(100)
|
async with db.transaction() as conn:
|
||||||
|
await db.set_import_progress(conn, 100)
|
||||||
|
|
||||||
|
|
||||||
async def download_datasets(*, basics_path: Path, ratings_path: Path) -> None:
|
async def download_datasets(*, basics_path: Path, ratings_path: Path) -> None:
|
||||||
|
|
@ -270,7 +274,8 @@ async def load_from_web(*, force: bool = False) -> None:
|
||||||
See https://www.imdb.com/interfaces/ and https://datasets.imdbws.com/ for
|
See https://www.imdb.com/interfaces/ and https://datasets.imdbws.com/ for
|
||||||
more information on the IMDb database dumps.
|
more information on the IMDb database dumps.
|
||||||
"""
|
"""
|
||||||
await db.set_import_progress(0)
|
async with db.transaction() as conn:
|
||||||
|
await db.set_import_progress(conn, 0)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ratings_file = config.datadir / "imdb/title.ratings.tsv.gz"
|
ratings_file = config.datadir / "imdb/title.ratings.tsv.gz"
|
||||||
|
|
@ -290,8 +295,10 @@ async def load_from_web(*, force: bool = False) -> None:
|
||||||
await import_from_file(basics_path=basics_file, ratings_path=ratings_file)
|
await import_from_file(basics_path=basics_file, ratings_path=ratings_file)
|
||||||
|
|
||||||
except BaseException as err:
|
except BaseException as err:
|
||||||
await db.stop_import_progress(error=err)
|
async with db.transaction() as conn:
|
||||||
|
await db.stop_import_progress(conn, error=err)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
else:
|
else:
|
||||||
await db.stop_import_progress()
|
async with db.transaction() as conn:
|
||||||
|
await db.stop_import_progress(conn)
|
||||||
|
|
|
||||||
174
unwind/models.py
174
unwind/models.py
|
|
@ -11,13 +11,18 @@ from typing import (
|
||||||
Container,
|
Container,
|
||||||
Literal,
|
Literal,
|
||||||
Mapping,
|
Mapping,
|
||||||
|
Protocol,
|
||||||
Type,
|
Type,
|
||||||
|
TypedDict,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
get_args,
|
get_args,
|
||||||
get_origin,
|
get_origin,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from sqlalchemy import Column, ForeignKey, Integer, String, Table
|
||||||
|
from sqlalchemy.orm import registry
|
||||||
|
|
||||||
from .types import ULID
|
from .types import ULID
|
||||||
|
|
||||||
JSON = int | float | str | None | list["JSON"] | dict[str, "JSON"]
|
JSON = int | float | str | None | list["JSON"] | dict[str, "JSON"]
|
||||||
|
|
@ -26,8 +31,16 @@ JSONObject = dict[str, JSON]
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class Model(Protocol):
|
||||||
|
__table__: ClassVar[Table]
|
||||||
|
|
||||||
|
|
||||||
|
mapper_registry = registry()
|
||||||
|
metadata = mapper_registry.metadata
|
||||||
|
|
||||||
|
|
||||||
def annotations(tp: Type) -> tuple | None:
|
def annotations(tp: Type) -> tuple | None:
|
||||||
return tp.__metadata__ if hasattr(tp, "__metadata__") else None
|
return tp.__metadata__ if hasattr(tp, "__metadata__") else None # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def fields(class_or_instance):
|
def fields(class_or_instance):
|
||||||
|
|
@ -112,7 +125,7 @@ def asplain(
|
||||||
if filter_fields is not None and f.name not in filter_fields:
|
if filter_fields is not None and f.name not in filter_fields:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
target = f.type
|
target: Any = f.type
|
||||||
# XXX this doesn't properly support any kind of nested types
|
# XXX this doesn't properly support any kind of nested types
|
||||||
if (otype := optional_type(f.type)) is not None:
|
if (otype := optional_type(f.type)) is not None:
|
||||||
target = otype
|
target = otype
|
||||||
|
|
@ -156,7 +169,7 @@ def fromplain(cls: Type[T], d: Mapping, *, serialized: bool = False) -> T:
|
||||||
|
|
||||||
dd: JSONObject = {}
|
dd: JSONObject = {}
|
||||||
for f in fields(cls):
|
for f in fields(cls):
|
||||||
target = f.type
|
target: Any = f.type
|
||||||
otype = optional_type(f.type)
|
otype = optional_type(f.type)
|
||||||
is_opt = otype is not None
|
is_opt = otype is not None
|
||||||
if is_opt:
|
if is_opt:
|
||||||
|
|
@ -194,12 +207,38 @@ def validate(o: object) -> None:
|
||||||
|
|
||||||
|
|
||||||
def utcnow():
|
def utcnow():
|
||||||
return datetime.utcnow().replace(tzinfo=timezone.utc)
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
@mapper_registry.mapped
|
||||||
|
@dataclass
|
||||||
|
class DbPatch:
|
||||||
|
__table__: ClassVar[Table] = Table(
|
||||||
|
"db_patches",
|
||||||
|
metadata,
|
||||||
|
Column("id", Integer, primary_key=True),
|
||||||
|
Column("current", String),
|
||||||
|
)
|
||||||
|
|
||||||
|
id: int
|
||||||
|
current: str
|
||||||
|
|
||||||
|
|
||||||
|
db_patches = DbPatch.__table__
|
||||||
|
|
||||||
|
|
||||||
|
@mapper_registry.mapped
|
||||||
@dataclass
|
@dataclass
|
||||||
class Progress:
|
class Progress:
|
||||||
_table: ClassVar[str] = "progress"
|
__table__: ClassVar[Table] = Table(
|
||||||
|
"progress",
|
||||||
|
metadata,
|
||||||
|
Column("id", String, primary_key=True), # ULID
|
||||||
|
Column("type", String, nullable=False),
|
||||||
|
Column("state", String, nullable=False), # JSON {"percent": ..., "error": ...}
|
||||||
|
Column("started", String, nullable=False), # datetime
|
||||||
|
Column("stopped", String),
|
||||||
|
)
|
||||||
|
|
||||||
id: ULID = field(default_factory=ULID)
|
id: ULID = field(default_factory=ULID)
|
||||||
type: str = None
|
type: str = None
|
||||||
|
|
@ -236,9 +275,28 @@ class Progress:
|
||||||
self._state = state
|
self._state = state
|
||||||
|
|
||||||
|
|
||||||
|
progress = Progress.__table__
|
||||||
|
|
||||||
|
|
||||||
|
@mapper_registry.mapped
|
||||||
@dataclass
|
@dataclass
|
||||||
class Movie:
|
class Movie:
|
||||||
_table: ClassVar[str] = "movies"
|
__table__: ClassVar[Table] = Table(
|
||||||
|
"movies",
|
||||||
|
metadata,
|
||||||
|
Column("id", String, primary_key=True), # ULID
|
||||||
|
Column("title", String, nullable=False),
|
||||||
|
Column("original_title", String),
|
||||||
|
Column("release_year", Integer, nullable=False),
|
||||||
|
Column("media_type", String, nullable=False),
|
||||||
|
Column("imdb_id", String, nullable=False, unique=True),
|
||||||
|
Column("imdb_score", Integer),
|
||||||
|
Column("imdb_votes", Integer),
|
||||||
|
Column("runtime", Integer),
|
||||||
|
Column("genres", String, nullable=False),
|
||||||
|
Column("created", String, nullable=False), # datetime
|
||||||
|
Column("updated", String, nullable=False), # datetime
|
||||||
|
)
|
||||||
|
|
||||||
id: ULID = field(default_factory=ULID)
|
id: ULID = field(default_factory=ULID)
|
||||||
title: str = None # canonical title (usually English)
|
title: str = None # canonical title (usually English)
|
||||||
|
|
@ -283,6 +341,8 @@ class Movie:
|
||||||
self._is_lazy = False
|
self._is_lazy = False
|
||||||
|
|
||||||
|
|
||||||
|
movies = Movie.__table__
|
||||||
|
|
||||||
_RelationSentinel = object()
|
_RelationSentinel = object()
|
||||||
"""Mark a model field as containing external data.
|
"""Mark a model field as containing external data.
|
||||||
|
|
||||||
|
|
@ -294,9 +354,65 @@ The contents of the Relation are ignored or discarded when using
|
||||||
Relation = Annotated[T | None, _RelationSentinel]
|
Relation = Annotated[T | None, _RelationSentinel]
|
||||||
|
|
||||||
|
|
||||||
|
Access = Literal[
|
||||||
|
"r", # read
|
||||||
|
"i", # index
|
||||||
|
"w", # write
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class UserGroup(TypedDict):
|
||||||
|
id: str
|
||||||
|
access: Access
|
||||||
|
|
||||||
|
|
||||||
|
@mapper_registry.mapped
|
||||||
|
@dataclass
|
||||||
|
class User:
|
||||||
|
__table__: ClassVar[Table] = Table(
|
||||||
|
"users",
|
||||||
|
metadata,
|
||||||
|
Column("id", String, primary_key=True), # ULID
|
||||||
|
Column("imdb_id", String, nullable=False, unique=True),
|
||||||
|
Column("name", String, nullable=False),
|
||||||
|
Column("secret", String, nullable=False),
|
||||||
|
Column("groups", String, nullable=False), # JSON array
|
||||||
|
)
|
||||||
|
|
||||||
|
id: ULID = field(default_factory=ULID)
|
||||||
|
imdb_id: str = None
|
||||||
|
name: str = None # canonical user name
|
||||||
|
secret: str = None
|
||||||
|
groups: list[UserGroup] = field(default_factory=list)
|
||||||
|
|
||||||
|
def has_access(self, group_id: ULID | str, access: Access = "r"):
|
||||||
|
group_id = group_id if isinstance(group_id, str) else str(group_id)
|
||||||
|
return any(g["id"] == group_id and access == g["access"] for g in self.groups)
|
||||||
|
|
||||||
|
def set_access(self, group_id: ULID | str, access: Access):
|
||||||
|
group_id = group_id if isinstance(group_id, str) else str(group_id)
|
||||||
|
for g in self.groups:
|
||||||
|
if g["id"] == group_id:
|
||||||
|
g["access"] = access
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
self.groups.append({"id": group_id, "access": access})
|
||||||
|
|
||||||
|
|
||||||
|
@mapper_registry.mapped
|
||||||
@dataclass
|
@dataclass
|
||||||
class Rating:
|
class Rating:
|
||||||
_table: ClassVar[str] = "ratings"
|
__table__: ClassVar[Table] = Table(
|
||||||
|
"ratings",
|
||||||
|
metadata,
|
||||||
|
Column("id", String, primary_key=True), # ULID
|
||||||
|
Column("movie_id", ForeignKey("movies.id"), nullable=False), # ULID
|
||||||
|
Column("user_id", ForeignKey("users.id"), nullable=False), # ULID
|
||||||
|
Column("score", Integer, nullable=False),
|
||||||
|
Column("rating_date", String, nullable=False), # datetime
|
||||||
|
Column("favorite", Integer), # bool
|
||||||
|
Column("finished", Integer), # bool
|
||||||
|
)
|
||||||
|
|
||||||
id: ULID = field(default_factory=ULID)
|
id: ULID = field(default_factory=ULID)
|
||||||
|
|
||||||
|
|
@ -304,7 +420,7 @@ class Rating:
|
||||||
movie: Relation[Movie] = None
|
movie: Relation[Movie] = None
|
||||||
|
|
||||||
user_id: ULID = None
|
user_id: ULID = None
|
||||||
user: Relation["User"] = None
|
user: Relation[User] = None
|
||||||
|
|
||||||
score: int = None # range: [0,100]
|
score: int = None # range: [0,100]
|
||||||
rating_date: datetime = None
|
rating_date: datetime = None
|
||||||
|
|
@ -324,41 +440,25 @@ class Rating:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
Access = Literal[
|
ratings = Rating.__table__
|
||||||
"r", # read
|
|
||||||
"i", # index
|
|
||||||
"w", # write
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
class GroupUser(TypedDict):
|
||||||
class User:
|
id: str
|
||||||
_table: ClassVar[str] = "users"
|
name: str
|
||||||
|
|
||||||
id: ULID = field(default_factory=ULID)
|
|
||||||
imdb_id: str = None
|
|
||||||
name: str = None # canonical user name
|
|
||||||
secret: str = None
|
|
||||||
groups: list[dict[str, str]] = field(default_factory=list)
|
|
||||||
|
|
||||||
def has_access(self, group_id: ULID | str, access: Access = "r"):
|
|
||||||
group_id = group_id if isinstance(group_id, str) else str(group_id)
|
|
||||||
return any(g["id"] == group_id and access == g["access"] for g in self.groups)
|
|
||||||
|
|
||||||
def set_access(self, group_id: ULID | str, access: Access):
|
|
||||||
group_id = group_id if isinstance(group_id, str) else str(group_id)
|
|
||||||
for g in self.groups:
|
|
||||||
if g["id"] == group_id:
|
|
||||||
g["access"] = access
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
self.groups.append({"id": group_id, "access": access})
|
|
||||||
|
|
||||||
|
|
||||||
|
@mapper_registry.mapped
|
||||||
@dataclass
|
@dataclass
|
||||||
class Group:
|
class Group:
|
||||||
_table: ClassVar[str] = "groups"
|
__table__: ClassVar[Table] = Table(
|
||||||
|
"groups",
|
||||||
|
metadata,
|
||||||
|
Column("id", String, primary_key=True), # ULID
|
||||||
|
Column("name", String, nullable=False),
|
||||||
|
Column("users", String, nullable=False), # JSON array
|
||||||
|
)
|
||||||
|
|
||||||
id: ULID = field(default_factory=ULID)
|
id: ULID = field(default_factory=ULID)
|
||||||
name: str = None
|
name: str = None
|
||||||
users: list[dict[str, str]] = field(default_factory=list)
|
users: list[GroupUser] = field(default_factory=list)
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from hashlib import md5
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from random import random
|
from random import random
|
||||||
from time import sleep, time
|
from time import sleep, time
|
||||||
from typing import Callable, ParamSpec, TypeVar, cast
|
from typing import Any, Callable, ParamSpec, TypeVar, cast
|
||||||
|
|
||||||
import bs4
|
import bs4
|
||||||
import httpx
|
import httpx
|
||||||
|
|
@ -190,9 +190,11 @@ async def asoup_from_url(url):
|
||||||
def _last_modified_from_response(resp: _Response_T) -> float | None:
|
def _last_modified_from_response(resp: _Response_T) -> float | None:
|
||||||
if last_mod := resp.headers.get("last-modified"):
|
if last_mod := resp.headers.get("last-modified"):
|
||||||
try:
|
try:
|
||||||
return email.utils.parsedate_to_datetime(last_mod).timestamp()
|
dt = email.utils.parsedate_to_datetime(last_mod)
|
||||||
except:
|
except ValueError:
|
||||||
log.exception("🐛 Received invalid value for Last-Modified: %s", last_mod)
|
log.exception("🐛 Received invalid value for Last-Modified: %s", last_mod)
|
||||||
|
else:
|
||||||
|
return dt.timestamp()
|
||||||
|
|
||||||
|
|
||||||
def _last_modified_from_file(path: Path) -> float:
|
def _last_modified_from_file(path: Path) -> float:
|
||||||
|
|
@ -206,8 +208,8 @@ async def adownload(
|
||||||
replace_existing: bool | None = None,
|
replace_existing: bool | None = None,
|
||||||
only_if_newer: bool = False,
|
only_if_newer: bool = False,
|
||||||
timeout: float | None = None,
|
timeout: float | None = None,
|
||||||
chunk_callback=None,
|
chunk_callback: Callable[[bytes], Any] | None = None,
|
||||||
response_callback=None,
|
response_callback: Callable[[_Response_T], Any] | None = None,
|
||||||
) -> bytes | None:
|
) -> bytes | None:
|
||||||
"""Download a file.
|
"""Download a file.
|
||||||
|
|
||||||
|
|
@ -246,7 +248,7 @@ async def adownload(
|
||||||
if response_callback is not None:
|
if response_callback is not None:
|
||||||
try:
|
try:
|
||||||
response_callback(resp)
|
response_callback(resp)
|
||||||
except:
|
except BaseException:
|
||||||
log.exception("🐛 Error in response callback.")
|
log.exception("🐛 Error in response callback.")
|
||||||
|
|
||||||
log.debug(
|
log.debug(
|
||||||
|
|
@ -267,7 +269,9 @@ async def adownload(
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
|
|
||||||
if to_path is None:
|
if to_path is None:
|
||||||
await resp.aread() # Download the response stream to allow `resp.content` access.
|
await (
|
||||||
|
resp.aread()
|
||||||
|
) # Download the response stream to allow `resp.content` access.
|
||||||
return resp.content
|
return resp.content
|
||||||
|
|
||||||
resp_lastmod = _last_modified_from_response(resp)
|
resp_lastmod = _last_modified_from_response(resp)
|
||||||
|
|
@ -275,7 +279,7 @@ async def adownload(
|
||||||
# Check Last-Modified in case the server ignored If-Modified-Since.
|
# Check Last-Modified in case the server ignored If-Modified-Since.
|
||||||
# XXX also check Content-Length?
|
# XXX also check Content-Length?
|
||||||
if file_exists and only_if_newer and resp_lastmod is not None:
|
if file_exists and only_if_newer and resp_lastmod is not None:
|
||||||
assert file_lastmod
|
assert file_lastmod # pyright: ignore [reportUnboundVariable]
|
||||||
|
|
||||||
if resp_lastmod <= file_lastmod:
|
if resp_lastmod <= file_lastmod:
|
||||||
log.debug("✋ Local file is newer, skipping download: %a", req.url)
|
log.debug("✋ Local file is newer, skipping download: %a", req.url)
|
||||||
|
|
@ -299,7 +303,7 @@ async def adownload(
|
||||||
if chunk_callback:
|
if chunk_callback:
|
||||||
try:
|
try:
|
||||||
chunk_callback(chunk)
|
chunk_callback(chunk)
|
||||||
except:
|
except BaseException:
|
||||||
log.exception("🐛 Error in chunk callback.")
|
log.exception("🐛 Error in chunk callback.")
|
||||||
finally:
|
finally:
|
||||||
os.close(tempfd)
|
os.close(tempfd)
|
||||||
|
|
|
||||||
132
unwind/web.py
132
unwind/web.py
|
|
@ -168,7 +168,8 @@ async def auth_user(request) -> User | None:
|
||||||
if not isinstance(request.user, AuthedUser):
|
if not isinstance(request.user, AuthedUser):
|
||||||
return
|
return
|
||||||
|
|
||||||
user = await db.get(User, id=request.user.user_id)
|
async with db.new_connection() as conn:
|
||||||
|
user = await db.get(conn, User, id=request.user.user_id)
|
||||||
if not user:
|
if not user:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -179,7 +180,7 @@ async def auth_user(request) -> User | None:
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
_routes = []
|
_routes: list[Route] = []
|
||||||
|
|
||||||
|
|
||||||
def route(path: str, *, methods: list[str] | None = None, **kwds):
|
def route(path: str, *, methods: list[str] | None = None, **kwds):
|
||||||
|
|
@ -191,16 +192,13 @@ def route(path: str, *, methods: list[str] | None = None, **kwds):
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
route.registered = _routes
|
|
||||||
|
|
||||||
|
|
||||||
@route("/groups/{group_id}/ratings")
|
@route("/groups/{group_id}/ratings")
|
||||||
async def get_ratings_for_group(request):
|
async def get_ratings_for_group(request):
|
||||||
group_id = as_ulid(request.path_params["group_id"])
|
group_id = as_ulid(request.path_params["group_id"])
|
||||||
group = await db.get(Group, id=str(group_id))
|
|
||||||
|
|
||||||
if not group:
|
async with db.new_connection() as conn:
|
||||||
return not_found()
|
if (group := await db.get(conn, Group, id=str(group_id))) is None:
|
||||||
|
return not_found()
|
||||||
|
|
||||||
user_ids = {u["id"] for u in group.users}
|
user_ids = {u["id"] for u in group.users}
|
||||||
|
|
||||||
|
|
@ -211,22 +209,26 @@ async def get_ratings_for_group(request):
|
||||||
|
|
||||||
# if (imdb_id or unwind_id) and (movie := await db.get(Movie, id=unwind_id, imdb_id=imdb_id)):
|
# if (imdb_id or unwind_id) and (movie := await db.get(Movie, id=unwind_id, imdb_id=imdb_id)):
|
||||||
if unwind_id:
|
if unwind_id:
|
||||||
rows = await db.ratings_for_movie_ids(ids=[unwind_id])
|
async with db.new_connection() as conn:
|
||||||
|
rows = await db.ratings_for_movie_ids(conn, ids=[unwind_id])
|
||||||
|
|
||||||
elif imdb_id:
|
elif imdb_id:
|
||||||
rows = await db.ratings_for_movie_ids(imdb_ids=[imdb_id])
|
async with db.new_connection() as conn:
|
||||||
|
rows = await db.ratings_for_movie_ids(conn, imdb_ids=[imdb_id])
|
||||||
|
|
||||||
else:
|
else:
|
||||||
rows = await find_ratings(
|
async with db.new_connection() as conn:
|
||||||
title=params.get("title"),
|
rows = await find_ratings(
|
||||||
media_type=params.get("media_type"),
|
conn,
|
||||||
exact=truthy(params.get("exact")),
|
title=params.get("title"),
|
||||||
ignore_tv_episodes=truthy(params.get("ignore_tv_episodes")),
|
media_type=params.get("media_type"),
|
||||||
include_unrated=truthy(params.get("include_unrated")),
|
exact=truthy(params.get("exact")),
|
||||||
yearcomp=yearcomp(params["year"]) if "year" in params else None,
|
ignore_tv_episodes=truthy(params.get("ignore_tv_episodes")),
|
||||||
limit_rows=as_int(params.get("per_page"), max=10, default=5),
|
include_unrated=truthy(params.get("include_unrated")),
|
||||||
user_ids=user_ids,
|
yearcomp=yearcomp(params["year"]) if "year" in params else None,
|
||||||
)
|
limit_rows=as_int(params.get("per_page"), max=10, default=5),
|
||||||
|
user_ids=user_ids,
|
||||||
|
)
|
||||||
|
|
||||||
ratings = (web_models.Rating(**r) for r in rows)
|
ratings = (web_models.Rating(**r) for r in rows)
|
||||||
|
|
||||||
|
|
@ -265,7 +267,8 @@ async def list_movies(request):
|
||||||
if group_id := params.get("group_id"):
|
if group_id := params.get("group_id"):
|
||||||
group_id = as_ulid(group_id)
|
group_id = as_ulid(group_id)
|
||||||
|
|
||||||
group = await db.get(Group, id=str(group_id))
|
async with db.new_connection() as conn:
|
||||||
|
group = await db.get(conn, Group, id=str(group_id))
|
||||||
if not group:
|
if not group:
|
||||||
return not_found("Group not found.")
|
return not_found("Group not found.")
|
||||||
|
|
||||||
|
|
@ -290,26 +293,31 @@ async def list_movies(request):
|
||||||
|
|
||||||
if imdb_id or unwind_id:
|
if imdb_id or unwind_id:
|
||||||
# XXX missing support for user_ids and user_scores
|
# XXX missing support for user_ids and user_scores
|
||||||
movies = (
|
async with db.new_connection() as conn:
|
||||||
[m] if (m := await db.get(Movie, id=unwind_id, imdb_id=imdb_id)) else []
|
movies = (
|
||||||
)
|
[m]
|
||||||
|
if (m := await db.get(conn, Movie, id=unwind_id, imdb_id=imdb_id))
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
|
||||||
resp = [asplain(web_models.RatingAggregate.from_movie(m)) for m in movies]
|
resp = [asplain(web_models.RatingAggregate.from_movie(m)) for m in movies]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
per_page = as_int(params.get("per_page"), max=1000, default=5)
|
per_page = as_int(params.get("per_page"), max=1000, default=5)
|
||||||
page = as_int(params.get("page"), min=1, default=1)
|
page = as_int(params.get("page"), min=1, default=1)
|
||||||
movieratings = await find_movies(
|
async with db.new_connection() as conn:
|
||||||
title=params.get("title"),
|
movieratings = await find_movies(
|
||||||
media_type=params.get("media_type"),
|
conn,
|
||||||
exact=truthy(params.get("exact")),
|
title=params.get("title"),
|
||||||
ignore_tv_episodes=truthy(params.get("ignore_tv_episodes")),
|
media_type=params.get("media_type"),
|
||||||
include_unrated=truthy(params.get("include_unrated")),
|
exact=truthy(params.get("exact")),
|
||||||
yearcomp=yearcomp(params["year"]) if "year" in params else None,
|
ignore_tv_episodes=truthy(params.get("ignore_tv_episodes")),
|
||||||
limit_rows=per_page,
|
include_unrated=truthy(params.get("include_unrated")),
|
||||||
skip_rows=(page - 1) * per_page,
|
yearcomp=yearcomp(params["year"]) if "year" in params else None,
|
||||||
user_ids=list(user_ids),
|
limit_rows=per_page,
|
||||||
)
|
skip_rows=(page - 1) * per_page,
|
||||||
|
user_ids=list(user_ids),
|
||||||
|
)
|
||||||
|
|
||||||
resp = []
|
resp = []
|
||||||
for movie, ratings in movieratings:
|
for movie, ratings in movieratings:
|
||||||
|
|
@ -329,7 +337,8 @@ async def add_movie(request):
|
||||||
@route("/movies/_reload_imdb", methods=["GET"])
|
@route("/movies/_reload_imdb", methods=["GET"])
|
||||||
@requires(["authenticated", "admin"])
|
@requires(["authenticated", "admin"])
|
||||||
async def progress_for_load_imdb_movies(request):
|
async def progress_for_load_imdb_movies(request):
|
||||||
progress = await db.get_import_progress()
|
async with db.new_connection() as conn:
|
||||||
|
progress = await db.get_import_progress(conn)
|
||||||
if not progress:
|
if not progress:
|
||||||
return JSONResponse({"status": "No import exists."}, status_code=404)
|
return JSONResponse({"status": "No import exists."}, status_code=404)
|
||||||
|
|
||||||
|
|
@ -368,14 +377,16 @@ async def load_imdb_movies(request):
|
||||||
force = truthy(params.get("force"))
|
force = truthy(params.get("force"))
|
||||||
|
|
||||||
async with _import_lock:
|
async with _import_lock:
|
||||||
progress = await db.get_import_progress()
|
async with db.new_connection() as conn:
|
||||||
|
progress = await db.get_import_progress(conn)
|
||||||
if progress and not progress.stopped:
|
if progress and not progress.stopped:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
{"status": "Import is running.", "progress": progress.percent},
|
{"status": "Import is running.", "progress": progress.percent},
|
||||||
status_code=409,
|
status_code=409,
|
||||||
)
|
)
|
||||||
|
|
||||||
await db.set_import_progress(0)
|
async with db.transaction() as conn:
|
||||||
|
await db.set_import_progress(conn, 0)
|
||||||
|
|
||||||
task = BackgroundTask(imdb_import.load_from_web, force=force)
|
task = BackgroundTask(imdb_import.load_from_web, force=force)
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
|
|
@ -386,7 +397,8 @@ async def load_imdb_movies(request):
|
||||||
@route("/users")
|
@route("/users")
|
||||||
@requires(["authenticated", "admin"])
|
@requires(["authenticated", "admin"])
|
||||||
async def list_users(request):
|
async def list_users(request):
|
||||||
users = await db.get_all(User)
|
async with db.new_connection() as conn:
|
||||||
|
users = await db.get_all(conn, User)
|
||||||
|
|
||||||
return JSONResponse([asplain(u) for u in users])
|
return JSONResponse([asplain(u) for u in users])
|
||||||
|
|
||||||
|
|
@ -402,7 +414,8 @@ async def add_user(request):
|
||||||
secret = secrets.token_bytes()
|
secret = secrets.token_bytes()
|
||||||
|
|
||||||
user = User(name=name, imdb_id=imdb_id, secret=phc_scrypt(secret))
|
user = User(name=name, imdb_id=imdb_id, secret=phc_scrypt(secret))
|
||||||
await db.add(user)
|
async with db.transaction() as conn:
|
||||||
|
await db.add(conn, user)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
{
|
{
|
||||||
|
|
@ -418,7 +431,8 @@ async def show_user(request):
|
||||||
user_id = as_ulid(request.path_params["user_id"])
|
user_id = as_ulid(request.path_params["user_id"])
|
||||||
|
|
||||||
if is_admin(request):
|
if is_admin(request):
|
||||||
user = await db.get(User, id=str(user_id))
|
async with db.new_connection() as conn:
|
||||||
|
user = await db.get(conn, User, id=str(user_id))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
user = await auth_user(request)
|
user = await auth_user(request)
|
||||||
|
|
@ -445,14 +459,15 @@ async def show_user(request):
|
||||||
async def remove_user(request):
|
async def remove_user(request):
|
||||||
user_id = as_ulid(request.path_params["user_id"])
|
user_id = as_ulid(request.path_params["user_id"])
|
||||||
|
|
||||||
user = await db.get(User, id=str(user_id))
|
async with db.new_connection() as conn:
|
||||||
|
user = await db.get(conn, User, id=str(user_id))
|
||||||
if not user:
|
if not user:
|
||||||
return not_found()
|
return not_found()
|
||||||
|
|
||||||
async with db.shared_connection().transaction():
|
async with db.transaction() as conn:
|
||||||
# XXX remove user refs from groups and ratings
|
# XXX remove user refs from groups and ratings
|
||||||
|
|
||||||
await db.remove(user)
|
await db.remove(conn, user)
|
||||||
|
|
||||||
return JSONResponse(asplain(user))
|
return JSONResponse(asplain(user))
|
||||||
|
|
||||||
|
|
@ -463,7 +478,8 @@ async def modify_user(request):
|
||||||
user_id = as_ulid(request.path_params["user_id"])
|
user_id = as_ulid(request.path_params["user_id"])
|
||||||
|
|
||||||
if is_admin(request):
|
if is_admin(request):
|
||||||
user = await db.get(User, id=str(user_id))
|
async with db.new_connection() as conn:
|
||||||
|
user = await db.get(conn, User, id=str(user_id))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
user = await auth_user(request)
|
user = await auth_user(request)
|
||||||
|
|
@ -499,7 +515,8 @@ async def modify_user(request):
|
||||||
|
|
||||||
user.secret = phc_scrypt(secret)
|
user.secret = phc_scrypt(secret)
|
||||||
|
|
||||||
await db.update(user)
|
async with db.transaction() as conn:
|
||||||
|
await db.update(conn, user)
|
||||||
|
|
||||||
return JSONResponse(asplain(user))
|
return JSONResponse(asplain(user))
|
||||||
|
|
||||||
|
|
@ -509,13 +526,15 @@ async def modify_user(request):
|
||||||
async def add_group_to_user(request):
|
async def add_group_to_user(request):
|
||||||
user_id = as_ulid(request.path_params["user_id"])
|
user_id = as_ulid(request.path_params["user_id"])
|
||||||
|
|
||||||
user = await db.get(User, id=str(user_id))
|
async with db.new_connection() as conn:
|
||||||
|
user = await db.get(conn, User, id=str(user_id))
|
||||||
if not user:
|
if not user:
|
||||||
return not_found("User not found")
|
return not_found("User not found")
|
||||||
|
|
||||||
(group_id, access) = await json_from_body(request, ["group", "access"])
|
(group_id, access) = await json_from_body(request, ["group", "access"])
|
||||||
|
|
||||||
group = await db.get(Group, id=str(group_id))
|
async with db.new_connection() as conn:
|
||||||
|
group = await db.get(conn, Group, id=str(group_id))
|
||||||
if not group:
|
if not group:
|
||||||
return not_found("Group not found")
|
return not_found("Group not found")
|
||||||
|
|
||||||
|
|
@ -523,7 +542,8 @@ async def add_group_to_user(request):
|
||||||
raise HTTPException(422, f"Invalid access level.")
|
raise HTTPException(422, f"Invalid access level.")
|
||||||
|
|
||||||
user.set_access(group_id, access)
|
user.set_access(group_id, access)
|
||||||
await db.update(user)
|
async with db.transaction() as conn:
|
||||||
|
await db.update(conn, user)
|
||||||
|
|
||||||
return JSONResponse(asplain(user))
|
return JSONResponse(asplain(user))
|
||||||
|
|
||||||
|
|
@ -551,7 +571,8 @@ async def load_imdb_user_ratings(request):
|
||||||
@route("/groups")
|
@route("/groups")
|
||||||
@requires(["authenticated", "admin"])
|
@requires(["authenticated", "admin"])
|
||||||
async def list_groups(request):
|
async def list_groups(request):
|
||||||
groups = await db.get_all(Group)
|
async with db.new_connection() as conn:
|
||||||
|
groups = await db.get_all(conn, Group)
|
||||||
|
|
||||||
return JSONResponse([asplain(g) for g in groups])
|
return JSONResponse([asplain(g) for g in groups])
|
||||||
|
|
||||||
|
|
@ -564,7 +585,8 @@ async def add_group(request):
|
||||||
# XXX restrict name
|
# XXX restrict name
|
||||||
|
|
||||||
group = Group(name=name)
|
group = Group(name=name)
|
||||||
await db.add(group)
|
async with db.transaction() as conn:
|
||||||
|
await db.add(conn, group)
|
||||||
|
|
||||||
return JSONResponse(asplain(group))
|
return JSONResponse(asplain(group))
|
||||||
|
|
||||||
|
|
@ -573,7 +595,8 @@ async def add_group(request):
|
||||||
@requires(["authenticated"])
|
@requires(["authenticated"])
|
||||||
async def add_user_to_group(request):
|
async def add_user_to_group(request):
|
||||||
group_id = as_ulid(request.path_params["group_id"])
|
group_id = as_ulid(request.path_params["group_id"])
|
||||||
group = await db.get(Group, id=str(group_id))
|
async with db.new_connection() as conn:
|
||||||
|
group = await db.get(conn, Group, id=str(group_id))
|
||||||
|
|
||||||
if not group:
|
if not group:
|
||||||
return not_found()
|
return not_found()
|
||||||
|
|
@ -600,7 +623,8 @@ async def add_user_to_group(request):
|
||||||
else:
|
else:
|
||||||
group.users.append({"name": name, "id": user_id})
|
group.users.append({"name": name, "id": user_id})
|
||||||
|
|
||||||
await db.update(group)
|
async with db.transaction() as conn:
|
||||||
|
await db.update(conn, group)
|
||||||
|
|
||||||
return JSONResponse(asplain(group))
|
return JSONResponse(asplain(group))
|
||||||
|
|
||||||
|
|
@ -632,7 +656,7 @@ def create_app():
|
||||||
return Starlette(
|
return Starlette(
|
||||||
lifespan=lifespan,
|
lifespan=lifespan,
|
||||||
routes=[
|
routes=[
|
||||||
Mount(f"{config.api_base}v1", routes=route.registered),
|
Mount(f"{config.api_base}v1", routes=_routes),
|
||||||
],
|
],
|
||||||
middleware=[
|
middleware=[
|
||||||
Middleware(ResponseTimeMiddleware, header_name="Unwind-Elapsed"),
|
Middleware(ResponseTimeMiddleware, header_name="Unwind-Elapsed"),
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue