From e1f35143df045b494567e779071eaf1da56c6b77 Mon Sep 17 00:00:00 2001 From: ducklet Date: Sun, 19 Dec 2021 19:30:08 +0100 Subject: [PATCH] add some tests for `db.find_ratings` --- scripts/tests | 4 +- tests/conftest.py | 30 +++++++++ tests/test_db.py | 160 ++++++++++++++++++++++++++++++++++++++++++++++ tests/test_web.py | 9 +-- 4 files changed, 195 insertions(+), 8 deletions(-) create mode 100644 tests/conftest.py create mode 100644 tests/test_db.py diff --git a/scripts/tests b/scripts/tests index 56ff576..4237558 100755 --- a/scripts/tests +++ b/scripts/tests @@ -2,8 +2,10 @@ cd "$RUN_DIR" -dbfile="$RUN_DIR/tests.sqlite.local" +dbfile="${UNWIND_DATA:-./data}/tests.sqlite" +# Rollback in Databases is currently broken, so we have to rebuild the database +# each time; see https://github.com/encode/databases/issues/403 trap 'rm "$dbfile"' EXIT TERM INT QUIT [ -z "${DEBUG:-}" ] || set -x diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..0fd79ea --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,30 @@ +import asyncio + +import pytest +from unwind import db + + +@pytest.fixture(scope="session") +def event_loop(): + loop = asyncio.new_event_loop() + yield loop + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.run_until_complete(loop.shutdown_default_executor()) + loop.close() + + +@pytest.fixture(scope="session") +async def shared_conn(): + c = db.shared_connection() + await c.connect() + + await db.apply_db_patches(c) + yield c + + await c.disconnect() + + +@pytest.fixture +async def conn(shared_conn): + async with shared_conn.transaction(force_rollback=True): + yield shared_conn diff --git a/tests/test_db.py b/tests/test_db.py new file mode 100644 index 0000000..caeaf69 --- /dev/null +++ b/tests/test_db.py @@ -0,0 +1,160 @@ +from datetime import datetime +import pytest + +from unwind import db, models, web_models + +pytestmark = pytest.mark.asyncio + + +async def test_add_and_get(shared_conn): + async with shared_conn.transaction(force_rollback=True): + + m1 = models.Movie( + title="test movie", + release_year=2013, + media_type="Movie", + imdb_id="tt0000000", + genres={"genre-1"}, + ) + await db.add(m1) + + m2 = models.Movie( + title="test movie", + release_year=2013, + media_type="Movie", + imdb_id="tt0000001", + genres={"genre-1"}, + ) + await db.add(m2) + + assert m1 == await db.get(models.Movie, id=str(m1.id)) + assert m2 == await db.get(models.Movie, id=str(m2.id)) + + +async def test_find_ratings(shared_conn): + async with shared_conn.transaction(force_rollback=True): + + m1 = models.Movie( + title="test movie", + release_year=2013, + media_type="Movie", + imdb_id="tt0000000", + genres={"genre-1"}, + ) + await db.add(m1) + + m2 = models.Movie( + title="it's anöther Movie, Part 2", + release_year=2015, + media_type="Movie", + imdb_id="tt0000001", + genres={"genre-2"}, + ) + await db.add(m2) + + m3 = models.Movie( + title="movie it's, Part 3", + release_year=2015, + media_type="Movie", + imdb_id="tt0000002", + genres={"genre-2"}, + ) + await db.add(m3) + + u1 = models.User( + imdb_id="u00001", + name="User1", + secret="secret1", + ) + await db.add(u1) + + u2 = models.User( + imdb_id="u00002", + name="User2", + secret="secret2", + ) + await db.add(u2) + + r1 = models.Rating( + movie_id=m2.id, + movie=m2, + user_id=u1.id, + user=u1, + score=66, + rating_date=datetime.now(), + ) + await db.add(r1) + + r2 = models.Rating( + movie_id=m2.id, + movie=m2, + user_id=u2.id, + user=u2, + score=77, + rating_date=datetime.now(), + ) + await db.add(r2) + + # --- + + rows = await db.find_ratings( + title=m1.title, + media_type=m1.media_type, + exact=True, + ignore_tv_episodes=True, + include_unrated=True, + yearcomp=("=", m1.release_year), + limit_rows=3, + user_ids=[], + ) + ratings = (web_models.Rating(**r) for r in rows) + assert (web_models.RatingAggregate.from_movie(m1),) == tuple( + web_models.aggregate_ratings(ratings, user_ids=[]) + ) + + rows = await db.find_ratings(title="movie", include_unrated=False) + ratings = tuple(web_models.Rating(**r) for r in rows) + assert ( + web_models.Rating.from_movie(m2, rating=r1), + web_models.Rating.from_movie(m2, rating=r2), + ) == ratings + + rows = await db.find_ratings(title="movie", include_unrated=True) + ratings = tuple(web_models.Rating(**r) for r in rows) + assert ( + web_models.Rating.from_movie(m1), + web_models.Rating.from_movie(m2, rating=r1), + web_models.Rating.from_movie(m2, rating=r2), + web_models.Rating.from_movie(m3), + ) == ratings + + aggr = web_models.aggregate_ratings(ratings, user_ids=[]) + assert tuple( + web_models.RatingAggregate.from_movie(m) for m in [m1, m2, m3] + ) == tuple(aggr) + + aggr = web_models.aggregate_ratings(ratings, user_ids=[str(u1.id)]) + assert ( + web_models.RatingAggregate.from_movie(m1), + web_models.RatingAggregate.from_movie(m2, ratings=[r1]), + web_models.RatingAggregate.from_movie(m3), + ) == tuple(aggr) + + aggr = web_models.aggregate_ratings(ratings, user_ids=[str(u1.id), str(u2.id)]) + assert ( + web_models.RatingAggregate.from_movie(m1), + web_models.RatingAggregate.from_movie(m2, ratings=[r1, r2]), + web_models.RatingAggregate.from_movie(m3), + ) == tuple(aggr) + + rows = await db.find_ratings(title="movie", include_unrated=True) + ratings = (web_models.Rating(**r) for r in rows) + aggr = web_models.aggregate_ratings(ratings, user_ids=[]) + assert tuple( + web_models.RatingAggregate.from_movie(m) for m in [m1, m2, m3] + ) == tuple(aggr) + + rows = await db.find_ratings(title="test", include_unrated=True) + ratings = tuple(web_models.Rating(**r) for r in rows) + assert (web_models.Rating.from_movie(m1),) == ratings + diff --git a/tests/test_web.py b/tests/test_web.py index a7b0852..55c2d23 100644 --- a/tests/test_web.py +++ b/tests/test_web.py @@ -10,11 +10,8 @@ pytestmark = pytest.mark.asyncio app = create_app() -async def test_app(): - await db.open_connection_pool() - conn = db.shared_connection() - - async with conn.transaction(force_rollback=True): +async def test_app(shared_conn): + async with shared_conn.transaction(force_rollback=True): # https://www.starlette.io/testclient/ client = TestClient(app) @@ -58,5 +55,3 @@ async def test_app(): response = client.get("/api/v1/movies", params={"unwind_id": str(m.id)}) assert response.status_code == 200 assert response.json() == [m_plain] - - await db.close_connection_pool()