import logging
import re
from dataclasses import fields
from pathlib import Path
from typing import Literal, Optional, Type, TypeVar, Union

import sqlalchemy
from databases import Database

from . import config
from .models import Movie, Rating, User, asplain, fromplain, optional_fields

log = logging.getLogger(__name__)

_shared_connection: Optional[Database] = None


async def open_connection_pool() -> None:
    """Open the DB connection pool.

    This function needs to be called before any access to the database can happen.
    """
    db = shared_connection()
    await db.connect()

    await init_db(db)


async def close_connection_pool() -> None:
    """Close the DB connection pool.

    This function should be called before the app shuts down to ensure all data
    has been flushed to the database.
    """
    db = shared_connection()

    # Run automatic ANALYZE prior to closing the db,
    # see https://sqlite.com/lang_analyze.html.
    await db.execute("PRAGMA analysis_limit=400")
    await db.execute("PRAGMA optimize")

    await db.disconnect()


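# Illustrative sketch only, not part of this module: one way an application
# could drive the two lifecycle helpers above. `app_main` is a hypothetical
# coroutine; only open_connection_pool() and close_connection_pool() are
# defined here.
#
#     import asyncio
#
#     async def app_main() -> None:
#         await open_connection_pool()
#         try:
#             ...  # handle requests, import ratings, etc.
#         finally:
#             await close_connection_pool()
#
#     asyncio.run(app_main())
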
def shared_connection() -> Database:
    """Return the process-wide Database instance, creating it on first use."""
    global _shared_connection

    if _shared_connection is None:
        uri = f"sqlite:///{config.storage_path}"
        _shared_connection = Database(uri)

    return _shared_connection


async def init_db(db):
    sql = Path(__file__).with_name("init.sql").read_text()
    async with db.transaction():
        # Statements in init.sql are separated by ";;".
        for stmt in sql.split(";;"):
            if stmt.strip():  # skip whitespace-only fragments
                await db.execute(query=stmt)


async def add(item):
    """Insert a model instance into its table."""
    values = asplain(item)
    keys = ", ".join(values)
    placeholders = ", ".join(f":{k}" for k in values)
    query = f"INSERT INTO {item._table} ({keys}) VALUES ({placeholders})"
    await shared_connection().execute(query=query, values=values)


ModelType = TypeVar("ModelType")


async def get(model: Type[ModelType], **kwds) -> Optional[ModelType]:
    """Return a single `model` instance matching all non-None keyword filters, or None."""
    values = {k: v for k, v in kwds.items() if v is not None}
    if not values:
        return None

    fields_ = ", ".join(f.name for f in fields(model))
    cond = " AND ".join(f"{k}=:{k}" for k in values)
    query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
    row = await shared_connection().fetch_one(query=query, values=values)
    return fromplain(model, row) if row else None


async def update(item):
    """Write all non-id fields of a model instance back to its row, matched by id."""
    values = asplain(item)
    keys = ", ".join(f"{k}=:{k}" for k in values if k != "id")
    query = f"UPDATE {item._table} SET {keys} WHERE id=:id"
    await shared_connection().execute(query=query, values=values)


async def add_or_update_user(user: User):
    """Insert the user, or update the stored row if it differs."""
    db_user = await get(User, imdb_id=user.imdb_id)
    if not db_user:
        await add(user)
    else:
        user.id = db_user.id

        if user != db_user:
            await update(user)


async def add_or_update_movie(movie: Movie):
    """Add or update a Movie in the database.

    This is an upsert operation, but it also updates the Movie you pass in:
    its `id` is set to match the `id` of the movie stored in the DB, and any
    optional fields that are unset on your Movie but present in the database
    are filled in. It's a bidirectional sync.
    """
    db_movie = await get(Movie, imdb_id=movie.imdb_id)
    if not db_movie:
        await add(movie)
    else:
        movie.id = db_movie.id

        # We want to keep any existing value in the DB for all optional fields.
        for f in optional_fields(movie):
            if getattr(movie, f.name) is None:
                setattr(movie, f.name, getattr(db_movie, f.name))

        if movie.updated <= db_movie.updated:
            return

        await update(movie)


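# Illustrative sketch of the bidirectional sync described above (the field
# values are made up; this assumes the Movie dataclass can be built from
# keyword fields like these):
#
#     movie = Movie(imdb_id="tt0000001", title="Example", updated=some_timestamp)
#     await add_or_update_movie(movie)
#     # movie.id now matches the stored row, and optional fields that were None
#     # have been filled in from the database; the row itself is only rewritten
#     # when movie.updated is newer than the stored value.
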
async def add_or_update_rating(rating: Rating) -> bool:
    """Insert or update a rating; return True if the database was changed."""
    db_rating = await get(
        Rating, movie_id=str(rating.movie_id), user_id=str(rating.user_id)
    )

    if not db_rating:
        await add(rating)
        return True

    else:
        rating.id = db_rating.id

        if rating != db_rating:
            await update(rating)
            return True

    return False


def sql_escape(s: str, char="#"):
    """Escape LIKE wildcards ("%" and "_") in `s`, using `char` as the escape character."""
    return s.replace(char, 2 * char).replace("%", f"{char}%").replace("_", f"{char}_")


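# Example of the escaping above with the default escape character "#":
#
#     sql_escape("100%_wool")  ->  "100#%#_wool"
#
# The result is safe to embed in a LIKE pattern as long as the query also
# specifies ESCAPE '#', as find_ratings() below does.
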
async def find_ratings(
    *,
    imdb_movie_id: Optional[str] = None,
    title: Optional[str] = None,
    media_type: Optional[str] = None,
    ignore_tv_episodes: bool = False,
    include_unrated: bool = False,
    yearcomp: Optional[tuple[Literal["<", "=", ">"], int]] = None,
    limit_rows: int = 10,
):
    """Search rated movies (and, optionally, unrated ones as well).

    Returns up to `limit_rows` matching rows as plain dicts.
    """
    values: dict[str, Union[int, str]] = {
        "limit_rows": limit_rows,
    }

    conditions = []
    if imdb_movie_id:
        # Restrict to a single title by its IMDb id (mirrors the other filters).
        values["imdb_movie_id"] = imdb_movie_id
        conditions.append(f"{Movie._table}.imdb_id=:imdb_movie_id")

    if title:
        values["escape"] = "#"
        escaped_title = sql_escape(title, char=values["escape"])
        values["pattern"] = "%" + "%".join(escaped_title.split()) + "%"
        conditions.append(
            f"""
            (
                {Movie._table}.title LIKE :pattern ESCAPE :escape
                OR {Movie._table}.original_title LIKE :pattern ESCAPE :escape
            )
            """
        )

    if yearcomp:
        op, year = yearcomp
        assert op in "<=>"
        values["year"] = year
        conditions.append(f"{Movie._table}.release_year{op}:year")

    if media_type:
        values["media_type"] = media_type
        conditions.append(f"{Movie._table}.media_type=:media_type")

    if ignore_tv_episodes:
        conditions.append(f"{Movie._table}.media_type!='TV Episode'")

source_table = "newest_movies"
|
|
ctes = [
|
|
f"""{source_table} AS (
|
|
SELECT DISTINCT {Rating._table}.movie_id
|
|
FROM {Rating._table}
|
|
LEFT JOIN {Movie._table} ON {Movie._table}.id={Rating._table}.movie_id
|
|
{('WHERE ' + ' AND '.join(conditions)) if conditions else ''}
|
|
ORDER BY length({Movie._table}.title) ASC, {Rating._table}.rating_date DESC
|
|
LIMIT :limit_rows
|
|
)"""
|
|
]
|
|
|
|
if include_unrated:
|
|
source_table = "target_movies"
|
|
ctes.extend(
|
|
[
|
|
f"""unrated_movies AS (
|
|
SELECT DISTINCT id AS movie_id
|
|
FROM {Movie._table}
|
|
WHERE id NOT IN newest_movies
|
|
{('AND ' + ' AND '.join(conditions)) if conditions else ''}
|
|
ORDER BY length(title) ASC, release_year DESC
|
|
LIMIT :limit_rows
|
|
)""",
|
|
f"""{source_table} AS (
|
|
SELECT * FROM newest_movies
|
|
UNION ALL -- using ALL here avoids the reordering of IDs
|
|
SELECT * FROM unrated_movies
|
|
LIMIT :limit_rows
|
|
)""",
|
|
]
|
|
)
|
|
|
|
query = f"""
|
|
WITH
|
|
{','.join(ctes)}
|
|
|
|
SELECT
|
|
-- {User._table}.name AS user_name,
|
|
{Rating._table}.score AS user_score,
|
|
{Movie._table}.score AS imdb_score,
|
|
{Movie._table}.imdb_id AS movie_imdb_id,
|
|
{Movie._table}.media_type AS media_type,
|
|
{Movie._table}.title AS canonical_title,
|
|
{Movie._table}.original_title AS original_title,
|
|
{Movie._table}.release_year AS release_year
|
|
FROM {source_table}
|
|
LEFT JOIN {Rating._table} ON {Rating._table}.movie_id={source_table}.movie_id
|
|
-- LEFT JOIN {User._table} ON {User._table}.id={Rating._table}.user_id
|
|
LEFT JOIN {Movie._table} ON {Movie._table}.id={source_table}.movie_id
|
|
"""
|
|
|
|
rows = await shared_connection().fetch_all(bindparams(query, values))
|
|
return tuple(dict(r) for r in rows)
|
|
|
|
|
|
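# Illustrative call (the filter values are made up): up to five titles matching
# "matrix", released after 1998, skipping individual TV episodes.
#
#     rows = await find_ratings(
#         title="matrix",
#         yearcomp=(">", 1998),
#         ignore_tv_episodes=True,
#         limit_rows=5,
#     )
#     for row in rows:
#         print(row["canonical_title"], row["user_score"])
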
def bindparams(query: str, values: dict):
    """Bind values to a query.

    This is similar to what SQLAlchemy and Databases do, but it makes it easy
    to use the same placeholder in multiple places within one query.
    """
    pump_vals = {}
    pump_keys = {}

    def pump(match):
        key = match[1]
        val = values[key]
        pump_keys[key] = 1 + pump_keys.setdefault(key, 0)
        pump_key = f"{key}_{pump_keys[key]}"
        pump_vals[pump_key] = val
        return f":{pump_key}"

    pump_query = re.sub(r":(\w+)\b", pump, query)
    return sqlalchemy.text(pump_query).bindparams(**pump_vals)
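

# Illustrative example of the renaming pump() performs (the SQL text is made
# up): a placeholder that appears twice is split into numbered parameters
# before the statement is handed to SQLAlchemy.
#
#     clause = bindparams(
#         "SELECT id FROM movies WHERE title LIKE :pattern OR original_title LIKE :pattern",
#         {"pattern": "%matrix%"},
#     )
#     # The statement now reads "... LIKE :pattern_1 OR ... LIKE :pattern_2",
#     # with both parameters bound to "%matrix%".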