From d09880438d12791711b9c516611baa78bd8d929d Mon Sep 17 00:00:00 2001 From: ducklet Date: Mon, 21 Jun 2021 23:48:36 +0200 Subject: [PATCH] add more filtering options --- scripts/dev | 2 +- scripts/server | 2 +- unwind/__init__.py | 2 +- unwind/db.py | 77 ++++++++++++++++++++++++++++++++++++++-------- unwind/web.py | 71 ++++++++++++++++++++++++++---------------- 5 files changed, 113 insertions(+), 41 deletions(-) diff --git a/scripts/dev b/scripts/dev index 4c2319f..7d3e2ef 100755 --- a/scripts/dev +++ b/scripts/dev @@ -4,4 +4,4 @@ cd "$RUN_DIR" [ -z "${DEBUG:-}" ] || set -x -exec uvicorn unwind:web_app --reload +exec uvicorn unwind:create_app --factory --reload diff --git a/scripts/server b/scripts/server index a184989..5440717 100755 --- a/scripts/server +++ b/scripts/server @@ -4,4 +4,4 @@ cd "$RUN_DIR" [ -z "${DEBUG:-}" ] || set -x -exec uvicorn --host 0.0.0.0 unwind:web_app +exec uvicorn --host 0.0.0.0 --factory unwind:create_app diff --git a/unwind/__init__.py b/unwind/__init__.py index 400fbe3..70cddc2 100644 --- a/unwind/__init__.py +++ b/unwind/__init__.py @@ -1 +1 @@ -from .web import app as web_app +from .web import create_app diff --git a/unwind/db.py b/unwind/db.py index e5eb29c..ae660c4 100644 --- a/unwind/db.py +++ b/unwind/db.py @@ -1,8 +1,10 @@ import logging +import re from dataclasses import fields from pathlib import Path from typing import Optional, Type, TypeVar +import sqlalchemy from databases import Database from . import config @@ -152,6 +154,8 @@ async def find_ratings( title: str = None, media_type: str = None, ignore_tv_episodes: bool = False, + include_unrated: bool = False, + year: int = None, limit_rows=10, ): values = { @@ -163,17 +167,19 @@ async def find_ratings( values["escape"] = "#" escaped_title = sql_escape(title, char=values["escape"]) values["pattern"] = "%" + "%".join(escaped_title.split()) + "%" - values["opattern"] = values["pattern"] - values["oescape"] = values["escape"] conditions.append( f""" ( {Movie._table}.title LIKE :pattern ESCAPE :escape - OR {Movie._table}.original_title LIKE :opattern ESCAPE :oescape + OR {Movie._table}.original_title LIKE :pattern ESCAPE :escape ) """ ) + if year: + values["year"] = year + conditions.append(f"{Movie._table}.release_year=:year") + if media_type: values["media_type"] = media_type conditions.append(f"{Movie._table}.media_type=:media_type") @@ -181,19 +187,44 @@ async def find_ratings( if ignore_tv_episodes: conditions.append(f"{Movie._table}.media_type!='TV Episode'") - query = f""" - WITH newest_movies - AS ( + source_table = "newest_movies" + ctes = [ + f"""{source_table} AS ( SELECT DISTINCT {Rating._table}.movie_id FROM {Rating._table} LEFT JOIN {Movie._table} ON {Movie._table}.id={Rating._table}.movie_id {('WHERE ' + ' AND '.join(conditions)) if conditions else ''} ORDER BY length({Movie._table}.title) ASC, {Rating._table}.rating_date DESC LIMIT :limit_rows + )""" + ] + + if include_unrated: + source_table = "target_movies" + ctes.extend( + [ + f"""unrated_movies AS ( + SELECT DISTINCT id AS movie_id + FROM {Movie._table} + WHERE id NOT IN newest_movies + {('AND ' + ' AND '.join(conditions)) if conditions else ''} + ORDER BY length(title) ASC, release_year DESC + LIMIT :limit_rows + )""", + f"""{source_table} AS ( + SELECT * FROM newest_movies + UNION ALL -- using ALL here avoids the reordering of IDs + SELECT * FROM unrated_movies + )""", + ] ) + query = f""" + WITH + {','.join(ctes)} + SELECT - {User._table}.name AS user_name, + -- {User._table}.name AS user_name, {Rating._table}.score AS user_score, {Movie._table}.score AS imdb_score, {Movie._table}.imdb_id AS movie_imdb_id, @@ -201,11 +232,33 @@ async def find_ratings( {Movie._table}.title AS canonical_title, {Movie._table}.original_title AS original_title, {Movie._table}.release_year AS release_year - FROM newest_movies - LEFT JOIN {Rating._table} ON {Rating._table}.movie_id=newest_movies.movie_id - LEFT JOIN {User._table} ON {User._table}.id={Rating._table}.user_id - LEFT JOIN {Movie._table} ON {Movie._table}.id={Rating._table}.movie_id + FROM {source_table} + LEFT JOIN {Rating._table} ON {Rating._table}.movie_id={source_table}.movie_id + -- LEFT JOIN {User._table} ON {User._table}.id={Rating._table}.user_id + LEFT JOIN {Movie._table} ON {Movie._table}.id={source_table}.movie_id + LIMIT :limit_rows """ - rows = await shared_connection().fetch_all(query=query, values=values) + rows = await shared_connection().fetch_all(bindparams(query, values)) return tuple(dict(r) for r in rows) + + +def bindparams(query: str, values: dict): + """Bind values to a query. + + This is similar to what SQLAlchemy and Databases do, but it allows to + easily use the same placeholder in multiple places. + """ + pump_vals = {} + pump_keys = {} + + def pump(match): + key = match[1] + val = values[key] + pump_keys[key] = 1 + pump_keys.setdefault(key, 0) + pump_key = f"{key}_{pump_keys[key]}" + pump_vals[pump_key] = val + return f":{pump_key}" + + pump_query = re.sub(r":(\w+)\b", pump, query) + return sqlalchemy.text(pump_query).bindparams(**pump_vals) diff --git a/unwind/web.py b/unwind/web.py index 4cbe6db..8b18f0d 100644 --- a/unwind/web.py +++ b/unwind/web.py @@ -1,5 +1,6 @@ import base64 import binascii +import logging from starlette.applications import Starlette from starlette.authentication import ( @@ -19,6 +20,8 @@ from . import config, db from .db import close_connection_pool, find_ratings, open_connection_pool from .models import Movie, asplain +log = logging.getLogger(__name__) + class BasicAuthBackend(AuthenticationBackend): async def authenticate(self, request): @@ -48,11 +51,13 @@ def truthy(s: str): async def ratings(request): - title = request.query_params.get("title") - media_type = request.query_params.get("media_type") - ignore_tv_episodes = truthy(request.query_params.get("ignore_tv_episodes")) + params = request.query_params rows = await find_ratings( - title=title, media_type=media_type, ignore_tv_episodes=ignore_tv_episodes + title=params.get("title"), + media_type=params.get("media_type"), + ignore_tv_episodes=truthy(params.get("ignore_tv_episodes")), + include_unrated=truthy(params.get("include_unrated")), + year=int(params["year"]) if "year" in params else None, ) aggr = {} @@ -69,7 +74,8 @@ async def ratings(request): "media_type": r["media_type"], }, ) - mov["user_scores"].append(r["user_score"]) + if r["user_score"] is not None: + mov["user_scores"].append(r["user_score"]) resp = tuple(aggr.values()) @@ -121,24 +127,37 @@ async def get_ratings_for_group(request): request.path_params["group_id"] -app = Starlette( - on_startup=[open_connection_pool], - on_shutdown=[close_connection_pool], - routes=[ - Mount( - "/api/v1", - routes=[ - Route("/ratings", ratings), # XXX legacy, remove. - Route("/movies", get_movies), - Route("/movies", add_movie, methods=["POST"]), - Route("/users", add_user, methods=["POST"]), - Route("/users/{user_id}/ratings", ratings_for_user), - Route("/users/{user_id}/ratings", set_rating_for_user, methods=["PUT"]), - Route("/groups", add_group, methods=["POST"]), - Route("/groups/{group_id}/users", add_user_to_group, methods=["POST"]), - Route("/groups/{group_id}/ratings", get_ratings_for_group), - ], - ), - ], - middleware=[Middleware(AuthenticationMiddleware, backend=BasicAuthBackend())], -) +def create_app(): + if config.loglevel == "DEBUG": + logging.basicConfig( + format="%(asctime)s.%(msecs)03d [%(name)s:%(process)d] %(levelname)s: %(message)s", + datefmt="%H:%M:%S", + level=config.loglevel, + ) + log.debug(f"Log level: {config.loglevel}") + + return Starlette( + on_startup=[open_connection_pool], + on_shutdown=[close_connection_pool], + routes=[ + Mount( + "/api/v1", + routes=[ + Route("/ratings", ratings), # XXX legacy, remove. + Route("/movies", get_movies), + Route("/movies", add_movie, methods=["POST"]), + Route("/users", add_user, methods=["POST"]), + Route("/users/{user_id}/ratings", ratings_for_user), + Route( + "/users/{user_id}/ratings", set_rating_for_user, methods=["PUT"] + ), + Route("/groups", add_group, methods=["POST"]), + Route( + "/groups/{group_id}/users", add_user_to_group, methods=["POST"] + ), + Route("/groups/{group_id}/ratings", get_ratings_for_group), + ], + ), + ], + middleware=[Middleware(AuthenticationMiddleware, backend=BasicAuthBackend())], + )