diff --git a/unwind/db.py b/unwind/db.py
index f2b519b..5cb72e5 100644
--- a/unwind/db.py
+++ b/unwind/db.py
@@ -83,12 +83,41 @@ async def get(model: Type[ModelType], **kwds) -> Optional[ModelType]:
         return
 
     fields_ = ", ".join(f.name for f in fields(model))
-    cond = " AND ".join(f"{k}=:{k}" for k, v in values.items())
+    cond = " AND ".join(f"{k}=:{k}" for k in values)
     query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
     row = await shared_connection().fetch_one(query=query, values=values)
     return fromplain(model, row) if row else None
 
 
+async def get_many(model: Type[ModelType], **kwds) -> list[ModelType]:
+    """Return all rows of `model` matching the given value lists.
+
+    Every keyword names a column and maps to an iterable of accepted values;
+    a row matches when each named column holds one of its listed values.
+    Without any keyword, or with an empty value list, the result is empty.
+    """
+    # Materialize first, so one-shot iterables survive being walked twice.
+    candidates = {k: list(vs) for k, vs in kwds.items()}
+
+    # An empty list can never match, and `IN ()` is not portable SQL.
+    if not candidates or not all(candidates.values()):
+        return []
+
+    keys = {
+        k: [f"{k}_{i}" for i, _ in enumerate(vs, start=1)]
+        for k, vs in candidates.items()
+    }
+    values = {n: v for k, vs in candidates.items() for n, v in zip(keys[k], vs)}
+
+    fields_ = ", ".join(f.name for f in fields(model))
+    cond = " AND ".join(
+        f"{k} IN ({','.join(':'+n for n in ns)})" for k, ns in keys.items()
+    )
+    query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
+    rows = await shared_connection().fetch_all(query=query, values=values)
+    return [fromplain(model, row) for row in rows]
+
+
 async def update(item):
     values = asplain(item)
     keys = ", ".join(f"{k}=:{k}" for k in values if k != "id")
@@ -107,6 +136,37 @@ async def add_or_update_user(user: User):
             await update(user)
 
 
+async def add_or_update_many_movies(movies: list[Movie]):
+    """Add or update Movies in the database.
+
+    This is an optimized version of `add_or_update_movie` for the purpose
+    of bulk operations: the rows already stored for the given movies are
+    looked up with a single `get_many` query.
+    """
+    db_movies = {
+        m.imdb_id: m for m in await get_many(Movie, imdb_id=[m.imdb_id for m in movies])
+    }
+    for movie in movies:
+        # XXX optimize bulk add & update as well
+        if movie.imdb_id not in db_movies:
+            await add(movie)
+        else:
+            db_movie = db_movies[movie.imdb_id]
+            movie.id = db_movie.id
+
+            # We want to keep any existing value in the DB for all optional fields.
+            for f in optional_fields(movie):
+                if getattr(movie, f.name) is None:
+                    setattr(movie, f.name, getattr(db_movie, f.name))
+
+            # The stored row is at least as recent: skip this movie only and
+            # keep processing the rest of the batch.
+            if movie.updated <= db_movie.updated:
+                continue
+
+            await update(movie)
+
+
 async def add_or_update_movie(movie: Movie):
     """Add or update a Movie in the database.
 
diff --git a/unwind/imdb_import.py b/unwind/imdb_import.py
index 3482774..e5888ef 100644
--- a/unwind/imdb_import.py
+++ b/unwind/imdb_import.py
@@ -7,7 +7,7 @@ from pathlib import Path
 from typing import Optional, cast
 
 from . import db
-from .db import add_or_update_movie
+from .db import add_or_update_many_movies
 from .imdb import score_from_imdb_rating
 from .models import Movie
 
@@ -191,6 +191,8 @@ async def import_from_file(basics_path: Path, ratings_path: Path):
 
     async with db.shared_connection().transaction():
 
+        chunk = []
+
         for i, m in enumerate(read_basics(basics_path)):
 
             if i / total > perc:
@@ -211,4 +213,12 @@ async def import_from_file(basics_path: Path, ratings_path: Path):
                 continue
 
             m.score = scores.get(m.imdb_id)
-            await add_or_update_movie(m)
+            chunk.append(m)
+
+            if len(chunk) > 1000:
+                await add_or_update_many_movies(chunk)
+                chunk = []
+
+        if chunk:
+            await add_or_update_many_movies(chunk)
+            chunk = []