diff --git a/unwind/__main__.py b/unwind/__main__.py index 82abfee..8ca995f 100644 --- a/unwind/__main__.py +++ b/unwind/__main__.py @@ -2,10 +2,11 @@ import argparse import asyncio import logging import secrets +import sys from base64 import b64encode from pathlib import Path -from . import config, db, models, utils +from . import cli, config, db, models, utils from .db import close_connection_pool, open_connection_pool from .imdb import refresh_user_ratings_from_imdb from .imdb_import import download_datasets, import_from_file @@ -70,8 +71,8 @@ async def run_download_imdb_dataset(basics_path: Path, ratings_path: Path): def getargs(): - parser = argparse.ArgumentParser() - commands = parser.add_subparsers(required=True) + parser = argparse.ArgumentParser(prog="unwind", allow_abbrev=False) + commands = parser.add_subparsers(title="commands", metavar="COMMAND", dest="mode") parser_import_imdb_dataset = commands.add_parser( "import-imdb-dataset", @@ -145,12 +146,20 @@ def getargs(): help="Allow overwriting an existing user. WARNING: This will reset the user's password!", ) + for module in cli.modules: + cmd = commands.add_parser(module.name, help=module.help, allow_abbrev=False) + module.add_args(cmd) + try: args = parser.parse_args() except TypeError: parser.print_usage() raise + if args.mode is None: + parser.print_help() + sys.exit(1) + return args @@ -162,10 +171,7 @@ def main(): ) log.debug(f"Log level: {config.loglevel}") - try: - args = getargs() - except Exception: - return + args = getargs() if args.mode == "load-user-ratings-from-imdb": asyncio.run(run_load_user_ratings_from_imdb()) @@ -176,5 +182,9 @@ def main(): elif args.mode == "download-imdb-dataset": asyncio.run(run_download_imdb_dataset(args.basics, args.ratings)) + modes = {m.name: m.main for m in cli.modules} + if handler := modes.get(args.mode): + asyncio.run(handler(args)) + main() diff --git a/unwind/cli/__init__.py b/unwind/cli/__init__.py new file mode 100644 index 0000000..dd6f8cc --- /dev/null +++ b/unwind/cli/__init__.py @@ -0,0 +1,39 @@ +import argparse +import importlib +from pathlib import Path +from types import ModuleType +from typing import Any, Callable, Coroutine, Iterable, Protocol, TypeGuard + +type CommandHandler = Callable[[argparse.Namespace], Coroutine[Any, Any, None]] + + +class CliModule(Protocol): + name: str + help: str + add_args: Callable[[argparse.ArgumentParser], None] + main: CommandHandler + + +def _is_cli_module(m: ModuleType) -> TypeGuard[CliModule]: + return ( + hasattr(m, "name") + and hasattr(m, "help") + and hasattr(m, "add_args") + and hasattr(m, "main") + ) + + +_clidir = Path(__file__).parent + + +def _load_cmds() -> Iterable[CliModule]: + """Return all CLI command modules.""" + for f in _clidir.iterdir(): + if f.suffix == ".py" and not f.name.startswith("__"): + m = importlib.import_module(f"{__package__}.{f.stem}") + if not _is_cli_module(m): + raise ValueError(f"Invalid CLI module: {m!a}") + yield m + + +modules = sorted(_load_cmds(), key=lambda m: m.name) diff --git a/unwind/cli/load_imdb_charts.py b/unwind/cli/load_imdb_charts.py new file mode 100644 index 0000000..2b82774 --- /dev/null +++ b/unwind/cli/load_imdb_charts.py @@ -0,0 +1,97 @@ +import argparse +import logging +from typing import Callable + +import sqlalchemy as sa + +from unwind import db, imdb, models, types, utils + +log = logging.getLogger(__name__) + +name = "load-imdb-charts" +help = "Load and import charts from imdb.com." + + +def add_args(cmd: argparse.ArgumentParser) -> None: + cmd.add_argument( + "--select", + action="append", + dest="charts", + default=[], + choices={"top250", "bottom100", "pop100"}, + help="Select which charts to refresh.", + ) + + +async def get_movie_ids( + conn: db.Connection, imdb_ids: list[imdb.MovieId] +) -> dict[imdb.MovieId, types.ULID]: + c = models.movies.c + query = sa.select(c.imdb_id, c.id).where(c.imdb_id.in_(imdb_ids)) + rows = await db.fetch_all(conn, query) + return {row.imdb_id: types.ULID(row.id) for row in rows} + + +async def remove_all_awards( + conn: db.Connection, category: models.AwardCategory +) -> None: + stmt = models.awards.delete().where(models.awards.c.category == category) + await conn.execute(stmt) + + +_award_handlers: dict[models.AwardCategory, Callable] = { + "imdb-pop-100": imdb.load_most_popular_100, + "imdb-top-250": imdb.load_top_250, + "imdb-bottom-100": imdb.load_bottom_100, +} + + +async def update_awards(conn: db.Connection, category: models.AwardCategory) -> None: + load_imdb_ids = _award_handlers[category] + imdb_ids = await load_imdb_ids() + + available = await get_movie_ids(conn, imdb_ids) + if missing := set(imdb_ids).difference(available): + log.warning( + "⚠️ Charts for category (%a) contained %i unknown movies: %a", + category, + len(missing), + missing, + ) + + await remove_all_awards(conn, category=category) + + for pos, imdb_id in enumerate(imdb_ids, 1): + if (movie_id := available.get(imdb_id)) is None: + continue + + award = models.Award( + movie_id=movie_id, + category=category, + details=utils.json_dump({"position": pos}), + ) + await db.add(conn, award) + + +async def main(args: argparse.Namespace) -> None: + await db.open_connection_pool() + + if not args.charts: + args.charts = {"top250", "bottom100", "pop100"} + + if "pop100" in args.charts: + async with db.transaction() as conn: + await update_awards(conn, "imdb-pop-100") + log.info("✨ Updated most popular 100 movies.") + + if "bottom100" in args.charts: + async with db.transaction() as conn: + await update_awards(conn, "imdb-bottom-100") + log.info("✨ Updated bottom 100 movies.") + + if "top250" in args.charts: + async with db.transaction() as conn: + await update_awards(conn, "imdb-top-250") + log.info("✨ Updated top 250 rated movies.") + + await db.close_connection_pool() diff --git a/unwind/db.py b/unwind/db.py index d94e335..b609c59 100644 --- a/unwind/db.py +++ b/unwind/db.py @@ -259,6 +259,13 @@ async def new_connection() -> AsyncGenerator[Connection, None]: async def transacted( conn: Connection, /, *, force_rollback: bool = False ) -> AsyncGenerator[None, None]: + """Start a transaction for the given connection. + + If `force_rollback` is `True` any changes will be rolled back at the end of the + transaction, unless they are explicitly committed. + Nesting transactions is allowed, but mixing values for `force_rollback` will likely + yield unexpected results. + """ transaction = contextlib.nullcontext() if conn.in_transaction() else conn.begin() async with transaction: