fix SQLite locking errors

This commit is contained in:
ducklet 2021-08-18 20:02:10 +02:00
parent f964c0ceb9
commit b676c9ddde

View file

@ -1,5 +1,8 @@
import asyncio
import contextlib
import logging import logging
import re import re
import threading
from pathlib import Path from pathlib import Path
from typing import Iterable, Literal, Optional, Type, TypeVar, Union from typing import Iterable, Literal, Optional, Type, TypeVar, Union
@ -173,7 +176,44 @@ async def set_import_progress(progress: float) -> Progress:
else: else:
await add(current) await add(current)
return current
_lock = threading.Lock()
_prelock = threading.Lock()
@contextlib.asynccontextmanager
async def single_threaded():
"""Ensure the nested code is run only by a single thread at a time."""
wait = 1e-5 # XXX not sure if there's a better magic value here
# The pre-lock (a lock for the lock) allows for multiple threads to hand of
# the main lock.
# With only a single lock the contending thread will spend most of its time
# in the asyncio.sleep and the reigning thread will have time to finish
# whatever it's doing and simply acquire the lock again before the other
# thread has had a change to try.
# By having another lock (and the same sleep time!) the contending thread
# will always have a chance to acquire the main lock.
while not _prelock.acquire(blocking=False):
await asyncio.sleep(wait)
try:
while not _lock.acquire(blocking=False):
await asyncio.sleep(wait)
finally:
_prelock.release()
try:
yield
finally:
_lock.release()
@contextlib.asynccontextmanager
async def locked_connection():
async with single_threaded():
yield shared_connection()
def shared_connection() -> Database: def shared_connection() -> Database:
@ -195,7 +235,8 @@ async def add(item):
keys = ", ".join(f"{k}" for k in values) keys = ", ".join(f"{k}" for k in values)
placeholders = ", ".join(f":{k}" for k in values) placeholders = ", ".join(f":{k}" for k in values)
query = f"INSERT INTO {item._table} ({keys}) VALUES ({placeholders})" query = f"INSERT INTO {item._table} ({keys}) VALUES ({placeholders})"
await shared_connection().execute(query=query, values=values) async with locked_connection() as conn:
await conn.execute(query=query, values=values)
ModelType = TypeVar("ModelType") ModelType = TypeVar("ModelType")
@ -213,7 +254,8 @@ async def get(
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}" query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
if order_by: if order_by:
query += f" ORDER BY {order_by}" query += f" ORDER BY {order_by}"
row = await shared_connection().fetch_one(query=query, values=values) async with locked_connection() as conn:
row = await conn.fetch_one(query=query, values=values)
return fromplain(model, row) if row else None return fromplain(model, row) if row else None
@ -232,7 +274,8 @@ async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
f"{k} IN ({','.join(':'+n for n in ns)})" for k, ns in keys.items() f"{k} IN ({','.join(':'+n for n in ns)})" for k, ns in keys.items()
) )
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}" query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
rows = await shared_connection().fetch_all(query=query, values=values) async with locked_connection() as conn:
rows = await conn.fetch_all(query=query, values=values)
return (fromplain(model, row) for row in rows) return (fromplain(model, row) for row in rows)
@ -242,7 +285,8 @@ async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
fields_ = ", ".join(f.name for f in fields(model)) fields_ = ", ".join(f.name for f in fields(model))
cond = " AND ".join(f"{k}=:{k}" for k in values) or "1=1" cond = " AND ".join(f"{k}=:{k}" for k in values) or "1=1"
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}" query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
rows = await shared_connection().fetch_all(query=query, values=values) async with locked_connection() as conn:
rows = await conn.fetch_all(query=query, values=values)
return (fromplain(model, row) for row in rows) return (fromplain(model, row) for row in rows)
@ -254,13 +298,15 @@ async def update(item):
values = asplain(item) values = asplain(item)
keys = ", ".join(f"{k}=:{k}" for k in values if k != "id") keys = ", ".join(f"{k}=:{k}" for k in values if k != "id")
query = f"UPDATE {item._table} SET {keys} WHERE id=:id" query = f"UPDATE {item._table} SET {keys} WHERE id=:id"
await shared_connection().execute(query=query, values=values) async with locked_connection() as conn:
await conn.execute(query=query, values=values)
async def remove(item): async def remove(item):
values = asplain(item, fields_={"id"}) values = asplain(item, fields_={"id"})
query = f"DELETE FROM {item._table} WHERE id=:id" query = f"DELETE FROM {item._table} WHERE id=:id"
await shared_connection().execute(query=query, values=values) async with locked_connection() as conn:
await conn.execute(query=query, values=values)
async def add_or_update_user(user: User): async def add_or_update_user(user: User):
@ -456,7 +502,8 @@ async def find_ratings(
LEFT JOIN {Movie._table} ON {Movie._table}.id={source_table}.movie_id LEFT JOIN {Movie._table} ON {Movie._table}.id={source_table}.movie_id
""" """
rows = await shared_connection().fetch_all(bindparams(query, values)) async with locked_connection() as conn:
rows = await conn.fetch_all(bindparams(query, values))
return tuple(dict(r) for r in rows) return tuple(dict(r) for r in rows)
@ -507,7 +554,8 @@ async def ratings_for_movies(
WHERE {' AND '.join(f'({c})' for c in conditions) if conditions else '1=1'} WHERE {' AND '.join(f'({c})' for c in conditions) if conditions else '1=1'}
""" """
rows = await shared_connection().fetch_all(query, values) async with locked_connection() as conn:
rows = await conn.fetch_all(query, values)
return (fromplain(Rating, row) for row in rows) return (fromplain(Rating, row) for row in rows)
@ -573,7 +621,8 @@ async def find_movies(
{Movie._table}.release_year DESC {Movie._table}.release_year DESC
LIMIT :skip_rows, :limit_rows LIMIT :skip_rows, :limit_rows
""" """
rows = await shared_connection().fetch_all(bindparams(query, values)) async with locked_connection() as conn:
rows = await conn.fetch_all(bindparams(query, values))
movies = [fromplain(Movie, row) for row in rows] movies = [fromplain(Movie, row) for row in rows]