import contextlib
import logging
from pathlib import Path
from typing import Any, AsyncGenerator, Iterable, Literal, Sequence, Type

import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine

import alembic.command
import alembic.config
import alembic.migration

from . import config
from .models import (
    Award,
    Model,
    Movie,
    Progress,
    Rating,
    User,
    asplain,
    awards,
    fromplain,
    metadata,
    movies,
    optional_fields,
    progress,
    ratings,
    utcnow,
)
from .types import ULID, ImdbMovieId, UserIdStr

log = logging.getLogger(__name__)

_engine: AsyncEngine | None = None

type Connection = AsyncConnection

_project_dir = Path(__file__).parent.parent
_alembic_ini = _project_dir / "alembic.ini"


def _init(conn: sa.Connection) -> None:
    # See https://alembic.sqlalchemy.org/en/latest/cookbook.html#building-an-up-to-date-database-from-scratch
    context = alembic.migration.MigrationContext.configure(conn)
    heads = context.get_current_heads()

    is_empty_db = not heads  # We consider a DB empty if Alembic hasn't touched it yet.
    if is_empty_db:
        log.info("⚡️ Initializing empty database.")
        metadata.create_all(conn)

        # We pass our existing connection to Alembic's env.py, to avoid running another asyncio loop there.
        alembic_cfg = alembic.config.Config(_alembic_ini)
        alembic_cfg.attributes["connection"] = conn
        alembic.command.stamp(alembic_cfg, "head")


async def open_connection_pool() -> None:
    """Open the DB connection pool.

    This function needs to be called before any access to the database can happen.
    """
    async with transaction() as conn:
        await conn.execute(sa.text("PRAGMA journal_mode=WAL"))
        await conn.run_sync(_init)


async def close_connection_pool() -> None:
    """Close the DB connection pool.

    This function should be called before the app shuts down to ensure all data
    has been flushed to the database.
    """
    engine = _shared_engine()

    async with engine.begin() as conn:
        # Run automatic ANALYZE prior to closing the db,
        # see https://sqlite.com/lang_analyze.html.
        await conn.execute(sa.text("PRAGMA analysis_limit=400"))
        await conn.execute(sa.text("PRAGMA optimize"))

    await engine.dispose()


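# Illustrative sketch (not part of the original module): typical application wiring,
# e.g. from a startup/shutdown hook. The application body itself is elided.
async def _example_app_lifespan() -> None:
    await open_connection_pool()
    try:
        ...  # serve requests / run the app while the pool is open
    finally:
        # Flushes the ANALYZE/optimize work and disposes the engine.
        await close_connection_pool()

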
async def vacuum(conn: Connection, /) -> None:
    """Vacuum the database.

    This function cannot be run on a connection with an open transaction.
    """
    # With SQLAlchemy's "autobegin" behavior we need to switch the connection
    # to "autocommit" first to keep it from automatically starting a transaction,
    # as VACUUM cannot be run inside a transaction for most databases.
    await conn.commit()
    isolation_level = await conn.get_isolation_level()
    log.debug("Previous isolation_level: %a", isolation_level)
    await conn.execution_options(isolation_level="AUTOCOMMIT")
    try:
        await conn.execute(sa.text("vacuum"))
        await conn.commit()
    finally:
        await conn.execution_options(isolation_level=isolation_level)


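# Illustrative sketch (not part of the original module): VACUUM needs a connection
# without an open transaction, so a plain `new_connection()` (which never commits on
# its own) is a reasonable fit here.
async def _example_vacuum() -> None:
    async with new_connection() as conn:
        await vacuum(conn)

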
async def get_import_progress(conn: Connection, /) -> Progress | None:
    """Return the latest import progress."""
    return await get(
        conn, Progress, type="import-imdb-movies", order_by=(progress.c.started, "desc")
    )


async def stop_import_progress(
    conn: Connection, /, *, error: BaseException | None = None
) -> None:
    """Stop the current import.

    If an error is given, it will be logged to the progress state.
    """
    current = await get_import_progress(conn)
    is_running = current and current.stopped is None

    if not is_running:
        return
    assert current

    if error:
        current.error = repr(error)
    current.stopped = utcnow().isoformat()

    await update(conn, current)


async def set_import_progress(conn: Connection, /, progress: float) -> Progress:
    """Set the current import progress percentage.

    If no import is currently running, this will create a new one.
    """
    progress = min(max(0.0, progress), 100.0)  # clamp to 0 <= progress <= 100

    current = await get_import_progress(conn)
    is_running = current and current.stopped is None

    if not is_running:
        current = Progress(type="import-imdb-movies")
    assert current

    current.percent = progress

    if is_running:
        await update(conn, current)
    else:
        await add(conn, current)

    return current


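# Illustrative sketch (not part of the original module): how an importer might drive
# the progress helpers above. `batches` and the elided per-batch work are hypothetical
# placeholders; only the progress/transaction helpers come from this module.
async def _example_import_with_progress(batches: Sequence[Any]) -> None:
    async with transaction() as conn:
        await set_import_progress(conn, 0.0)

    try:
        for i, _batch in enumerate(batches, start=1):
            ...  # do the actual import work for _batch here
            async with transaction() as conn:
                await set_import_progress(conn, 100.0 * i / len(batches))
    except Exception as error:
        async with transaction() as conn:
            await stop_import_progress(conn, error=error)
        raise
    else:
        async with transaction() as conn:
            await stop_import_progress(conn)

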
def _connection_uri() -> str:
    return f"sqlite+aiosqlite:///{config.storage_path}"


def _new_engine() -> AsyncEngine:
    return create_async_engine(
        _connection_uri(),
        isolation_level="SERIALIZABLE",
    )


def _shared_engine() -> AsyncEngine:
    global _engine

    if _engine is None:
        _engine = _new_engine()

    return _engine


def _new_connection() -> Connection:
    return _shared_engine().connect()


@contextlib.asynccontextmanager
async def transaction(
    *, force_rollback: bool = False
) -> AsyncGenerator[Connection, None]:
    """Open a new connection and commit at the end, unless `force_rollback` is set."""
    async with new_connection() as conn:
        yield conn

        if not force_rollback:
            await conn.commit()


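# Illustrative sketch (not part of the original module): `transaction()` commits when
# the block exits normally, while `force_rollback=True` skips the commit so nothing is
# persisted (handy in tests). The IMDb user id below is a made-up example value.
async def _example_transaction() -> None:
    async with transaction() as conn:
        user = await get(conn, User, imdb_id="ur0000001")
        if user:
            await update(conn, user)  # persisted when the block commits

    async with transaction(force_rollback=True) as conn:
        ...  # any writes made here are not committed

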
# The _test_connection allows pinning a connection that will be shared across the app.
# It should only be used when running tests, NOT IN PRODUCTION!
_test_connection: Connection | None = None


@contextlib.asynccontextmanager
async def new_connection() -> AsyncGenerator[Connection, None]:
    """Return a new connection.

    Any changes will be rolled back, unless `.commit()` is called on the
    connection.

    If you want to commit changes, consider using `transaction()` instead.
    """
    conn = _test_connection or _new_connection()

    # Support reusing the same connection for _test_connection.
    is_started = conn.sync_connection is not None
    if is_started:
        yield conn
        return

    async with conn:
        yield conn


@contextlib.asynccontextmanager
async def transacted(
    conn: Connection, /, *, force_rollback: bool = False
) -> AsyncGenerator[None, None]:
    """Start a transaction for the given connection.

    If `force_rollback` is `True` any changes will be rolled back at the end of the
    transaction, unless they are explicitly committed.
    Nesting transactions is allowed, but mixing values for `force_rollback` will likely
    yield unexpected results.
    """
    transaction = contextlib.nullcontext() if conn.in_transaction() else conn.begin()

    async with transaction:
        try:
            yield

        finally:
            if force_rollback:
                await conn.rollback()


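# Illustrative sketch (not part of the original module): wrapping several calls on an
# existing connection in one transaction and discarding all changes at the end, e.g.
# from a test fixture.
async def _example_transacted(conn: Connection) -> None:
    async with transacted(conn, force_rollback=True):
        current = await get_import_progress(conn)
        if current:
            await stop_import_progress(conn)
        # everything above is rolled back when the block exits

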
async def add(conn: Connection, /, item: Model) -> None:
    """Insert the given model instance into its table."""
    # Support late initializing - used for optimization.
    if getattr(item, "_is_lazy", False):
        assert hasattr(item, "_lazy_init")
        item._lazy_init()  # pyright: ignore[reportAttributeAccessIssue]

    table: sa.Table = item.__table__
    values = asplain(item, serialize=True)
    stmt = table.insert().values(values)
    await conn.execute(stmt)


async def fetch_all(
    conn: Connection, /, query: sa.Executable, values: "dict | None" = None
) -> Sequence[sa.Row]:
    result = await conn.execute(query, values)
    return result.all()


async def fetch_one(
    conn: Connection, /, query: sa.Executable, values: "dict | None" = None
) -> sa.Row | None:
    result = await conn.execute(query, values)
    return result.first()


async def get[T: Model](
    conn: Connection,
    /,
    model: Type[T],
    *,
    order_by: tuple[sa.Column, Literal["asc", "desc"]] | None = None,
    **field_values,
) -> T | None:
    """Load a model instance from the database.

    Passing `field_values` filters which item is loaded. You have to encode the
    values as the appropriate database data type before passing them to this
    function.
    """
    if not field_values:
        return

    table: sa.Table = model.__table__
    query = sa.select(model).where(
        *(table.c[k] == v for k, v in field_values.items() if v is not None)
    )
    if order_by:
        order_col, order_dir = order_by
        query = query.order_by(
            order_col.asc() if order_dir == "asc" else order_col.desc()
        )
    row = await fetch_one(conn, query)
    return fromplain(model, row._mapping, serialized=True) if row else None


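# Illustrative sketch (not part of the original module): filter values must already be
# in their database encoding, e.g. ULIDs are stored as strings (compare
# `add_or_update_rating` below).
async def _example_get_rating(
    conn: Connection, movie: Movie, user: User
) -> Rating | None:
    return await get(conn, Rating, movie_id=str(movie.id), user_id=str(user.id))

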
async def get_many[T: Model](
    conn: Connection, /, model: Type[T], **field_sets: set | list
) -> Iterable[T]:
    """Return the items matching all given field sets.

    This is similar to `get_all`, but instead of a scalar value a list (or set)
    of values must be given per field. An item matches a field if its value is
    any of the given values; only items matching all fields are returned.
    If no field sets are given, no items will be returned.
    """
    if not field_sets:
        return []

    table: sa.Table = model.__table__
    query = sa.select(model).where(*(table.c[k].in_(v) for k, v in field_sets.items()))
    rows = await fetch_all(conn, query)
    return (fromplain(model, row._mapping, serialized=True) for row in rows)


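# Illustrative sketch (not part of the original module): fetching a batch of movies by
# their IMDb ids in a single query, which is what the bulk upsert below relies on.
async def _example_get_movies_by_imdb_ids(
    conn: Connection, imdb_ids: list[ImdbMovieId]
) -> Iterable[Movie]:
    return await get_many(conn, Movie, imdb_id=imdb_ids)

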
async def get_all[T: Model](
    conn: Connection, /, model: Type[T], **field_values
) -> Iterable[T]:
    """Filter all items by comparing all given field values.

    If no filters are given, all items will be returned.
    """
    table: sa.Table = model.__table__
    query = sa.select(model).where(
        *(table.c[k] == v for k, v in field_values.items() if v is not None)
    )
    rows = await fetch_all(conn, query)
    return (fromplain(model, row._mapping, serialized=True) for row in rows)


async def update(conn: Connection, /, item: Model) -> None:
    """Update the stored row matching the item's `id` with the item's current values."""
    # Support late initializing - used for optimization.
    if getattr(item, "_is_lazy", False):
        assert hasattr(item, "_lazy_init")
        item._lazy_init()  # pyright: ignore[reportAttributeAccessIssue]

    table: sa.Table = item.__table__
    values = asplain(item, serialize=True)
    stmt = table.update().where(table.c.id == values["id"]).values(values)
    await conn.execute(stmt)


async def remove(conn: Connection, /, item: Model) -> None:
    """Delete the stored row matching the item's `id`."""
    table: sa.Table = item.__table__
    values = asplain(item, filter_fields={"id"}, serialize=True)
    stmt = table.delete().where(table.c.id == values["id"])
    await conn.execute(stmt)


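# Illustrative sketch (not part of the original module): the low-level helpers above
# combined on one connection. The IMDb id is a made-up example; `imdb_score` is an
# existing `Movie` field used elsewhere in this module.
async def _example_crud(conn: Connection) -> None:
    movie = await get(conn, Movie, imdb_id="tt0000001")
    if movie is None:
        return

    movie.imdb_score = 8.5
    await update(conn, movie)  # UPDATE ... WHERE id = movie.id

    await remove(conn, movie)  # DELETE ... WHERE id = movie.id

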
async def add_or_update_user(conn: Connection, /, user: User) -> None:
    """Insert the user if it is unknown (by IMDb id), otherwise update it when it changed."""
    db_user = await get(conn, User, imdb_id=user.imdb_id)
    if not db_user:
        await add(conn, user)
    else:
        user.id = db_user.id

        if user != db_user:
            await update(conn, user)


async def add_or_update_many_movies(conn: Connection, /, movies: list[Movie]) -> None:
    """Add or update Movies in the database.

    This is an optimized version of `add_or_update_movie` for the purpose
    of bulk operations.
    """
    # Equivalent to, but faster than:
    # for movie in movies:
    #     await add_or_update_movie(conn, movie)
    db_movies = {
        m.imdb_id: m
        for m in await get_many(conn, Movie, imdb_id=[m.imdb_id for m in movies])
    }
    for movie in movies:
        # XXX optimize bulk add & update as well
        if movie.imdb_id not in db_movies:
            await add(conn, movie)
        else:
            db_movie = db_movies[movie.imdb_id]
            movie.id = db_movie.id

            # We want to keep any existing value in the DB for all optional fields.
            for f in optional_fields(movie):
                if getattr(movie, f.name) is None:
                    setattr(movie, f.name, getattr(db_movie, f.name))

            if movie.updated <= db_movie.updated:
                # Nothing newer to write for this movie; skip it instead of
                # aborting the whole batch.
                continue

            await update(conn, movie)


async def add_or_update_movie(conn: Connection, /, movie: Movie) -> None:
    """Add or update a Movie in the database.

    This is an upsert operation, but it will also update the Movie you pass
    into the function to make its `id` match the DB's movie's `id`, and also
    set all optional values on your Movie that might be unset but exist in the
    database. It's a bidirectional sync.
    """
    db_movie = await get(conn, Movie, imdb_id=movie.imdb_id)
    if not db_movie:
        await add(conn, movie)
    else:
        movie.id = db_movie.id

        # We want to keep any existing value in the DB for all optional fields.
        for f in optional_fields(movie):
            if getattr(movie, f.name) is None:
                setattr(movie, f.name, getattr(db_movie, f.name))

        if movie.updated <= db_movie.updated:
            return

        await update(conn, movie)


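# Illustrative sketch (not part of the original module): after the upsert the passed-in
# Movie carries the stored row's `id` and any optional fields that only existed in the
# database, so it can be used directly afterwards.
async def _example_upsert_movie(conn: Connection, movie: Movie) -> Movie:
    await add_or_update_movie(conn, movie)
    return movie  # movie.id now matches the stored row

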
async def add_or_update_rating(conn: Connection, /, rating: Rating) -> bool:
    """Insert or update a rating. Returns `True` if the database was modified."""
    db_rating = await get(
        conn, Rating, movie_id=str(rating.movie_id), user_id=str(rating.user_id)
    )

    if not db_rating:
        await add(conn, rating)
        return True

    else:
        rating.id = db_rating.id

        if rating != db_rating:
            await update(conn, rating)
            return True

    return False


async def get_awards(
    conn: Connection, /, imdb_ids: list[ImdbMovieId]
) -> dict[ImdbMovieId, list[Award]]:
    """Return the awards for the given movies, keyed by IMDb movie id."""
    query = (
        sa.select(Award, movies.c.imdb_id)
        .join(movies, awards.c.movie_id == movies.c.id)
        .where(movies.c.imdb_id.in_(imdb_ids))
    )
    rows = await fetch_all(conn, query)
    awards_dict: dict[ImdbMovieId, list[Award]] = {}
    for row in rows:
        awards_dict.setdefault(row.imdb_id, []).append(
            fromplain(Award, row._mapping, serialized=True)
        )
    return awards_dict


def sql_escape(s: str, char: str = "#") -> str:
    """Escape LIKE wildcards in `s`, using `char` as the escape character."""
    return s.replace(char, 2 * char).replace("%", f"{char}%").replace("_", f"{char}_")


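# Worked example (illustrative, not part of the original module): the escape character
# is doubled and LIKE wildcards are prefixed, so user input cannot act as a wildcard.
# `find_ratings`/`find_movies` below then join the words with "%" (or "_" for exact
# matches) and pass `escape="#"` to `Column.like()`.
def _example_sql_escape() -> None:
    assert sql_escape("100%_pure #1") == "100#%#_pure ##1"

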
async def find_ratings(
    conn: Connection,
    /,
    *,
    title: str | None = None,
    media_type: str | None = None,
    exact: bool = False,
    ignore_tv_episodes: bool = False,
    include_unrated: bool = False,
    yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
    limit_rows: int = 10,
    user_ids: Iterable[UserIdStr] = [],
) -> Iterable[dict[str, Any]]:
    """Search movies by title, year and type and return rating rows for the matches."""
    conditions = []

    if title:
        escape_char = "#"
        escaped_title = sql_escape(title, char=escape_char)
        pattern = (
            "_".join(escaped_title.split())
            if exact
            else "%" + "%".join(escaped_title.split()) + "%"
        )
        conditions.append(
            sa.or_(
                movies.c.title.like(pattern, escape=escape_char),
                movies.c.original_title.like(pattern, escape=escape_char),
            )
        )

    match yearcomp:
        case ("<", year):
            conditions.append(movies.c.release_year < year)
        case ("=", year):
            conditions.append(movies.c.release_year == year)
        case (">", year):
            conditions.append(movies.c.release_year > year)

    if media_type is not None:
        conditions.append(movies.c.media_type == media_type)

    if ignore_tv_episodes:
        conditions.append(movies.c.media_type != "TV Episode")

    user_condition = []
    if user_ids:
        user_condition.append(ratings.c.user_id.in_(user_ids))

    query = (
        sa.select(ratings.c.movie_id)
        .distinct()
        .outerjoin_from(ratings, movies, movies.c.id == ratings.c.movie_id)
        .where(*conditions, *user_condition)
        .order_by(
            sa.func.length(movies.c.title).asc(),
            ratings.c.rating_date.desc(),
            movies.c.imdb_score.desc(),
        )
        .limit(limit_rows)
    )
    rating_rows: sa.CursorResult[Rating] = await conn.execute(query)
    movie_ids = [r.movie_id for r in rating_rows]

    if include_unrated and len(movie_ids) < limit_rows:
        query = (
            sa.select(movies.c.id)
            .distinct()
            .where(movies.c.id.not_in(movie_ids), *conditions)
            .order_by(
                sa.func.length(movies.c.title).asc(),
                movies.c.imdb_score.desc(),
                movies.c.release_year.desc(),
            )
            .limit(limit_rows - len(movie_ids))
        )
        movie_rows: sa.CursorResult[Movie] = await conn.execute(query)
        movie_ids += [r.id for r in movie_rows]

    return await ratings_for_movie_ids(conn, ids=movie_ids)


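# Illustrative sketch (not part of the original module): a fuzzy title search restricted
# to one user's ratings; the ULID string is a made-up placeholder value.
async def _example_find_ratings(conn: Connection) -> Iterable[dict[str, Any]]:
    return await find_ratings(
        conn,
        title="matrix",
        yearcomp=(">", 1998),
        ignore_tv_episodes=True,
        limit_rows=5,
        user_ids=["01ARZ3NDEKTSV4RRFFQ69G5FAV"],
    )

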
async def ratings_for_movie_ids(
    conn: Connection,
    /,
    ids: Iterable[ULID | str] = [],
    imdb_ids: Iterable[str] = [],
) -> Iterable[dict[str, Any]]:
    """Return rating and movie details (as plain dicts) for the given movie or IMDb ids."""
    conds = []

    if ids:
        conds.append(movies.c.id.in_([str(x) for x in ids]))

    if imdb_ids:
        conds.append(movies.c.imdb_id.in_(imdb_ids))

    if not conds:
        return []

    query = (
        sa.select(
            ratings.c.score.label("user_score"),
            ratings.c.user_id.label("user_id"),
            movies.c.imdb_score,
            movies.c.imdb_votes,
            movies.c.imdb_id.label("movie_imdb_id"),
            movies.c.media_type.label("media_type"),
            movies.c.title.label("canonical_title"),
            movies.c.original_title.label("original_title"),
            movies.c.release_year.label("release_year"),
        )
        .outerjoin_from(movies, ratings, movies.c.id == ratings.c.movie_id)
        .where(sa.or_(*conds))
        .order_by(
            ratings.c.rating_date.asc(),
            movies.c.title.asc(),
        )
    )
    rows = await fetch_all(conn, query)
    return tuple(dict(r._mapping) for r in rows)


async def ratings_for_movies(
    conn: Connection, /, movie_ids: Iterable[ULID], user_ids: Iterable[ULID] = []
) -> Iterable[Rating]:
    """Return all ratings for the given movies, optionally limited to the given users."""
    conditions = [ratings.c.movie_id.in_(str(x) for x in movie_ids)]

    if user_ids:
        conditions.append(ratings.c.user_id.in_(str(x) for x in user_ids))

    query = sa.select(ratings).where(*conditions)

    rows = await fetch_all(conn, query)

    return (fromplain(Rating, row._mapping, serialized=True) for row in rows)


async def find_movies(
    conn: Connection,
    /,
    *,
    title: str | None = None,
    media_type: str | None = None,
    exact: bool = False,
    ignore_tv_episodes: bool = False,
    yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
    limit_rows: int = 10,
    skip_rows: int = 0,
    include_unrated: bool = False,
    user_ids: list[ULID] | None = None,
) -> Iterable[tuple[Movie, list[Rating]]]:
    """Search movies by title, year and type and return them with the given users' ratings."""
    conditions = []

    if title:
        escape_char = "#"
        escaped_title = sql_escape(title, char=escape_char)
        pattern = (
            "_".join(escaped_title.split())
            if exact
            else "%" + "%".join(escaped_title.split()) + "%"
        )
        conditions.append(
            sa.or_(
                movies.c.title.like(pattern, escape=escape_char),
                movies.c.original_title.like(pattern, escape=escape_char),
            )
        )

    match yearcomp:
        case ("<", year):
            conditions.append(movies.c.release_year < year)
        case ("=", year):
            conditions.append(movies.c.release_year == year)
        case (">", year):
            conditions.append(movies.c.release_year > year)

    if media_type is not None:
        conditions.append(movies.c.media_type == media_type)

    if ignore_tv_episodes:
        conditions.append(movies.c.media_type != "TV Episode")

    if not include_unrated:
        conditions.append(movies.c.imdb_score.is_not(None))

    query = (
        sa.select(movies)
        .where(*conditions)
        .order_by(
            sa.func.length(movies.c.title).asc(),
            movies.c.imdb_score.desc(),
            movies.c.release_year.desc(),
        )
        .limit(limit_rows)
        .offset(skip_rows)
    )

    rows = await fetch_all(conn, query)

    movies_ = [fromplain(Movie, row._mapping, serialized=True) for row in rows]

    if not user_ids:
        return ((m, []) for m in movies_)

    ratings = await ratings_for_movies(conn, (m.id for m in movies_), user_ids)

    aggreg: dict[ULID, tuple[Movie, list[Rating]]] = {m.id: (m, []) for m in movies_}
    for rating in ratings:
        aggreg[rating.movie_id][1].append(rating)

    return aggreg.values()
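

# Illustrative sketch (not part of the original module): a movie search that also
# returns the given users' ratings per movie.
async def _example_find_movies(
    conn: Connection, user_ids: list[ULID]
) -> list[tuple[Movie, list[Rating]]]:
    return list(
        await find_movies(
            conn,
            title="dune",
            yearcomp=("=", 2021),
            include_unrated=True,
            limit_rows=10,
            user_ids=user_ids,
        )
    )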