unwind/unwind/db.py

185 lines
5.1 KiB
Python

import logging
from dataclasses import fields
from pathlib import Path
from typing import Optional, Type, TypeVar
from databases import Database
from . import config
from .models import Movie, Rating, User, asplain, fromplain, utcnow
# Module logger, named after this module.
log = logging.getLogger(__name__)
# Lazily created by shared_connection(); every DB helper in this module
# obtains the Database instance through that function.
_shared_connection: Optional[Database] = None
async def open_connection_pool() -> None:
    """Connect the shared database and run the statements in ``init.sql``.

    Await this once, before any other helper in this module is used to talk
    to the database.
    """
    database = shared_connection()
    await database.connect()
    # init_db() executes init.sql on every start-up, after the connection
    # has been established.
    await init_db(database)
async def close_connection_pool() -> None:
    """Close the DB connection pool.

    This function should be called before the app shuts down to ensure all data
    has been flushed to the database.

    Before disconnecting it sets ``PRAGMA analysis_limit=400`` and runs
    ``PRAGMA optimize`` so SQLite can perform its automatic ANALYZE. The
    disconnect is performed even if one of those PRAGMAs fails; in that case
    the PRAGMA's exception still propagates to the caller.
    """
    db = shared_connection()
    try:
        # Run automatic ANALYZE prior to closing the db,
        # see https://sqlite.com/lang_analyze.html.
        await db.execute("PRAGMA analysis_limit=400")
        await db.execute("PRAGMA optimize")
    finally:
        # The optimisation step is best-effort; a failure there must not
        # leave the connection pool open at shutdown.
        await db.disconnect()
def shared_connection() -> Database:
    """Return the module-wide ``Database`` object, creating it on first use.

    The first call builds a SQLite URI from ``config.storage_path`` and caches
    the resulting ``Database`` in ``_shared_connection``. Every later call
    returns that same object. Building it does not connect; connecting is
    done by open_connection_pool().
    """
    global _shared_connection
    if _shared_connection is not None:
        return _shared_connection
    _shared_connection = Database(f"sqlite:///{config.storage_path}")
    return _shared_connection
async def init_db(db):
    """Execute every statement from ``init.sql`` inside a single transaction.

    ``init.sql`` is read from the same directory as this module, as UTF-8.
    Statements in the file are separated by ``;;``. Chunks that contain only
    whitespace, such as the tail after a trailing ``;;``, are skipped.

    Args:
        db: A connected ``Database`` (as passed by open_connection_pool()),
            or any object providing ``transaction()`` and ``execute(query=)``.
    """
    # Decode explicitly so the result doesn't depend on the host's locale.
    sql = Path(__file__).with_name("init.sql").read_text(encoding="utf-8")
    async with db.transaction():
        for stmt in sql.split(";;"):
            # Don't send empty statements to the driver.
            if not stmt.strip():
                continue
            await db.execute(query=stmt)
async def add(item):
    """INSERT ``item`` as a new row of the table named by ``item._table``.

    The column list and the bound parameter values both come from
    ``asplain(item)``. Each key becomes a column name and a ``:key`` placeholder.
    """
    row = asplain(item)
    column_list = ", ".join(map(str, row))
    param_list = ", ".join(f":{name}" for name in row)
    statement = (
        f"INSERT INTO {item._table} ({column_list}) VALUES ({param_list})"
    )
    await shared_connection().execute(query=statement, values=row)
ModelType = TypeVar("ModelType")


async def get(model: Type[ModelType], **kwds) -> Optional[ModelType]:
    """Fetch one row of ``model``'s table whose columns match ``kwds``.

    Args:
        model: A dataclass model with a ``_table`` attribute. Every dataclass
            field is selected as a column of that table.
        **kwds: Column/value pairs that must all match (joined with AND).
            Without any keyword arguments no WHERE clause is emitted, so an
            arbitrary row of the table is returned (there is no ORDER BY).

    Returns:
        The row converted with ``fromplain(model, row)``, or ``None`` when no
        row matched.
    """
    fields_ = ", ".join(f.name for f in fields(model))
    query = f"SELECT {fields_} FROM {model._table}"
    if kwds:
        # Column names are interpolated into the SQL; only the values are
        # bound parameters, so keyword names must come from trusted code.
        cond = " AND ".join(f"{k}=:{k}" for k in kwds)
        query += f" WHERE {cond}"
    row = await shared_connection().fetch_one(query=query, values=kwds)
    return fromplain(model, row) if row is not None else None
async def update(item):
    """UPDATE the row of ``item._table`` whose ``id`` equals ``item``'s id.

    Every key from ``asplain(item)`` other than ``id`` is written back.
    ``id`` itself is only used in the WHERE clause.
    """
    row = asplain(item)
    set_parts = [f"{col}=:{col}" for col in row if col != "id"]
    assignments = ", ".join(set_parts)
    statement = f"UPDATE {item._table} SET {assignments} WHERE id=:id"
    await shared_connection().execute(query=statement, values=row)
async def add_or_update_user(user: User):
    """Insert ``user``, or update the stored user that has the same ``imdb_id``.

    If a stored user exists, its ``id`` is copied onto ``user`` (mutating the
    argument). An UPDATE is issued only if the two then differ.
    """
    existing = await get(User, imdb_id=user.imdb_id)
    if not existing:
        await add(user)
        return
    user.id = existing.id
    if user != existing:
        await update(user)
async def add_or_update_movie(movie: Movie):
    """Insert ``movie``, or refresh the stored movie with the same ``imdb_id``.

    For a stored movie, ``id`` and ``updated`` are first copied onto ``movie``
    (mutating it), so neither counts as a difference. Only if something else
    differs is ``updated`` set to ``utcnow()`` and the row written back.
    """
    existing = await get(Movie, imdb_id=movie.imdb_id)
    if not existing:
        await add(movie)
        return
    movie.id, movie.updated = existing.id, existing.updated
    if movie != existing:
        # Bump the timestamp only when the content actually changed.
        movie.updated = utcnow()
        await update(movie)
async def add_or_update_rating(rating: Rating) -> bool:
    """Insert or update ``rating`` and report whether anything was written.

    The stored rating is looked up by the stringified ``movie_id`` and
    ``user_id`` pair. If it exists, its ``id`` is copied onto ``rating``.

    Returns:
        ``True`` if a row was inserted or updated, or ``False`` if an
        identical row was already stored.
    """
    existing = await get(
        Rating, movie_id=str(rating.movie_id), user_id=str(rating.user_id)
    )
    if existing:
        rating.id = existing.id
        if rating != existing:
            await update(rating)
            return True
        return False
    await add(rating)
    return True
def sql_escape(s: str, char="#"):
    """Escape ``s`` for use in a SQL LIKE pattern with ``ESCAPE char``.

    The escape character itself is doubled first, then ``%`` and ``_`` are
    each prefixed with ``char`` so they match literally.
    """
    # Order matters: the escape char must be handled before the wildcards,
    # otherwise the escapes inserted for them would be doubled.
    for needle, replacement in (
        (char, 2 * char),
        ("%", f"{char}%"),
        ("_", f"{char}_"),
    ):
        s = s.replace(needle, replacement)
    return s
async def find_ratings(
    *,
    imdb_movie_id: Optional[str] = None,
    title: Optional[str] = None,
    media_type: Optional[str] = None,
    ignore_tv_episodes: bool = False,
    limit_rows: int = 10,
):
    """Return user ratings for up to ``limit_rows`` movies matching the filters.

    All filters are optional and are combined with AND. A falsy value
    disables a filter.

    Args:
        imdb_movie_id: Only the movie whose ``movies.imdb_id`` equals this.
        title: Every whitespace-separated word must occur, in the given
            order, in ``movies.title`` (matched with SQL LIKE). ``%``, ``_``
            and ``#`` in the title are matched literally.
        media_type: Exact match on ``movies.media_type``.
        ignore_tv_episodes: Exclude movies whose media type is 'TV Episode'.
        limit_rows: Maximum number of distinct movies to return ratings for.

    Returns:
        A tuple of dicts, one per rating of each selected movie. Each dict has
        the keys ``user_name``, ``user_score``, ``imdb_score``,
        ``movie_imdb_id``, ``media_type``, ``movie_title`` and ``release_year``.
    """
    values = {
        "limit_rows": limit_rows,
    }
    conditions = []
    if imdb_movie_id:
        # Previously this argument was accepted but silently ignored.
        values["imdb_movie_id"] = imdb_movie_id
        conditions.append("movies.imdb_id=:imdb_movie_id")
    if title:
        values["escape"] = "#"
        escaped_title = sql_escape(title, char=values["escape"])
        # "foo bar" -> "%foo%bar%": each word must appear, in order.
        values["pattern"] = "%" + "%".join(escaped_title.split()) + "%"
        conditions.append("movies.title LIKE :pattern ESCAPE :escape")
    if media_type:
        values["media_type"] = media_type
        conditions.append("movies.media_type=:media_type")
    if ignore_tv_episodes:
        conditions.append("movies.media_type!='TV Episode'")

    # All conditions are fixed strings with bound placeholders, so the
    # f-string interpolation below never embeds caller-supplied text.
    where_clause = ("WHERE " + " AND ".join(conditions)) if conditions else ""

    # NOTE(review): SQLite accepts ORDER BY on columns that are not in a
    # SELECT DISTINCT list, but for a movie with several ratings it may not be
    # well defined which rating_date drives the ordering. Confirm intended.
    query = f"""
        WITH newest_movies
        AS (
            SELECT DISTINCT ratings.movie_id
            FROM ratings
            LEFT JOIN movies ON movies.id=ratings.movie_id
            {where_clause}
            ORDER BY length(movies.title) ASC, ratings.rating_date DESC
            LIMIT :limit_rows
        )
        SELECT
            users.name AS user_name,
            ratings.score AS user_score,
            movies.score AS imdb_score,
            movies.imdb_id AS movie_imdb_id,
            movies.media_type AS media_type,
            movies.title AS movie_title,
            movies.release_year AS release_year
        FROM newest_movies
        LEFT JOIN ratings ON ratings.movie_id=newest_movies.movie_id
        LEFT JOIN users ON users.id=ratings.user_id
        LEFT JOIN movies ON movies.id=ratings.movie_id
    """
    rows = await shared_connection().fetch_all(query=query, values=values)
    return tuple(dict(r) for r in rows)