add some route tests
This commit is contained in:
parent
00486778db
commit
f97c5c8472
3 changed files with 221 additions and 20 deletions
|
|
@ -1,22 +1,140 @@
|
|||
from datetime import datetime
|
||||
import pytest
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from unwind import create_app, db, imdb, models
|
||||
from unwind import config, create_app, db, imdb, models
|
||||
|
||||
app = create_app()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def unauthorized_client() -> TestClient:
|
||||
# https://www.starlette.io/testclient/
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def authorized_client() -> TestClient:
|
||||
client = TestClient(app)
|
||||
client.auth = "user1", "secret1"
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def admin_client() -> TestClient:
|
||||
client = TestClient(app)
|
||||
for token in config.api_credentials.values():
|
||||
break
|
||||
else:
|
||||
raise RuntimeError("No bearer tokens configured.")
|
||||
client.headers = {"Authorization": f"Bearer {token}"}
|
||||
return client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_app(shared_conn: db.Database):
|
||||
async def test_get_ratings_for_group(
|
||||
shared_conn: db.Database, unauthorized_client: TestClient
|
||||
):
|
||||
user = models.User(
|
||||
imdb_id="ur12345678",
|
||||
name="user-1",
|
||||
secret="secret-1",
|
||||
groups=[],
|
||||
)
|
||||
group = models.Group(
|
||||
name="group-1",
|
||||
users=[models.GroupUser(id=str(user.id), name=user.name)],
|
||||
)
|
||||
user.groups = [models.UserGroup(id=str(group.id), access="r")]
|
||||
path = app.url_path_for("get_ratings_for_group", group_id=str(group.id))
|
||||
async with shared_conn.transaction(force_rollback=True):
|
||||
# https://www.starlette.io/testclient/
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/movies")
|
||||
resp = unauthorized_client.get(path)
|
||||
assert resp.status_code == 404, "Group does not exist (yet)"
|
||||
|
||||
await db.add(user)
|
||||
await db.add(group)
|
||||
|
||||
resp = unauthorized_client.get(path)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
movie = models.Movie(
|
||||
title="test movie",
|
||||
release_year=2013,
|
||||
media_type="Movie",
|
||||
imdb_id="tt12345678",
|
||||
genres={"genre-1"},
|
||||
)
|
||||
await db.add(movie)
|
||||
|
||||
rating = models.Rating(
|
||||
movie_id=movie.id, user_id=user.id, score=66, rating_date=datetime.now()
|
||||
)
|
||||
await db.add(rating)
|
||||
|
||||
rating_aggregate = {
|
||||
"canonical_title": movie.title,
|
||||
"imdb_score": movie.imdb_score,
|
||||
"imdb_votes": movie.imdb_votes,
|
||||
"link": imdb.movie_url(movie.imdb_id),
|
||||
"media_type": movie.media_type,
|
||||
"original_title": movie.original_title,
|
||||
"user_scores": [rating.score],
|
||||
"year": movie.release_year,
|
||||
}
|
||||
|
||||
resp = unauthorized_client.get(path)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == [rating_aggregate]
|
||||
|
||||
filters = {
|
||||
"imdb_id": movie.imdb_id,
|
||||
"unwind_id": str(movie.id),
|
||||
"title": movie.title,
|
||||
"media_type": movie.media_type,
|
||||
"year": movie.release_year,
|
||||
}
|
||||
for k, v in filters.items():
|
||||
resp = unauthorized_client.get(path, params={k: v})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == [rating_aggregate]
|
||||
|
||||
resp = unauthorized_client.get(path, params={"title": "no such thing"})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
# Test "exact" query param.
|
||||
resp = unauthorized_client.get(
|
||||
path, params={"title": "test movie", "exact": "true"}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == [rating_aggregate]
|
||||
resp = unauthorized_client.get(
|
||||
path, params={"title": "te mo", "exact": "false"}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == [rating_aggregate]
|
||||
resp = unauthorized_client.get(path, params={"title": "te mo", "exact": "true"})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
# XXX Test "ignore_tv_episodes" query param.
|
||||
# XXX Test "include_unrated" query param.
|
||||
# XXX Test "per_page" query param.
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_movies(
|
||||
shared_conn: db.Database,
|
||||
unauthorized_client: TestClient,
|
||||
authorized_client: TestClient,
|
||||
):
|
||||
path = app.url_path_for("list_movies")
|
||||
async with shared_conn.transaction(force_rollback=True):
|
||||
response = unauthorized_client.get(path)
|
||||
assert response.status_code == 403
|
||||
|
||||
client.auth = "user1", "secret1"
|
||||
|
||||
response = client.get("/api/v1/movies")
|
||||
response = authorized_client.get(path)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
|
|
@ -29,7 +147,7 @@ async def test_app(shared_conn: db.Database):
|
|||
)
|
||||
await db.add(m)
|
||||
|
||||
response = client.get("/api/v1/movies", params={"include_unrated": 1})
|
||||
response = authorized_client.get(path, params={"include_unrated": 1})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == [{**models.asplain(m), "user_scores": []}]
|
||||
|
||||
|
|
@ -44,10 +162,86 @@ async def test_app(shared_conn: db.Database):
|
|||
"year": m.release_year,
|
||||
}
|
||||
|
||||
response = client.get("/api/v1/movies", params={"imdb_id": m.imdb_id})
|
||||
response = authorized_client.get(path, params={"imdb_id": m.imdb_id})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == [m_plain]
|
||||
|
||||
response = client.get("/api/v1/movies", params={"unwind_id": str(m.id)})
|
||||
response = authorized_client.get(path, params={"unwind_id": str(m.id)})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == [m_plain]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users(
|
||||
shared_conn: db.Database,
|
||||
unauthorized_client: TestClient,
|
||||
authorized_client: TestClient,
|
||||
admin_client: TestClient,
|
||||
):
|
||||
path = app.url_path_for("list_users")
|
||||
async with shared_conn.transaction(force_rollback=True):
|
||||
response = unauthorized_client.get(path)
|
||||
assert response.status_code == 403
|
||||
|
||||
response = authorized_client.get(path)
|
||||
assert response.status_code == 403
|
||||
|
||||
response = admin_client.get(path)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
m = models.User(
|
||||
imdb_id="ur12345678",
|
||||
name="user-1",
|
||||
secret="secret-1",
|
||||
groups=[],
|
||||
)
|
||||
await db.add(m)
|
||||
|
||||
m_plain = {
|
||||
"groups": m.groups,
|
||||
"id": m.id,
|
||||
"imdb_id": m.imdb_id,
|
||||
"name": m.name,
|
||||
"secret": m.secret,
|
||||
}
|
||||
|
||||
response = admin_client.get(path)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == [m_plain]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_groups(
|
||||
shared_conn: db.Database,
|
||||
unauthorized_client: TestClient,
|
||||
authorized_client: TestClient,
|
||||
admin_client: TestClient,
|
||||
):
|
||||
path = app.url_path_for("list_groups")
|
||||
async with shared_conn.transaction(force_rollback=True):
|
||||
response = unauthorized_client.get(path)
|
||||
assert response.status_code == 403
|
||||
|
||||
response = authorized_client.get(path)
|
||||
assert response.status_code == 403
|
||||
|
||||
response = admin_client.get(path)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
m = models.Group(
|
||||
name="group-1",
|
||||
users=[models.GroupUser(id="123", name="itsa-me")],
|
||||
)
|
||||
await db.add(m)
|
||||
|
||||
m_plain = {
|
||||
"users": m.users,
|
||||
"id": m.id,
|
||||
"name": m.name,
|
||||
}
|
||||
|
||||
response = admin_client.get(path)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == [m_plain]
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from typing import (
|
|||
Mapping,
|
||||
Type,
|
||||
TypeVar,
|
||||
TypedDict,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
|
|
@ -331,6 +332,11 @@ Access = Literal[
|
|||
]
|
||||
|
||||
|
||||
class UserGroup(TypedDict):
|
||||
id: str
|
||||
access: Access
|
||||
|
||||
|
||||
@dataclass
|
||||
class User:
|
||||
_table: ClassVar[str] = "users"
|
||||
|
|
@ -339,7 +345,7 @@ class User:
|
|||
imdb_id: str = None
|
||||
name: str = None # canonical user name
|
||||
secret: str = None
|
||||
groups: list[dict[str, str]] = field(default_factory=list)
|
||||
groups: list[UserGroup] = field(default_factory=list)
|
||||
|
||||
def has_access(self, group_id: ULID | str, access: Access = "r"):
|
||||
group_id = group_id if isinstance(group_id, str) else str(group_id)
|
||||
|
|
@ -355,10 +361,15 @@ class User:
|
|||
self.groups.append({"id": group_id, "access": access})
|
||||
|
||||
|
||||
class GroupUser(TypedDict):
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Group:
|
||||
_table: ClassVar[str] = "groups"
|
||||
|
||||
id: ULID = field(default_factory=ULID)
|
||||
name: str = None
|
||||
users: list[dict[str, str]] = field(default_factory=list)
|
||||
users: list[GroupUser] = field(default_factory=list)
|
||||
|
|
|
|||
|
|
@ -179,7 +179,7 @@ async def auth_user(request) -> User | None:
|
|||
return user
|
||||
|
||||
|
||||
_routes = []
|
||||
_routes: list[Route] = []
|
||||
|
||||
|
||||
def route(path: str, *, methods: list[str] | None = None, **kwds):
|
||||
|
|
@ -191,15 +191,11 @@ def route(path: str, *, methods: list[str] | None = None, **kwds):
|
|||
return decorator
|
||||
|
||||
|
||||
route.registered = _routes
|
||||
|
||||
|
||||
@route("/groups/{group_id}/ratings")
|
||||
async def get_ratings_for_group(request):
|
||||
group_id = as_ulid(request.path_params["group_id"])
|
||||
group = await db.get(Group, id=str(group_id))
|
||||
|
||||
if not group:
|
||||
if (group := await db.get(Group, id=str(group_id))) is None:
|
||||
return not_found()
|
||||
|
||||
user_ids = {u["id"] for u in group.users}
|
||||
|
|
@ -632,7 +628,7 @@ def create_app():
|
|||
return Starlette(
|
||||
lifespan=lifespan,
|
||||
routes=[
|
||||
Mount(f"{config.api_base}v1", routes=route.registered),
|
||||
Mount(f"{config.api_base}v1", routes=_routes),
|
||||
],
|
||||
middleware=[
|
||||
Middleware(ResponseTimeMiddleware, header_name="Unwind-Elapsed"),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue