diff --git a/unwind/db.py b/unwind/db.py index 21e9728..c8ef3a0 100644 --- a/unwind/db.py +++ b/unwind/db.py @@ -1,5 +1,8 @@ +import asyncio +import contextlib import logging import re +import threading from pathlib import Path from typing import Iterable, Literal, Optional, Type, TypeVar, Union @@ -173,7 +176,44 @@ async def set_import_progress(progress: float) -> Progress: else: await add(current) - return current + +_lock = threading.Lock() +_prelock = threading.Lock() + + +@contextlib.asynccontextmanager +async def single_threaded(): + """Ensure the nested code is run only by a single thread at a time.""" + wait = 1e-5 # XXX not sure if there's a better magic value here + + # The pre-lock (a lock for the lock) allows for multiple threads to hand of + # the main lock. + # With only a single lock the contending thread will spend most of its time + # in the asyncio.sleep and the reigning thread will have time to finish + # whatever it's doing and simply acquire the lock again before the other + # thread has had a change to try. + # By having another lock (and the same sleep time!) the contending thread + # will always have a chance to acquire the main lock. + while not _prelock.acquire(blocking=False): + await asyncio.sleep(wait) + + try: + while not _lock.acquire(blocking=False): + await asyncio.sleep(wait) + finally: + _prelock.release() + + try: + yield + + finally: + _lock.release() + + +@contextlib.asynccontextmanager +async def locked_connection(): + async with single_threaded(): + yield shared_connection() def shared_connection() -> Database: @@ -195,7 +235,8 @@ async def add(item): keys = ", ".join(f"{k}" for k in values) placeholders = ", ".join(f":{k}" for k in values) query = f"INSERT INTO {item._table} ({keys}) VALUES ({placeholders})" - await shared_connection().execute(query=query, values=values) + async with locked_connection() as conn: + await conn.execute(query=query, values=values) ModelType = TypeVar("ModelType") @@ -213,7 +254,8 @@ async def get( query = f"SELECT {fields_} FROM {model._table} WHERE {cond}" if order_by: query += f" ORDER BY {order_by}" - row = await shared_connection().fetch_one(query=query, values=values) + async with locked_connection() as conn: + row = await conn.fetch_one(query=query, values=values) return fromplain(model, row) if row else None @@ -232,7 +274,8 @@ async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]: f"{k} IN ({','.join(':'+n for n in ns)})" for k, ns in keys.items() ) query = f"SELECT {fields_} FROM {model._table} WHERE {cond}" - rows = await shared_connection().fetch_all(query=query, values=values) + async with locked_connection() as conn: + rows = await conn.fetch_all(query=query, values=values) return (fromplain(model, row) for row in rows) @@ -242,7 +285,8 @@ async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]: fields_ = ", ".join(f.name for f in fields(model)) cond = " AND ".join(f"{k}=:{k}" for k in values) or "1=1" query = f"SELECT {fields_} FROM {model._table} WHERE {cond}" - rows = await shared_connection().fetch_all(query=query, values=values) + async with locked_connection() as conn: + rows = await conn.fetch_all(query=query, values=values) return (fromplain(model, row) for row in rows) @@ -254,13 +298,15 @@ async def update(item): values = asplain(item) keys = ", ".join(f"{k}=:{k}" for k in values if k != "id") query = f"UPDATE {item._table} SET {keys} WHERE id=:id" - await shared_connection().execute(query=query, values=values) + async with locked_connection() as conn: + await conn.execute(query=query, values=values) async def remove(item): values = asplain(item, fields_={"id"}) query = f"DELETE FROM {item._table} WHERE id=:id" - await shared_connection().execute(query=query, values=values) + async with locked_connection() as conn: + await conn.execute(query=query, values=values) async def add_or_update_user(user: User): @@ -456,7 +502,8 @@ async def find_ratings( LEFT JOIN {Movie._table} ON {Movie._table}.id={source_table}.movie_id """ - rows = await shared_connection().fetch_all(bindparams(query, values)) + async with locked_connection() as conn: + rows = await conn.fetch_all(bindparams(query, values)) return tuple(dict(r) for r in rows) @@ -507,7 +554,8 @@ async def ratings_for_movies( WHERE {' AND '.join(f'({c})' for c in conditions) if conditions else '1=1'} """ - rows = await shared_connection().fetch_all(query, values) + async with locked_connection() as conn: + rows = await conn.fetch_all(query, values) return (fromplain(Rating, row) for row in rows) @@ -573,7 +621,8 @@ async def find_movies( {Movie._table}.release_year DESC LIMIT :skip_rows, :limit_rows """ - rows = await shared_connection().fetch_all(bindparams(query, values)) + async with locked_connection() as conn: + rows = await conn.fetch_all(bindparams(query, values)) movies = [fromplain(Movie, row) for row in rows]