fix SQLite locking errors
parent f964c0ceb9
commit b676c9ddde

1 changed file with 59 additions and 10 deletions
unwind/db.py (69 changed lines: +59, -10)
@@ -1,5 +1,8 @@
+import asyncio
+import contextlib
 import logging
 import re
+import threading
 from pathlib import Path
 from typing import Iterable, Literal, Optional, Type, TypeVar, Union
 
@@ -173,7 +176,44 @@ async def set_import_progress(progress: float) -> Progress:
     else:
         await add(current)
 
     return current
 
 
+_lock = threading.Lock()
+_prelock = threading.Lock()
+
+
+@contextlib.asynccontextmanager
+async def single_threaded():
+    """Ensure the nested code is run only by a single thread at a time."""
+    wait = 1e-5  # XXX not sure if there's a better magic value here
+
+    # The pre-lock (a lock for the lock) allows multiple threads to hand off
+    # the main lock.
+    # With only a single lock the contending thread will spend most of its time
+    # in the asyncio.sleep and the reigning thread will have time to finish
+    # whatever it's doing and simply acquire the lock again before the other
+    # thread has had a chance to try.
+    # By having another lock (and the same sleep time!) the contending thread
+    # will always have a chance to acquire the main lock.
+    while not _prelock.acquire(blocking=False):
+        await asyncio.sleep(wait)
+
+    try:
+        while not _lock.acquire(blocking=False):
+            await asyncio.sleep(wait)
+    finally:
+        _prelock.release()
+
+    try:
+        yield
+
+    finally:
+        _lock.release()
+
+
+@contextlib.asynccontextmanager
+async def locked_connection():
+    async with single_threaded():
+        yield shared_connection()
+
+
 def shared_connection() -> Database:
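The comment block above explains why a single polled lock tends to starve the waiting thread: the holder can re-grab the lock before the sleeper wakes. Below is a minimal, self-contained sketch of that hand-off pattern, separate from the commit itself; it uses plain threads and time.sleep where the real helper awaits asyncio.sleep, and the names WAIT, fair_acquire and worker are illustrative only.

import threading
import time

WAIT = 1e-5  # same polling interval as the commit's wait value
prelock = threading.Lock()
lock = threading.Lock()


def fair_acquire():
    # Same two-step acquire as single_threaded(), only synchronous:
    # the pre-lock is held just long enough to win the main lock.
    while not prelock.acquire(blocking=False):
        time.sleep(WAIT)
    try:
        while not lock.acquire(blocking=False):
            time.sleep(WAIT)
    finally:
        prelock.release()


def worker(name, counts):
    end = time.monotonic() + 0.2
    while time.monotonic() < end:
        fair_acquire()
        try:
            counts[name] += 1  # the "work" done while holding the lock
        finally:
            lock.release()


counts = {"a": 0, "b": 0}
threads = [threading.Thread(target=worker, args=(n, counts)) for n in counts]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(counts)  # with the pre-lock, both workers make comparable progress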
@@ -195,7 +235,8 @@ async def add(item):
     keys = ", ".join(f"{k}" for k in values)
     placeholders = ", ".join(f":{k}" for k in values)
     query = f"INSERT INTO {item._table} ({keys}) VALUES ({placeholders})"
-    await shared_connection().execute(query=query, values=values)
+    async with locked_connection() as conn:
+        await conn.execute(query=query, values=values)
 
 
 ModelType = TypeVar("ModelType")
@@ -213,7 +254,8 @@ async def get(
     query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
     if order_by:
         query += f" ORDER BY {order_by}"
-    row = await shared_connection().fetch_one(query=query, values=values)
+    async with locked_connection() as conn:
+        row = await conn.fetch_one(query=query, values=values)
     return fromplain(model, row) if row else None
 
 
@@ -232,7 +274,8 @@ async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
         f"{k} IN ({','.join(':'+n for n in ns)})" for k, ns in keys.items()
     )
     query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
-    rows = await shared_connection().fetch_all(query=query, values=values)
+    async with locked_connection() as conn:
+        rows = await conn.fetch_all(query=query, values=values)
     return (fromplain(model, row) for row in rows)
 
 
@@ -242,7 +285,8 @@ async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
     fields_ = ", ".join(f.name for f in fields(model))
     cond = " AND ".join(f"{k}=:{k}" for k in values) or "1=1"
     query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
-    rows = await shared_connection().fetch_all(query=query, values=values)
+    async with locked_connection() as conn:
+        rows = await conn.fetch_all(query=query, values=values)
     return (fromplain(model, row) for row in rows)
 
 
@@ -254,13 +298,15 @@ async def update(item):
     values = asplain(item)
     keys = ", ".join(f"{k}=:{k}" for k in values if k != "id")
     query = f"UPDATE {item._table} SET {keys} WHERE id=:id"
-    await shared_connection().execute(query=query, values=values)
+    async with locked_connection() as conn:
+        await conn.execute(query=query, values=values)
 
 
 async def remove(item):
     values = asplain(item, fields_={"id"})
     query = f"DELETE FROM {item._table} WHERE id=:id"
-    await shared_connection().execute(query=query, values=values)
+    async with locked_connection() as conn:
+        await conn.execute(query=query, values=values)
 
 
 async def add_or_update_user(user: User):
@@ -456,7 +502,8 @@ async def find_ratings(
         LEFT JOIN {Movie._table} ON {Movie._table}.id={source_table}.movie_id
     """
 
-    rows = await shared_connection().fetch_all(bindparams(query, values))
+    async with locked_connection() as conn:
+        rows = await conn.fetch_all(bindparams(query, values))
     return tuple(dict(r) for r in rows)
 
 
@@ -507,7 +554,8 @@ async def ratings_for_movies(
         WHERE {' AND '.join(f'({c})' for c in conditions) if conditions else '1=1'}
     """
 
-    rows = await shared_connection().fetch_all(query, values)
+    async with locked_connection() as conn:
+        rows = await conn.fetch_all(query, values)
 
     return (fromplain(Rating, row) for row in rows)
@@ -573,7 +621,8 @@ async def find_movies(
             {Movie._table}.release_year DESC
         LIMIT :skip_rows, :limit_rows
     """
-    rows = await shared_connection().fetch_all(bindparams(query, values))
+    async with locked_connection() as conn:
+        rows = await conn.fetch_all(bindparams(query, values))
 
     movies = [fromplain(Movie, row) for row in rows]
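Every hunk above applies the same mechanical change: direct shared_connection() calls are routed through locked_connection(). As a rough sketch, not part of the commit, of how callers running in separate threads (each with its own event loop) would now reach the serialized connection: count_rows, worker, the "movies" table name and the thread count are all placeholders, and the fetch_one call assumes the same connection API the rest of db.py already uses.

import asyncio
import threading

from unwind import db  # the module this commit changes


async def count_rows(table: str) -> int:
    # locked_connection() serializes access to the shared SQLite connection,
    # which is what avoids the locking errors named in the commit title.
    async with db.locked_connection() as conn:
        row = await conn.fetch_one(query=f"SELECT COUNT(*) AS n FROM {table}")
        return row["n"]


def worker() -> None:
    # Each thread drives its own event loop; the module-level threading.Lock
    # pair in unwind.db coordinates the threads with one another.
    print(asyncio.run(count_rows("movies")))  # "movies" is a placeholder


threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()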