fix SQLite locking errors

This commit is contained in:
ducklet 2021-08-18 20:02:10 +02:00
parent f964c0ceb9
commit b676c9ddde

View file

@ -1,5 +1,8 @@
import asyncio
import contextlib
import logging
import re
import threading
from pathlib import Path
from typing import Iterable, Literal, Optional, Type, TypeVar, Union
@ -173,7 +176,44 @@ async def set_import_progress(progress: float) -> Progress:
else:
await add(current)
return current
_lock = threading.Lock()
_prelock = threading.Lock()
@contextlib.asynccontextmanager
async def single_threaded():
"""Ensure the nested code is run only by a single thread at a time."""
wait = 1e-5 # XXX not sure if there's a better magic value here
# The pre-lock (a lock for the lock) allows for multiple threads to hand of
# the main lock.
# With only a single lock the contending thread will spend most of its time
# in the asyncio.sleep and the reigning thread will have time to finish
# whatever it's doing and simply acquire the lock again before the other
# thread has had a change to try.
# By having another lock (and the same sleep time!) the contending thread
# will always have a chance to acquire the main lock.
while not _prelock.acquire(blocking=False):
await asyncio.sleep(wait)
try:
while not _lock.acquire(blocking=False):
await asyncio.sleep(wait)
finally:
_prelock.release()
try:
yield
finally:
_lock.release()
@contextlib.asynccontextmanager
async def locked_connection():
async with single_threaded():
yield shared_connection()
def shared_connection() -> Database:
@ -195,7 +235,8 @@ async def add(item):
keys = ", ".join(f"{k}" for k in values)
placeholders = ", ".join(f":{k}" for k in values)
query = f"INSERT INTO {item._table} ({keys}) VALUES ({placeholders})"
await shared_connection().execute(query=query, values=values)
async with locked_connection() as conn:
await conn.execute(query=query, values=values)
ModelType = TypeVar("ModelType")
@ -213,7 +254,8 @@ async def get(
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
if order_by:
query += f" ORDER BY {order_by}"
row = await shared_connection().fetch_one(query=query, values=values)
async with locked_connection() as conn:
row = await conn.fetch_one(query=query, values=values)
return fromplain(model, row) if row else None
@ -232,7 +274,8 @@ async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
f"{k} IN ({','.join(':'+n for n in ns)})" for k, ns in keys.items()
)
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
rows = await shared_connection().fetch_all(query=query, values=values)
async with locked_connection() as conn:
rows = await conn.fetch_all(query=query, values=values)
return (fromplain(model, row) for row in rows)
@ -242,7 +285,8 @@ async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
fields_ = ", ".join(f.name for f in fields(model))
cond = " AND ".join(f"{k}=:{k}" for k in values) or "1=1"
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
rows = await shared_connection().fetch_all(query=query, values=values)
async with locked_connection() as conn:
rows = await conn.fetch_all(query=query, values=values)
return (fromplain(model, row) for row in rows)
@ -254,13 +298,15 @@ async def update(item):
values = asplain(item)
keys = ", ".join(f"{k}=:{k}" for k in values if k != "id")
query = f"UPDATE {item._table} SET {keys} WHERE id=:id"
await shared_connection().execute(query=query, values=values)
async with locked_connection() as conn:
await conn.execute(query=query, values=values)
async def remove(item):
values = asplain(item, fields_={"id"})
query = f"DELETE FROM {item._table} WHERE id=:id"
await shared_connection().execute(query=query, values=values)
async with locked_connection() as conn:
await conn.execute(query=query, values=values)
async def add_or_update_user(user: User):
@ -456,7 +502,8 @@ async def find_ratings(
LEFT JOIN {Movie._table} ON {Movie._table}.id={source_table}.movie_id
"""
rows = await shared_connection().fetch_all(bindparams(query, values))
async with locked_connection() as conn:
rows = await conn.fetch_all(bindparams(query, values))
return tuple(dict(r) for r in rows)
@ -507,7 +554,8 @@ async def ratings_for_movies(
WHERE {' AND '.join(f'({c})' for c in conditions) if conditions else '1=1'}
"""
rows = await shared_connection().fetch_all(query, values)
async with locked_connection() as conn:
rows = await conn.fetch_all(query, values)
return (fromplain(Rating, row) for row in rows)
@ -573,7 +621,8 @@ async def find_movies(
{Movie._table}.release_year DESC
LIMIT :skip_rows, :limit_rows
"""
rows = await shared_connection().fetch_all(bindparams(query, values))
async with locked_connection() as conn:
rows = await conn.fetch_all(bindparams(query, values))
movies = [fromplain(Movie, row) for row in rows]