feat: add CLI command to load IMDb charts
This introduces a generalized module interface for CLI commands.
This commit is contained in:
parent
1789b2ce45
commit
f7fc84c050
4 changed files with 160 additions and 7 deletions
|
|
@ -2,10 +2,11 @@ import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import secrets
|
import secrets
|
||||||
|
import sys
|
||||||
from base64 import b64encode
|
from base64 import b64encode
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from . import config, db, models, utils
|
from . import cli, config, db, models, utils
|
||||||
from .db import close_connection_pool, open_connection_pool
|
from .db import close_connection_pool, open_connection_pool
|
||||||
from .imdb import refresh_user_ratings_from_imdb
|
from .imdb import refresh_user_ratings_from_imdb
|
||||||
from .imdb_import import download_datasets, import_from_file
|
from .imdb_import import download_datasets, import_from_file
|
||||||
|
|
@ -70,8 +71,8 @@ async def run_download_imdb_dataset(basics_path: Path, ratings_path: Path):
|
||||||
|
|
||||||
|
|
||||||
def getargs():
|
def getargs():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser(prog="unwind", allow_abbrev=False)
|
||||||
commands = parser.add_subparsers(required=True)
|
commands = parser.add_subparsers(title="commands", metavar="COMMAND", dest="mode")
|
||||||
|
|
||||||
parser_import_imdb_dataset = commands.add_parser(
|
parser_import_imdb_dataset = commands.add_parser(
|
||||||
"import-imdb-dataset",
|
"import-imdb-dataset",
|
||||||
|
|
@ -145,12 +146,20 @@ def getargs():
|
||||||
help="Allow overwriting an existing user. WARNING: This will reset the user's password!",
|
help="Allow overwriting an existing user. WARNING: This will reset the user's password!",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
for module in cli.modules:
|
||||||
|
cmd = commands.add_parser(module.name, help=module.help, allow_abbrev=False)
|
||||||
|
module.add_args(cmd)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
except TypeError:
|
except TypeError:
|
||||||
parser.print_usage()
|
parser.print_usage()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
if args.mode is None:
|
||||||
|
parser.print_help()
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -162,10 +171,7 @@ def main():
|
||||||
)
|
)
|
||||||
log.debug(f"Log level: {config.loglevel}")
|
log.debug(f"Log level: {config.loglevel}")
|
||||||
|
|
||||||
try:
|
args = getargs()
|
||||||
args = getargs()
|
|
||||||
except Exception:
|
|
||||||
return
|
|
||||||
|
|
||||||
if args.mode == "load-user-ratings-from-imdb":
|
if args.mode == "load-user-ratings-from-imdb":
|
||||||
asyncio.run(run_load_user_ratings_from_imdb())
|
asyncio.run(run_load_user_ratings_from_imdb())
|
||||||
|
|
@ -176,5 +182,9 @@ def main():
|
||||||
elif args.mode == "download-imdb-dataset":
|
elif args.mode == "download-imdb-dataset":
|
||||||
asyncio.run(run_download_imdb_dataset(args.basics, args.ratings))
|
asyncio.run(run_download_imdb_dataset(args.basics, args.ratings))
|
||||||
|
|
||||||
|
modes = {m.name: m.main for m in cli.modules}
|
||||||
|
if handler := modes.get(args.mode):
|
||||||
|
asyncio.run(handler(args))
|
||||||
|
|
||||||
|
|
||||||
main()
|
main()
|
||||||
|
|
|
||||||
39
unwind/cli/__init__.py
Normal file
39
unwind/cli/__init__.py
Normal file
|
|
@ -0,0 +1,39 @@
|
||||||
|
import argparse
|
||||||
|
import importlib
|
||||||
|
from pathlib import Path
|
||||||
|
from types import ModuleType
|
||||||
|
from typing import Any, Callable, Coroutine, Iterable, Protocol, TypeGuard
|
||||||
|
|
||||||
|
type CommandHandler = Callable[[argparse.Namespace], Coroutine[Any, Any, None]]
|
||||||
|
|
||||||
|
|
||||||
|
class CliModule(Protocol):
    """Structural interface a module must satisfy to act as a CLI command.

    A command is an ordinary Python module that exposes the four
    module-level attributes listed below.
    """

    # Name of the sub-command; used as the sub-parser name and as the
    # dispatch key for the command's handler.
    name: str
    # Short description passed as `help` when the sub-parser is created.
    help: str
    # Hook that registers the command's own arguments on its sub-parser.
    add_args: Callable[[argparse.ArgumentParser], None]
    # Coroutine function that receives the parsed arguments and runs the
    # command.
    main: CommandHandler
|
||||||
|
|
||||||
|
|
||||||
|
def _is_cli_module(m: ModuleType) -> TypeGuard[CliModule]:
    """Tell whether module `m` exposes every attribute named by `CliModule`.

    Only the presence of the attributes is checked, not their types.
    """
    required = ("name", "help", "add_args", "main")
    return all(hasattr(m, attr) for attr in required)
|
||||||
|
|
||||||
|
|
||||||
|
_clidir = Path(__file__).parent
|
||||||
|
|
||||||
|
|
||||||
|
def _load_cmds() -> Iterable[CliModule]:
    """Import and yield every command module located next to this file.

    Python files whose name starts with a double underscore are skipped.

    Raises:
        ValueError: if an imported module lacks one of the attributes
            required of a CLI module.
    """
    for path in _clidir.iterdir():
        # Guard clause: ignore non-Python files and dunder files.
        if path.suffix != ".py" or path.name.startswith("__"):
            continue

        module = importlib.import_module(f"{__package__}.{path.stem}")
        if _is_cli_module(module):
            yield module
        else:
            raise ValueError(f"Invalid CLI module: {module!a}")
|
||||||
|
|
||||||
|
|
||||||
|
modules = sorted(_load_cmds(), key=lambda m: m.name)
|
||||||
97
unwind/cli/load_imdb_charts.py
Normal file
97
unwind/cli/load_imdb_charts.py
Normal file
|
|
@ -0,0 +1,97 @@
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from unwind import db, imdb, models, types, utils
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
name = "load-imdb-charts"
|
||||||
|
help = "Load and import charts from imdb.com."
|
||||||
|
|
||||||
|
|
||||||
|
def add_args(cmd: argparse.ArgumentParser) -> None:
    """Register the command's options on the sub-parser `cmd`.

    `--select` may be given several times; the chosen values are collected,
    in command-line order, in `args.charts`. When the option is never given
    `args.charts` is an empty list.
    """
    cmd.add_argument(
        "--select",
        action="append",
        dest="charts",
        default=[],
        # An ordered sequence rather than a set: argparse iterates over the
        # choices to build its usage and error messages, and the iteration
        # order of a set of strings varies between interpreter runs.
        choices=("top250", "bottom100", "pop100"),
        help="Select which charts to refresh.",
    )
|
||||||
|
|
||||||
|
|
||||||
|
async def get_movie_ids(
    conn: db.Connection, imdb_ids: list[imdb.MovieId]
) -> dict[imdb.MovieId, types.ULID]:
    """Map the given IMDb IDs to the IDs of the matching stored movies.

    IMDb IDs for which the query returns no row are absent from the result.
    """
    movies = models.movies.c
    lookup = (
        sa.select(movies.imdb_id, movies.id)
        .where(movies.imdb_id.in_(imdb_ids))
    )

    found = {}
    for row in await db.fetch_all(conn, lookup):
        found[row.imdb_id] = types.ULID(row.id)
    return found
|
||||||
|
|
||||||
|
|
||||||
|
async def remove_all_awards(
    conn: db.Connection, category: models.AwardCategory
) -> None:
    """Delete every award row whose category equals `category`."""
    awards = models.awards
    in_category = awards.c.category == category
    await conn.execute(awards.delete().where(in_category))
|
||||||
|
|
||||||
|
|
||||||
|
# Maps each chart's award category to the `imdb` loader that is awaited to
# obtain the chart's list of IMDb movie IDs.
_award_handlers: dict[models.AwardCategory, Callable] = {
    "imdb-pop-100": imdb.load_most_popular_100,
    "imdb-top-250": imdb.load_top_250,
    "imdb-bottom-100": imdb.load_bottom_100,
}
|
||||||
|
|
||||||
|
|
||||||
|
async def update_awards(conn: db.Connection, category: models.AwardCategory) -> None:
    """Replace all awards of `category` with a freshly loaded chart.

    The chart's IMDb IDs come from the loader registered for `category`.
    Existing awards of the category are removed, then one award is added per
    chart entry that matches a stored movie. Entries without a match are
    reported in a single warning and skipped; they still occupy their
    position, so stored positions are the 1-based ranks within the full chart.
    """
    fetch_chart = _award_handlers[category]
    chart = await fetch_chart()

    known = await get_movie_ids(conn, chart)
    unknown = set(chart).difference(known)
    if unknown:
        log.warning(
            "⚠️ Charts for category (%a) contained %i unknown movies: %a",
            category,
            len(unknown),
            unknown,
        )

    await remove_all_awards(conn, category=category)

    for rank, imdb_id in enumerate(chart, 1):
        movie_id = known.get(imdb_id)
        if movie_id is not None:
            await db.add(
                conn,
                models.Award(
                    movie_id=movie_id,
                    category=category,
                    details=utils.json_dump({"position": rank}),
                ),
            )
|
||||||
|
|
||||||
|
|
||||||
|
async def main(args: argparse.Namespace) -> None:
    """Refresh the selected charts, each one in its own transaction.

    `args.charts` holds the chart names chosen on the command line; when it
    is empty every chart is refreshed. Charts are processed in a fixed order
    (most popular, bottom, top), independent of the order they were selected.

    The connection pool is opened for the duration of the run and is closed
    again even when refreshing a chart raises.
    """
    # (chart option, award category, success message), in processing order.
    jobs: tuple[tuple[str, models.AwardCategory, str], ...] = (
        ("pop100", "imdb-pop-100", "✨ Updated most popular 100 movies."),
        ("bottom100", "imdb-bottom-100", "✨ Updated bottom 100 movies."),
        ("top250", "imdb-top-250", "✨ Updated top 250 rated movies."),
    )
    # No explicit selection means "all charts"; `args` itself is left untouched.
    selected = args.charts or {chart for chart, _, _ in jobs}

    await db.open_connection_pool()
    try:
        for chart, category, message in jobs:
            if chart not in selected:
                continue
            async with db.transaction() as conn:
                await update_awards(conn, category)
            # Report success only once the transaction block has been left.
            log.info(message)
    finally:
        # Release the pool on failure as well, not only on the happy path.
        await db.close_connection_pool()
|
||||||
|
|
@ -259,6 +259,13 @@ async def new_connection() -> AsyncGenerator[Connection, None]:
|
||||||
async def transacted(
|
async def transacted(
|
||||||
conn: Connection, /, *, force_rollback: bool = False
|
conn: Connection, /, *, force_rollback: bool = False
|
||||||
) -> AsyncGenerator[None, None]:
|
) -> AsyncGenerator[None, None]:
|
||||||
|
"""Start a transaction for the given connection.
|
||||||
|
|
||||||
|
If `force_rollback` is `True` any changes will be rolled back at the end of the
|
||||||
|
transaction, unless they are explicitly committed.
|
||||||
|
Nesting transactions is allowed, but mixing values for `force_rollback` will likely
|
||||||
|
yield unexpected results.
|
||||||
|
"""
|
||||||
transaction = contextlib.nullcontext() if conn.in_transaction() else conn.begin()
|
transaction = contextlib.nullcontext() if conn.in_transaction() else conn.begin()
|
||||||
|
|
||||||
async with transaction:
|
async with transaction:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue