# unwind/unwind/db.py
import contextlib
import logging
from pathlib import Path
from typing import Any, AsyncGenerator, Iterable, Literal, Sequence, Type
import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
import alembic.command
import alembic.config
import alembic.migration
from . import config
from .models import (
Award,
Model,
Movie,
Progress,
Rating,
User,
asplain,
awards,
fromplain,
metadata,
movies,
optional_fields,
progress,
ratings,
utcnow,
)
from .types import ULID, ImdbMovieId, UserIdStr

log = logging.getLogger(__name__)

_engine: AsyncEngine | None = None

type Connection = AsyncConnection

_project_dir = Path(__file__).parent.parent
_alembic_ini = _project_dir / "alembic.ini"


def _init(conn: sa.Connection) -> None:
# See https://alembic.sqlalchemy.org/en/latest/cookbook.html#building-an-up-to-date-database-from-scratch
context = alembic.migration.MigrationContext.configure(conn)
heads = context.get_current_heads()
is_empty_db = not heads # We consider a DB empty if Alembic hasn't touched it yet.
if is_empty_db:
log.info("⚡️ Initializing empty database.")
metadata.create_all(conn)
# We pass our existing connection to Alembic's env.py, to avoid running another asyncio loop there.
alembic_cfg = alembic.config.Config(_alembic_ini)
alembic_cfg.attributes["connection"] = conn
alembic.command.stamp(alembic_cfg, "head")


async def open_connection_pool() -> None:
    """Open the DB connection pool.

    This function must be called before any database access can happen.
    """
async with transaction() as conn:
await conn.execute(sa.text("PRAGMA journal_mode=WAL"))
await conn.run_sync(_init)


async def close_connection_pool() -> None:
    """Close the DB connection pool.

    This function should be called before the app shuts down to ensure all data
    has been flushed to the database.
    """
engine = _shared_engine()
async with engine.begin() as conn:
# Run automatic ANALYZE prior to closing the db,
# see https://sqlite.com/lang_analyze.html.
await conn.execute(sa.text("PRAGMA analysis_limit=400"))
await conn.execute(sa.text("PRAGMA optimize"))
await engine.dispose()
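
# A minimal lifecycle sketch (the `main` coroutine is an assumption for
# illustration, not part of this module): open the pool once at startup and
# close it on shutdown so the ANALYZE/optimize pragmas run before exit.
#
#     import asyncio
#     from unwind import db
#
#     async def main() -> None:
#         await db.open_connection_pool()
#         try:
#             ...  # run the app
#         finally:
#             await db.close_connection_pool()
#
#     asyncio.run(main())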


async def vacuum(conn: Connection, /) -> None:
    """Vacuum the database.

    This function cannot be run on a connection with an open transaction.
    """
# With SQLAlchemy's "autobegin" behavior we need to switch the connection
# to "autocommit" first to keep it from automatically starting a transaction,
# as VACUUM cannot be run inside a transaction for most databases.
await conn.commit()
isolation_level = await conn.get_isolation_level()
log.debug("Previous isolation_level: %a", isolation_level)
await conn.execution_options(isolation_level="AUTOCOMMIT")
try:
await conn.execute(sa.text("vacuum"))
await conn.commit()
finally:
await conn.execution_options(isolation_level=isolation_level)
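
# Usage sketch: VACUUM needs a connection without an open transaction, so run
# it on a fresh connection (the names below are this module's own helpers).
#
#     async with new_connection() as conn:
#         await vacuum(conn)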


async def get_import_progress(conn: Connection, /) -> Progress | None:
    """Return the latest import progress."""
return await get(
conn, Progress, type="import-imdb-movies", order_by=(progress.c.started, "desc")
)


async def stop_import_progress(
    conn: Connection, /, *, error: BaseException | None = None
) -> None:
    """Stop the current import.

    If an error is given, it will be logged to the progress state.
    """
current = await get_import_progress(conn)
is_running = current and current.stopped is None
if not is_running:
return
assert current
if error:
current.error = repr(error)
current.stopped = utcnow().isoformat()
await update(conn, current)


async def set_import_progress(conn: Connection, /, progress: float) -> Progress:
    """Set the current import progress percentage.

    If no import is currently running, this will create a new one.
    """
progress = min(max(0.0, progress), 100.0) # clamp to 0 <= progress <= 100
current = await get_import_progress(conn)
is_running = current and current.stopped is None
if not is_running:
current = Progress(type="import-imdb-movies")
assert current
current.percent = progress
if is_running:
await update(conn, current)
else:
await add(conn, current)
return current
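
# Progress-reporting sketch for an importer loop (`chunks` is hypothetical;
# only the progress helpers are real):
#
#     async with transaction() as conn:
#         try:
#             for i, chunk in enumerate(chunks):
#                 ...  # import one chunk of IMDb movies
#                 await set_import_progress(conn, 100 * (i + 1) / len(chunks))
#             await stop_import_progress(conn)
#         except Exception as exc:
#             await stop_import_progress(conn, error=exc)
#             raise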


def _connection_uri() -> str:
return f"sqlite+aiosqlite:///{config.storage_path}"


def _new_engine() -> AsyncEngine:
return create_async_engine(
_connection_uri(),
isolation_level="SERIALIZABLE",
)


def _shared_engine() -> AsyncEngine:
global _engine
if _engine is None:
_engine = _new_engine()
return _engine


def _new_connection() -> Connection:
return _shared_engine().connect()


@contextlib.asynccontextmanager
async def transaction(
    *, force_rollback: bool = False
) -> AsyncGenerator[Connection, None]:
    """Yield a new connection and commit on clean exit, unless `force_rollback`."""
async with new_connection() as conn:
yield conn
if not force_rollback:
await conn.commit()
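
# Usage sketch (`user` is a hypothetical User instance): changes are committed
# when the block exits without raising.
#
#     async with transaction() as conn:
#         await add(conn, user)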


# The _test_connection allows pinning a single connection that is shared across
# the app. It must only be used when running tests, NOT IN PRODUCTION!
_test_connection: Connection | None = None


@contextlib.asynccontextmanager
async def new_connection() -> AsyncGenerator[Connection, None]:
"""Return a new connection.
Any changes will be rolled back, unless `.commit()` is called on the
connection.
If you want to commit changes, consider using `transaction()` instead.
"""
conn = _test_connection or _new_connection()
# Support reusing the same connection for _test_connection.
is_started = conn.sync_connection is not None
if is_started:
yield conn
return
async with conn:
yield conn


@contextlib.asynccontextmanager
async def transacted(
conn: Connection, /, *, force_rollback: bool = False
) -> AsyncGenerator[None, None]:
"""Start a transaction for the given connection.
If `force_rollback` is `True` any changes will be rolled back at the end of the
transaction, unless they are explicitly committed.
Nesting transactions is allowed, but mixing values for `force_rollback` will likely
yield unexpected results.
"""
    txn = contextlib.nullcontext() if conn.in_transaction() else conn.begin()
    async with txn:
try:
yield
finally:
if force_rollback:
await conn.rollback()
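
# Test-style sketch, assuming a pinned test connection: make changes, assert
# against the uncommitted state, and roll everything back at the end.
#
#     async with new_connection() as conn:
#         async with transacted(conn, force_rollback=True):
#             await add(conn, user)  # `user` is a hypothetical User instance
#             ...  # assertions against the uncommitted state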


async def add(conn: Connection, /, item: Model) -> None:
    """Insert the given model instance into its table."""
# Support late initializing - used for optimization.
if getattr(item, "_is_lazy", False):
assert hasattr(item, "_lazy_init")
item._lazy_init() # pyright: ignore[reportAttributeAccessIssue]
table: sa.Table = item.__table__
values = asplain(item, serialize=True)
stmt = table.insert().values(values)
await conn.execute(stmt)


async def fetch_all(
    conn: Connection, /, query: sa.Executable, values: "dict | None" = None
) -> Sequence[sa.Row]:
    """Execute a query and return all of its result rows."""
result = await conn.execute(query, values)
return result.all()


async def fetch_one(
    conn: Connection, /, query: sa.Executable, values: "dict | None" = None
) -> sa.Row | None:
    """Execute a query and return the first result row, if any."""
result = await conn.execute(query, values)
return result.first()


async def get[T: Model](
conn: Connection,
/,
model: Type[T],
*,
order_by: tuple[sa.Column, Literal["asc", "desc"]] | None = None,
**field_values,
) -> T | None:
"""Load a model instance from the database.
Passing `field_values` allows to filter the item to load. You have to encode the
values as the appropriate data type for the database prior to passing them
to this function.
"""
if not field_values:
return
table: sa.Table = model.__table__
query = sa.select(model).where(
*(table.c[k] == v for k, v in field_values.items() if v is not None)
)
if order_by:
order_col, order_dir = order_by
query = query.order_by(
order_col.asc() if order_dir == "asc" else order_col.desc()
)
row = await fetch_one(conn, query)
return fromplain(model, row._mapping, serialized=True) if row else None
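
# Usage sketch (the IMDb user id is hypothetical):
#
#     user = await get(conn, User, imdb_id="ur0000001")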


async def get_many[T: Model](
    conn: Connection, /, model: Type[T], **field_sets: set | list
) -> Iterable[T]:
    """Return the items whose value for each given field is in that field's set.

    This is similar to `get_all`, but instead of a scalar value a list of
    values must be given per field. An item matches a field if its value equals
    any of the given values, and it must match all given fields.

    If no field sets are given, no items will be returned.
    """
if not field_sets:
return []
table: sa.Table = model.__table__
query = sa.select(model).where(*(table.c[k].in_(v) for k, v in field_sets.items()))
rows = await fetch_all(conn, query)
return (fromplain(model, row._mapping, serialized=True) for row in rows)
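
# Usage sketch (the IMDb movie ids are hypothetical):
#
#     found = await get_many(conn, Movie, imdb_id=["tt0000001", "tt0000002"])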


async def get_all[T: Model](
    conn: Connection, /, model: Type[T], **field_values
) -> Iterable[T]:
    """Filter all items by comparing all given field values.

    If no filters are given, all items will be returned.
    """
table: sa.Table = model.__table__
query = sa.select(model).where(
*(table.c[k] == v for k, v in field_values.items() if v is not None)
)
rows = await fetch_all(conn, query)
return (fromplain(model, row._mapping, serialized=True) for row in rows)


async def update(conn: Connection, /, item: Model) -> None:
    """Write the given model instance back to its table, matched by `id`."""
# Support late initializing - used for optimization.
if getattr(item, "_is_lazy", False):
assert hasattr(item, "_lazy_init")
item._lazy_init() # pyright: ignore[reportAttributeAccessIssue]
table: sa.Table = item.__table__
values = asplain(item, serialize=True)
stmt = table.update().where(table.c.id == values["id"]).values(values)
await conn.execute(stmt)


async def remove(conn: Connection, /, item: Model) -> None:
    """Delete the given model instance from its table, matched by `id`."""
table: sa.Table = item.__table__
values = asplain(item, filter_fields={"id"}, serialize=True)
stmt = table.delete().where(table.c.id == values["id"])
await conn.execute(stmt)
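
# CRUD round-trip sketch (the Movie field values are hypothetical):
#
#     movie = Movie(imdb_id="tt0000001", title="Example")
#     await add(conn, movie)
#     movie.title = "Example, Revised"
#     await update(conn, movie)
#     await remove(conn, movie)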


async def add_or_update_user(conn: Connection, /, user: User) -> None:
    """Add the user, or update the stored user with the same IMDb id."""
db_user = await get(conn, User, imdb_id=user.imdb_id)
if not db_user:
await add(conn, user)
else:
user.id = db_user.id
if user != db_user:
await update(conn, user)


async def add_or_update_many_movies(conn: Connection, /, movies: list[Movie]) -> None:
    """Add or update Movies in the database.

    This is an optimized version of `add_or_update_movie` for the purpose of
    bulk operations. It is roughly equivalent to:

        for movie in movies:
            await add_or_update_movie(conn, movie)
    """
db_movies = {
m.imdb_id: m
for m in await get_many(conn, Movie, imdb_id=[m.imdb_id for m in movies])
}
for movie in movies:
# XXX optimize bulk add & update as well
if movie.imdb_id not in db_movies:
await add(conn, movie)
else:
db_movie = db_movies[movie.imdb_id]
movie.id = db_movie.id
# We want to keep any existing value in the DB for all optional fields.
for f in optional_fields(movie):
if getattr(movie, f.name) is None:
setattr(movie, f.name, getattr(db_movie, f.name))
            if movie.updated <= db_movie.updated:
                # Skip stale data, but keep processing the remaining movies
                # (a `return` here would abort the whole bulk operation).
                continue
await update(conn, movie)


async def add_or_update_movie(conn: Connection, /, movie: Movie) -> None:
    """Add or update a Movie in the database.

    This is an upsert operation, but it also updates the Movie you pass into
    the function: its `id` is set to match the DB movie's `id`, and any
    optional fields that are unset on your Movie but exist in the database are
    filled in. It's a bidirectional sync.
    """
db_movie = await get(conn, Movie, imdb_id=movie.imdb_id)
if not db_movie:
await add(conn, movie)
else:
movie.id = db_movie.id
# We want to keep any existing value in the DB for all optional fields.
for f in optional_fields(movie):
if getattr(movie, f.name) is None:
setattr(movie, f.name, getattr(db_movie, f.name))
if movie.updated <= db_movie.updated:
return
await update(conn, movie)


async def add_or_update_rating(conn: Connection, /, rating: Rating) -> bool:
    """Add or update the rating; return True if anything was written."""
db_rating = await get(
conn, Rating, movie_id=str(rating.movie_id), user_id=str(rating.user_id)
)
if not db_rating:
await add(conn, rating)
return True
else:
rating.id = db_rating.id
if rating != db_rating:
await update(conn, rating)
return True
return False


async def get_awards(
    conn: Connection, /, imdb_ids: list[ImdbMovieId]
) -> dict[ImdbMovieId, list[Award]]:
    """Return the awards for the given movies, keyed by IMDb movie id."""
query = (
sa.select(Award, movies.c.imdb_id)
.join(movies, awards.c.movie_id == movies.c.id)
.where(movies.c.imdb_id.in_(imdb_ids))
)
rows = await fetch_all(conn, query)
awards_dict: dict[ImdbMovieId, list[Award]] = {}
for row in rows:
awards_dict.setdefault(row.imdb_id, []).append(
fromplain(Award, row._mapping, serialized=True)
)
return awards_dict


def sql_escape(s: str, char: str = "#") -> str:
    """Escape the SQL LIKE wildcards `%` and `_` in `s`, using `char` as escape character."""
    return s.replace(char, 2 * char).replace("%", f"{char}%").replace("_", f"{char}_")
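
# For example, with the default escape character:
#
#     sql_escape("50%_off")  # -> "50#%#_off"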


async def find_ratings(
    conn: Connection,
    /,
    *,
    title: str | None = None,
    media_type: str | None = None,
    exact: bool = False,
    ignore_tv_episodes: bool = False,
    include_unrated: bool = False,
    yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
    limit_rows: int = 10,
    user_ids: Iterable[UserIdStr] = [],
) -> Iterable[dict[str, Any]]:
    """Find movies matching the given filters and return their rating rows.

    Results are capped at `limit_rows`; if `include_unrated` is set, unrated
    movies fill any remaining slots.
    """
conditions = []
if title:
escape_char = "#"
escaped_title = sql_escape(title, char=escape_char)
pattern = (
"_".join(escaped_title.split())
if exact
else "%" + "%".join(escaped_title.split()) + "%"
)
conditions.append(
sa.or_(
movies.c.title.like(pattern, escape=escape_char),
movies.c.original_title.like(pattern, escape=escape_char),
)
)
match yearcomp:
case ("<", year):
conditions.append(movies.c.release_year < year)
case ("=", year):
conditions.append(movies.c.release_year == year)
case (">", year):
conditions.append(movies.c.release_year > year)
if media_type is not None:
conditions.append(movies.c.media_type == media_type)
if ignore_tv_episodes:
conditions.append(movies.c.media_type != "TV Episode")
user_condition = []
if user_ids:
user_condition.append(ratings.c.user_id.in_(user_ids))
query = (
sa.select(ratings.c.movie_id)
.distinct()
.outerjoin_from(ratings, movies, movies.c.id == ratings.c.movie_id)
.where(*conditions, *user_condition)
.order_by(
sa.func.length(movies.c.title).asc(),
ratings.c.rating_date.desc(),
movies.c.imdb_score.desc(),
)
.limit(limit_rows)
)
rating_rows: sa.CursorResult[Rating] = await conn.execute(query)
movie_ids = [r.movie_id for r in rating_rows]
if include_unrated and len(movie_ids) < limit_rows:
query = (
sa.select(movies.c.id)
.distinct()
.where(movies.c.id.not_in(movie_ids), *conditions)
.order_by(
sa.func.length(movies.c.title).asc(),
movies.c.imdb_score.desc(),
movies.c.release_year.desc(),
)
.limit(limit_rows - len(movie_ids))
)
movie_rows: sa.CursorResult[Movie] = await conn.execute(query)
movie_ids += [r.id for r in movie_rows]
return await ratings_for_movie_ids(conn, ids=movie_ids)
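
# Usage sketch (filter values are hypothetical): up to ten rating rows for
# fuzzy title matches released after 2000.
#
#     rows = await find_ratings(conn, title="matrix", yearcomp=(">", 2000))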


async def ratings_for_movie_ids(
    conn: Connection,
    /,
    ids: Iterable[ULID | str] = [],
    imdb_ids: Iterable[str] = [],
) -> Iterable[dict[str, Any]]:
    """Return joined rating and movie rows for the given movie ids or IMDb ids."""
conds = []
if ids:
conds.append(movies.c.id.in_([str(x) for x in ids]))
if imdb_ids:
conds.append(movies.c.imdb_id.in_(imdb_ids))
if not conds:
return []
query = (
sa.select(
ratings.c.score.label("user_score"),
ratings.c.user_id.label("user_id"),
movies.c.imdb_score,
movies.c.imdb_votes,
movies.c.imdb_id.label("movie_imdb_id"),
movies.c.media_type.label("media_type"),
movies.c.title.label("canonical_title"),
movies.c.original_title.label("original_title"),
movies.c.release_year.label("release_year"),
)
.outerjoin_from(movies, ratings, movies.c.id == ratings.c.movie_id)
.where(sa.or_(*conds))
.order_by(
ratings.c.rating_date.asc(),
movies.c.title.asc(),
)
)
rows = await fetch_all(conn, query)
return tuple(dict(r._mapping) for r in rows)


async def ratings_for_movies(
    conn: Connection, /, movie_ids: Iterable[ULID], user_ids: Iterable[ULID] = []
) -> Iterable[Rating]:
    """Return all Ratings for the given movies, optionally filtered by user ids."""
conditions = [ratings.c.movie_id.in_(str(x) for x in movie_ids)]
if user_ids:
conditions.append(ratings.c.user_id.in_(str(x) for x in user_ids))
query = sa.select(ratings).where(*conditions)
rows = await fetch_all(conn, query)
return (fromplain(Rating, row._mapping, serialized=True) for row in rows)


async def find_movies(
    conn: Connection,
    /,
    *,
    title: str | None = None,
    media_type: str | None = None,
    exact: bool = False,
    ignore_tv_episodes: bool = False,
    yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
    limit_rows: int = 10,
    skip_rows: int = 0,
    include_unrated: bool = False,
    user_ids: list[ULID] | None = None,
) -> Iterable[tuple[Movie, list[Rating]]]:
    """Find movies matching the given filters.

    Each result is paired with the matching ratings of the given `user_ids`
    (an empty list if no `user_ids` are given).
    """
conditions = []
if title:
escape_char = "#"
escaped_title = sql_escape(title, char=escape_char)
pattern = (
"_".join(escaped_title.split())
if exact
else "%" + "%".join(escaped_title.split()) + "%"
)
conditions.append(
sa.or_(
movies.c.title.like(pattern, escape=escape_char),
movies.c.original_title.like(pattern, escape=escape_char),
)
)
match yearcomp:
case ("<", year):
conditions.append(movies.c.release_year < year)
case ("=", year):
conditions.append(movies.c.release_year == year)
case (">", year):
conditions.append(movies.c.release_year > year)
if media_type is not None:
conditions.append(movies.c.media_type == media_type)
if ignore_tv_episodes:
conditions.append(movies.c.media_type != "TV Episode")
if not include_unrated:
conditions.append(movies.c.imdb_score.is_not(None))
query = (
sa.select(movies)
.where(*conditions)
.order_by(
sa.func.length(movies.c.title).asc(),
movies.c.imdb_score.desc(),
movies.c.release_year.desc(),
)
.limit(limit_rows)
.offset(skip_rows)
)
rows = await fetch_all(conn, query)
movies_ = [fromplain(Movie, row._mapping, serialized=True) for row in rows]
if not user_ids:
return ((m, []) for m in movies_)
    movie_ratings = await ratings_for_movies(conn, (m.id for m in movies_), user_ids)
    aggreg: dict[ULID, tuple[Movie, list[Rating]]] = {m.id: (m, []) for m in movies_}
    for rating in movie_ratings:
        aggreg[rating.movie_id][1].append(rating)
return aggreg.values()
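

# Usage sketch (filter values and user ids are hypothetical):
#
#     results = await find_movies(
#         conn,
#         title="the office",
#         exact=True,
#         ignore_tv_episodes=True,
#         limit_rows=20,
#         user_ids=[user_a.id, user_b.id],
#     )
#     for movie, movie_ratings in results:
#         ...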