diff --git a/tests/test_web.py b/tests/test_web.py index 358c2a2..364edd5 100644 --- a/tests/test_web.py +++ b/tests/test_web.py @@ -1,22 +1,140 @@ +from datetime import datetime import pytest from starlette.testclient import TestClient -from unwind import create_app, db, imdb, models +from unwind import config, create_app, db, imdb, models app = create_app() +@pytest.fixture(scope="module") +def unauthorized_client() -> TestClient: + # https://www.starlette.io/testclient/ + return TestClient(app) + + +@pytest.fixture(scope="module") +def authorized_client() -> TestClient: + client = TestClient(app) + client.auth = "user1", "secret1" + return client + + +@pytest.fixture(scope="module") +def admin_client() -> TestClient: + client = TestClient(app) + for token in config.api_credentials.values(): + break + else: + raise RuntimeError("No bearer tokens configured.") + client.headers = {"Authorization": f"Bearer {token}"} + return client + + @pytest.mark.asyncio -async def test_app(shared_conn: db.Database): +async def test_get_ratings_for_group( + shared_conn: db.Database, unauthorized_client: TestClient +): + user = models.User( + imdb_id="ur12345678", + name="user-1", + secret="secret-1", + groups=[], + ) + group = models.Group( + name="group-1", + users=[models.GroupUser(id=str(user.id), name=user.name)], + ) + user.groups = [models.UserGroup(id=str(group.id), access="r")] + path = app.url_path_for("get_ratings_for_group", group_id=str(group.id)) async with shared_conn.transaction(force_rollback=True): - # https://www.starlette.io/testclient/ - client = TestClient(app) - response = client.get("/api/v1/movies") + resp = unauthorized_client.get(path) + assert resp.status_code == 404, "Group does not exist (yet)" + + await db.add(user) + await db.add(group) + + resp = unauthorized_client.get(path) + assert resp.status_code == 200 + assert resp.json() == [] + + movie = models.Movie( + title="test movie", + release_year=2013, + media_type="Movie", + imdb_id="tt12345678", + genres={"genre-1"}, + ) + await db.add(movie) + + rating = models.Rating( + movie_id=movie.id, user_id=user.id, score=66, rating_date=datetime.now() + ) + await db.add(rating) + + rating_aggregate = { + "canonical_title": movie.title, + "imdb_score": movie.imdb_score, + "imdb_votes": movie.imdb_votes, + "link": imdb.movie_url(movie.imdb_id), + "media_type": movie.media_type, + "original_title": movie.original_title, + "user_scores": [rating.score], + "year": movie.release_year, + } + + resp = unauthorized_client.get(path) + assert resp.status_code == 200 + assert resp.json() == [rating_aggregate] + + filters = { + "imdb_id": movie.imdb_id, + "unwind_id": str(movie.id), + "title": movie.title, + "media_type": movie.media_type, + "year": movie.release_year, + } + for k, v in filters.items(): + resp = unauthorized_client.get(path, params={k: v}) + assert resp.status_code == 200 + assert resp.json() == [rating_aggregate] + + resp = unauthorized_client.get(path, params={"title": "no such thing"}) + assert resp.status_code == 200 + assert resp.json() == [] + + # Test "exact" query param. + resp = unauthorized_client.get( + path, params={"title": "test movie", "exact": "true"} + ) + assert resp.status_code == 200 + assert resp.json() == [rating_aggregate] + resp = unauthorized_client.get( + path, params={"title": "te mo", "exact": "false"} + ) + assert resp.status_code == 200 + assert resp.json() == [rating_aggregate] + resp = unauthorized_client.get(path, params={"title": "te mo", "exact": "true"}) + assert resp.status_code == 200 + assert resp.json() == [] + + # XXX Test "ignore_tv_episodes" query param. + # XXX Test "include_unrated" query param. + # XXX Test "per_page" query param. + + +@pytest.mark.asyncio +async def test_list_movies( + shared_conn: db.Database, + unauthorized_client: TestClient, + authorized_client: TestClient, +): + path = app.url_path_for("list_movies") + async with shared_conn.transaction(force_rollback=True): + response = unauthorized_client.get(path) assert response.status_code == 403 - client.auth = "user1", "secret1" - - response = client.get("/api/v1/movies") + response = authorized_client.get(path) assert response.status_code == 200 assert response.json() == [] @@ -29,7 +147,7 @@ async def test_app(shared_conn: db.Database): ) await db.add(m) - response = client.get("/api/v1/movies", params={"include_unrated": 1}) + response = authorized_client.get(path, params={"include_unrated": 1}) assert response.status_code == 200 assert response.json() == [{**models.asplain(m), "user_scores": []}] @@ -44,10 +162,86 @@ async def test_app(shared_conn: db.Database): "year": m.release_year, } - response = client.get("/api/v1/movies", params={"imdb_id": m.imdb_id}) + response = authorized_client.get(path, params={"imdb_id": m.imdb_id}) assert response.status_code == 200 assert response.json() == [m_plain] - response = client.get("/api/v1/movies", params={"unwind_id": str(m.id)}) + response = authorized_client.get(path, params={"unwind_id": str(m.id)}) + assert response.status_code == 200 + assert response.json() == [m_plain] + + +@pytest.mark.asyncio +async def test_list_users( + shared_conn: db.Database, + unauthorized_client: TestClient, + authorized_client: TestClient, + admin_client: TestClient, +): + path = app.url_path_for("list_users") + async with shared_conn.transaction(force_rollback=True): + response = unauthorized_client.get(path) + assert response.status_code == 403 + + response = authorized_client.get(path) + assert response.status_code == 403 + + response = admin_client.get(path) + assert response.status_code == 200 + assert response.json() == [] + + m = models.User( + imdb_id="ur12345678", + name="user-1", + secret="secret-1", + groups=[], + ) + await db.add(m) + + m_plain = { + "groups": m.groups, + "id": m.id, + "imdb_id": m.imdb_id, + "name": m.name, + "secret": m.secret, + } + + response = admin_client.get(path) + assert response.status_code == 200 + assert response.json() == [m_plain] + + +@pytest.mark.asyncio +async def test_list_groups( + shared_conn: db.Database, + unauthorized_client: TestClient, + authorized_client: TestClient, + admin_client: TestClient, +): + path = app.url_path_for("list_groups") + async with shared_conn.transaction(force_rollback=True): + response = unauthorized_client.get(path) + assert response.status_code == 403 + + response = authorized_client.get(path) + assert response.status_code == 403 + + response = admin_client.get(path) + assert response.status_code == 200 + assert response.json() == [] + + m = models.Group( + name="group-1", + users=[models.GroupUser(id="123", name="itsa-me")], + ) + await db.add(m) + + m_plain = { + "users": m.users, + "id": m.id, + "name": m.name, + } + + response = admin_client.get(path) assert response.status_code == 200 assert response.json() == [m_plain] diff --git a/unwind/models.py b/unwind/models.py index 4480307..0bf489d 100644 --- a/unwind/models.py +++ b/unwind/models.py @@ -13,6 +13,7 @@ from typing import ( Mapping, Type, TypeVar, + TypedDict, Union, get_args, get_origin, @@ -331,6 +332,11 @@ Access = Literal[ ] +class UserGroup(TypedDict): + id: str + access: Access + + @dataclass class User: _table: ClassVar[str] = "users" @@ -339,7 +345,7 @@ class User: imdb_id: str = None name: str = None # canonical user name secret: str = None - groups: list[dict[str, str]] = field(default_factory=list) + groups: list[UserGroup] = field(default_factory=list) def has_access(self, group_id: ULID | str, access: Access = "r"): group_id = group_id if isinstance(group_id, str) else str(group_id) @@ -355,10 +361,15 @@ class User: self.groups.append({"id": group_id, "access": access}) +class GroupUser(TypedDict): + id: str + name: str + + @dataclass class Group: _table: ClassVar[str] = "groups" id: ULID = field(default_factory=ULID) name: str = None - users: list[dict[str, str]] = field(default_factory=list) + users: list[GroupUser] = field(default_factory=list) diff --git a/unwind/web.py b/unwind/web.py index eb08e9c..3ebbcdc 100644 --- a/unwind/web.py +++ b/unwind/web.py @@ -179,7 +179,7 @@ async def auth_user(request) -> User | None: return user -_routes = [] +_routes: list[Route] = [] def route(path: str, *, methods: list[str] | None = None, **kwds): @@ -191,15 +191,11 @@ def route(path: str, *, methods: list[str] | None = None, **kwds): return decorator -route.registered = _routes - - @route("/groups/{group_id}/ratings") async def get_ratings_for_group(request): group_id = as_ulid(request.path_params["group_id"]) - group = await db.get(Group, id=str(group_id)) - if not group: + if (group := await db.get(Group, id=str(group_id))) is None: return not_found() user_ids = {u["id"] for u in group.users} @@ -632,7 +628,7 @@ def create_app(): return Starlette( lifespan=lifespan, routes=[ - Mount(f"{config.api_base}v1", routes=route.registered), + Mount(f"{config.api_base}v1", routes=_routes), ], middleware=[ Middleware(ResponseTimeMiddleware, header_name="Unwind-Elapsed"),