# unwind/unwind/db.py — async SQLite persistence layer for the unwind app.
import asyncio
import contextlib
import logging
import re
import threading
from pathlib import Path
from typing import Any, Iterable, Literal, Type, TypeVar
import sqlalchemy as sa
from databases import Database
from . import config
from .models import (
Movie,
Progress,
Rating,
User,
asplain,
fields,
fromplain,
optional_fields,
utcnow,
)
from .types import ULID
log = logging.getLogger(__name__)

T = TypeVar("T")

# Process-wide Database handle, created lazily by `shared_connection()`.
_shared_connection: Database | None = None
async def open_connection_pool() -> None:
    """Open the DB connection pool and bring the schema up to date.

    This function needs to be called before any access to the database
    can happen.
    """
    database = shared_connection()
    await database.connect()
    # Apply pending schema patches eagerly so every later query sees the
    # current layout.
    await apply_db_patches(database)
async def close_connection_pool() -> None:
    """Close the DB connection pool.

    Call this before the app shuts down to ensure all data has been
    flushed to the database.
    """
    database = shared_connection()
    # Refresh SQLite's query-planner statistics before closing,
    # see https://sqlite.com/lang_analyze.html.
    await database.execute("PRAGMA analysis_limit=400")
    await database.execute("PRAGMA optimize")
    await database.disconnect()
async def _create_patch_db(db):
    """Ensure the bookkeeping table for applied schema patches exists."""
    await db.execute(
        """
        CREATE TABLE IF NOT EXISTS db_patches (
            id INTEGER PRIMARY KEY,
            current TEXT
        )
        """
    )
async def current_patch_level(db) -> str:
    """Return the name of the most recently applied patch ("" if none)."""
    await _create_patch_db(db)
    level = await db.fetch_val("SELECT current FROM db_patches")
    return level or ""
async def set_current_patch_level(db, current: str):
    """Record `current` as the latest applied patch (single-row upsert)."""
    await _create_patch_db(db)
    # Row id is fixed to 1 so there is always exactly one state row.
    query = """
    INSERT INTO db_patches VALUES (1, :current)
    ON CONFLICT DO UPDATE SET current=excluded.current
    """
    await db.execute(query, values={"current": current})
# Directory holding the *.sql schema patch files shipped with the package.
db_patches_dir = Path(__file__).parent / "sql"
async def apply_db_patches(db: Database):
    """Apply all remaining patches to the database.

    Beware that patches will be applied in lexicographical order of the
    file stem, i.e. "10" comes before "9".  The current patch state is
    recorded in the DB itself.

    Please note that every SQL statement in a patch file MUST be
    terminated using two consecutive semi-colons (;).  Failing to do so
    will result in an error.
    """
    highest_applied = await current_patch_level(db)
    applied_any = False
    for patch_path in sorted(db_patches_dir.glob("*.sql"), key=lambda p: p.stem):
        level = patch_path.stem
        if level <= highest_applied:
            continue  # already applied in a previous run
        log.info("Applying patch: %s", level)
        statements = patch_path.read_text().split(";;")
        if len(statements) < 2:
            # A file without ";;" yields a single chunk -- refuse to guess.
            log.error(
                "Patch file is missing statement terminator (`;;'): %s", patch_path
            )
            raise RuntimeError("No statement found.")
        # The patch's statements and the level bump commit atomically.
        async with db.transaction():
            for statement in statements:
                await db.execute(statement)
            await set_current_patch_level(db, level)
        applied_any = True
    if applied_any:
        # Reclaim space potentially freed by schema changes.
        await db.execute("vacuum")
async def get_import_progress() -> Progress | None:
    """Return the most recently started IMDb import progress, if any."""
    return await get(Progress, type="import-imdb-movies", order_by="started DESC")
async def stop_import_progress(*, error: BaseException | None = None):
    """Stop the current import.

    If an error is given, it will be logged to the progress state.  Does
    nothing when no import is currently running.
    """
    current = await get_import_progress()
    if current is None or current.stopped is not None:
        return  # nothing is running
    if error:
        current.error = repr(error)
    current.stopped = utcnow().isoformat()
    await update(current)
async def set_import_progress(progress: float) -> Progress:
    """Set the current import progress percentage.

    The value is clamped to 0..100.  If no import is currently running,
    a new Progress record is created.
    """
    percent = min(max(0.0, progress), 100.0)  # clamp to 0 <= percent <= 100
    current = await get_import_progress()
    is_running = bool(current) and current.stopped is None
    if not is_running:
        current = Progress(type="import-imdb-movies")
    current.percent = percent
    if is_running:
        await update(current)
    else:
        await add(current)
    return current
# Main lock serializing all DB access across threads.
_lock = threading.Lock()
# "Lock for the lock": gives contending threads a fair chance to acquire
# `_lock`; see the explanation inside `single_threaded`.
_prelock = threading.Lock()
@contextlib.asynccontextmanager
async def single_threaded():
    """Ensure the nested code is run only by a single thread at a time.

    Locks are acquired non-blocking and retried after a short async sleep
    so the event loop stays responsive while waiting.
    """
    wait = 1e-5  # XXX not sure if there's a better magic value here
    # The pre-lock (a lock for the lock) allows for multiple threads to hand off
    # the main lock.
    # With only a single lock the contending thread will spend most of its time
    # in the asyncio.sleep and the reigning thread will have time to finish
    # whatever it's doing and simply acquire the lock again before the other
    # thread has had a chance to try.
    # By having another lock (and the same sleep time!) the contending thread
    # will always have a chance to acquire the main lock.
    while not _prelock.acquire(blocking=False):
        await asyncio.sleep(wait)
    try:
        while not _lock.acquire(blocking=False):
            await asyncio.sleep(wait)
    finally:
        _prelock.release()
    try:
        yield
    finally:
        _lock.release()
@contextlib.asynccontextmanager
async def locked_connection():
    """Yield the shared DB connection while holding the global DB lock."""
    async with single_threaded():
        yield shared_connection()
def shared_connection() -> Database:
    """Return the process-wide Database handle, creating it on first use."""
    global _shared_connection
    if _shared_connection is None:
        _shared_connection = Database(f"sqlite:///{config.storage_path}")
    return _shared_connection
async def add(item):
    """INSERT `item` into its table, serializing all fields."""
    # Support late initializing - used for optimization.
    if getattr(item, "_is_lazy", False):
        item._lazy_init()
    values = asplain(item, serialize=True)
    columns = ", ".join(values)
    placeholders = ", ".join(f":{name}" for name in values)
    query = f"INSERT INTO {item._table} ({columns}) VALUES ({placeholders})"
    async with locked_connection() as conn:
        await conn.execute(query=query, values=values)
# Type variable used by the generic CRUD helpers below.
ModelType = TypeVar("ModelType")
async def get(
    model: Type[ModelType], *, order_by: str | None = None, **kwds
) -> ModelType | None:
    """Load a model instance from the database.

    Passing `kwds` allows to filter the instance to load.  You have to
    encode the values as the appropriate data type for the database prior
    to passing them to this function.  Returns None when no filters
    remain after dropping None values, or when nothing matches.
    """
    filters = {key: val for key, val in kwds.items() if val is not None}
    if not filters:
        return None
    columns = ", ".join(f.name for f in fields(model))
    where = " AND ".join(f"{key}=:{key}" for key in filters)
    query = f"SELECT {columns} FROM {model._table} WHERE {where}"
    if order_by:
        query += f" ORDER BY {order_by}"
    async with locked_connection() as conn:
        row = await conn.fetch_one(query=query, values=filters)
    if row is None:
        return None
    return fromplain(model, row._mapping, serialized=True)
async def get_many(
    model: Type[ModelType], **field_sets: set | list
) -> Iterable[ModelType]:
    """Return the items with any values matching all given field sets.

    This is similar to `get_all`, but instead of a scalar value a list of
    values must be given.  If any of the given values is set for that
    field on an item, the item is considered a match.  If no field values
    are given, no items will be returned.
    """
    if not field_sets:
        return []
    table: sa.Table = model.__table__
    clauses = (table.c[name].in_(accepted) for name, accepted in field_sets.items())
    query = sa.select(model).where(*clauses)
    async with locked_connection() as conn:
        rows = await conn.fetch_all(query)
    return (fromplain(model, row._mapping, serialized=True) for row in rows)
async def get_all(model: Type[ModelType], **field_values) -> Iterable[ModelType]:
    """Filter all items by comparing all given field values.

    If no filters are given (or all are None), all items are returned.
    """
    table: sa.Table = model.__table__
    clauses = (
        table.c[name] == value
        for name, value in field_values.items()
        if value is not None
    )
    query = sa.select(model).where(*clauses)
    async with locked_connection() as conn:
        rows = await conn.fetch_all(query)
    return (fromplain(model, row._mapping, serialized=True) for row in rows)
async def update(item):
    """Persist all fields of `item` to its existing row (matched by id)."""
    # Support late initializing - used for optimization.
    if getattr(item, "_is_lazy", False):
        item._lazy_init()
    values = asplain(item, serialize=True)
    assignments = ", ".join(f"{name}=:{name}" for name in values if name != "id")
    async with locked_connection() as conn:
        await conn.execute(
            query=f"UPDATE {item._table} SET {assignments} WHERE id=:id",
            values=values,
        )
async def remove(item):
    """Delete the row whose id matches `item.id` from the item's table."""
    key = asplain(item, filter_fields={"id"}, serialize=True)
    async with locked_connection() as conn:
        await conn.execute(
            query=f"DELETE FROM {item._table} WHERE id=:id", values=key
        )
async def add_or_update_user(user: User):
    """Insert `user`, or sync it onto the existing row with the same imdb_id.

    On update the passed-in user's `id` is overwritten with the DB row's id
    so the two objects refer to the same record afterwards.
    """
    existing = await get(User, imdb_id=user.imdb_id)
    if existing is None:
        await add(user)
        return
    user.id = existing.id
    if user != existing:
        await update(user)
async def add_or_update_many_movies(movies: list[Movie]):
    """Add or update Movies in the database.

    This is an optimized version of `add_or_update_movie` for the purpose
    of bulk operations: all existing rows are pre-fetched with a single
    query.  Like the single-item variant, optional fields that are unset
    on an incoming movie are backfilled from its DB row, and rows that are
    at least as recent as the incoming movie are left untouched.
    """
    db_movies = {
        m.imdb_id: m for m in await get_many(Movie, imdb_id=[m.imdb_id for m in movies])
    }
    for movie in movies:
        # XXX optimize bulk add & update as well
        if movie.imdb_id not in db_movies:
            await add(movie)
        else:
            db_movie = db_movies[movie.imdb_id]
            movie.id = db_movie.id
            # We want to keep any existing value in the DB for all optional fields.
            for f in optional_fields(movie):
                if getattr(movie, f.name) is None:
                    setattr(movie, f.name, getattr(db_movie, f.name))
            if movie.updated <= db_movie.updated:
                # BUG FIX: this was `return`, which aborted the whole bulk
                # run as soon as one movie was already up to date.  Only
                # skip this movie and keep processing the rest.
                continue
            await update(movie)
async def add_or_update_movie(movie: Movie):
    """Add or update a Movie in the database.

    This is an upsert operation, but it will also update the Movie you pass
    into the function to make its `id` match the DB's movie's `id`, and also
    set all optional values on your Movie that might be unset but exist in the
    database. It's a bidirectional sync.
    """
    db_movie = await get(Movie, imdb_id=movie.imdb_id)
    if db_movie is None:
        await add(movie)
        return
    movie.id = db_movie.id
    # We want to keep any existing value in the DB for all optional fields.
    for field in optional_fields(movie):
        if getattr(movie, field.name) is None:
            setattr(movie, field.name, getattr(db_movie, field.name))
    # Only write back when the incoming data is strictly newer.
    if movie.updated > db_movie.updated:
        await update(movie)
async def add_or_update_rating(rating: Rating) -> bool:
    """Upsert a rating; return True when the database was actually modified."""
    existing = await get(
        Rating, movie_id=str(rating.movie_id), user_id=str(rating.user_id)
    )
    if existing is None:
        await add(rating)
        return True
    rating.id = existing.id
    if rating == existing:
        return False  # nothing changed
    await update(rating)
    return True
def sql_escape(s: str, char="#"):
    """Escape the LIKE wildcards `%` and `_` in `s` using `char`.

    The escape character itself is doubled first so the result is
    unambiguous.  Intended for use with `LIKE ... ESCAPE char`.
    """
    escaped = s.replace(char, char + char)
    escaped = escaped.replace("%", f"{char}%")
    return escaped.replace("_", f"{char}_")
async def find_ratings(
    *,
    title: str | None = None,
    media_type: str | None = None,
    exact: bool = False,
    ignore_tv_episodes: bool = False,
    include_unrated: bool = False,
    yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
    limit_rows: int = 10,
    user_ids: Iterable[str] = (),
):
    """Search movies by the given filters and return their rating rows.

    `title` is matched word-wise against title and original_title using
    LIKE (substring match unless `exact`); `yearcomp` compares
    release_year with the given operator; `user_ids` limits which users'
    ratings are considered.  When `include_unrated` is set and fewer than
    `limit_rows` rated matches exist, the result is topped up with
    unrated movies matching the same conditions.

    Returns the detail rows produced by `ratings_for_movie_ids`.

    FIX: the `user_ids` default was a shared mutable list; it is now an
    immutable empty tuple (behavior unchanged for all callers).
    """
    values: dict[str, int | str] = {
        "limit_rows": limit_rows,
    }
    conditions = []
    if title:
        values["escape"] = "#"
        escaped_title = sql_escape(title, char=values["escape"])
        # "exact" joins words with `_` (exactly one arbitrary char between
        # words); otherwise anything may surround and separate the words.
        values["pattern"] = (
            "_".join(escaped_title.split())
            if exact
            else "%" + "%".join(escaped_title.split()) + "%"
        )
        conditions.append(
            f"""
            (
                {Movie._table}.title LIKE :pattern ESCAPE :escape
                OR {Movie._table}.original_title LIKE :pattern ESCAPE :escape
            )
            """
        )
    if yearcomp:
        op, year = yearcomp
        assert op in "<=>"
        values["year"] = year
        conditions.append(f"{Movie._table}.release_year{op}:year")
    if media_type:
        values["media_type"] = media_type
        conditions.append(f"{Movie._table}.media_type=:media_type")
    if ignore_tv_episodes:
        conditions.append(f"{Movie._table}.media_type!='TV Episode'")
    user_condition = "1=1"
    if user_ids:
        uvs = {f"user_id_{i}": v for i, v in enumerate(user_ids, start=1)}
        values.update(uvs)
        user_condition = f"{Rating._table}.user_id IN ({','.join(':'+n for n in uvs)})"
    query = f"""
    SELECT DISTINCT {Rating._table}.movie_id
    FROM {Rating._table}
    LEFT JOIN {Movie._table} ON {Movie._table}.id={Rating._table}.movie_id
    WHERE {user_condition}{(' AND ' + ' AND '.join(conditions)) if conditions else ''}
    ORDER BY length({Movie._table}.title) ASC, {Rating._table}.rating_date DESC, {Movie._table}.imdb_score DESC
    LIMIT :limit_rows
    """
    async with locked_connection() as conn:
        rows = await conn.fetch_all(bindparams(query, values))
    movie_ids = tuple(r._mapping["movie_id"] for r in rows)
    if include_unrated and len(movie_ids) < limit_rows:
        # Top up the result with unrated movies, excluding ids already found.
        sqlin, sqlin_vals = sql_in("id", movie_ids, not_=True)
        query = f"""
        SELECT DISTINCT id AS movie_id
        FROM {Movie._table}
        WHERE {sqlin}
        {('AND ' + ' AND '.join(conditions)) if conditions else ''}
        ORDER BY length(title) ASC, imdb_score DESC, release_year DESC
        LIMIT :limit_rows
        """
        async with locked_connection() as conn:
            rows = await conn.fetch_all(
                bindparams(
                    query,
                    {**values, **sqlin_vals, "limit_rows": limit_rows - len(movie_ids)},
                )
            )
        movie_ids += tuple(r._mapping["movie_id"] for r in rows)
    return await ratings_for_movie_ids(ids=movie_ids)
async def ratings_for_movie_ids(
    ids: Iterable[ULID | str] = (), imdb_ids: Iterable[str] = ()
) -> Iterable[dict[str, Any]]:
    """Return rating/movie detail rows for the given movies.

    Movies may be selected by internal `ids`, by `imdb_ids`, or both (the
    two conditions are OR-ed).  The LEFT JOIN keeps movies without any
    rating in the result (user columns are NULL there).  Returns an empty
    list when no ids are given at all.

    FIX: both defaults were shared mutable lists; they are now immutable
    empty tuples (behavior unchanged for all callers).
    """
    conds: list[str] = []
    vals: dict[str, str] = {}
    if ids:
        sqlin, sqlin_vals = sql_in(f"{Movie._table}.id", (str(x) for x in ids))
        conds.append(sqlin)
        vals.update(sqlin_vals)
    if imdb_ids:
        sqlin, sqlin_vals = sql_in(f"{Movie._table}.imdb_id", imdb_ids)
        conds.append(sqlin)
        vals.update(sqlin_vals)
    if not conds:
        return []
    query = f"""
    SELECT
    {Rating._table}.score AS user_score,
    {Rating._table}.user_id AS user_id,
    {Movie._table}.imdb_score,
    {Movie._table}.imdb_votes,
    {Movie._table}.imdb_id AS movie_imdb_id,
    {Movie._table}.media_type AS media_type,
    {Movie._table}.title AS canonical_title,
    {Movie._table}.original_title AS original_title,
    {Movie._table}.release_year AS release_year
    FROM {Movie._table}
    LEFT JOIN {Rating._table} ON {Movie._table}.id={Rating._table}.movie_id
    WHERE {(' OR '.join(conds))}
    """
    async with locked_connection() as conn:
        rows = await conn.fetch_all(bindparams(query, vals))
    return tuple(dict(r._mapping) for r in rows)
def sql_fields(tp: Type):
    """Yield fully qualified "table.column" names for all fields of `tp`."""
    table = tp._table
    return (f"{table}.{field.name}" for field in fields(tp))
def sql_in(column: str, values: Iterable[T], not_=False) -> tuple[str, dict[str, T]]:
    """Build an `IN` (or `NOT IN`) clause with numbered named placeholders.

    Returns the SQL fragment and the placeholder-name -> value mapping.
    Dots in `column` are replaced by triple underscores so placeholder
    names stay valid.
    """
    prefix = column.replace(".", "___")
    value_map = {f"{prefix}_{i}": val for i, val in enumerate(values, start=1)}
    placeholders = ",".join(f":{name}" for name in value_map)
    keyword = "NOT IN" if not_ else "IN"
    return f"{column} {keyword} ({placeholders})", value_map
async def ratings_for_movies(
    movie_ids: Iterable[ULID], user_ids: Iterable[ULID] = ()
) -> Iterable[Rating]:
    """Return all ratings for the given movies, optionally limited to users.

    `movie_ids` selects which movies' ratings are loaded; when `user_ids`
    is non-empty only those users' ratings are included.

    FIX: the `user_ids` default was a shared mutable list; it is now an
    immutable empty tuple (behavior unchanged for all callers).
    """
    values: dict[str, str] = {}
    conditions: list[str] = []
    q, vm = sql_in("movie_id", [str(m) for m in movie_ids])
    conditions.append(q)
    values.update(vm)
    if user_ids:
        q, vm = sql_in("user_id", [str(m) for m in user_ids])
        conditions.append(q)
        values.update(vm)
    query = f"""
    SELECT {','.join(sql_fields(Rating))}
    FROM {Rating._table}
    WHERE {' AND '.join(f'({c})' for c in conditions) if conditions else '1=1'}
    """
    async with locked_connection() as conn:
        rows = await conn.fetch_all(query, values)
    return (fromplain(Rating, row._mapping, serialized=True) for row in rows)
async def find_movies(
    *,
    title: str | None = None,
    media_type: str | None = None,
    exact: bool = False,
    ignore_tv_episodes: bool = False,
    yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
    limit_rows: int = 10,
    skip_rows: int = 0,
    include_unrated: bool = False,
    user_ids: Iterable[ULID] = (),
) -> Iterable[tuple[Movie, list[Rating]]]:
    """Search movies and pair each hit with the given users' ratings.

    `title` is matched word-wise against title and original_title using
    LIKE (substring match unless `exact`); `yearcomp` compares
    release_year with the given operator.  Unless `include_unrated` is
    set, only movies with an IMDb score are returned.  Pagination is done
    via `skip_rows`/`limit_rows`.  When `user_ids` is empty, every movie
    is paired with an empty ratings list.

    FIX: the `user_ids` default was a shared mutable list; it is now an
    immutable empty tuple with a widened `Iterable` annotation (behavior
    unchanged for all callers).
    """
    values: dict[str, int | str] = {
        "limit_rows": limit_rows,
        "skip_rows": skip_rows,
    }
    conditions = []
    if title:
        values["escape"] = "#"
        escaped_title = sql_escape(title, char=values["escape"])
        # "exact" joins words with `_` (exactly one arbitrary char between
        # words); otherwise anything may surround and separate the words.
        values["pattern"] = (
            "_".join(escaped_title.split())
            if exact
            else "%" + "%".join(escaped_title.split()) + "%"
        )
        conditions.append(
            f"""
            (
                {Movie._table}.title LIKE :pattern ESCAPE :escape
                OR {Movie._table}.original_title LIKE :pattern ESCAPE :escape
            )
            """
        )
    if yearcomp:
        op, year = yearcomp
        assert op in "<=>"
        values["year"] = year
        conditions.append(f"{Movie._table}.release_year{op}:year")
    if media_type:
        values["media_type"] = media_type
        conditions.append(f"{Movie._table}.media_type=:media_type")
    if ignore_tv_episodes:
        conditions.append(f"{Movie._table}.media_type!='TV Episode'")
    if not include_unrated:
        # "rated" here means IMDb has a score, not that one of our users
        # rated it.
        conditions.append(f"{Movie._table}.imdb_score NOTNULL")
    query = f"""
    SELECT {','.join(sql_fields(Movie))}
    FROM {Movie._table}
    WHERE {(' AND '.join(conditions)) if conditions else '1=1'}
    ORDER BY
    length({Movie._table}.title) ASC,
    {Movie._table}.imdb_score DESC,
    {Movie._table}.release_year DESC
    LIMIT :skip_rows, :limit_rows
    """
    async with locked_connection() as conn:
        rows = await conn.fetch_all(bindparams(query, values))
    movies = [fromplain(Movie, row._mapping, serialized=True) for row in rows]
    if not user_ids:
        return ((m, []) for m in movies)
    ratings = await ratings_for_movies((m.id for m in movies), user_ids)
    aggreg: dict[ULID, tuple[Movie, list[Rating]]] = {m.id: (m, []) for m in movies}
    for rating in ratings:
        aggreg[rating.movie_id][1].append(rating)
    return aggreg.values()
def bindparams(query: str, values: dict):
    """Bind values to a query.

    This is similar to what SQLAlchemy and Databases do, but it allows to
    easily use the same placeholder in multiple places: each occurrence of
    `:name` is rewritten to a unique numbered placeholder (`:name_1`,
    `:name_2`, ...) bound to the same value from `values`.
    """
    bound_values = {}
    seen_counts: dict[str, int] = {}

    def renumber(match):
        name = match[1]
        seen_counts[name] = seen_counts.get(name, 0) + 1
        unique = f"{name}_{seen_counts[name]}"
        bound_values[unique] = values[name]
        return f":{unique}"

    renamed = re.sub(r":(\w+)\b", renumber, query)
    return sa.text(renamed).bindparams(**bound_values)