feat: add CLI command to load IMDb charts
This introduces a generalized module interface for CLI commands.
This commit is contained in:
parent
1789b2ce45
commit
f7fc84c050
4 changed files with 160 additions and 7 deletions
|
|
@ -2,10 +2,11 @@ import argparse
|
|||
import asyncio
|
||||
import logging
|
||||
import secrets
|
||||
import sys
|
||||
from base64 import b64encode
|
||||
from pathlib import Path
|
||||
|
||||
from . import config, db, models, utils
|
||||
from . import cli, config, db, models, utils
|
||||
from .db import close_connection_pool, open_connection_pool
|
||||
from .imdb import refresh_user_ratings_from_imdb
|
||||
from .imdb_import import download_datasets, import_from_file
|
||||
|
|
@ -70,8 +71,8 @@ async def run_download_imdb_dataset(basics_path: Path, ratings_path: Path):
|
|||
|
||||
|
||||
def getargs():
|
||||
parser = argparse.ArgumentParser()
|
||||
commands = parser.add_subparsers(required=True)
|
||||
parser = argparse.ArgumentParser(prog="unwind", allow_abbrev=False)
|
||||
commands = parser.add_subparsers(title="commands", metavar="COMMAND", dest="mode")
|
||||
|
||||
parser_import_imdb_dataset = commands.add_parser(
|
||||
"import-imdb-dataset",
|
||||
|
|
@ -145,12 +146,20 @@ def getargs():
|
|||
help="Allow overwriting an existing user. WARNING: This will reset the user's password!",
|
||||
)
|
||||
|
||||
for module in cli.modules:
|
||||
cmd = commands.add_parser(module.name, help=module.help, allow_abbrev=False)
|
||||
module.add_args(cmd)
|
||||
|
||||
try:
|
||||
args = parser.parse_args()
|
||||
except TypeError:
|
||||
parser.print_usage()
|
||||
raise
|
||||
|
||||
if args.mode is None:
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
|
|
@ -162,10 +171,7 @@ def main():
|
|||
)
|
||||
log.debug(f"Log level: {config.loglevel}")
|
||||
|
||||
try:
|
||||
args = getargs()
|
||||
except Exception:
|
||||
return
|
||||
args = getargs()
|
||||
|
||||
if args.mode == "load-user-ratings-from-imdb":
|
||||
asyncio.run(run_load_user_ratings_from_imdb())
|
||||
|
|
@ -176,5 +182,9 @@ def main():
|
|||
elif args.mode == "download-imdb-dataset":
|
||||
asyncio.run(run_download_imdb_dataset(args.basics, args.ratings))
|
||||
|
||||
modes = {m.name: m.main for m in cli.modules}
|
||||
if handler := modes.get(args.mode):
|
||||
asyncio.run(handler(args))
|
||||
|
||||
|
||||
main()
|
||||
|
|
|
|||
39
unwind/cli/__init__.py
Normal file
39
unwind/cli/__init__.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
import argparse
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Any, Callable, Coroutine, Iterable, Protocol, TypeGuard
|
||||
|
||||
type CommandHandler = Callable[[argparse.Namespace], Coroutine[Any, Any, None]]
|
||||
|
||||
|
||||
class CliModule(Protocol):
    """Structural interface of a module that provides one CLI command."""

    # Name of the command.
    name: str
    # Help text for the command.
    help: str
    # Callback receiving the command's argument parser to add arguments to.
    add_args: Callable[[argparse.ArgumentParser], None]
    # Coroutine function called with the parsed arguments.
    main: CommandHandler
|
||||
|
||||
|
||||
def _is_cli_module(m: ModuleType) -> TypeGuard[CliModule]:
    """Tell whether `m` has every attribute a CLI command module must provide.

    Only the presence of the attributes is checked, not their types.
    """
    required = ("name", "help", "add_args", "main")
    return all(hasattr(m, attr) for attr in required)
|
||||
|
||||
|
||||
_clidir = Path(__file__).parent
|
||||
|
||||
|
||||
def _load_cmds() -> Iterable[CliModule]:
    """Import and yield every command module found next to this file.

    Files whose name starts with a double underscore are skipped.

    Raises:
        ValueError: If an imported module lacks one of the required attributes.
    """
    for path in _clidir.iterdir():
        if path.suffix != ".py" or path.name.startswith("__"):
            continue

        mod = importlib.import_module(f"{__package__}.{path.stem}")
        if _is_cli_module(mod):
            yield mod
        else:
            raise ValueError(f"Invalid CLI module: {mod!a}")
|
||||
|
||||
|
||||
modules = sorted(_load_cmds(), key=lambda m: m.name)
|
||||
97
unwind/cli/load_imdb_charts.py
Normal file
97
unwind/cli/load_imdb_charts.py
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
import argparse
|
||||
import logging
|
||||
from typing import Callable
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from unwind import db, imdb, models, types, utils
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
name = "load-imdb-charts"
|
||||
help = "Load and import charts from imdb.com."
|
||||
|
||||
|
||||
def add_args(cmd: argparse.ArgumentParser) -> None:
    """Register the command's options on the given parser.

    `--select` may be passed several times; the chosen chart names are
    collected, in order, in `args.charts`.  The list is empty when the option
    is not used at all.
    """
    cmd.add_argument(
        "--select",
        action="append",
        dest="charts",
        default=[],
        # A tuple instead of a set: argparse lists the choices in iteration
        # order in its usage, help, and error output, and the iteration order
        # of a set of strings changes between runs (hash randomization).
        choices=("top250", "bottom100", "pop100"),
        help="Select which charts to refresh.",
    )
|
||||
|
||||
|
||||
async def get_movie_ids(
    conn: db.Connection, imdb_ids: list[imdb.MovieId]
) -> dict[imdb.MovieId, types.ULID]:
    """Look up the ids of the movies with the given IMDb ids.

    Returns a mapping from IMDb id to the movie's ULID.  IMDb ids for which the
    query returns no row are absent from the mapping.
    """
    movies = models.movies
    lookup = sa.select(movies.c.imdb_id, movies.c.id).where(
        movies.c.imdb_id.in_(imdb_ids)
    )

    found: dict[imdb.MovieId, types.ULID] = {}
    for record in await db.fetch_all(conn, lookup):
        found[record.imdb_id] = types.ULID(record.id)
    return found
|
||||
|
||||
|
||||
async def remove_all_awards(
    conn: db.Connection, category: models.AwardCategory
) -> None:
    """Delete every award whose category equals `category`."""
    awards = models.awards
    in_category = awards.c.category == category
    await conn.execute(awards.delete().where(in_category))
|
||||
|
||||
|
||||
_award_handlers: dict[models.AwardCategory, Callable] = {
|
||||
"imdb-pop-100": imdb.load_most_popular_100,
|
||||
"imdb-top-250": imdb.load_top_250,
|
||||
"imdb-bottom-100": imdb.load_bottom_100,
|
||||
}
|
||||
|
||||
|
||||
async def update_awards(conn: db.Connection, category: models.AwardCategory) -> None:
    """Replace the awards of `category` with the freshly loaded chart.

    The chart is loaded with the handler registered for `category`.  All
    existing awards of the category are removed, then one award is added per
    chart entry that matches a known movie.  Each award stores the entry's
    1-based chart position in its details.  Entries without a known movie are
    reported in a single warning and skipped.
    """
    chart = await _award_handlers[category]()
    known = await get_movie_ids(conn, chart)

    unknown = set(chart).difference(known)
    if unknown:
        log.warning(
            "⚠️ Charts for category (%a) contained %i unknown movies: %a",
            category,
            len(unknown),
            unknown,
        )

    await remove_all_awards(conn, category=category)

    # Positions count every chart entry, including the skipped ones.
    for rank, chart_id in enumerate(chart, 1):
        ulid = known.get(chart_id)
        if ulid is not None:
            entry = models.Award(
                movie_id=ulid,
                category=category,
                details=utils.json_dump({"position": rank}),
            )
            await db.add(conn, entry)
|
||||
|
||||
|
||||
async def main(args: argparse.Namespace) -> None:
    """Refresh the selected IMDb charts, or all charts if none was selected.

    Every chart is updated in a transaction of its own.  The connection pool
    is closed again even when one of the updates raises.
    """
    # Work on a copy, so the caller's namespace is left untouched.
    selected = set(args.charts) or {"top250", "bottom100", "pop100"}

    # (chart name, award category, success message), in processing order.
    jobs: tuple[tuple[str, models.AwardCategory, str], ...] = (
        ("pop100", "imdb-pop-100", "✨ Updated most popular 100 movies."),
        ("bottom100", "imdb-bottom-100", "✨ Updated bottom 100 movies."),
        ("top250", "imdb-top-250", "✨ Updated top 250 rated movies."),
    )

    await db.open_connection_pool()
    try:
        for chart, category, done_msg in jobs:
            if chart not in selected:
                continue

            async with db.transaction() as conn:
                await update_awards(conn, category)
            # Report success only after the transaction block has exited.
            log.info(done_msg)
    finally:
        await db.close_connection_pool()
|
||||
|
|
@ -259,6 +259,13 @@ async def new_connection() -> AsyncGenerator[Connection, None]:
|
|||
async def transacted(
|
||||
conn: Connection, /, *, force_rollback: bool = False
|
||||
) -> AsyncGenerator[None, None]:
|
||||
"""Start a transaction for the given connection.
|
||||
|
||||
If `force_rollback` is `True` any changes will be rolled back at the end of the
|
||||
transaction, unless they are explicitly committed.
|
||||
Nesting transactions is allowed, but mixing values for `force_rollback` will likely
|
||||
yield unexpected results.
|
||||
"""
|
||||
transaction = contextlib.nullcontext() if conn.in_transaction() else conn.begin()
|
||||
|
||||
async with transaction:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue