import asyncio
import contextlib
import logging
import re
import threading
from pathlib import Path
from typing import Any, Iterable, Literal, Type, TypeVar

import sqlalchemy as sa
from databases import Database

from . import config
from .models import (
    Movie,
    Progress,
    Rating,
    User,
    asplain,
    fields,
    fromplain,
    optional_fields,
    utcnow,
)
from .types import ULID

log = logging.getLogger(__name__)
T = TypeVar("T")

_shared_connection: Database | None = None


async def open_connection_pool() -> None:
    """Open the DB connection pool.

    This function needs to be called before any access to the database can happen.
    """
    db = shared_connection()
    await db.connect()

    await apply_db_patches(db)


async def close_connection_pool() -> None:
    """Close the DB connection pool.

    This function should be called before the app shuts down to ensure all data
    has been flushed to the database.
    """
    db = shared_connection()

    # Run automatic ANALYZE prior to closing the db,
    # see https://sqlite.com/lang_analyze.html.
    await db.execute("PRAGMA analysis_limit=400")
    await db.execute("PRAGMA optimize")

    await db.disconnect()


async def _create_patch_db(db):
    query = """
        CREATE TABLE IF NOT EXISTS db_patches (
            id INTEGER PRIMARY KEY,
            current TEXT
        )
    """
    await db.execute(query)


async def current_patch_level(db) -> str:
    await _create_patch_db(db)

    query = "SELECT current FROM db_patches"
    current = await db.fetch_val(query)
    return current or ""


async def set_current_patch_level(db, current: str):
    await _create_patch_db(db)

    query = """
        INSERT INTO db_patches VALUES (1, :current)
        ON CONFLICT DO UPDATE SET current=excluded.current
    """
    await db.execute(query, values={"current": current})


db_patches_dir = Path(__file__).parent / "sql"


async def apply_db_patches(db: Database):
    """Apply all remaining patches to the database.

    Beware that patches will be applied in lexicographical order,
    i.e. "10" comes before "9".

    The current patch state is recorded in the DB itself.

    Please note that every SQL statement in a patch file MUST be terminated
    by two consecutive semi-colons (;;).
    Failing to do so will result in an error.
    """
    applied_lvl = await current_patch_level(db)
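
    # Illustrative patch file (hypothetical name: sql/0001-initial.sql); note
    # the ";;" terminating every statement:
    #
    #   CREATE TABLE IF NOT EXISTS users (
    #       id TEXT PRIMARY KEY,
    #       imdb_id TEXT
    #   );;
    #   CREATE INDEX IF NOT EXISTS ix_users_imdb_id ON users (imdb_id);;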

    did_patch = False

    for patchfile in sorted(db_patches_dir.glob("*.sql"), key=lambda p: p.stem):
        patch_lvl = patchfile.stem
        if patch_lvl <= applied_lvl:
            continue

        log.info("Applying patch: %s", patch_lvl)

        sql = patchfile.read_text()
        queries = sql.split(";;")
        if len(queries) < 2:
            log.error(
                "Patch file is missing statement terminator (`;;'): %s", patchfile
            )
            raise RuntimeError("No statement found.")

        async with db.transaction():
            for query in queries:
                await db.execute(query)

            await set_current_patch_level(db, patch_lvl)

        did_patch = True

    if did_patch:
        await db.execute("vacuum")


async def get_import_progress() -> Progress | None:
    """Return the latest import progress."""
    return await get(Progress, type="import-imdb-movies", order_by="started DESC")


async def stop_import_progress(*, error: BaseException | None = None):
    """Stop the current import.

    If an error is given, it will be logged to the progress state.
    """
    current = await get_import_progress()
    is_running = current and current.stopped is None

    if not is_running:
        return
    assert current

    if error:
        current.error = repr(error)
    current.stopped = utcnow().isoformat()

    await update(current)


async def set_import_progress(progress: float) -> Progress:
    """Set the current import progress percentage.

    If no import is currently running, this will create a new one.
    """
    progress = min(max(0.0, progress), 100.0)  # clamp to 0 <= progress <= 100

    current = await get_import_progress()
    is_running = current and current.stopped is None

    if not is_running:
        current = Progress(type="import-imdb-movies")
    assert current

    current.percent = progress

    if is_running:
        await update(current)
    else:
        await add(current)

    return current


_lock = threading.Lock()
_prelock = threading.Lock()


@contextlib.asynccontextmanager
async def single_threaded():
    """Ensure the nested code is run only by a single thread at a time."""
    wait = 1e-5  # XXX not sure if there's a better magic value here

    # The pre-lock (a lock for the lock) allows multiple threads to hand off
    # the main lock.
    # With only a single lock, the contending thread would spend most of its
    # time in the asyncio.sleep, and the reigning thread would have time to
    # finish whatever it's doing and simply acquire the lock again before the
    # other thread has had a chance to try.
    # By having another lock (and the same sleep time!) the contending thread
    # will always have a chance to acquire the main lock.
    while not _prelock.acquire(blocking=False):
        await asyncio.sleep(wait)

    try:
        while not _lock.acquire(blocking=False):
            await asyncio.sleep(wait)
    finally:
        _prelock.release()

    try:
        yield

    finally:
        _lock.release()
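

# Usage sketch (illustrative): serialize DB access across worker threads.
#
#   async with single_threaded():
#       await shared_connection().execute("...")
#
# `locked_connection()` below bundles exactly this lock/connection pairing.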


@contextlib.asynccontextmanager
async def locked_connection():
    async with single_threaded():
        yield shared_connection()


def shared_connection() -> Database:
    global _shared_connection

    if _shared_connection is None:
        uri = f"sqlite:///{config.storage_path}"
        _shared_connection = Database(uri)

    return _shared_connection


async def add(item):
    # Support late initializing - used for optimization.
    if getattr(item, "_is_lazy", False):
        item._lazy_init()

    values = asplain(item, serialize=True)
    keys = ", ".join(f"{k}" for k in values)
    placeholders = ", ".join(f":{k}" for k in values)
    query = f"INSERT INTO {item._table} ({keys}) VALUES ({placeholders})"
    async with locked_connection() as conn:
        await conn.execute(query=query, values=values)
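

# For illustration only: an item whose plain form is {"id": ..., "title": ...}
# and whose `_table` is "movies" (hypothetical) would produce:
#
#   INSERT INTO movies (id, title) VALUES (:id, :title)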


ModelType = TypeVar("ModelType")


async def get(
    model: Type[ModelType], *, order_by: str | None = None, **kwds
) -> ModelType | None:
    """Load a model instance from the database.

    Passing `kwds` filters which instance is loaded. You have to encode the
    values as the appropriate data type for the database prior to passing them
    to this function.
    """
    values = {k: v for k, v in kwds.items() if v is not None}
    if not values:
        return None

    fields_ = ", ".join(f.name for f in fields(model))
    cond = " AND ".join(f"{k}=:{k}" for k in values)
    query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
    if order_by:
        query += f" ORDER BY {order_by}"
    async with locked_connection() as conn:
        row = await conn.fetch_one(query=query, values=values)
        return fromplain(model, row._mapping, serialized=True) if row else None
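

# Examples (illustrative; "ur0000001" is a made-up IMDb user id):
#
#   user = await get(User, imdb_id="ur0000001")
#   latest = await get(Progress, type="import-imdb-movies", order_by="started DESC")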


async def get_many(
    model: Type[ModelType], **field_sets: set | list
) -> Iterable[ModelType]:
    """Return the items with any values matching all given field sets.

    This is similar to `get_all`, but instead of a scalar value a list of values
    must be given. If any of the given values is set for that field on an item,
    the item is considered a match.
    If no field values are given, no items will be returned.
    """
    if not field_sets:
        return []

    table: sa.Table = model.__table__
    query = sa.select(model).where(*(table.c[k].in_(v) for k, v in field_sets.items()))
    async with locked_connection() as conn:
        rows = await conn.fetch_all(query)
        return (fromplain(model, row._mapping, serialized=True) for row in rows)


async def get_all(model: Type[ModelType], **field_values) -> Iterable[ModelType]:
    """Filter all items by comparing all given field values.

    If no filters are given, all items will be returned.
    """
    table: sa.Table = model.__table__
    query = sa.select(model).where(
        *(table.c[k] == v for k, v in field_values.items() if v is not None)
    )
    async with locked_connection() as conn:
        rows = await conn.fetch_all(query)
        return (fromplain(model, row._mapping, serialized=True) for row in rows)


async def update(item):
    # Support late initializing - used for optimization.
    if getattr(item, "_is_lazy", False):
        item._lazy_init()

    values = asplain(item, serialize=True)
    keys = ", ".join(f"{k}=:{k}" for k in values if k != "id")
    query = f"UPDATE {item._table} SET {keys} WHERE id=:id"
    async with locked_connection() as conn:
        await conn.execute(query=query, values=values)


async def remove(item):
    values = asplain(item, filter_fields={"id"}, serialize=True)
    query = f"DELETE FROM {item._table} WHERE id=:id"
    async with locked_connection() as conn:
        await conn.execute(query=query, values=values)


async def add_or_update_user(user: User):
    db_user = await get(User, imdb_id=user.imdb_id)
    if not db_user:
        await add(user)
    else:
        user.id = db_user.id

        if user != db_user:
            await update(user)


async def add_or_update_many_movies(movies: list[Movie]):
    """Add or update Movies in the database.

    This is an optimized version of `add_or_update_movie` for the purpose
    of bulk operations.
    """
    # Naive equivalent: awaiting add_or_update_movie(movie) for each movie.
    db_movies = {
        m.imdb_id: m for m in await get_many(Movie, imdb_id=[m.imdb_id for m in movies])
    }
    for movie in movies:
        # XXX optimize bulk add & update as well
        if movie.imdb_id not in db_movies:
            await add(movie)
        else:
            db_movie = db_movies[movie.imdb_id]
            movie.id = db_movie.id

            # We want to keep any existing value in the DB for all optional fields.
            for f in optional_fields(movie):
                if getattr(movie, f.name) is None:
                    setattr(movie, f.name, getattr(db_movie, f.name))

            if movie.updated <= db_movie.updated:
                # `continue`, not `return`: an up-to-date movie must not
                # abort the rest of the batch.
                continue

            await update(movie)


async def add_or_update_movie(movie: Movie):
    """Add or update a Movie in the database.

    This is an upsert operation, but it will also update the Movie you pass
    into the function: its `id` is set to match the `id` of the movie in the
    DB, and any optional values that are unset on your Movie but exist in the
    database are filled in. It's a bidirectional sync.
    """
    db_movie = await get(Movie, imdb_id=movie.imdb_id)
    if not db_movie:
        await add(movie)
    else:
        movie.id = db_movie.id

        # We want to keep any existing value in the DB for all optional fields.
        for f in optional_fields(movie):
            if getattr(movie, f.name) is None:
                setattr(movie, f.name, getattr(db_movie, f.name))

        if movie.updated <= db_movie.updated:
            return

        await update(movie)


async def add_or_update_rating(rating: Rating) -> bool:
    db_rating = await get(
        Rating, movie_id=str(rating.movie_id), user_id=str(rating.user_id)
    )

    if not db_rating:
        await add(rating)
        return True

    else:
        rating.id = db_rating.id

        if rating != db_rating:
            await update(rating)
            return True

    return False


def sql_escape(s: str, char="#"):
    return s.replace(char, 2 * char).replace("%", f"{char}%").replace("_", f"{char}_")
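

# Example: sql_escape("100%_a#b") == "100#%#_a##b" -- the LIKE wildcards and
# the escape character itself are neutralized for use with `... ESCAPE '#'`.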


async def find_ratings(
    *,
    title: str | None = None,
    media_type: str | None = None,
    exact: bool = False,
    ignore_tv_episodes: bool = False,
    include_unrated: bool = False,
    yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
    limit_rows: int = 10,
    user_ids: Iterable[str] = [],
):
    values: dict[str, int | str] = {
        "limit_rows": limit_rows,
    }

    conditions = []
    if title:
        values["escape"] = "#"
        escaped_title = sql_escape(title, char=values["escape"])
        values["pattern"] = (
            "_".join(escaped_title.split())
            if exact
            else "%" + "%".join(escaped_title.split()) + "%"
        )
        conditions.append(
            f"""
            (
                {Movie._table}.title LIKE :pattern ESCAPE :escape
                OR {Movie._table}.original_title LIKE :pattern ESCAPE :escape
            )
            """
        )

    if yearcomp:
        op, year = yearcomp
        assert op in "<=>"
        values["year"] = year
        conditions.append(f"{Movie._table}.release_year{op}:year")

    if media_type:
        values["media_type"] = media_type
        conditions.append(f"{Movie._table}.media_type=:media_type")

    if ignore_tv_episodes:
        conditions.append(f"{Movie._table}.media_type!='TV Episode'")

    user_condition = "1=1"
    if user_ids:
        uvs = {f"user_id_{i}": v for i, v in enumerate(user_ids, start=1)}
        values.update(uvs)
        user_condition = f"{Rating._table}.user_id IN ({','.join(':'+n for n in uvs)})"

    query = f"""
        SELECT DISTINCT {Rating._table}.movie_id
        FROM {Rating._table}
        LEFT JOIN {Movie._table} ON {Movie._table}.id={Rating._table}.movie_id
        WHERE {user_condition}{(' AND ' + ' AND '.join(conditions)) if conditions else ''}
        ORDER BY
            length({Movie._table}.title) ASC,
            {Rating._table}.rating_date DESC,
            {Movie._table}.imdb_score DESC
        LIMIT :limit_rows
    """
    async with locked_connection() as conn:
        rows = await conn.fetch_all(bindparams(query, values))
        movie_ids = tuple(r._mapping["movie_id"] for r in rows)

    if include_unrated and len(movie_ids) < limit_rows:
        sqlin, sqlin_vals = sql_in("id", movie_ids, not_=True)
        query = f"""
            SELECT DISTINCT id AS movie_id
            FROM {Movie._table}
            WHERE {sqlin}
            {('AND ' + ' AND '.join(conditions)) if conditions else ''}
            ORDER BY length(title) ASC, imdb_score DESC, release_year DESC
            LIMIT :limit_rows
        """
        async with locked_connection() as conn:
            rows = await conn.fetch_all(
                bindparams(
                    query,
                    {**values, **sqlin_vals, "limit_rows": limit_rows - len(movie_ids)},
                )
            )
            movie_ids += tuple(r._mapping["movie_id"] for r in rows)

    return await ratings_for_movie_ids(ids=movie_ids)
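

# Example call (illustrative): rated movies whose title matches "alien",
# released after 1980, for any user:
#
#   rows = await find_ratings(title="alien", yearcomp=(">", 1980), limit_rows=5)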


async def ratings_for_movie_ids(
    ids: Iterable[ULID | str] = [], imdb_ids: Iterable[str] = []
) -> Iterable[dict[str, Any]]:
    conds: list[str] = []
    vals: dict[str, str] = {}

    if ids:
        sqlin, sqlin_vals = sql_in(f"{Movie._table}.id", (str(x) for x in ids))
        conds.append(sqlin)
        vals.update(sqlin_vals)

    if imdb_ids:
        sqlin, sqlin_vals = sql_in(f"{Movie._table}.imdb_id", imdb_ids)
        conds.append(sqlin)
        vals.update(sqlin_vals)

    if not conds:
        return []

    query = f"""
        SELECT
            {Rating._table}.score AS user_score,
            {Rating._table}.user_id AS user_id,
            {Movie._table}.imdb_score,
            {Movie._table}.imdb_votes,
            {Movie._table}.imdb_id AS movie_imdb_id,
            {Movie._table}.media_type AS media_type,
            {Movie._table}.title AS canonical_title,
            {Movie._table}.original_title AS original_title,
            {Movie._table}.release_year AS release_year
        FROM {Movie._table}
        LEFT JOIN {Rating._table} ON {Movie._table}.id={Rating._table}.movie_id
        WHERE {(' OR '.join(conds))}
    """

    async with locked_connection() as conn:
        rows = await conn.fetch_all(bindparams(query, vals))
        return tuple(dict(r._mapping) for r in rows)


def sql_fields(tp: Type):
    return (f"{tp._table}.{f.name}" for f in fields(tp))


def sql_in(column: str, values: Iterable[T], not_=False) -> tuple[str, dict[str, T]]:
    c = column.replace(".", "___")
    value_map = {f"{c}_{i}": v for i, v in enumerate(values, start=1)}
    placeholders = ",".join(":" + k for k in value_map)
    if not_:
        return f"{column} NOT IN ({placeholders})", value_map
    return f"{column} IN ({placeholders})", value_map
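

# Example: sql_in("m.id", ["a", "b"]) returns
#
#   ("m.id IN (:m___id_1,:m___id_2)", {"m___id_1": "a", "m___id_2": "b"})
#
# Dots are rewritten to "___" so that qualified column names remain valid
# placeholder names.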


async def ratings_for_movies(
    movie_ids: Iterable[ULID], user_ids: Iterable[ULID] = []
) -> Iterable[Rating]:
    values: dict[str, str] = {}
    conditions: list[str] = []

    q, vm = sql_in("movie_id", [str(m) for m in movie_ids])
    conditions.append(q)
    values.update(vm)

    if user_ids:
        q, vm = sql_in("user_id", [str(m) for m in user_ids])
        conditions.append(q)
        values.update(vm)

    query = f"""
        SELECT {','.join(sql_fields(Rating))}
        FROM {Rating._table}
        WHERE {' AND '.join(f'({c})' for c in conditions) if conditions else '1=1'}
    """

    async with locked_connection() as conn:
        rows = await conn.fetch_all(query, values)

    return (fromplain(Rating, row._mapping, serialized=True) for row in rows)


async def find_movies(
    *,
    title: str | None = None,
    media_type: str | None = None,
    exact: bool = False,
    ignore_tv_episodes: bool = False,
    yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
    limit_rows: int = 10,
    skip_rows: int = 0,
    include_unrated: bool = False,
    user_ids: list[ULID] = [],
) -> Iterable[tuple[Movie, list[Rating]]]:
    values: dict[str, int | str] = {
        "limit_rows": limit_rows,
        "skip_rows": skip_rows,
    }

    conditions = []
    if title:
        values["escape"] = "#"
        escaped_title = sql_escape(title, char=values["escape"])
        values["pattern"] = (
            "_".join(escaped_title.split())
            if exact
            else "%" + "%".join(escaped_title.split()) + "%"
        )
        conditions.append(
            f"""
            (
                {Movie._table}.title LIKE :pattern ESCAPE :escape
                OR {Movie._table}.original_title LIKE :pattern ESCAPE :escape
            )
            """
        )

    if yearcomp:
        op, year = yearcomp
        assert op in "<=>"
        values["year"] = year
        conditions.append(f"{Movie._table}.release_year{op}:year")

    if media_type:
        values["media_type"] = media_type
        conditions.append(f"{Movie._table}.media_type=:media_type")

    if ignore_tv_episodes:
        conditions.append(f"{Movie._table}.media_type!='TV Episode'")

    if not include_unrated:
        conditions.append(f"{Movie._table}.imdb_score NOTNULL")

    query = f"""
        SELECT {','.join(sql_fields(Movie))}
        FROM {Movie._table}
        WHERE {(' AND '.join(conditions)) if conditions else '1=1'}
        ORDER BY
            length({Movie._table}.title) ASC,
            {Movie._table}.imdb_score DESC,
            {Movie._table}.release_year DESC
        LIMIT :skip_rows, :limit_rows
    """
    async with locked_connection() as conn:
        rows = await conn.fetch_all(bindparams(query, values))

    movies = [fromplain(Movie, row._mapping, serialized=True) for row in rows]

    if not user_ids:
        return ((m, []) for m in movies)

    ratings = await ratings_for_movies((m.id for m in movies), user_ids)

    aggreg: dict[ULID, tuple[Movie, list[Rating]]] = {m.id: (m, []) for m in movies}
    for rating in ratings:
        aggreg[rating.movie_id][1].append(rating)

    return aggreg.values()


def bindparams(query: str, values: dict):
    """Bind values to a query.

    This is similar to what SQLAlchemy and Databases do, but it makes it easy
    to use the same placeholder in multiple places within one query.
    """
    pump_vals = {}
    pump_keys = {}

    def pump(match):
        key = match[1]
        val = values[key]
        pump_keys[key] = 1 + pump_keys.setdefault(key, 0)
        pump_key = f"{key}_{pump_keys[key]}"
        pump_vals[pump_key] = val
        return f":{pump_key}"

    pump_query = re.sub(r":(\w+)\b", pump, query)
    return sa.text(pump_query).bindparams(**pump_vals)
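

# Example: a placeholder used twice is duplicated before binding.
#
#   bindparams("SELECT :a, :a + :b", {"a": 1, "b": 2})
#
# internally becomes "SELECT :a_1, :a_2 + :b_1", bound with
# {"a_1": 1, "a_2": 1, "b_1": 2}.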