diff --git a/unwind/db.py b/unwind/db.py index 22ea0b0..fa035e7 100644 --- a/unwind/db.py +++ b/unwind/db.py @@ -157,7 +157,7 @@ async def get(model: Type[ModelType], **kwds) -> Optional[ModelType]: return fromplain(model, row) if row else None -async def get_many(model: Type[ModelType], **kwds) -> list[ModelType]: +async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]: keys = { k: [f"{k}_{i}" for i, _ in enumerate(vs, start=1)] for k, vs in kwds.items() } @@ -173,7 +173,17 @@ async def get_many(model: Type[ModelType], **kwds) -> list[ModelType]: ) query = f"SELECT {fields_} FROM {model._table} WHERE {cond}" rows = await shared_connection().fetch_all(query=query, values=values) - return [fromplain(model, row) for row in rows] + return (fromplain(model, row) for row in rows) + + +async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]: + values = {k: v for k, v in kwds.items() if v is not None} + + fields_ = ", ".join(f.name for f in fields(model)) + cond = " AND ".join(f"{k}=:{k}" for k in values) or "1=1" + query = f"SELECT {fields_} FROM {model._table} WHERE {cond}" + rows = await shared_connection().fetch_all(query=query, values=values) + return (fromplain(model, row) for row in rows) async def update(item): diff --git a/unwind/web.py b/unwind/web.py index 0974f13..d50df10 100644 --- a/unwind/web.py +++ b/unwind/web.py @@ -21,7 +21,7 @@ from starlette.routing import Mount, Route from . import config, db from .db import close_connection_pool, find_ratings, open_connection_pool from .middleware.responsetime import ResponseTimeMiddleware -from .models import Group, Movie, asplain +from .models import Group, Movie, User, asplain from .types import ULID from .utils import b64encode, phc_compare, phc_scrypt @@ -160,6 +160,12 @@ async def add_movie(request): pass +@requires(["authenticated", "admin"]) +async def list_users(request): + users = await db.get_all(User) + return JSONResponse([asplain(u) for u in users]) + + @requires(["authenticated", "admin"]) async def add_user(request): pass @@ -263,6 +269,7 @@ def create_app(): routes=[ Route("/movies", get_movies), Route("/movies", add_movie, methods=["POST"]), + Route("/users", list_users), Route("/users", add_user, methods=["POST"]), Route("/users/{user_id}/ratings", ratings_for_user), Route(