feat: add CLI command to load IMDb charts

This introduces a generalized module interface for CLI commands.
This commit is contained in:
ducklet 2024-05-18 23:46:56 +02:00
parent 1789b2ce45
commit f7fc84c050
4 changed files with 160 additions and 7 deletions

View file

@ -2,10 +2,11 @@ import argparse
import asyncio import asyncio
import logging import logging
import secrets import secrets
import sys
from base64 import b64encode from base64 import b64encode
from pathlib import Path from pathlib import Path
from . import config, db, models, utils from . import cli, config, db, models, utils
from .db import close_connection_pool, open_connection_pool from .db import close_connection_pool, open_connection_pool
from .imdb import refresh_user_ratings_from_imdb from .imdb import refresh_user_ratings_from_imdb
from .imdb_import import download_datasets, import_from_file from .imdb_import import download_datasets, import_from_file
@ -70,8 +71,8 @@ async def run_download_imdb_dataset(basics_path: Path, ratings_path: Path):
def getargs(): def getargs():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser(prog="unwind", allow_abbrev=False)
commands = parser.add_subparsers(required=True) commands = parser.add_subparsers(title="commands", metavar="COMMAND", dest="mode")
parser_import_imdb_dataset = commands.add_parser( parser_import_imdb_dataset = commands.add_parser(
"import-imdb-dataset", "import-imdb-dataset",
@ -145,12 +146,20 @@ def getargs():
help="Allow overwriting an existing user. WARNING: This will reset the user's password!", help="Allow overwriting an existing user. WARNING: This will reset the user's password!",
) )
for module in cli.modules:
cmd = commands.add_parser(module.name, help=module.help, allow_abbrev=False)
module.add_args(cmd)
try: try:
args = parser.parse_args() args = parser.parse_args()
except TypeError: except TypeError:
parser.print_usage() parser.print_usage()
raise raise
if args.mode is None:
parser.print_help()
sys.exit(1)
return args return args
@ -162,10 +171,7 @@ def main():
) )
log.debug(f"Log level: {config.loglevel}") log.debug(f"Log level: {config.loglevel}")
try: args = getargs()
args = getargs()
except Exception:
return
if args.mode == "load-user-ratings-from-imdb": if args.mode == "load-user-ratings-from-imdb":
asyncio.run(run_load_user_ratings_from_imdb()) asyncio.run(run_load_user_ratings_from_imdb())
@ -176,5 +182,9 @@ def main():
elif args.mode == "download-imdb-dataset": elif args.mode == "download-imdb-dataset":
asyncio.run(run_download_imdb_dataset(args.basics, args.ratings)) asyncio.run(run_download_imdb_dataset(args.basics, args.ratings))
modes = {m.name: m.main for m in cli.modules}
if handler := modes.get(args.mode):
asyncio.run(handler(args))
main() main()

39
unwind/cli/__init__.py Normal file
View file

@ -0,0 +1,39 @@
import argparse
import importlib
from pathlib import Path
from types import ModuleType
from typing import Any, Callable, Coroutine, Iterable, Protocol, TypeGuard
type CommandHandler = Callable[[argparse.Namespace], Coroutine[Any, Any, None]]
class CliModule(Protocol):
name: str
help: str
add_args: Callable[[argparse.ArgumentParser], None]
main: CommandHandler
def _is_cli_module(m: ModuleType) -> TypeGuard[CliModule]:
return (
hasattr(m, "name")
and hasattr(m, "help")
and hasattr(m, "add_args")
and hasattr(m, "main")
)
_clidir = Path(__file__).parent
def _load_cmds() -> Iterable[CliModule]:
"""Return all CLI command modules."""
for f in _clidir.iterdir():
if f.suffix == ".py" and not f.name.startswith("__"):
m = importlib.import_module(f"{__package__}.{f.stem}")
if not _is_cli_module(m):
raise ValueError(f"Invalid CLI module: {m!a}")
yield m
modules = sorted(_load_cmds(), key=lambda m: m.name)

View file

@ -0,0 +1,97 @@
import argparse
import logging
from typing import Callable
import sqlalchemy as sa
from unwind import db, imdb, models, types, utils
log = logging.getLogger(__name__)
name = "load-imdb-charts"
help = "Load and import charts from imdb.com."
def add_args(cmd: argparse.ArgumentParser) -> None:
cmd.add_argument(
"--select",
action="append",
dest="charts",
default=[],
choices={"top250", "bottom100", "pop100"},
help="Select which charts to refresh.",
)
async def get_movie_ids(
conn: db.Connection, imdb_ids: list[imdb.MovieId]
) -> dict[imdb.MovieId, types.ULID]:
c = models.movies.c
query = sa.select(c.imdb_id, c.id).where(c.imdb_id.in_(imdb_ids))
rows = await db.fetch_all(conn, query)
return {row.imdb_id: types.ULID(row.id) for row in rows}
async def remove_all_awards(
conn: db.Connection, category: models.AwardCategory
) -> None:
stmt = models.awards.delete().where(models.awards.c.category == category)
await conn.execute(stmt)
_award_handlers: dict[models.AwardCategory, Callable] = {
"imdb-pop-100": imdb.load_most_popular_100,
"imdb-top-250": imdb.load_top_250,
"imdb-bottom-100": imdb.load_bottom_100,
}
async def update_awards(conn: db.Connection, category: models.AwardCategory) -> None:
load_imdb_ids = _award_handlers[category]
imdb_ids = await load_imdb_ids()
available = await get_movie_ids(conn, imdb_ids)
if missing := set(imdb_ids).difference(available):
log.warning(
"⚠️ Charts for category (%a) contained %i unknown movies: %a",
category,
len(missing),
missing,
)
await remove_all_awards(conn, category=category)
for pos, imdb_id in enumerate(imdb_ids, 1):
if (movie_id := available.get(imdb_id)) is None:
continue
award = models.Award(
movie_id=movie_id,
category=category,
details=utils.json_dump({"position": pos}),
)
await db.add(conn, award)
async def main(args: argparse.Namespace) -> None:
await db.open_connection_pool()
if not args.charts:
args.charts = {"top250", "bottom100", "pop100"}
if "pop100" in args.charts:
async with db.transaction() as conn:
await update_awards(conn, "imdb-pop-100")
log.info("✨ Updated most popular 100 movies.")
if "bottom100" in args.charts:
async with db.transaction() as conn:
await update_awards(conn, "imdb-bottom-100")
log.info("✨ Updated bottom 100 movies.")
if "top250" in args.charts:
async with db.transaction() as conn:
await update_awards(conn, "imdb-top-250")
log.info("✨ Updated top 250 rated movies.")
await db.close_connection_pool()

View file

@ -259,6 +259,13 @@ async def new_connection() -> AsyncGenerator[Connection, None]:
async def transacted( async def transacted(
conn: Connection, /, *, force_rollback: bool = False conn: Connection, /, *, force_rollback: bool = False
) -> AsyncGenerator[None, None]: ) -> AsyncGenerator[None, None]:
"""Start a transaction for the given connection.
If `force_rollback` is `True` any changes will be rolled back at the end of the
transaction, unless they are explicitly committed.
Nesting transactions is allowed, but mixing values for `force_rollback` will likely
yield unexpected results.
"""
transaction = contextlib.nullcontext() if conn.in_transaction() else conn.begin() transaction = contextlib.nullcontext() if conn.in_transaction() else conn.begin()
async with transaction: async with transaction: