unwind/unwind/db.py

266 lines
7.9 KiB
Python

import logging
import re
from dataclasses import fields
from pathlib import Path
from typing import Literal, Optional, Type, TypeVar, Union
import sqlalchemy
from databases import Database
from . import config
from .models import Movie, Rating, User, asplain, fromplain, optional_fields
# Module logger.
log = logging.getLogger(__name__)
# Single Database instance for this module; created lazily by
# shared_connection() and reused by every query helper below.
_shared_connection: Optional[Database] = None
async def open_connection_pool() -> None:
    """Connect the shared database and run the init.sql statements on it.

    Must be awaited before any other function in this module accesses the
    database.
    """
    database = shared_connection()
    await database.connect()
    # Apply init.sql right after connecting so later queries can rely on it.
    await init_db(database)
async def close_connection_pool() -> None:
    """Optimize and then disconnect the shared database.

    Await this before the app shuts down so that all data is flushed to the
    database.
    """
    database = shared_connection()
    # Let SQLite run its automatic ANALYZE before the connection goes away,
    # see https://sqlite.com/lang_analyze.html.
    for pragma in ("PRAGMA analysis_limit=400", "PRAGMA optimize"):
        await database.execute(pragma)
    await database.disconnect()
def shared_connection() -> Database:
    """Return the module-wide Database object, creating it on first call.

    The SQLite file location comes from ``config.storage_path``; the object is
    cached in the module global ``_shared_connection``.
    """
    global _shared_connection
    if _shared_connection is not None:
        return _shared_connection
    _shared_connection = Database(f"sqlite:///{config.storage_path}")
    return _shared_connection
async def init_db(db):
    """Execute every statement of ``init.sql`` (located next to this module).

    Statements in the file are separated by ``;;`` and are all executed on
    *db* inside a single transaction.
    """
    script = Path(__file__).with_name("init.sql").read_text()
    statements = script.split(";;")
    async with db.transaction():
        for statement in statements:
            await db.execute(query=statement)
async def add(item):
    """INSERT *item* as a new row into ``item._table``.

    Column names and bound values are taken from ``asplain(item)``.
    """
    row = asplain(item)
    columns = ", ".join(map(str, row))
    params = ", ".join(f":{name}" for name in row)
    sql = f"INSERT INTO {item._table} ({columns}) VALUES ({params})"
    await shared_connection().execute(query=sql, values=row)
ModelType = TypeVar("ModelType")


async def get(model: Type[ModelType], **kwds) -> Optional[ModelType]:
    """Fetch a single *model* row whose columns equal the given keyword values.

    Keyword arguments set to ``None`` are ignored. If no filter remains, the
    database is not queried and ``None`` is returned. Otherwise the matching
    row is converted with ``fromplain``; ``None`` is returned when nothing
    matches.
    """
    criteria = {}
    for column, value in kwds.items():
        if value is not None:
            criteria[column] = value
    if not criteria:
        return None

    columns = ", ".join(field.name for field in fields(model))
    where = " AND ".join(f"{column}=:{column}" for column in criteria)
    sql = f"SELECT {columns} FROM {model._table} WHERE {where}"
    row = await shared_connection().fetch_one(query=sql, values=criteria)
    if not row:
        return None
    return fromplain(model, row)
async def update(item):
    """UPDATE the row of ``item._table`` identified by the ``id`` column.

    All other columns from ``asplain(item)`` are overwritten with the item's
    values.
    """
    row = asplain(item)
    assignments = ", ".join(f"{name}=:{name}" for name in row if name != "id")
    sql = f"UPDATE {item._table} SET {assignments} WHERE id=:id"
    await shared_connection().execute(query=sql, values=row)
async def add_or_update_user(user: User):
    """Insert *user*, or update the stored user with the same ``imdb_id``.

    If a stored row exists, ``user.id`` is set to its id, and an UPDATE is
    only issued when *user* differs from the stored row.
    """
    existing = await get(User, imdb_id=user.imdb_id)
    if existing:
        user.id = existing.id
        if user != existing:
            await update(user)
        return
    await add(user)
async def add_or_update_movie(movie: Movie):
    """Upsert *movie* by ``imdb_id`` and sync it with the stored row.

    Besides writing to the database, this mutates *movie*. When a row with the
    same ``imdb_id`` exists, ``movie.id`` is set to its id, and every optional
    field that is ``None`` on *movie* is filled from the stored row. The row
    is then only updated if ``movie.updated`` is later than the stored
    ``updated``.
    """
    stored = await get(Movie, imdb_id=movie.imdb_id)
    if not stored:
        await add(movie)
        return

    movie.id = stored.id
    # Keep existing DB values for optional fields we don't have a value for.
    for opt_field in optional_fields(movie):
        name = opt_field.name
        if getattr(movie, name) is None:
            setattr(movie, name, getattr(stored, name))

    # An equally old or older copy must not overwrite the stored row.
    if movie.updated <= stored.updated:
        return
    await update(movie)
async def add_or_update_rating(rating: Rating) -> bool:
    """Insert or update the rating for its (movie_id, user_id) pair.

    If a stored rating exists, ``rating.id`` is set to its id. Returns
    ``True`` when a row was inserted or updated, and ``False`` when the stored
    rating already equals *rating*.
    """
    existing = await get(
        Rating, movie_id=str(rating.movie_id), user_id=str(rating.user_id)
    )
    if existing:
        rating.id = existing.id
        if rating != existing:
            await update(rating)
            return True
        return False
    await add(rating)
    return True
def sql_escape(s: str, char="#"):
    """Escape the LIKE wildcards in *s*, using *char* as the escape character.

    *char* itself is doubled first, then ``%`` and ``_`` are prefixed with it,
    so the result can be used with ``LIKE ... ESCAPE <char>``.
    """
    replacements = (
        (char, 2 * char),
        ("%", f"{char}%"),
        ("_", f"{char}_"),
    )
    escaped = s
    # Order matters: the escape char must be doubled before others are added.
    for needle, substitute in replacements:
        escaped = escaped.replace(needle, substitute)
    return escaped
async def find_ratings(
    *,
    imdb_movie_id: Optional[str] = None,
    title: Optional[str] = None,
    media_type: Optional[str] = None,
    ignore_tv_episodes: bool = False,
    include_unrated: bool = False,
    yearcomp: Optional[tuple[Literal["<", "=", ">"], int]] = None,
    limit_rows: int = 10,
):
    """Search movies and return them joined with their ratings.

    All given filters are combined with ``AND``:

    - ``imdb_movie_id``: the movie's ``imdb_id`` must match exactly.
    - ``title``: the whitespace-separated words of *title* must occur, in that
      order, in the title or the original title (SQL ``LIKE``; ``%`` and ``_``
      inside *title* are matched literally).
    - ``yearcomp``: ``(op, year)`` compares ``release_year`` with *year* using
      *op*, which must be one of ``"<"``, ``"="`` or ``">"``.
    - ``media_type``: the media type must match exactly.
    - ``ignore_tv_episodes``: exclude movies whose media type is
      ``'TV Episode'``.

    At most *limit_rows* movies that appear in the ratings table are picked
    (ordered by title length, then rating date descending). With
    *include_unrated*, that selection is topped up with other matching movies
    not picked yet (ordered by title length, then release year descending).
    The total is still capped at *limit_rows* movies.

    Returns a tuple of dicts with the keys ``user_score``, ``imdb_score``,
    ``movie_imdb_id``, ``media_type``, ``canonical_title``,
    ``original_title`` and ``release_year``. Ratings are LEFT JOINed, so a
    movie yields one row per rating, or a single row with ``user_score`` set
    to ``None`` if it has no rating.

    Raises:
        ValueError: if the operator in *yearcomp* is not ``<``, ``=`` or ``>``.
    """
    # Validate before building any SQL: the operator is interpolated verbatim
    # into the query text. A substring test like `op in "<=>"` would also let
    # "<=", "=>" or "" through, and `assert` is stripped under `python -O`.
    if yearcomp and yearcomp[0] not in ("<", "=", ">"):
        raise ValueError(f"Invalid yearcomp operator: {yearcomp[0]!r}")

    movies = Movie._table
    ratings = Rating._table

    # Named parameters; bindparams() allows each name to be used repeatedly.
    values: dict[str, Union[int, str]] = {
        "limit_rows": limit_rows,
    }
    conditions: list[str] = []

    if imdb_movie_id:
        values["imdb_movie_id"] = imdb_movie_id
        conditions.append(f"{movies}.imdb_id=:imdb_movie_id")

    if title:
        values["escape"] = "#"
        escaped_title = sql_escape(title, char=values["escape"])
        # "%" between the words: they must appear in order, with anything
        # (or nothing) in between.
        values["pattern"] = "%" + "%".join(escaped_title.split()) + "%"
        conditions.append(
            f"""
            (
                {movies}.title LIKE :pattern ESCAPE :escape
                OR {movies}.original_title LIKE :pattern ESCAPE :escape
            )
            """
        )

    if yearcomp:
        op, year = yearcomp
        values["year"] = year
        conditions.append(f"{movies}.release_year{op}:year")

    if media_type:
        values["media_type"] = media_type
        conditions.append(f"{movies}.media_type=:media_type")

    if ignore_tv_episodes:
        conditions.append(f"{movies}.media_type!='TV Episode'")

    where_sql = " AND ".join(conditions)

    source_table = "newest_movies"
    ctes = [
        f"""{source_table} AS (
            SELECT DISTINCT {ratings}.movie_id
            FROM {ratings}
            LEFT JOIN {movies} ON {movies}.id={ratings}.movie_id
            {('WHERE ' + where_sql) if conditions else ''}
            ORDER BY length({movies}.title) ASC, {ratings}.rating_date DESC
            LIMIT :limit_rows
        )"""
    ]

    if include_unrated:
        source_table = "target_movies"
        ctes.extend(
            [
                f"""unrated_movies AS (
                    SELECT DISTINCT id AS movie_id
                    FROM {movies}
                    WHERE id NOT IN newest_movies
                    {('AND ' + where_sql) if conditions else ''}
                    ORDER BY length(title) ASC, release_year DESC
                    LIMIT :limit_rows
                )""",
                f"""{source_table} AS (
                    SELECT * FROM newest_movies
                    UNION ALL -- using ALL here avoids the reordering of IDs
                    SELECT * FROM unrated_movies
                    LIMIT :limit_rows
                )""",
            ]
        )

    query = f"""
        WITH
            {','.join(ctes)}
        SELECT
            -- {User._table}.name AS user_name,
            {ratings}.score AS user_score,
            {movies}.score AS imdb_score,
            {movies}.imdb_id AS movie_imdb_id,
            {movies}.media_type AS media_type,
            {movies}.title AS canonical_title,
            {movies}.original_title AS original_title,
            {movies}.release_year AS release_year
        FROM {source_table}
        LEFT JOIN {ratings} ON {ratings}.movie_id={source_table}.movie_id
        -- LEFT JOIN {User._table} ON {User._table}.id={ratings}.user_id
        LEFT JOIN {movies} ON {movies}.id={source_table}.movie_id
    """

    rows = await shared_connection().fetch_all(bindparams(query, values))
    return tuple(dict(r) for r in rows)
def bindparams(query: str, values: dict):
    """Turn *query* into a SQLAlchemy text clause with *values* bound.

    Every ``:name`` placeholder occurrence is renamed to ``:name_N``, where N
    counts the occurrences of that name starting at 1. Each occurrence is
    bound separately, so one value can be referenced from several places in
    the query. A placeholder missing from *values* raises ``KeyError``.
    """
    bound = {}
    seen = {}

    def rename(match):
        name = match[1]
        value = values[name]
        occurrence = seen.get(name, 0) + 1
        seen[name] = occurrence
        unique_name = f"{name}_{occurrence}"
        bound[unique_name] = value
        return f":{unique_name}"

    rewritten = re.sub(r":(\w+)\b", rename, query)
    return sqlalchemy.text(rewritten).bindparams(**bound)