feat: add CLI command to load IMDb charts

This introduces a generalized module interface for CLI commands.
This commit is contained in:
ducklet 2024-05-18 23:46:56 +02:00
parent 1789b2ce45
commit f7fc84c050
4 changed files with 160 additions and 7 deletions

View file

@ -2,10 +2,11 @@ import argparse
import asyncio
import logging
import secrets
import sys
from base64 import b64encode
from pathlib import Path
from . import config, db, models, utils
from . import cli, config, db, models, utils
from .db import close_connection_pool, open_connection_pool
from .imdb import refresh_user_ratings_from_imdb
from .imdb_import import download_datasets, import_from_file
@ -70,8 +71,8 @@ async def run_download_imdb_dataset(basics_path: Path, ratings_path: Path):
def getargs():
parser = argparse.ArgumentParser()
commands = parser.add_subparsers(required=True)
parser = argparse.ArgumentParser(prog="unwind", allow_abbrev=False)
commands = parser.add_subparsers(title="commands", metavar="COMMAND", dest="mode")
parser_import_imdb_dataset = commands.add_parser(
"import-imdb-dataset",
@ -145,12 +146,20 @@ def getargs():
help="Allow overwriting an existing user. WARNING: This will reset the user's password!",
)
for module in cli.modules:
cmd = commands.add_parser(module.name, help=module.help, allow_abbrev=False)
module.add_args(cmd)
try:
args = parser.parse_args()
except TypeError:
parser.print_usage()
raise
if args.mode is None:
parser.print_help()
sys.exit(1)
return args
@ -162,10 +171,7 @@ def main():
)
log.debug(f"Log level: {config.loglevel}")
try:
args = getargs()
except Exception:
return
args = getargs()
if args.mode == "load-user-ratings-from-imdb":
asyncio.run(run_load_user_ratings_from_imdb())
@ -176,5 +182,9 @@ def main():
elif args.mode == "download-imdb-dataset":
asyncio.run(run_download_imdb_dataset(args.basics, args.ratings))
modes = {m.name: m.main for m in cli.modules}
if handler := modes.get(args.mode):
asyncio.run(handler(args))
main()

39
unwind/cli/__init__.py Normal file
View file

@ -0,0 +1,39 @@
import argparse
import importlib
from pathlib import Path
from types import ModuleType
from typing import Any, Callable, Coroutine, Iterable, Protocol, TypeGuard
type CommandHandler = Callable[[argparse.Namespace], Coroutine[Any, Any, None]]
class CliModule(Protocol):
name: str
help: str
add_args: Callable[[argparse.ArgumentParser], None]
main: CommandHandler
def _is_cli_module(m: ModuleType) -> TypeGuard[CliModule]:
return (
hasattr(m, "name")
and hasattr(m, "help")
and hasattr(m, "add_args")
and hasattr(m, "main")
)
_clidir = Path(__file__).parent
def _load_cmds() -> Iterable[CliModule]:
"""Return all CLI command modules."""
for f in _clidir.iterdir():
if f.suffix == ".py" and not f.name.startswith("__"):
m = importlib.import_module(f"{__package__}.{f.stem}")
if not _is_cli_module(m):
raise ValueError(f"Invalid CLI module: {m!a}")
yield m
modules = sorted(_load_cmds(), key=lambda m: m.name)

View file

@ -0,0 +1,97 @@
import argparse
import logging
from typing import Callable
import sqlalchemy as sa
from unwind import db, imdb, models, types, utils
log = logging.getLogger(__name__)
name = "load-imdb-charts"
help = "Load and import charts from imdb.com."
def add_args(cmd: argparse.ArgumentParser) -> None:
cmd.add_argument(
"--select",
action="append",
dest="charts",
default=[],
choices={"top250", "bottom100", "pop100"},
help="Select which charts to refresh.",
)
async def get_movie_ids(
conn: db.Connection, imdb_ids: list[imdb.MovieId]
) -> dict[imdb.MovieId, types.ULID]:
c = models.movies.c
query = sa.select(c.imdb_id, c.id).where(c.imdb_id.in_(imdb_ids))
rows = await db.fetch_all(conn, query)
return {row.imdb_id: types.ULID(row.id) for row in rows}
async def remove_all_awards(
conn: db.Connection, category: models.AwardCategory
) -> None:
stmt = models.awards.delete().where(models.awards.c.category == category)
await conn.execute(stmt)
_award_handlers: dict[models.AwardCategory, Callable] = {
"imdb-pop-100": imdb.load_most_popular_100,
"imdb-top-250": imdb.load_top_250,
"imdb-bottom-100": imdb.load_bottom_100,
}
async def update_awards(conn: db.Connection, category: models.AwardCategory) -> None:
load_imdb_ids = _award_handlers[category]
imdb_ids = await load_imdb_ids()
available = await get_movie_ids(conn, imdb_ids)
if missing := set(imdb_ids).difference(available):
log.warning(
"⚠️ Charts for category (%a) contained %i unknown movies: %a",
category,
len(missing),
missing,
)
await remove_all_awards(conn, category=category)
for pos, imdb_id in enumerate(imdb_ids, 1):
if (movie_id := available.get(imdb_id)) is None:
continue
award = models.Award(
movie_id=movie_id,
category=category,
details=utils.json_dump({"position": pos}),
)
await db.add(conn, award)
async def main(args: argparse.Namespace) -> None:
await db.open_connection_pool()
if not args.charts:
args.charts = {"top250", "bottom100", "pop100"}
if "pop100" in args.charts:
async with db.transaction() as conn:
await update_awards(conn, "imdb-pop-100")
log.info("✨ Updated most popular 100 movies.")
if "bottom100" in args.charts:
async with db.transaction() as conn:
await update_awards(conn, "imdb-bottom-100")
log.info("✨ Updated bottom 100 movies.")
if "top250" in args.charts:
async with db.transaction() as conn:
await update_awards(conn, "imdb-top-250")
log.info("✨ Updated top 250 rated movies.")
await db.close_connection_pool()

View file

@ -259,6 +259,13 @@ async def new_connection() -> AsyncGenerator[Connection, None]:
async def transacted(
conn: Connection, /, *, force_rollback: bool = False
) -> AsyncGenerator[None, None]:
"""Start a transaction for the given connection.
If `force_rollback` is `True` any changes will be rolled back at the end of the
transaction, unless they are explicitly committed.
Nesting transactions is allowed, but mixing values for `force_rollback` will likely
yield unexpected results.
"""
transaction = contextlib.nullcontext() if conn.in_transaction() else conn.begin()
async with transaction: