diff --git a/tests/conftest.py b/tests/conftest.py index 0fd79ea..e57d3e1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,8 @@ import asyncio import pytest +import pytest_asyncio + from unwind import db @@ -13,7 +15,7 @@ def event_loop(): loop.close() -@pytest.fixture(scope="session") +@pytest_asyncio.fixture(scope="session") async def shared_conn(): c = db.shared_connection() await c.connect() @@ -24,7 +26,7 @@ async def shared_conn(): await c.disconnect() -@pytest.fixture +@pytest_asyncio.fixture async def conn(shared_conn): async with shared_conn.transaction(force_rollback=True): yield shared_conn diff --git a/tests/test_db.py b/tests/test_db.py index caeaf69..13a7de4 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -3,10 +3,9 @@ import pytest from unwind import db, models, web_models -pytestmark = pytest.mark.asyncio - -async def test_add_and_get(shared_conn): +@pytest.mark.asyncio +async def test_add_and_get(shared_conn: db.Database): async with shared_conn.transaction(force_rollback=True): m1 = models.Movie( @@ -31,7 +30,8 @@ async def test_add_and_get(shared_conn): assert m2 == await db.get(models.Movie, id=str(m2.id)) -async def test_find_ratings(shared_conn): +@pytest.mark.asyncio +async def test_find_ratings(shared_conn: db.Database): async with shared_conn.transaction(force_rollback=True): m1 = models.Movie( diff --git a/tests/test_imdb.py b/tests/test_imdb.py index 13a03fd..48308a7 100644 --- a/tests/test_imdb.py +++ b/tests/test_imdb.py @@ -3,12 +3,12 @@ from unwind.imdb import imdb_rating_from_score, score_from_imdb_rating @pytest.mark.parametrize("rating", (x / 10 for x in range(10, 101))) -def test_rating_conversion(rating): +def test_rating_conversion(rating: float): assert rating == imdb_rating_from_score(score_from_imdb_rating(rating)) @pytest.mark.parametrize("score", range(0, 101)) -def test_score_conversion(score): +def test_score_conversion(score: int): # Because our score covers 101 discrete values and IMDb's rating only 91 # discrete values, the mapping is non-injective, i.e. 10 values can't be # mapped uniquely. diff --git a/tests/test_web.py b/tests/test_web.py index 55c2d23..250447d 100644 --- a/tests/test_web.py +++ b/tests/test_web.py @@ -4,13 +4,11 @@ import pytest from unwind import create_app from unwind import db, models, imdb -# https://pypi.org/project/pytest-asyncio/ -pytestmark = pytest.mark.asyncio - app = create_app() -async def test_app(shared_conn): +@pytest.mark.asyncio +async def test_app(shared_conn: db.Database): async with shared_conn.transaction(force_rollback=True): # https://www.starlette.io/testclient/