remove databases, use SQLAlchemy 2.0 instead
Among the many changes, we switch to SQLAlchemy's connection pool, which means we no longer need to guard against multiple threads working on the database. All db functions now receive the connection to use as their first argument; this lets the caller control transaction and rollback behavior.
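To illustrate the calling convention this commit adopts, a minimal sketch (`count_movies` and the database path are hypothetical, invented for illustration; only the pattern of passing the connection first mirrors the code below):

    import sqlalchemy as sa
    from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine

    async def count_movies(conn: AsyncConnection) -> int:
        # A db function in the new style: it uses the connection it is
        # given and never opens, commits, or rolls back a transaction
        # on its own.
        result = await conn.execute(sa.text("SELECT count(*) FROM movies"))
        return result.scalar_one()

    async def example() -> None:
        engine = create_async_engine("sqlite+aiosqlite:///./data/unwind.sqlite")
        # The caller owns the transaction boundary: engine.begin()
        # commits on success and rolls back if the block raises.
        async with engine.begin() as conn:
            print(await count_movies(conn))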
This commit is contained in:
parent c63bee072f
commit 4981de4a04

12 changed files with 876 additions and 765 deletions
poetry.lock (generated, 129 lines changed)
@@ -153,31 +153,6 @@ files = [
 [package.extras]
 toml = ["tomli"]
 
-[[package]]
-name = "databases"
-version = "0.7.0"
-description = "Async database support for Python."
-optional = false
-python-versions = ">=3.7"
-files = [
-    {file = "databases-0.7.0-py3-none-any.whl", hash = "sha256:cf5da4b8a3e3cd038c459529725ebb64931cbbb7a091102664f20ef8f6cefd0d"},
-    {file = "databases-0.7.0.tar.gz", hash = "sha256:ea2d419d3d2eb80595b7ceb8f282056f080af62efe2fb9bcd83562f93ec4b674"},
-]
-
-[package.dependencies]
-aiosqlite = {version = "*", optional = true, markers = "extra == \"sqlite\""}
-sqlalchemy = ">=1.4.42,<1.5"
-
-[package.extras]
-aiomysql = ["aiomysql"]
-aiopg = ["aiopg"]
-aiosqlite = ["aiosqlite"]
-asyncmy = ["asyncmy"]
-asyncpg = ["asyncpg"]
-mysql = ["aiomysql"]
-postgresql = ["asyncpg"]
-sqlite = ["aiosqlite"]
-
 [[package]]
 name = "greenlet"
 version = "3.0.1"
@@ -554,62 +529,90 @@ files = [
 
 [[package]]
 name = "sqlalchemy"
-version = "1.4.50"
+version = "2.0.23"
 description = "Database Abstraction Library"
 optional = false
-python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7"
+python-versions = ">=3.7"
 files = [
-    {file = "SQLAlchemy-1.4.50-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d00665725063692c42badfd521d0c4392e83c6c826795d38eb88fb108e5660e5"},
-    {file = "SQLAlchemy-1.4.50-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85292ff52ddf85a39367057c3d7968a12ee1fb84565331a36a8fead346f08796"},
-    {file = "SQLAlchemy-1.4.50-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:d0fed0f791d78e7767c2db28d34068649dfeea027b83ed18c45a423f741425cb"},
-    {file = "SQLAlchemy-1.4.50-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db4db3c08ffbb18582f856545f058a7a5e4ab6f17f75795ca90b3c38ee0a8ba4"},
-    {file = "SQLAlchemy-1.4.50-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:14b0cacdc8a4759a1e1bd47dc3ee3f5db997129eb091330beda1da5a0e9e5bd7"},
-    {file = "SQLAlchemy-1.4.50-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1fb9cb60e0f33040e4f4681e6658a7eb03b5cb4643284172f91410d8c493dace"},
-    {file = "SQLAlchemy-1.4.50-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c4cb501d585aa74a0f86d0ea6263b9c5e1d1463f8f9071392477fd401bd3c7cc"},
-    {file = "SQLAlchemy-1.4.50-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a7a66297e46f85a04d68981917c75723e377d2e0599d15fbe7a56abed5e2d75"},
-    {file = "SQLAlchemy-1.4.50-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c1db0221cb26d66294f4ca18c533e427211673ab86c1fbaca8d6d9ff78654293"},
-    {file = "SQLAlchemy-1.4.50-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b7dbe6369677a2bea68fe9812c6e4bbca06ebfa4b5cde257b2b0bf208709131"},
-    {file = "SQLAlchemy-1.4.50-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a9bddb60566dc45c57fd0a5e14dd2d9e5f106d2241e0a2dc0c1da144f9444516"},
-    {file = "SQLAlchemy-1.4.50-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:82dd4131d88395df7c318eeeef367ec768c2a6fe5bd69423f7720c4edb79473c"},
-    {file = "SQLAlchemy-1.4.50-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:273505fcad22e58cc67329cefab2e436006fc68e3c5423056ee0513e6523268a"},
-    {file = "SQLAlchemy-1.4.50-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3257a6e09626d32b28a0c5b4f1a97bced585e319cfa90b417f9ab0f6145c33c"},
-    {file = "SQLAlchemy-1.4.50-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:d69738d582e3a24125f0c246ed8d712b03bd21e148268421e4a4d09c34f521a5"},
-    {file = "SQLAlchemy-1.4.50-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:34e1c5d9cd3e6bf3d1ce56971c62a40c06bfc02861728f368dcfec8aeedb2814"},
-    {file = "SQLAlchemy-1.4.50-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1fcee5a2c859eecb4ed179edac5ffbc7c84ab09a5420219078ccc6edda45436"},
-    {file = "SQLAlchemy-1.4.50-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fbaf6643a604aa17e7a7afd74f665f9db882df5c297bdd86c38368f2c471f37d"},
-    {file = "SQLAlchemy-1.4.50-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2e70e0673d7d12fa6cd363453a0d22dac0d9978500aa6b46aa96e22690a55eab"},
-    {file = "SQLAlchemy-1.4.50-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b881ac07d15fb3e4f68c5a67aa5cdaf9eb8f09eb5545aaf4b0a5f5f4659be18"},
-    {file = "SQLAlchemy-1.4.50-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f6997da81114daef9203d30aabfa6b218a577fc2bd797c795c9c88c9eb78d49"},
-    {file = "SQLAlchemy-1.4.50-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bdb77e1789e7596b77fd48d99ec1d2108c3349abd20227eea0d48d3f8cf398d9"},
-    {file = "SQLAlchemy-1.4.50-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:128a948bd40780667114b0297e2cc6d657b71effa942e0a368d8cc24293febb3"},
-    {file = "SQLAlchemy-1.4.50-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f2d526aeea1bd6a442abc7c9b4b00386fd70253b80d54a0930c0a216230a35be"},
-    {file = "SQLAlchemy-1.4.50.tar.gz", hash = "sha256:3b97ddf509fc21e10b09403b5219b06c5b558b27fc2453150274fa4e70707dbf"},
+    {file = "SQLAlchemy-2.0.23-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:638c2c0b6b4661a4fd264f6fb804eccd392745c5887f9317feb64bb7cb03b3ea"},
+    {file = "SQLAlchemy-2.0.23-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e3b5036aa326dc2df50cba3c958e29b291a80f604b1afa4c8ce73e78e1c9f01d"},
+    {file = "SQLAlchemy-2.0.23-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:787af80107fb691934a01889ca8f82a44adedbf5ef3d6ad7d0f0b9ac557e0c34"},
+    {file = "SQLAlchemy-2.0.23-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c14eba45983d2f48f7546bb32b47937ee2cafae353646295f0e99f35b14286ab"},
+    {file = "SQLAlchemy-2.0.23-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0666031df46b9badba9bed00092a1ffa3aa063a5e68fa244acd9f08070e936d3"},
+    {file = "SQLAlchemy-2.0.23-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:89a01238fcb9a8af118eaad3ffcc5dedaacbd429dc6fdc43fe430d3a941ff965"},
+    {file = "SQLAlchemy-2.0.23-cp310-cp310-win32.whl", hash = "sha256:cabafc7837b6cec61c0e1e5c6d14ef250b675fa9c3060ed8a7e38653bd732ff8"},
+    {file = "SQLAlchemy-2.0.23-cp310-cp310-win_amd64.whl", hash = "sha256:87a3d6b53c39cd173990de2f5f4b83431d534a74f0e2f88bd16eabb5667e65c6"},
+    {file = "SQLAlchemy-2.0.23-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d5578e6863eeb998980c212a39106ea139bdc0b3f73291b96e27c929c90cd8e1"},
+    {file = "SQLAlchemy-2.0.23-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:62d9e964870ea5ade4bc870ac4004c456efe75fb50404c03c5fd61f8bc669a72"},
+    {file = "SQLAlchemy-2.0.23-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c80c38bd2ea35b97cbf7c21aeb129dcbebbf344ee01a7141016ab7b851464f8e"},
+    {file = "SQLAlchemy-2.0.23-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75eefe09e98043cff2fb8af9796e20747ae870c903dc61d41b0c2e55128f958d"},
+    {file = "SQLAlchemy-2.0.23-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:bd45a5b6c68357578263d74daab6ff9439517f87da63442d244f9f23df56138d"},
+    {file = "SQLAlchemy-2.0.23-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a86cb7063e2c9fb8e774f77fbf8475516d270a3e989da55fa05d08089d77f8c4"},
+    {file = "SQLAlchemy-2.0.23-cp311-cp311-win32.whl", hash = "sha256:b41f5d65b54cdf4934ecede2f41b9c60c9f785620416e8e6c48349ab18643855"},
+    {file = "SQLAlchemy-2.0.23-cp311-cp311-win_amd64.whl", hash = "sha256:9ca922f305d67605668e93991aaf2c12239c78207bca3b891cd51a4515c72e22"},
+    {file = "SQLAlchemy-2.0.23-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d0f7fb0c7527c41fa6fcae2be537ac137f636a41b4c5a4c58914541e2f436b45"},
+    {file = "SQLAlchemy-2.0.23-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7c424983ab447dab126c39d3ce3be5bee95700783204a72549c3dceffe0fc8f4"},
+    {file = "SQLAlchemy-2.0.23-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f508ba8f89e0a5ecdfd3761f82dda2a3d7b678a626967608f4273e0dba8f07ac"},
+    {file = "SQLAlchemy-2.0.23-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6463aa765cf02b9247e38b35853923edbf2f6fd1963df88706bc1d02410a5577"},
+    {file = "SQLAlchemy-2.0.23-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e599a51acf3cc4d31d1a0cf248d8f8d863b6386d2b6782c5074427ebb7803bda"},
+    {file = "SQLAlchemy-2.0.23-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:fd54601ef9cc455a0c61e5245f690c8a3ad67ddb03d3b91c361d076def0b4c60"},
+    {file = "SQLAlchemy-2.0.23-cp312-cp312-win32.whl", hash = "sha256:42d0b0290a8fb0165ea2c2781ae66e95cca6e27a2fbe1016ff8db3112ac1e846"},
+    {file = "SQLAlchemy-2.0.23-cp312-cp312-win_amd64.whl", hash = "sha256:227135ef1e48165f37590b8bfc44ed7ff4c074bf04dc8d6f8e7f1c14a94aa6ca"},
+    {file = "SQLAlchemy-2.0.23-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:14aebfe28b99f24f8a4c1346c48bc3d63705b1f919a24c27471136d2f219f02d"},
+    {file = "SQLAlchemy-2.0.23-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e983fa42164577d073778d06d2cc5d020322425a509a08119bdcee70ad856bf"},
+    {file = "SQLAlchemy-2.0.23-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e0dc9031baa46ad0dd5a269cb7a92a73284d1309228be1d5935dac8fb3cae24"},
+    {file = "SQLAlchemy-2.0.23-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:5f94aeb99f43729960638e7468d4688f6efccb837a858b34574e01143cf11f89"},
+    {file = "SQLAlchemy-2.0.23-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:63bfc3acc970776036f6d1d0e65faa7473be9f3135d37a463c5eba5efcdb24c8"},
+    {file = "SQLAlchemy-2.0.23-cp37-cp37m-win32.whl", hash = "sha256:f48ed89dd11c3c586f45e9eec1e437b355b3b6f6884ea4a4c3111a3358fd0c18"},
+    {file = "SQLAlchemy-2.0.23-cp37-cp37m-win_amd64.whl", hash = "sha256:1e018aba8363adb0599e745af245306cb8c46b9ad0a6fc0a86745b6ff7d940fc"},
+    {file = "SQLAlchemy-2.0.23-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:64ac935a90bc479fee77f9463f298943b0e60005fe5de2aa654d9cdef46c54df"},
+    {file = "SQLAlchemy-2.0.23-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c4722f3bc3c1c2fcc3702dbe0016ba31148dd6efcd2a2fd33c1b4897c6a19693"},
+    {file = "SQLAlchemy-2.0.23-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4af79c06825e2836de21439cb2a6ce22b2ca129bad74f359bddd173f39582bf5"},
+    {file = "SQLAlchemy-2.0.23-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:683ef58ca8eea4747737a1c35c11372ffeb84578d3aab8f3e10b1d13d66f2bc4"},
+    {file = "SQLAlchemy-2.0.23-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:d4041ad05b35f1f4da481f6b811b4af2f29e83af253bf37c3c4582b2c68934ab"},
+    {file = "SQLAlchemy-2.0.23-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:aeb397de65a0a62f14c257f36a726945a7f7bb60253462e8602d9b97b5cbe204"},
+    {file = "SQLAlchemy-2.0.23-cp38-cp38-win32.whl", hash = "sha256:42ede90148b73fe4ab4a089f3126b2cfae8cfefc955c8174d697bb46210c8306"},
+    {file = "SQLAlchemy-2.0.23-cp38-cp38-win_amd64.whl", hash = "sha256:964971b52daab357d2c0875825e36584d58f536e920f2968df8d581054eada4b"},
+    {file = "SQLAlchemy-2.0.23-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:616fe7bcff0a05098f64b4478b78ec2dfa03225c23734d83d6c169eb41a93e55"},
+    {file = "SQLAlchemy-2.0.23-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0e680527245895aba86afbd5bef6c316831c02aa988d1aad83c47ffe92655e74"},
+    {file = "SQLAlchemy-2.0.23-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9585b646ffb048c0250acc7dad92536591ffe35dba624bb8fd9b471e25212a35"},
+    {file = "SQLAlchemy-2.0.23-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4895a63e2c271ffc7a81ea424b94060f7b3b03b4ea0cd58ab5bb676ed02f4221"},
+    {file = "SQLAlchemy-2.0.23-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:cc1d21576f958c42d9aec68eba5c1a7d715e5fc07825a629015fe8e3b0657fb0"},
+    {file = "SQLAlchemy-2.0.23-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:967c0b71156f793e6662dd839da54f884631755275ed71f1539c95bbada9aaab"},
+    {file = "SQLAlchemy-2.0.23-cp39-cp39-win32.whl", hash = "sha256:0a8c6aa506893e25a04233bc721c6b6cf844bafd7250535abb56cb6cc1368884"},
+    {file = "SQLAlchemy-2.0.23-cp39-cp39-win_amd64.whl", hash = "sha256:f3420d00d2cb42432c1d0e44540ae83185ccbbc67a6054dcc8ab5387add6620b"},
+    {file = "SQLAlchemy-2.0.23-py3-none-any.whl", hash = "sha256:31952bbc527d633b9479f5f81e8b9dfada00b91d6baba021a869095f1a97006d"},
+    {file = "SQLAlchemy-2.0.23.tar.gz", hash = "sha256:c1bda93cbbe4aa2aa0aa8655c5aeda505cd219ff3e8da91d1d329e143e4aff69"},
 ]
 
 [package.dependencies]
-aiosqlite = {version = "*", optional = true, markers = "python_version >= \"3\" and extra == \"aiosqlite\""}
-greenlet = {version = "!=0.4.17", optional = true, markers = "python_version >= \"3\" and (platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\" or extra == \"aiosqlite\")"}
-typing-extensions = {version = "!=3.10.0.1", optional = true, markers = "extra == \"aiosqlite\""}
+aiosqlite = {version = "*", optional = true, markers = "extra == \"aiosqlite\""}
+greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\" or extra == \"aiosqlite\""}
+typing-extensions = {version = ">=4.2.0", optional = true, markers = "extra == \"aiosqlite\""}
 
 [package.extras]
 aiomysql = ["aiomysql (>=0.2.0)", "greenlet (!=0.4.17)"]
+aioodbc = ["aioodbc", "greenlet (!=0.4.17)"]
 aiosqlite = ["aiosqlite", "greenlet (!=0.4.17)", "typing-extensions (!=3.10.0.1)"]
 asyncio = ["greenlet (!=0.4.17)"]
-asyncmy = ["asyncmy (>=0.2.3,!=0.2.4)", "greenlet (!=0.4.17)"]
-mariadb-connector = ["mariadb (>=1.0.1,!=1.1.2)"]
+asyncmy = ["asyncmy (>=0.2.3,!=0.2.4,!=0.2.6)", "greenlet (!=0.4.17)"]
+mariadb-connector = ["mariadb (>=1.0.1,!=1.1.2,!=1.1.5)"]
 mssql = ["pyodbc"]
 mssql-pymssql = ["pymssql"]
 mssql-pyodbc = ["pyodbc"]
-mypy = ["mypy (>=0.910)", "sqlalchemy2-stubs"]
-mysql = ["mysqlclient (>=1.4.0)", "mysqlclient (>=1.4.0,<2)"]
+mypy = ["mypy (>=0.910)"]
+mysql = ["mysqlclient (>=1.4.0)"]
 mysql-connector = ["mysql-connector-python"]
-oracle = ["cx-oracle (>=7)", "cx-oracle (>=7,<8)"]
+oracle = ["cx-oracle (>=8)"]
+oracle-oracledb = ["oracledb (>=1.0.1)"]
 postgresql = ["psycopg2 (>=2.7)"]
 postgresql-asyncpg = ["asyncpg", "greenlet (!=0.4.17)"]
-postgresql-pg8000 = ["pg8000 (>=1.16.6,!=1.29.0)"]
+postgresql-pg8000 = ["pg8000 (>=1.29.1)"]
+postgresql-psycopg = ["psycopg (>=3.0.7)"]
 postgresql-psycopg2binary = ["psycopg2-binary"]
 postgresql-psycopg2cffi = ["psycopg2cffi"]
-pymysql = ["pymysql", "pymysql (<1)"]
+postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"]
+pymysql = ["pymysql"]
 sqlcipher = ["sqlcipher3-binary"]
 
 [[package]]
@@ -683,4 +686,4 @@ files = [
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.12"
-content-hash = "8d0ddcdcd96f4736bb3608df11678d78776f5cf7c6883474b61b158c99ac4732"
+content-hash = "fc07028820963701634eb55b42ea12962fd7c6fc25ef76ddadf30f2c74544b5f"
pyproject.toml

@@ -11,10 +11,9 @@ beautifulsoup4 = "^4.9.3"
 html5lib = "^1.1"
 starlette = "^0.30"
 ulid-py = "^1.1.0"
-databases = {extras = ["sqlite"], version = "^0.7.0"}
 uvicorn = "^0.23"
 httpx = "^0.24"
-sqlalchemy = {version = "^1.4", extras = ["aiosqlite"]}
+sqlalchemy = {version = "^2.0", extras = ["aiosqlite"]}
 
 [tool.poetry.group.dev]
 optional = true
@@ -6,7 +6,7 @@ dbfile="${UNWIND_DATA:-./data}/tests.sqlite"
 
 # Rollback in Databases is currently broken, so we have to rebuild the database
 # each time; see https://github.com/encode/databases/issues/403
-trap 'rm "$dbfile"' EXIT TERM INT QUIT
+trap 'rm "$dbfile" "${dbfile}-shm" "${dbfile}-wal"' EXIT TERM INT QUIT
 
 [ -z "${DEBUG:-}" ] || set -x
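With the database now opened in WAL mode (see the `PRAGMA journal_mode=WAL` in unwind/db.py below), SQLite keeps `-shm` and `-wal` sidecar files next to the database file, which is why the cleanup trap has to remove those as well.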
@@ -17,16 +17,19 @@ def event_loop():
 
 @pytest_asyncio.fixture(scope="session")
 async def shared_conn():
-    c = db._shared_connection()
-    await c.connect()
-    await db.apply_db_patches(c)
-    yield c
-    await c.disconnect()
+    """A database connection, ready to use."""
+    await db.open_connection_pool()
+
+    async with db.new_connection() as c:
+        db._test_connection = c
+        yield c
+        db._test_connection = None
+
+    await db.close_connection_pool()
 
 
 @pytest_asyncio.fixture
-async def conn(shared_conn):
-    async with shared_conn.transaction(force_rollback=True):
+async def conn(shared_conn: db.Connection):
+    """A transacted database connection, will be rolled back after use."""
+    async with db.transacted(shared_conn, force_rollback=True):
         yield shared_conn
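The `db.transacted` helper that the `conn` fixture relies on is not part of this excerpt. As a rough sketch of what such a helper can look like on SQLAlchemy 2.0's async API (an assumption, not the committed implementation, in particular the `force_rollback` handling):

    import contextlib
    from typing import AsyncIterator

    from sqlalchemy.ext.asyncio import AsyncConnection


    @contextlib.asynccontextmanager
    async def transacted(
        conn: AsyncConnection, force_rollback: bool = False
    ) -> AsyncIterator[AsyncConnection]:
        # Use a SAVEPOINT when a transaction is already open, so tests
        # can nest inside the session-scoped connection.
        tx = await (conn.begin_nested() if conn.in_transaction() else conn.begin())
        try:
            yield conn
        except BaseException:
            await tx.rollback()
            raise
        else:
            if force_rollback:
                await tx.rollback()
            else:
                await tx.commit()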
tests/test_db.py (627 lines changed)

@@ -4,7 +4,7 @@ import pytest
 
 from unwind import db, models, web_models
 
-_movie_imdb_id = 1234567
+_movie_imdb_id = 1230000
 
 
 def a_movie(**kwds) -> models.Movie:
@@ -21,394 +21,399 @@ def a_movie(**kwds) -> models.Movie:
 
 
 @pytest.mark.asyncio
-async def test_current_patch_level(shared_conn: db.Database):
-    async with shared_conn.transaction(force_rollback=True):
-        patch_level = "some-patch-level"
-        assert patch_level != await db.current_patch_level(shared_conn)
-        await db.set_current_patch_level(shared_conn, patch_level)
-        assert patch_level == await db.current_patch_level(shared_conn)
+async def test_current_patch_level(conn: db.Connection):
+    patch_level = "some-patch-level"
+    assert patch_level != await db.current_patch_level(conn)
+    await db.set_current_patch_level(conn, patch_level)
+    assert patch_level == await db.current_patch_level(conn)
 
 
 @pytest.mark.asyncio
-async def test_get(shared_conn: db.Database):
-    async with shared_conn.transaction(force_rollback=True):
-        m1 = a_movie()
-        await db.add(m1)
-
-        m2 = a_movie(release_year=m1.release_year + 1)
-        await db.add(m2)
-
-        assert None is await db.get(models.Movie)
-        assert None is await db.get(models.Movie, id="blerp")
-        assert m1 == await db.get(models.Movie, id=str(m1.id))
-        assert m2 == await db.get(models.Movie, release_year=m2.release_year)
-        assert None is await db.get(
-            models.Movie, id=str(m1.id), release_year=m2.release_year
-        )
-        assert m2 == await db.get(
-            models.Movie, id=str(m2.id), release_year=m2.release_year
-        )
-        assert m1 == await db.get(
-            models.Movie,
-            media_type=m1.media_type,
-            order_by=(models.movies.c.release_year, "asc"),
-        )
-        assert m2 == await db.get(
-            models.Movie,
-            media_type=m1.media_type,
-            order_by=(models.movies.c.release_year, "desc"),
-        )
+async def test_get(conn: db.Connection):
+    m1 = a_movie()
+    await db.add(conn, m1)
+
+    m2 = a_movie(release_year=m1.release_year + 1)
+    await db.add(conn, m2)
+
+    assert None is await db.get(conn, models.Movie)
+    assert None is await db.get(conn, models.Movie, id="blerp")
+    assert m1 == await db.get(conn, models.Movie, id=str(m1.id))
+    assert m2 == await db.get(conn, models.Movie, release_year=m2.release_year)
+    assert None is await db.get(
+        conn, models.Movie, id=str(m1.id), release_year=m2.release_year
+    )
+    assert m2 == await db.get(
+        conn, models.Movie, id=str(m2.id), release_year=m2.release_year
+    )
+    assert m1 == await db.get(
+        conn,
+        models.Movie,
+        media_type=m1.media_type,
+        order_by=(models.movies.c.release_year, "asc"),
+    )
+    assert m2 == await db.get(
+        conn,
+        models.Movie,
+        media_type=m1.media_type,
+        order_by=(models.movies.c.release_year, "desc"),
+    )
 
 
 @pytest.mark.asyncio
-async def test_get_all(shared_conn: db.Database):
-    async with shared_conn.transaction(force_rollback=True):
-        m1 = a_movie()
-        await db.add(m1)
-
-        m2 = a_movie(release_year=m1.release_year)
-        await db.add(m2)
-
-        m3 = a_movie(release_year=m1.release_year + 1)
-        await db.add(m3)
-
-        assert [] == list(await db.get_all(models.Movie, id="blerp"))
-        assert [m1] == list(await db.get_all(models.Movie, id=str(m1.id)))
-        assert [m1, m2] == list(
-            await db.get_all(models.Movie, release_year=m1.release_year)
-        )
-        assert [m1, m2, m3] == list(await db.get_all(models.Movie))
+async def test_get_all(conn: db.Connection):
+    m1 = a_movie()
+    await db.add(conn, m1)
+
+    m2 = a_movie(release_year=m1.release_year)
+    await db.add(conn, m2)
+
+    m3 = a_movie(release_year=m1.release_year + 1)
+    await db.add(conn, m3)
+
+    assert [] == list(await db.get_all(conn, models.Movie, id="blerp"))
+    assert [m1] == list(await db.get_all(conn, models.Movie, id=str(m1.id)))
+    assert [m1, m2] == list(
+        await db.get_all(conn, models.Movie, release_year=m1.release_year)
+    )
+    assert [m1, m2, m3] == list(await db.get_all(conn, models.Movie))
 
 
 @pytest.mark.asyncio
-async def test_get_many(shared_conn: db.Database):
-    async with shared_conn.transaction(force_rollback=True):
-        m1 = a_movie()
-        await db.add(m1)
-
-        m2 = a_movie(release_year=m1.release_year)
-        await db.add(m2)
-
-        m3 = a_movie(release_year=m1.release_year + 1)
-        await db.add(m3)
-
-        assert [] == list(await db.get_many(models.Movie)), "selected nothing"
-        assert [m1] == list(await db.get_many(models.Movie, id=[str(m1.id)]))
-        assert [m1] == list(await db.get_many(models.Movie, id={str(m1.id)}))
-        assert [m1, m2] == list(
-            await db.get_many(models.Movie, release_year=[m1.release_year])
-        )
-        assert [m1, m2, m3] == list(
-            await db.get_many(
-                models.Movie, release_year=[m1.release_year, m3.release_year]
-            )
-        )
+async def test_get_many(conn: db.Connection):
+    m1 = a_movie()
+    await db.add(conn, m1)
+
+    m2 = a_movie(release_year=m1.release_year)
+    await db.add(conn, m2)
+
+    m3 = a_movie(release_year=m1.release_year + 1)
+    await db.add(conn, m3)
+
+    assert [] == list(await db.get_many(conn, models.Movie)), "selected nothing"
+    assert [m1] == list(await db.get_many(conn, models.Movie, id=[str(m1.id)]))
+    assert [m1] == list(await db.get_many(conn, models.Movie, id={str(m1.id)}))
+    assert [m1, m2] == list(
+        await db.get_many(conn, models.Movie, release_year=[m1.release_year])
+    )
+    assert [m1, m2, m3] == list(
+        await db.get_many(
+            conn, models.Movie, release_year=[m1.release_year, m3.release_year]
+        )
+    )
 
 
 @pytest.mark.asyncio
-async def test_add_and_get(shared_conn: db.Database):
-    async with shared_conn.transaction(force_rollback=True):
-        m1 = a_movie()
-        await db.add(m1)
-
-        m2 = a_movie()
-        await db.add(m2)
-
-        assert m1 == await db.get(models.Movie, id=str(m1.id))
-        assert m2 == await db.get(models.Movie, id=str(m2.id))
+async def test_add_and_get(conn: db.Connection):
+    m1 = a_movie()
+    await db.add(conn, m1)
+
+    m2 = a_movie()
+    await db.add(conn, m2)
+
+    assert m1 == await db.get(conn, models.Movie, id=str(m1.id))
+    assert m2 == await db.get(conn, models.Movie, id=str(m2.id))
 
 
 @pytest.mark.asyncio
-async def test_update(shared_conn: db.Database):
-    async with shared_conn.transaction(force_rollback=True):
-        m = a_movie()
-        await db.add(m)
-
-        assert m == await db.get(models.Movie, id=str(m.id))
-        m.title += "something else"
-        assert m != await db.get(models.Movie, id=str(m.id))
-
-        await db.update(m)
-        assert m == await db.get(models.Movie, id=str(m.id))
+async def test_update(conn: db.Connection):
+    m = a_movie()
+    await db.add(conn, m)
+
+    assert m == await db.get(conn, models.Movie, id=str(m.id))
+    m.title += "something else"
+    assert m != await db.get(conn, models.Movie, id=str(m.id))
+
+    await db.update(conn, m)
+    assert m == await db.get(conn, models.Movie, id=str(m.id))
 
 
 @pytest.mark.asyncio
-async def test_remove(shared_conn: db.Database):
-    async with shared_conn.transaction(force_rollback=True):
-        m1 = a_movie()
-        await db.add(m1)
-        assert m1 == await db.get(models.Movie, id=str(m1.id))
-
-        await db.remove(m1)
-        assert None is await db.get(models.Movie, id=str(m1.id))
+async def test_remove(conn: db.Connection):
+    m1 = a_movie()
+    await db.add(conn, m1)
+    assert m1 == await db.get(conn, models.Movie, id=str(m1.id))
+
+    await db.remove(conn, m1)
+    assert None is await db.get(conn, models.Movie, id=str(m1.id))
 
 
 @pytest.mark.asyncio
-async def test_find_ratings(shared_conn: db.Database):
-    async with shared_conn.transaction(force_rollback=True):
-        m1 = a_movie(
-            title="test movie",
-            release_year=2013,
-            genres={"genre-1"},
-        )
-        await db.add(m1)
-
-        m2 = a_movie(
-            title="it's anöther Movie, Part 2",
-            release_year=2015,
-            genres={"genre-2"},
-        )
-        await db.add(m2)
-
-        m3 = a_movie(
-            title="movie it's, Part 3",
-            release_year=m2.release_year,
-            genres=m2.genres,
-        )
-        await db.add(m3)
-
-        u1 = models.User(
-            imdb_id="u00001",
-            name="User1",
-            secret="secret1",
-        )
-        await db.add(u1)
-
-        u2 = models.User(
-            imdb_id="u00002",
-            name="User2",
-            secret="secret2",
-        )
-        await db.add(u2)
-
-        r1 = models.Rating(
-            movie_id=m2.id,
-            movie=m2,
-            user_id=u1.id,
-            user=u1,
-            score=66,
-            rating_date=datetime.now(),
-        )
-        await db.add(r1)
-
-        r2 = models.Rating(
-            movie_id=m2.id,
-            movie=m2,
-            user_id=u2.id,
-            user=u2,
-            score=77,
-            rating_date=datetime.now(),
-        )
-        await db.add(r2)
-
-        # ---
-
-        rows = await db.find_ratings(
-            title=m1.title,
-            media_type=m1.media_type,
-            exact=True,
-            ignore_tv_episodes=True,
-            include_unrated=True,
-            yearcomp=("=", m1.release_year),
-            limit_rows=3,
-            user_ids=[],
-        )
-        ratings = (web_models.Rating(**r) for r in rows)
-        assert (web_models.RatingAggregate.from_movie(m1),) == tuple(
-            web_models.aggregate_ratings(ratings, user_ids=[])
-        )
-
-        rows = await db.find_ratings(title="movie", include_unrated=False)
-        ratings = tuple(web_models.Rating(**r) for r in rows)
-        assert (
-            web_models.Rating.from_movie(m2, rating=r1),
-            web_models.Rating.from_movie(m2, rating=r2),
-        ) == ratings
-
-        rows = await db.find_ratings(title="movie", include_unrated=True)
-        ratings = tuple(web_models.Rating(**r) for r in rows)
-        assert (
-            web_models.Rating.from_movie(m1),
-            web_models.Rating.from_movie(m2, rating=r1),
-            web_models.Rating.from_movie(m2, rating=r2),
-            web_models.Rating.from_movie(m3),
-        ) == ratings
-
-        aggr = web_models.aggregate_ratings(ratings, user_ids=[])
-        assert tuple(
-            web_models.RatingAggregate.from_movie(m) for m in [m1, m2, m3]
-        ) == tuple(aggr)
-
-        aggr = web_models.aggregate_ratings(ratings, user_ids=[str(u1.id)])
-        assert (
-            web_models.RatingAggregate.from_movie(m1),
-            web_models.RatingAggregate.from_movie(m2, ratings=[r1]),
-            web_models.RatingAggregate.from_movie(m3),
-        ) == tuple(aggr)
-
-        aggr = web_models.aggregate_ratings(ratings, user_ids=[str(u1.id), str(u2.id)])
-        assert (
-            web_models.RatingAggregate.from_movie(m1),
-            web_models.RatingAggregate.from_movie(m2, ratings=[r1, r2]),
-            web_models.RatingAggregate.from_movie(m3),
-        ) == tuple(aggr)
-
-        rows = await db.find_ratings(title="movie", include_unrated=True)
-        ratings = (web_models.Rating(**r) for r in rows)
-        aggr = web_models.aggregate_ratings(ratings, user_ids=[])
-        assert tuple(
-            web_models.RatingAggregate.from_movie(m) for m in [m1, m2, m3]
-        ) == tuple(aggr)
-
-        rows = await db.find_ratings(title="test", include_unrated=True)
-        ratings = tuple(web_models.Rating(**r) for r in rows)
-        assert (web_models.Rating.from_movie(m1),) == ratings
+async def test_find_ratings(conn: db.Connection):
+    m1 = a_movie(
+        title="test movie",
+        release_year=2013,
+        genres={"genre-1"},
+    )
+    await db.add(conn, m1)
+
+    m2 = a_movie(
+        title="it's anöther Movie, Part 2",
+        release_year=2015,
+        genres={"genre-2"},
+    )
+    await db.add(conn, m2)
+
+    m3 = a_movie(
+        title="movie it's, Part 3",
+        release_year=m2.release_year,
+        genres=m2.genres,
+    )
+    await db.add(conn, m3)
+
+    u1 = models.User(
+        imdb_id="u00001",
+        name="User1",
+        secret="secret1",
+    )
+    await db.add(conn, u1)
+
+    u2 = models.User(
+        imdb_id="u00002",
+        name="User2",
+        secret="secret2",
+    )
+    await db.add(conn, u2)
+
+    r1 = models.Rating(
+        movie_id=m2.id,
+        movie=m2,
+        user_id=u1.id,
+        user=u1,
+        score=66,
+        rating_date=datetime.now(),
+    )
+    await db.add(conn, r1)
+
+    r2 = models.Rating(
+        movie_id=m2.id,
+        movie=m2,
+        user_id=u2.id,
+        user=u2,
+        score=77,
+        rating_date=datetime.now(),
+    )
+    await db.add(conn, r2)
+
+    # ---
+
+    rows = await db.find_ratings(
+        conn,
+        title=m1.title,
+        media_type=m1.media_type,
+        exact=True,
+        ignore_tv_episodes=True,
+        include_unrated=True,
+        yearcomp=("=", m1.release_year),
+        limit_rows=3,
+        user_ids=[],
+    )
+    ratings = (web_models.Rating(**r) for r in rows)
+    assert (web_models.RatingAggregate.from_movie(m1),) == tuple(
+        web_models.aggregate_ratings(ratings, user_ids=[])
+    )
+
+    rows = await db.find_ratings(conn, title="movie", include_unrated=False)
+    ratings = tuple(web_models.Rating(**r) for r in rows)
+    assert (
+        web_models.Rating.from_movie(m2, rating=r1),
+        web_models.Rating.from_movie(m2, rating=r2),
+    ) == ratings
+
+    rows = await db.find_ratings(conn, title="movie", include_unrated=True)
+    ratings = tuple(web_models.Rating(**r) for r in rows)
+    assert (
+        web_models.Rating.from_movie(m1),
+        web_models.Rating.from_movie(m2, rating=r1),
+        web_models.Rating.from_movie(m2, rating=r2),
+        web_models.Rating.from_movie(m3),
+    ) == ratings
+
+    aggr = web_models.aggregate_ratings(ratings, user_ids=[])
+    assert tuple(
+        web_models.RatingAggregate.from_movie(m) for m in [m1, m2, m3]
+    ) == tuple(aggr)
+
+    aggr = web_models.aggregate_ratings(ratings, user_ids=[str(u1.id)])
+    assert (
+        web_models.RatingAggregate.from_movie(m1),
+        web_models.RatingAggregate.from_movie(m2, ratings=[r1]),
+        web_models.RatingAggregate.from_movie(m3),
+    ) == tuple(aggr)
+
+    aggr = web_models.aggregate_ratings(ratings, user_ids=[str(u1.id), str(u2.id)])
+    assert (
+        web_models.RatingAggregate.from_movie(m1),
+        web_models.RatingAggregate.from_movie(m2, ratings=[r1, r2]),
+        web_models.RatingAggregate.from_movie(m3),
+    ) == tuple(aggr)
+
+    rows = await db.find_ratings(conn, title="movie", include_unrated=True)
+    ratings = (web_models.Rating(**r) for r in rows)
+    aggr = web_models.aggregate_ratings(ratings, user_ids=[])
+    assert tuple(
+        web_models.RatingAggregate.from_movie(m) for m in [m1, m2, m3]
+    ) == tuple(aggr)
+
+    rows = await db.find_ratings(conn, title="test", include_unrated=True)
+    ratings = tuple(web_models.Rating(**r) for r in rows)
+    assert (web_models.Rating.from_movie(m1),) == ratings
 
 
 @pytest.mark.asyncio
-async def test_ratings_for_movies(shared_conn: db.Database):
-    async with shared_conn.transaction(force_rollback=True):
-        m1 = a_movie()
-        await db.add(m1)
-
-        m2 = a_movie()
-        await db.add(m2)
-
-        u1 = models.User(
-            imdb_id="u00001",
-            name="User1",
-            secret="secret1",
-        )
-        await db.add(u1)
-
-        u2 = models.User(
-            imdb_id="u00002",
-            name="User2",
-            secret="secret2",
-        )
-        await db.add(u2)
-
-        r1 = models.Rating(
-            movie_id=m2.id,
-            movie=m2,
-            user_id=u1.id,
-            user=u1,
-            score=66,
-            rating_date=datetime.now(),
-        )
-        await db.add(r1)
-
-        # ---
-
-        movie_ids = [m1.id]
-        user_ids = []
-        assert tuple() == tuple(
-            await db.ratings_for_movies(movie_ids=movie_ids, user_ids=user_ids)
-        )
-
-        movie_ids = [m2.id]
-        user_ids = []
-        assert (r1,) == tuple(
-            await db.ratings_for_movies(movie_ids=movie_ids, user_ids=user_ids)
-        )
-
-        movie_ids = [m2.id]
-        user_ids = [u2.id]
-        assert tuple() == tuple(
-            await db.ratings_for_movies(movie_ids=movie_ids, user_ids=user_ids)
-        )
-
-        movie_ids = [m2.id]
-        user_ids = [u1.id]
-        assert (r1,) == tuple(
-            await db.ratings_for_movies(movie_ids=movie_ids, user_ids=user_ids)
-        )
-
-        movie_ids = [m1.id, m2.id]
-        user_ids = [u1.id, u2.id]
-        assert (r1,) == tuple(
-            await db.ratings_for_movies(movie_ids=movie_ids, user_ids=user_ids)
-        )
+async def test_ratings_for_movies(conn: db.Connection):
+    m1 = a_movie()
+    await db.add(conn, m1)
+
+    m2 = a_movie()
+    await db.add(conn, m2)
+
+    u1 = models.User(
+        imdb_id="u00001",
+        name="User1",
+        secret="secret1",
+    )
+    await db.add(conn, u1)
+
+    u2 = models.User(
+        imdb_id="u00002",
+        name="User2",
+        secret="secret2",
+    )
+    await db.add(conn, u2)
+
+    r1 = models.Rating(
+        movie_id=m2.id,
+        movie=m2,
+        user_id=u1.id,
+        user=u1,
+        score=66,
+        rating_date=datetime.now(),
+    )
+    await db.add(conn, r1)
+
+    # ---
+
+    movie_ids = [m1.id]
+    user_ids = []
+    assert tuple() == tuple(
+        await db.ratings_for_movies(conn, movie_ids=movie_ids, user_ids=user_ids)
+    )
+
+    movie_ids = [m2.id]
+    user_ids = []
+    assert (r1,) == tuple(
+        await db.ratings_for_movies(conn, movie_ids=movie_ids, user_ids=user_ids)
+    )
+
+    movie_ids = [m2.id]
+    user_ids = [u2.id]
+    assert tuple() == tuple(
+        await db.ratings_for_movies(conn, movie_ids=movie_ids, user_ids=user_ids)
+    )
+
+    movie_ids = [m2.id]
+    user_ids = [u1.id]
+    assert (r1,) == tuple(
+        await db.ratings_for_movies(conn, movie_ids=movie_ids, user_ids=user_ids)
+    )
+
+    movie_ids = [m1.id, m2.id]
+    user_ids = [u1.id, u2.id]
+    assert (r1,) == tuple(
+        await db.ratings_for_movies(conn, movie_ids=movie_ids, user_ids=user_ids)
+    )
 
 
 @pytest.mark.asyncio
-async def test_find_movies(shared_conn: db.Database):
-    async with shared_conn.transaction(force_rollback=True):
-        m1 = a_movie(title="movie one")
-        await db.add(m1)
-
-        m2 = a_movie(title="movie two", imdb_score=33, release_year=m1.release_year + 1)
-        await db.add(m2)
-
-        u1 = models.User(
-            imdb_id="u00001",
-            name="User1",
-            secret="secret1",
-        )
-        await db.add(u1)
-
-        u2 = models.User(
-            imdb_id="u00002",
-            name="User2",
-            secret="secret2",
-        )
-        await db.add(u2)
-
-        r1 = models.Rating(
-            movie_id=m2.id,
-            movie=m2,
-            user_id=u1.id,
-            user=u1,
-            score=66,
-            rating_date=datetime.now(),
-        )
-        await db.add(r1)
-
-        # ---
-
-        assert () == tuple(await db.find_movies(title=m1.title, include_unrated=False))
-        assert ((m1, []),) == tuple(
-            await db.find_movies(title=m1.title, include_unrated=True)
-        )
-
-        assert ((m1, []),) == tuple(
-            await db.find_movies(title="mo on", exact=False, include_unrated=True)
-        )
-        assert ((m1, []),) == tuple(
-            await db.find_movies(title="movie one", exact=True, include_unrated=True)
-        )
-        assert () == tuple(
-            await db.find_movies(title="mo on", exact=True, include_unrated=True)
-        )
-
-        assert ((m2, []),) == tuple(
-            await db.find_movies(title="movie", exact=False, include_unrated=False)
-        )
-        assert ((m2, []), (m1, [])) == tuple(
-            await db.find_movies(title="movie", exact=False, include_unrated=True)
-        )
-
-        assert ((m1, []),) == tuple(
-            await db.find_movies(include_unrated=True, yearcomp=("=", m1.release_year))
-        )
-        assert ((m2, []),) == tuple(
-            await db.find_movies(include_unrated=True, yearcomp=("=", m2.release_year))
-        )
-        assert ((m1, []),) == tuple(
-            await db.find_movies(include_unrated=True, yearcomp=("<", m2.release_year))
-        )
-        assert ((m2, []),) == tuple(
-            await db.find_movies(include_unrated=True, yearcomp=(">", m1.release_year))
-        )
-
-        assert ((m2, []), (m1, [])) == tuple(await db.find_movies(include_unrated=True))
-        assert ((m2, []),) == tuple(
-            await db.find_movies(include_unrated=True, limit_rows=1)
-        )
-        assert ((m1, []),) == tuple(
-            await db.find_movies(include_unrated=True, skip_rows=1)
-        )
-
-        assert ((m2, [r1]), (m1, [])) == tuple(
-            await db.find_movies(include_unrated=True, user_ids=[u1.id, u2.id])
-        )
+async def test_find_movies(conn: db.Connection):
+    m1 = a_movie(title="movie one")
+    await db.add(conn, m1)
+
+    m2 = a_movie(title="movie two", imdb_score=33, release_year=m1.release_year + 1)
+    await db.add(conn, m2)
+
+    u1 = models.User(
+        imdb_id="u00001",
+        name="User1",
+        secret="secret1",
+    )
+    await db.add(conn, u1)
+
+    u2 = models.User(
+        imdb_id="u00002",
+        name="User2",
+        secret="secret2",
+    )
+    await db.add(conn, u2)
+
+    r1 = models.Rating(
+        movie_id=m2.id,
+        movie=m2,
+        user_id=u1.id,
+        user=u1,
+        score=66,
+        rating_date=datetime.now(),
+    )
+    await db.add(conn, r1)
+
+    # ---
+
+    assert () == tuple(
+        await db.find_movies(conn, title=m1.title, include_unrated=False)
+    )
+    assert ((m1, []),) == tuple(
+        await db.find_movies(conn, title=m1.title, include_unrated=True)
+    )
+
+    assert ((m1, []),) == tuple(
+        await db.find_movies(conn, title="mo on", exact=False, include_unrated=True)
+    )
+    assert ((m1, []),) == tuple(
+        await db.find_movies(conn, title="movie one", exact=True, include_unrated=True)
+    )
+    assert () == tuple(
+        await db.find_movies(conn, title="mo on", exact=True, include_unrated=True)
+    )
+
+    assert ((m2, []),) == tuple(
+        await db.find_movies(conn, title="movie", exact=False, include_unrated=False)
+    )
+    assert ((m2, []), (m1, [])) == tuple(
+        await db.find_movies(conn, title="movie", exact=False, include_unrated=True)
+    )
+
+    assert ((m1, []),) == tuple(
+        await db.find_movies(
+            conn, include_unrated=True, yearcomp=("=", m1.release_year)
+        )
+    )
+    assert ((m2, []),) == tuple(
+        await db.find_movies(
+            conn, include_unrated=True, yearcomp=("=", m2.release_year)
+        )
+    )
+    assert ((m1, []),) == tuple(
+        await db.find_movies(
+            conn, include_unrated=True, yearcomp=("<", m2.release_year)
+        )
+    )
+    assert ((m2, []),) == tuple(
+        await db.find_movies(
+            conn, include_unrated=True, yearcomp=(">", m1.release_year)
+        )
+    )
+
+    assert ((m2, []), (m1, [])) == tuple(
+        await db.find_movies(conn, include_unrated=True)
+    )
+    assert ((m2, []),) == tuple(
+        await db.find_movies(conn, include_unrated=True, limit_rows=1)
+    )
+    assert ((m1, []),) == tuple(
+        await db.find_movies(conn, include_unrated=True, skip_rows=1)
+    )
+
+    assert ((m2, [r1]), (m1, [])) == tuple(
+        await db.find_movies(conn, include_unrated=True, user_ids=[u1.id, u2.id])
+    )
tests/test_models.py (new file, 11 lines)

@@ -0,0 +1,11 @@
+import pytest
+
+from unwind import models
+
+
+@pytest.mark.parametrize("mapper", models.mapper_registry.mappers)
+def test_fields(mapper):
+    """Test that models.fields() matches exactly all table columns."""
+    dcfields = {f.name for f in models.fields(mapper.class_)}
+    mfields = {c.name for c in mapper.columns}
+    assert dcfields == mfields
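`models.fields()` itself is not shown in this commit; judging by the test it behaves like `dataclasses.fields` over the mapped class. A sketch under that assumption:

    import dataclasses


    def fields(model_class: type) -> tuple[dataclasses.Field, ...]:
        # Assumes the mapped models are dataclasses registered with the
        # SQLAlchemy mapper_registry.
        return dataclasses.fields(model_class)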
@@ -34,7 +34,7 @@ def admin_client() -> TestClient:
 
 @pytest.mark.asyncio
 async def test_get_ratings_for_group(
-    shared_conn: db.Database, unauthorized_client: TestClient
+    conn: db.Connection, unauthorized_client: TestClient
 ):
     user = models.User(
         imdb_id="ur12345678",
@@ -48,201 +48,196 @@ async def test_get_ratings_for_group(
     )
     user.groups = [models.UserGroup(id=str(group.id), access="r")]
     path = app.url_path_for("get_ratings_for_group", group_id=str(group.id))
-    async with shared_conn.transaction(force_rollback=True):
-        resp = unauthorized_client.get(path)
-        assert resp.status_code == 404, "Group does not exist (yet)"
-
-        await db.add(user)
-        await db.add(group)
-
-        resp = unauthorized_client.get(path)
-        assert resp.status_code == 200
-        assert resp.json() == []
-
-        movie = models.Movie(
-            title="test movie",
-            release_year=2013,
-            media_type="Movie",
-            imdb_id="tt12345678",
-            genres={"genre-1"},
-        )
-        await db.add(movie)
-
-        rating = models.Rating(
-            movie_id=movie.id, user_id=user.id, score=66, rating_date=datetime.now()
-        )
-        await db.add(rating)
-
-        rating_aggregate = {
-            "canonical_title": movie.title,
-            "imdb_score": movie.imdb_score,
-            "imdb_votes": movie.imdb_votes,
-            "link": imdb.movie_url(movie.imdb_id),
-            "media_type": movie.media_type,
-            "original_title": movie.original_title,
-            "user_scores": [rating.score],
-            "year": movie.release_year,
-        }
-
-        resp = unauthorized_client.get(path)
-        assert resp.status_code == 200
-        assert resp.json() == [rating_aggregate]
-
-        filters = {
-            "imdb_id": movie.imdb_id,
-            "unwind_id": str(movie.id),
-            "title": movie.title,
-            "media_type": movie.media_type,
-            "year": movie.release_year,
-        }
-        for k, v in filters.items():
-            resp = unauthorized_client.get(path, params={k: v})
-            assert resp.status_code == 200
-            assert resp.json() == [rating_aggregate]
-
-        resp = unauthorized_client.get(path, params={"title": "no such thing"})
-        assert resp.status_code == 200
-        assert resp.json() == []
-
-        # Test "exact" query param.
-        resp = unauthorized_client.get(
-            path, params={"title": "test movie", "exact": "true"}
-        )
-        assert resp.status_code == 200
-        assert resp.json() == [rating_aggregate]
-        resp = unauthorized_client.get(
-            path, params={"title": "te mo", "exact": "false"}
-        )
-        assert resp.status_code == 200
-        assert resp.json() == [rating_aggregate]
-        resp = unauthorized_client.get(path, params={"title": "te mo", "exact": "true"})
-        assert resp.status_code == 200
-        assert resp.json() == []
-
-        # XXX Test "ignore_tv_episodes" query param.
-        # XXX Test "include_unrated" query param.
-        # XXX Test "per_page" query param.
+    resp = unauthorized_client.get(path)
+    assert resp.status_code == 404, "Group does not exist (yet)"
+
+    await db.add(conn, user)
+    await db.add(conn, group)
+
+    resp = unauthorized_client.get(path)
+    assert resp.status_code == 200
+    assert resp.json() == []
+
+    movie = models.Movie(
+        title="test movie",
+        release_year=2013,
+        media_type="Movie",
+        imdb_id="tt12345678",
+        genres={"genre-1"},
+    )
+    await db.add(conn, movie)
+
+    rating = models.Rating(
+        movie_id=movie.id, user_id=user.id, score=66, rating_date=datetime.now()
+    )
+    await db.add(conn, rating)
+
+    rating_aggregate = {
+        "canonical_title": movie.title,
+        "imdb_score": movie.imdb_score,
+        "imdb_votes": movie.imdb_votes,
+        "link": imdb.movie_url(movie.imdb_id),
+        "media_type": movie.media_type,
+        "original_title": movie.original_title,
+        "user_scores": [rating.score],
+        "year": movie.release_year,
+    }
+
+    resp = unauthorized_client.get(path)
+    assert resp.status_code == 200
+    assert resp.json() == [rating_aggregate]
+
+    filters = {
+        "imdb_id": movie.imdb_id,
+        "unwind_id": str(movie.id),
+        "title": movie.title,
+        "media_type": movie.media_type,
+        "year": movie.release_year,
+    }
+    for k, v in filters.items():
+        resp = unauthorized_client.get(path, params={k: v})
+        assert resp.status_code == 200
+        assert resp.json() == [rating_aggregate]
+
+    resp = unauthorized_client.get(path, params={"title": "no such thing"})
+    assert resp.status_code == 200
+    assert resp.json() == []
+
+    # Test "exact" query param.
+    resp = unauthorized_client.get(
+        path, params={"title": "test movie", "exact": "true"}
+    )
+    assert resp.status_code == 200
+    assert resp.json() == [rating_aggregate]
+    resp = unauthorized_client.get(path, params={"title": "te mo", "exact": "false"})
+    assert resp.status_code == 200
+    assert resp.json() == [rating_aggregate]
+    resp = unauthorized_client.get(path, params={"title": "te mo", "exact": "true"})
+    assert resp.status_code == 200
+    assert resp.json() == []
+
+    # XXX Test "ignore_tv_episodes" query param.
+    # XXX Test "include_unrated" query param.
+    # XXX Test "per_page" query param.
 
 
 @pytest.mark.asyncio
 async def test_list_movies(
-    shared_conn: db.Database,
+    conn: db.Connection,
     unauthorized_client: TestClient,
     authorized_client: TestClient,
 ):
     path = app.url_path_for("list_movies")
-    async with shared_conn.transaction(force_rollback=True):
-        response = unauthorized_client.get(path)
-        assert response.status_code == 403
-
-        response = authorized_client.get(path)
-        assert response.status_code == 200
-        assert response.json() == []
-
-        m = models.Movie(
-            title="test movie",
-            release_year=2013,
-            media_type="Movie",
-            imdb_id="tt12345678",
-            genres={"genre-1"},
-        )
-        await db.add(m)
-
-        response = authorized_client.get(path, params={"include_unrated": 1})
-        assert response.status_code == 200
-        assert response.json() == [{**models.asplain(m), "user_scores": []}]
-
-        m_plain = {
-            "canonical_title": m.title,
-            "imdb_score": m.imdb_score,
-            "imdb_votes": m.imdb_votes,
-            "link": imdb.movie_url(m.imdb_id),
-            "media_type": m.media_type,
-            "original_title": m.original_title,
-            "user_scores": [],
-            "year": m.release_year,
-        }
-
-        response = authorized_client.get(path, params={"imdb_id": m.imdb_id})
-        assert response.status_code == 200
-        assert response.json() == [m_plain]
-
-        response = authorized_client.get(path, params={"unwind_id": str(m.id)})
-        assert response.status_code == 200
-        assert response.json() == [m_plain]
+    response = unauthorized_client.get(path)
+    assert response.status_code == 403
+
+    response = authorized_client.get(path)
+    assert response.status_code == 200
+    assert response.json() == []
+
+    m = models.Movie(
+        title="test movie",
+        release_year=2013,
+        media_type="Movie",
+        imdb_id="tt12345678",
+        genres={"genre-1"},
+    )
+    await db.add(conn, m)
+
+    response = authorized_client.get(path, params={"include_unrated": 1})
+    assert response.status_code == 200
+    assert response.json() == [{**models.asplain(m), "user_scores": []}]
+
+    m_plain = {
+        "canonical_title": m.title,
+        "imdb_score": m.imdb_score,
+        "imdb_votes": m.imdb_votes,
+        "link": imdb.movie_url(m.imdb_id),
+        "media_type": m.media_type,
+        "original_title": m.original_title,
+        "user_scores": [],
+        "year": m.release_year,
+    }
+
+    response = authorized_client.get(path, params={"imdb_id": m.imdb_id})
+    assert response.status_code == 200
+    assert response.json() == [m_plain]
+
+    response = authorized_client.get(path, params={"unwind_id": str(m.id)})
+    assert response.status_code == 200
+    assert response.json() == [m_plain]
 
 
 @pytest.mark.asyncio
 async def test_list_users(
-    shared_conn: db.Database,
+    conn: db.Connection,
     unauthorized_client: TestClient,
     authorized_client: TestClient,
     admin_client: TestClient,
 ):
     path = app.url_path_for("list_users")
-    async with shared_conn.transaction(force_rollback=True):
-        response = unauthorized_client.get(path)
-        assert response.status_code == 403
-
-        response = authorized_client.get(path)
-        assert response.status_code == 403
-
-        response = admin_client.get(path)
-        assert response.status_code == 200
-        assert response.json() == []
-
-        m = models.User(
-            imdb_id="ur12345678",
-            name="user-1",
-            secret="secret-1",
-            groups=[],
-        )
-        await db.add(m)
-
-        m_plain = {
-            "groups": m.groups,
-            "id": m.id,
-            "imdb_id": m.imdb_id,
-            "name": m.name,
-            "secret": m.secret,
-        }
-
-        response = admin_client.get(path)
-        assert response.status_code == 200
-        assert response.json() == [m_plain]
+    response = unauthorized_client.get(path)
+    assert response.status_code == 403
+
+    response = authorized_client.get(path)
+    assert response.status_code == 403
+
+    response = admin_client.get(path)
+    assert response.status_code == 200
+    assert response.json() == []
+
+    m = models.User(
+        imdb_id="ur12345678",
+        name="user-1",
+        secret="secret-1",
+        groups=[],
+    )
+    await db.add(conn, m)
+
+    m_plain = {
+        "groups": m.groups,
+        "id": m.id,
+        "imdb_id": m.imdb_id,
+        "name": m.name,
+        "secret": m.secret,
+    }
+
+    response = admin_client.get(path)
+    assert response.status_code == 200
+    assert response.json() == [m_plain]
 
 
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_groups(
|
async def test_list_groups(
|
||||||
shared_conn: db.Database,
|
conn: db.Connection,
|
||||||
unauthorized_client: TestClient,
|
unauthorized_client: TestClient,
|
||||||
authorized_client: TestClient,
|
authorized_client: TestClient,
|
||||||
admin_client: TestClient,
|
admin_client: TestClient,
|
||||||
):
|
):
|
||||||
path = app.url_path_for("list_groups")
|
path = app.url_path_for("list_groups")
|
||||||
async with shared_conn.transaction(force_rollback=True):
|
response = unauthorized_client.get(path)
|
||||||
response = unauthorized_client.get(path)
|
assert response.status_code == 403
|
||||||
assert response.status_code == 403
|
|
||||||
|
|
||||||
response = authorized_client.get(path)
|
response = authorized_client.get(path)
|
||||||
assert response.status_code == 403
|
assert response.status_code == 403
|
||||||
|
|
||||||
response = admin_client.get(path)
|
response = admin_client.get(path)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == []
|
assert response.json() == []
|
||||||
|
|
||||||
m = models.Group(
|
m = models.Group(
|
||||||
name="group-1",
|
name="group-1",
|
||||||
users=[models.GroupUser(id="123", name="itsa-me")],
|
users=[models.GroupUser(id="123", name="itsa-me")],
|
||||||
)
|
)
|
||||||
await db.add(m)
|
await db.add(conn, m)
|
||||||
|
|
||||||
m_plain = {
|
m_plain = {
|
||||||
"users": m.users,
|
"users": m.users,
|
||||||
"id": m.id,
|
"id": m.id,
|
||||||
"name": m.name,
|
"name": m.name,
|
||||||
}
|
}
|
||||||
|
|
||||||
response = admin_client.get(path)
|
response = admin_client.get(path)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == [m_plain]
|
assert response.json() == [m_plain]
|
||||||
|
|
|
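Both tests now receive a plain `conn: db.Connection` fixture in place of the old `shared_conn` database object, and rollback isolation moves from the test body into the fixture. The fixture itself is not part of this hunk; below is a minimal sketch of one way to provide it, assuming pytest-asyncio and building on the `_test_connection` pinning that `unwind/db.py` introduces further down (fixture name, module layout, and teardown order are assumptions, not part of this commit):

    # conftest.py (hypothetical)
    import pytest_asyncio

    from unwind import db


    @pytest_asyncio.fixture
    async def conn():
        # Pin a single connection for the whole test so app code going through
        # db.new_connection() sees the same uncommitted state as the test,
        # then roll everything back on teardown.
        async with db.new_connection() as connection:
            db._test_connection = connection
            try:
                yield connection
            finally:
                db._test_connection = None
                await connection.rollback()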
305 unwind/db.py

@@ -1,13 +1,11 @@
-import asyncio
 import contextlib
 import logging
-import threading
 from pathlib import Path
-from typing import Any, AsyncGenerator, Iterable, Literal, Type, TypeVar
+from typing import Any, AsyncGenerator, Iterable, Literal, Sequence, Type, TypeVar
 
 import sqlalchemy as sa
-from databases import Database
 from sqlalchemy.dialects.sqlite import insert
+from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
 
 from . import config
 from .models import (
@@ -31,7 +29,9 @@ from .types import ULID
 log = logging.getLogger(__name__)
 T = TypeVar("T")
 
-_database: Database | None = None
+_engine: AsyncEngine | None = None
 
+type Connection = AsyncConnection
+
 
 async def open_connection_pool() -> None:
@@ -39,12 +39,13 @@ async def open_connection_pool() -> None:
 
     This function needs to be called before any access to the database can happen.
     """
-    db = _shared_connection()
-    await db.connect()
+    async with transaction() as conn:
+        await conn.execute(sa.text("PRAGMA journal_mode=WAL"))
+        await conn.run_sync(metadata.create_all, tables=[db_patches])
 
-    await db.execute(sa.text("PRAGMA journal_mode=WAL"))
-
-    await apply_db_patches(db)
+    async with new_connection() as conn:
+        await apply_db_patches(conn)
 
 
 async def close_connection_pool() -> None:
@@ -53,32 +54,33 @@ async def close_connection_pool() -> None:
     This function should be called before the app shuts down to ensure all data
     has been flushed to the database.
     """
-    db = _shared_connection()
+    engine = _shared_engine()
 
-    # Run automatic ANALYZE prior to closing the db,
-    # see https://sqlite.com/lang_analyze.html.
-    await db.execute(sa.text("PRAGMA analysis_limit=400"))
-    await db.execute(sa.text("PRAGMA optimize"))
+    async with engine.begin() as conn:
+        # Run automatic ANALYZE prior to closing the db,
+        # see https://sqlite.com/lang_analyze.html.
+        await conn.execute(sa.text("PRAGMA analysis_limit=400"))
+        await conn.execute(sa.text("PRAGMA optimize"))
 
-    await db.disconnect()
+    await engine.dispose()
 
 
-async def current_patch_level(db: Database) -> str:
+async def current_patch_level(conn: Connection, /) -> str:
     query = sa.select(db_patches.c.current)
-    current = await db.fetch_val(query)
+    current = await conn.scalar(query)
     return current or ""
 
 
-async def set_current_patch_level(db: Database, current: str) -> None:
+async def set_current_patch_level(conn: Connection, /, current: str) -> None:
     stmt = insert(db_patches).values(id=1, current=current)
     stmt = stmt.on_conflict_do_update(set_={"current": stmt.excluded.current})
-    await db.execute(stmt)
+    await conn.execute(stmt)
 
 
 db_patches_dir = Path(__file__).parent / "sql"
 
 
-async def apply_db_patches(db: Database) -> None:
+async def apply_db_patches(conn: Connection, /) -> None:
     """Apply all remaining patches to the database.
 
     Beware that patches will be applied in lexicographical order,
@@ -90,7 +92,7 @@ async def apply_db_patches(db: Database) -> None:
     using two consecutive semi-colons (;).
     Failing to do so will result in an error.
     """
-    applied_lvl = await current_patch_level(db)
+    applied_lvl = await current_patch_level(conn)
 
     did_patch = False
 
@@ -109,31 +111,52 @@ async def apply_db_patches(db: Database) -> None:
             )
             raise RuntimeError("No statement found.")
 
-        async with db.transaction():
+        async with transacted(conn):
             for query in queries:
-                await db.execute(sa.text(query))
+                await conn.execute(sa.text(query))
 
-            await set_current_patch_level(db, patch_lvl)
+            await set_current_patch_level(conn, patch_lvl)
 
         did_patch = True
 
     if did_patch:
-        await db.execute(sa.text("vacuum"))
+        await _vacuum(conn)
 
 
-async def get_import_progress() -> Progress | None:
+async def _vacuum(conn: Connection, /) -> None:
+    """Vacuum the database.
+
+    This function cannot be run on a connection with an open transaction.
+    """
+    # With SQLAlchemy's "autobegin" behavior we need to switch the connection
+    # to "autocommit" first to keep it from automatically starting a transaction,
+    # as VACUUM cannot be run inside a transaction for most databases.
+    await conn.commit()
+    isolation_level = await conn.get_isolation_level()
+    log.debug("Previous isolation_level: %a", isolation_level)
+    await conn.execution_options(isolation_level="AUTOCOMMIT")
+    try:
+        await conn.execute(sa.text("vacuum"))
+        await conn.commit()
+    finally:
+        await conn.execution_options(isolation_level=isolation_level)
+
+
+async def get_import_progress(conn: Connection, /) -> Progress | None:
     """Return the latest import progress."""
     return await get(
-        Progress, type="import-imdb-movies", order_by=(progress.c.started, "desc")
+        conn, Progress, type="import-imdb-movies", order_by=(progress.c.started, "desc")
     )
 
 
-async def stop_import_progress(*, error: BaseException | None = None) -> None:
+async def stop_import_progress(
+    conn: Connection, /, *, error: BaseException | None = None
+) -> None:
     """Stop the current import.
 
     If an error is given, it will be logged to the progress state.
     """
-    current = await get_import_progress()
+    current = await get_import_progress(conn)
     is_running = current and current.stopped is None
 
     if not is_running:
@@ -144,17 +167,17 @@ async def stop_import_progress(*, error: BaseException | None = None) -> None:
     current.error = repr(error)
     current.stopped = utcnow().isoformat()
 
-    await update(current)
+    await update(conn, current)
 
 
-async def set_import_progress(progress: float) -> Progress:
+async def set_import_progress(conn: Connection, /, progress: float) -> Progress:
     """Set the current import progress percentage.
 
     If no import is currently running, this will create a new one.
     """
     progress = min(max(0.0, progress), 100.0)  # clamp to 0 <= progress <= 100
 
-    current = await get_import_progress()
+    current = await get_import_progress(conn)
     is_running = current and current.stopped is None
 
     if not is_running:
@@ -164,71 +187,88 @@ async def set_import_progress(progress: float) -> Progress:
     current.percent = progress
 
     if is_running:
-        await update(current)
+        await update(conn, current)
     else:
-        await add(current)
+        await add(conn, current)
 
     return current
 
 
-_lock = threading.Lock()
-_prelock = threading.Lock()
+def _new_engine() -> AsyncEngine:
+    uri = f"sqlite+aiosqlite:///{config.storage_path}"
+
+    return create_async_engine(
+        uri,
+        isolation_level="SERIALIZABLE",
+    )
+
+
+def _shared_engine() -> AsyncEngine:
+    global _engine
+
+    if _engine is None:
+        _engine = _new_engine()
+
+    return _engine
+
+
+def _new_connection() -> Connection:
+    return _shared_engine().connect()
 
 
 @contextlib.asynccontextmanager
-async def single_threaded():
-    """Ensure the nested code is run only by a single thread at a time."""
-    wait = 1e-5  # XXX not sure if there's a better magic value here
-
-    # The pre-lock (a lock for the lock) allows for multiple threads to hand off
-    # the main lock.
-    # With only a single lock the contending thread will spend most of its time
-    # in the asyncio.sleep and the reigning thread will have time to finish
-    # whatever it's doing and simply acquire the lock again before the other
-    # thread has had a chance to try.
-    # By having another lock (and the same sleep time!) the contending thread
-    # will always have a chance to acquire the main lock.
-    while not _prelock.acquire(blocking=False):
-        await asyncio.sleep(wait)
-
-    try:
-        while not _lock.acquire(blocking=False):
-            await asyncio.sleep(wait)
-    finally:
-        _prelock.release()
-
-    try:
-        yield
-    finally:
-        _lock.release()
-
-
-@contextlib.asynccontextmanager
-async def _locked_connection():
-    async with single_threaded():
-        yield _shared_connection()
+async def transaction(
+    *, force_rollback: bool = False
+) -> AsyncGenerator[Connection, None]:
+    async with new_connection() as conn:
+        yield conn
+
+        if not force_rollback:
+            await conn.commit()
+
+
+# The _test_connection allows pinning a connection that will be shared across the app.
+# This can (and should) only be used when running tests, NOT IN PRODUCTION!
+_test_connection: Connection | None = None
+
+
+@contextlib.asynccontextmanager
+async def new_connection() -> AsyncGenerator[Connection, None]:
+    """Return a new connection.
+
+    Any changes will be rolled back, unless `.commit()` is called on the
+    connection.
+
+    If you want to commit changes, consider using `transaction()` instead.
+    """
+    conn = _test_connection or _new_connection()
+
+    # Support reusing the same connection for _test_connection.
+    is_started = conn.sync_connection is not None
+    if is_started:
+        yield conn
+        return
+
+    async with conn:
+        yield conn
 
 
-def _shared_connection() -> Database:
-    global _database
-
-    if _database is None:
-        uri = f"sqlite:///{config.storage_path}"
-        # uri = f"sqlite+aiosqlite:///{config.storage_path}"
-        _database = Database(uri)
-
-        engine = sa.create_engine(uri, future=True)
-        metadata.create_all(engine, tables=[db_patches])
-
-    return _database
+@contextlib.asynccontextmanager
+async def transacted(
+    conn: Connection, /, *, force_rollback: bool = False
+) -> AsyncGenerator[None, None]:
+    transaction = contextlib.nullcontext() if conn.in_transaction() else conn.begin()
+
+    async with transaction:
+        try:
+            yield
+
+        finally:
+            if force_rollback:
+                await conn.rollback()
 
 
-def transaction():
-    return _shared_connection().transaction()
-
-
-async def add(item: Model) -> None:
+async def add(conn: Connection, /, item: Model) -> None:
     # Support late initializing - used for optimization.
     if getattr(item, "_is_lazy", False):
         assert hasattr(item, "_lazy_init")
@@ -237,14 +277,29 @@ async def add(item: Model) -> None:
     table: sa.Table = item.__table__
     values = asplain(item, serialize=True)
     stmt = table.insert().values(values)
-    async with _locked_connection() as conn:
-        await conn.execute(stmt)
+    await conn.execute(stmt)
+
+
+async def fetch_all(
+    conn: Connection, /, query: sa.Executable, values: "dict | None" = None
+) -> Sequence[sa.Row]:
+    result = await conn.execute(query, values)
+    return result.all()
+
+
+async def fetch_one(
+    conn: Connection, /, query: sa.Executable, values: "dict | None" = None
+) -> sa.Row | None:
+    result = await conn.execute(query, values)
+    return result.first()
 
 
 ModelType = TypeVar("ModelType", bound=Model)
 
 
 async def get(
+    conn: Connection,
+    /,
     model: Type[ModelType],
     *,
     order_by: tuple[sa.Column, Literal["asc", "desc"]] | None = None,
@@ -268,13 +323,12 @@ async def get(
         query = query.order_by(
             order_col.asc() if order_dir == "asc" else order_col.desc()
         )
-    async with _locked_connection() as conn:
-        row = await conn.fetch_one(query)
+    row = await fetch_one(conn, query)
     return fromplain(model, row._mapping, serialized=True) if row else None
 
 
 async def get_many(
-    model: Type[ModelType], **field_sets: set | list
+    conn: Connection, /, model: Type[ModelType], **field_sets: set | list
 ) -> Iterable[ModelType]:
     """Return the items with any values matching all given field sets.
 
@@ -288,12 +342,13 @@ async def get_many(
 
     table: sa.Table = model.__table__
     query = sa.select(model).where(*(table.c[k].in_(v) for k, v in field_sets.items()))
-    async with _locked_connection() as conn:
-        rows = await conn.fetch_all(query)
+    rows = await fetch_all(conn, query)
     return (fromplain(model, row._mapping, serialized=True) for row in rows)
 
 
-async def get_all(model: Type[ModelType], **field_values) -> Iterable[ModelType]:
+async def get_all(
+    conn: Connection, /, model: Type[ModelType], **field_values
+) -> Iterable[ModelType]:
     """Filter all items by comparing all given field values.
 
     If no filters are given, all items will be returned.
@@ -302,12 +357,11 @@ async def get_all(model: Type[ModelType], **field_values) -> Iterable[ModelType]
     query = sa.select(model).where(
         *(table.c[k] == v for k, v in field_values.items() if v is not None)
     )
-    async with _locked_connection() as conn:
-        rows = await conn.fetch_all(query)
+    rows = await fetch_all(conn, query)
     return (fromplain(model, row._mapping, serialized=True) for row in rows)
 
 
-async def update(item: Model) -> None:
+async def update(conn: Connection, /, item: Model) -> None:
     # Support late initializing - used for optimization.
     if getattr(item, "_is_lazy", False):
         assert hasattr(item, "_lazy_init")
@@ -316,30 +370,28 @@ async def update(item: Model) -> None:
     table: sa.Table = item.__table__
     values = asplain(item, serialize=True)
     stmt = table.update().where(table.c.id == values["id"]).values(values)
-    async with _locked_connection() as conn:
-        await conn.execute(stmt)
+    await conn.execute(stmt)
 
 
-async def remove(item: Model) -> None:
+async def remove(conn: Connection, /, item: Model) -> None:
     table: sa.Table = item.__table__
     values = asplain(item, filter_fields={"id"}, serialize=True)
     stmt = table.delete().where(table.c.id == values["id"])
-    async with _locked_connection() as conn:
-        await conn.execute(stmt)
+    await conn.execute(stmt)
 
 
-async def add_or_update_user(user: User) -> None:
-    db_user = await get(User, imdb_id=user.imdb_id)
+async def add_or_update_user(conn: Connection, /, user: User) -> None:
+    db_user = await get(conn, User, imdb_id=user.imdb_id)
     if not db_user:
-        await add(user)
+        await add(conn, user)
     else:
         user.id = db_user.id
 
         if user != db_user:
-            await update(user)
+            await update(conn, user)
 
 
-async def add_or_update_many_movies(movies: list[Movie]) -> None:
+async def add_or_update_many_movies(conn: Connection, /, movies: list[Movie]) -> None:
     """Add or update Movies in the database.
 
     This is an optimized version of `add_or_update_movie` for the purpose
@@ -348,12 +400,13 @@ async def add_or_update_many_movies(movies: list[Movie]) -> None:
     # for movie in movies:
    #     await add_or_update_movie(movie)
     db_movies = {
-        m.imdb_id: m for m in await get_many(Movie, imdb_id=[m.imdb_id for m in movies])
+        m.imdb_id: m
+        for m in await get_many(conn, Movie, imdb_id=[m.imdb_id for m in movies])
     }
     for movie in movies:
         # XXX optimize bulk add & update as well
         if movie.imdb_id not in db_movies:
-            await add(movie)
+            await add(conn, movie)
         else:
             db_movie = db_movies[movie.imdb_id]
             movie.id = db_movie.id
@@ -366,10 +419,10 @@ async def add_or_update_many_movies(movies: list[Movie]) -> None:
             if movie.updated <= db_movie.updated:
                 return
 
-            await update(movie)
+            await update(conn, movie)
 
 
-async def add_or_update_movie(movie: Movie) -> None:
+async def add_or_update_movie(conn: Connection, /, movie: Movie) -> None:
     """Add or update a Movie in the database.
 
     This is an upsert operation, but it will also update the Movie you pass
@@ -377,9 +430,9 @@ async def add_or_update_movie(movie: Movie) -> None:
     set all optional values on your Movie that might be unset but exist in the
     database. It's a bidirectional sync.
     """
-    db_movie = await get(Movie, imdb_id=movie.imdb_id)
+    db_movie = await get(conn, Movie, imdb_id=movie.imdb_id)
     if not db_movie:
-        await add(movie)
+        await add(conn, movie)
     else:
         movie.id = db_movie.id
 
@@ -391,23 +444,23 @@ async def add_or_update_movie(movie: Movie) -> None:
     if movie.updated <= db_movie.updated:
         return
 
-    await update(movie)
+    await update(conn, movie)
 
 
-async def add_or_update_rating(rating: Rating) -> bool:
+async def add_or_update_rating(conn: Connection, /, rating: Rating) -> bool:
     db_rating = await get(
-        Rating, movie_id=str(rating.movie_id), user_id=str(rating.user_id)
+        conn, Rating, movie_id=str(rating.movie_id), user_id=str(rating.user_id)
     )
 
     if not db_rating:
-        await add(rating)
+        await add(conn, rating)
         return True
 
     else:
         rating.id = db_rating.id
 
         if rating != db_rating:
-            await update(rating)
+            await update(conn, rating)
             return True
 
     return False
@@ -418,6 +471,8 @@ def sql_escape(s: str, char: str = "#") -> str:
 
 
 async def find_ratings(
+    conn: Connection,
+    /,
     *,
     title: str | None = None,
     media_type: str | None = None,
@@ -475,9 +530,8 @@ async def find_ratings(
         )
         .limit(limit_rows)
     )
-    async with _locked_connection() as conn:
-        rating_rows: AsyncGenerator[Rating, None] = conn.iterate(query)  # type: ignore
-        movie_ids = [r.movie_id async for r in rating_rows]
+    rating_rows: sa.CursorResult[Rating] = await conn.execute(query)
+    movie_ids = [r.movie_id for r in rating_rows]
 
     if include_unrated and len(movie_ids) < limit_rows:
         query = (
@@ -491,15 +545,17 @@ async def find_ratings(
             )
             .limit(limit_rows - len(movie_ids))
         )
-        async with _locked_connection() as conn:
-            movie_rows: AsyncGenerator[Movie, None] = conn.iterate(query)  # type: ignore
-            movie_ids += [r.id async for r in movie_rows]
+        movie_rows: sa.CursorResult[Movie] = await conn.execute(query)
+        movie_ids += [r.id for r in movie_rows]
 
-    return await ratings_for_movie_ids(ids=movie_ids)
+    return await ratings_for_movie_ids(conn, ids=movie_ids)
 
 
 async def ratings_for_movie_ids(
-    ids: Iterable[ULID | str] = [], imdb_ids: Iterable[str] = []
+    conn: Connection,
+    /,
+    ids: Iterable[ULID | str] = [],
+    imdb_ids: Iterable[str] = [],
 ) -> Iterable[dict[str, Any]]:
     conds = []
 
@@ -527,13 +583,12 @@ async def ratings_for_movie_ids(
         .outerjoin_from(movies, ratings, movies.c.id == ratings.c.movie_id)
         .where(sa.or_(*conds))
     )
-    async with _locked_connection() as conn:
-        rows = await conn.fetch_all(query)
+    rows = await fetch_all(conn, query)
     return tuple(dict(r._mapping) for r in rows)
 
 
 async def ratings_for_movies(
-    movie_ids: Iterable[ULID], user_ids: Iterable[ULID] = []
+    conn: Connection, /, movie_ids: Iterable[ULID], user_ids: Iterable[ULID] = []
 ) -> Iterable[Rating]:
     conditions = [ratings.c.movie_id.in_(str(x) for x in movie_ids)]
 
@@ -542,13 +597,14 @@ async def ratings_for_movies(
 
     query = sa.select(ratings).where(*conditions)
 
-    async with _locked_connection() as conn:
-        rows = await conn.fetch_all(query)
+    rows = await fetch_all(conn, query)
 
     return (fromplain(Rating, row._mapping, serialized=True) for row in rows)
 
 
 async def find_movies(
+    conn: Connection,
+    /,
     *,
     title: str | None = None,
     media_type: str | None = None,
@@ -606,15 +662,14 @@ async def find_movies(
         .offset(skip_rows)
     )
 
-    async with _locked_connection() as conn:
-        rows = await conn.fetch_all(query)
+    rows = await fetch_all(conn, query)
 
     movies_ = [fromplain(Movie, row._mapping, serialized=True) for row in rows]
 
     if not user_ids:
         return ((m, []) for m in movies_)
 
-    ratings = await ratings_for_movies((m.id for m in movies_), user_ids)
+    ratings = await ratings_for_movies(conn, (m.id for m in movies_), user_ids)
 
     aggreg: dict[ULID, tuple[Movie, list[Rating]]] = {m.id: (m, []) for m in movies_}
     for rating in ratings:
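The new surface of unwind/db.py splits connection handling in three: `transaction()` opens a pooled connection and commits on a clean exit (unless `force_rollback=True`), `new_connection()` hands out a connection that never commits by itself, which makes it the safe default for reads, and `transacted(conn)` wraps statements in a transaction on a connection you already hold, tolerating an already-open transaction via `nullcontext`. A short usage sketch of the caller side, not part of the commit (the user values are illustrative only):

    from unwind import db, models


    async def create_user() -> None:
        # transaction() commits for us when the block exits without error.
        async with db.transaction() as conn:
            user = models.User(imdb_id="ur0000000", name="demo", secret="s3cret", groups=[])
            await db.add(conn, user)


    async def read_user() -> "models.User | None":
        # new_connection() never commits on its own; fine for read-only work.
        async with db.new_connection() as conn:
            return await db.get(conn, models.User, imdb_id="ur0000000")

The hunks that follow apply the same convention to the IMDb scraper, the dataset importer, the models module, and unwind/web.py.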
@@ -40,7 +40,9 @@ async def refresh_user_ratings_from_imdb(stop_on_dupe: bool = True):
     async with asession() as s:
         s.headers["Accept-Language"] = "en-US, en;q=0.5"
 
-        for user in await db.get_all(User):
+        async with db.new_connection() as conn:
+            users = list(await db.get_all(conn, User))
+        for user in users:
             log.info("⚡️ Loading data for %s ...", user.name)
 
             try:
@@ -96,7 +98,7 @@ find_year = re.compile(
 find_movie_id = re.compile(r"/title/(?P<id>tt\d+)/").search
 
 
-def movie_and_rating_from_item(item) -> tuple[Movie, Rating]:
+def movie_and_rating_from_item(item: bs4.Tag) -> tuple[Movie, Rating]:
     genres = (genre := item.find("span", "genre")) and genre.string or ""
     movie = Movie(
         title=item.h3.a.string.strip(),
@@ -161,9 +163,10 @@ async def parse_page(url: str) -> tuple[list[Rating], str | None]:
     assert isinstance(meta, bs4.Tag)
     imdb_id = meta["content"]
     assert isinstance(imdb_id, str)
-    user = await db.get(User, imdb_id=imdb_id) or User(
-        imdb_id=imdb_id, name="", secret=""
-    )
+    async with db.new_connection() as conn:
+        user = await db.get(conn, User, imdb_id=imdb_id) or User(
+            imdb_id=imdb_id, name="", secret=""
+        )
 
     if (headline := soup.h1) is None:
         raise RuntimeError("No headline found.")
@@ -213,14 +216,15 @@ async def load_ratings(user_id: str):
     for i, rating in enumerate(ratings):
         assert rating.user and rating.movie
 
-        if i == 0:
-            # All rating objects share the same user.
-            await db.add_or_update_user(rating.user)
-            rating.user_id = rating.user.id
+        async with db.transaction() as conn:
+            if i == 0:
+                # All rating objects share the same user.
+                await db.add_or_update_user(conn, rating.user)
+                rating.user_id = rating.user.id
 
-        await db.add_or_update_movie(rating.movie)
-        rating.movie_id = rating.movie.id
+            await db.add_or_update_movie(conn, rating.movie)
+            rating.movie_id = rating.movie.id
 
-        is_updated = await db.add_or_update_rating(rating)
+            is_updated = await db.add_or_update_rating(conn, rating)
 
         yield rating, is_updated
@@ -209,7 +209,8 @@ async def import_from_file(*, basics_path: Path, ratings_path: Path):
     for i, m in enumerate(read_basics(basics_path)):
         perc = 100 * i / total
         if perc >= perc_next_report:
-            await db.set_import_progress(perc)
+            async with db.transaction() as conn:
+                await db.set_import_progress(conn, perc)
             log.info("⏳ Imported %s%%", round(perc, 1))
             perc_next_report += perc_step
 
@@ -233,15 +234,18 @@ async def import_from_file(*, basics_path: Path, ratings_path: Path):
         chunk.append(m)
 
         if len(chunk) > 1000:
-            await add_or_update_many_movies(chunk)
+            async with db.transaction() as conn:
+                await add_or_update_many_movies(conn, chunk)
             chunk = []
 
     if chunk:
-        await add_or_update_many_movies(chunk)
+        async with db.transaction() as conn:
+            await add_or_update_many_movies(conn, chunk)
         chunk = []
 
     log.info("👍 Imported 100%")
-    await db.set_import_progress(100)
+    async with db.transaction() as conn:
+        await db.set_import_progress(conn, 100)
 
 
 async def download_datasets(*, basics_path: Path, ratings_path: Path) -> None:
@@ -270,7 +274,8 @@ async def load_from_web(*, force: bool = False) -> None:
     See https://www.imdb.com/interfaces/ and https://datasets.imdbws.com/ for
     more information on the IMDb database dumps.
     """
-    await db.set_import_progress(0)
+    async with db.transaction() as conn:
+        await db.set_import_progress(conn, 0)
 
     try:
         ratings_file = config.datadir / "imdb/title.ratings.tsv.gz"
@@ -290,8 +295,10 @@ async def load_from_web(*, force: bool = False) -> None:
         await import_from_file(basics_path=basics_file, ratings_path=ratings_file)
 
     except BaseException as err:
-        await db.stop_import_progress(error=err)
+        async with db.transaction() as conn:
+            await db.stop_import_progress(conn, error=err)
         raise
 
     else:
-        await db.stop_import_progress()
+        async with db.transaction() as conn:
+            await db.stop_import_progress(conn)
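Note that the importer now opens a short-lived `db.transaction()` per chunk (and per progress update) instead of holding one long transaction: an interrupted import should lose at most the chunk in flight, and the SQLite write lock is released between chunks. The pattern in isolation, as a hedged sketch (the function name and chunk size are placeholders, not part of the commit):

    from typing import Iterable

    from unwind import db
    from unwind.models import Movie


    async def write_in_chunks(movies: Iterable[Movie], chunk_size: int = 1000) -> None:
        chunk: list[Movie] = []
        for movie in movies:
            chunk.append(movie)
            if len(chunk) >= chunk_size:
                # One commit per chunk keeps each write transaction short.
                async with db.transaction() as conn:
                    await db.add_or_update_many_movies(conn, chunk)
                chunk = []
        if chunk:
            async with db.transaction() as conn:
                await db.add_or_update_many_movies(conn, chunk)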
@@ -354,49 +354,6 @@ The contents of the Relation are ignored or discarded when using
 Relation = Annotated[T | None, _RelationSentinel]
 
 
-@mapper_registry.mapped
-@dataclass
-class Rating:
-    __table__: ClassVar[Table] = Table(
-        "ratings",
-        metadata,
-        Column("id", String, primary_key=True),  # ULID
-        Column("movie_id", ForeignKey("movies.id"), nullable=False),  # ULID
-        Column("user_id", ForeignKey("users.id"), nullable=False),  # ULID
-        Column("score", Integer, nullable=False),
-        Column("rating_date", String, nullable=False),  # datetime
-        Column("favorite", Integer),  # bool
-        Column("finished", Integer),  # bool
-    )
-
-    id: ULID = field(default_factory=ULID)
-
-    movie_id: ULID = None
-    movie: Relation[Movie] = None
-
-    user_id: ULID = None
-    user: Relation["User"] = None
-
-    score: int = None  # range: [0,100]
-    rating_date: datetime = None
-    favorite: bool | None = None
-    finished: bool | None = None
-
-    def __eq__(self, other):
-        """Return whether two Ratings are equal.
-
-        This operation compares all fields as expected, except that it
-        ignores any field marked as Relation.
-        """
-        if type(other) is not type(self):
-            return False
-        return all(
-            getattr(self, f.name) == getattr(other, f.name) for f in fields(self)
-        )
-
-
-ratings = Rating.__table__
-
 Access = Literal[
     "r",  # read
     "i",  # index
@@ -442,6 +399,50 @@ class User:
         self.groups.append({"id": group_id, "access": access})
 
 
+@mapper_registry.mapped
+@dataclass
+class Rating:
+    __table__: ClassVar[Table] = Table(
+        "ratings",
+        metadata,
+        Column("id", String, primary_key=True),  # ULID
+        Column("movie_id", ForeignKey("movies.id"), nullable=False),  # ULID
+        Column("user_id", ForeignKey("users.id"), nullable=False),  # ULID
+        Column("score", Integer, nullable=False),
+        Column("rating_date", String, nullable=False),  # datetime
+        Column("favorite", Integer),  # bool
+        Column("finished", Integer),  # bool
+    )
+
+    id: ULID = field(default_factory=ULID)
+
+    movie_id: ULID = None
+    movie: Relation[Movie] = None
+
+    user_id: ULID = None
+    user: Relation[User] = None
+
+    score: int = None  # range: [0,100]
+    rating_date: datetime = None
+    favorite: bool | None = None
+    finished: bool | None = None
+
+    def __eq__(self, other):
+        """Return whether two Ratings are equal.
+
+        This operation compares all fields as expected, except that it
+        ignores any field marked as Relation.
+        """
+        if type(other) is not type(self):
+            return False
+        return all(
+            getattr(self, f.name) == getattr(other, f.name) for f in fields(self)
+        )
+
+
+ratings = Rating.__table__
+
+
 class GroupUser(TypedDict):
     id: str
     name: str
124
unwind/web.py
124
unwind/web.py
|
|
@ -168,7 +168,8 @@ async def auth_user(request) -> User | None:
|
||||||
if not isinstance(request.user, AuthedUser):
|
if not isinstance(request.user, AuthedUser):
|
||||||
return
|
return
|
||||||
|
|
||||||
user = await db.get(User, id=request.user.user_id)
|
async with db.new_connection() as conn:
|
||||||
|
user = await db.get(conn, User, id=request.user.user_id)
|
||||||
if not user:
|
if not user:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -195,8 +196,9 @@ def route(path: str, *, methods: list[str] | None = None, **kwds):
|
||||||
async def get_ratings_for_group(request):
|
async def get_ratings_for_group(request):
|
||||||
group_id = as_ulid(request.path_params["group_id"])
|
group_id = as_ulid(request.path_params["group_id"])
|
||||||
|
|
||||||
if (group := await db.get(Group, id=str(group_id))) is None:
|
async with db.new_connection() as conn:
|
||||||
return not_found()
|
if (group := await db.get(conn, Group, id=str(group_id))) is None:
|
||||||
|
return not_found()
|
||||||
|
|
||||||
user_ids = {u["id"] for u in group.users}
|
user_ids = {u["id"] for u in group.users}
|
||||||
|
|
||||||
|
|
@ -207,22 +209,26 @@ async def get_ratings_for_group(request):
|
||||||
|
|
||||||
# if (imdb_id or unwind_id) and (movie := await db.get(Movie, id=unwind_id, imdb_id=imdb_id)):
|
# if (imdb_id or unwind_id) and (movie := await db.get(Movie, id=unwind_id, imdb_id=imdb_id)):
|
||||||
if unwind_id:
|
if unwind_id:
|
||||||
rows = await db.ratings_for_movie_ids(ids=[unwind_id])
|
async with db.new_connection() as conn:
|
||||||
|
rows = await db.ratings_for_movie_ids(conn, ids=[unwind_id])
|
||||||
|
|
||||||
elif imdb_id:
|
elif imdb_id:
|
||||||
rows = await db.ratings_for_movie_ids(imdb_ids=[imdb_id])
|
async with db.new_connection() as conn:
|
||||||
|
rows = await db.ratings_for_movie_ids(conn, imdb_ids=[imdb_id])
|
||||||
|
|
||||||
else:
|
else:
|
||||||
rows = await find_ratings(
|
async with db.new_connection() as conn:
|
||||||
title=params.get("title"),
|
rows = await find_ratings(
|
||||||
media_type=params.get("media_type"),
|
conn,
|
||||||
exact=truthy(params.get("exact")),
|
title=params.get("title"),
|
||||||
ignore_tv_episodes=truthy(params.get("ignore_tv_episodes")),
|
media_type=params.get("media_type"),
|
||||||
include_unrated=truthy(params.get("include_unrated")),
|
exact=truthy(params.get("exact")),
|
||||||
yearcomp=yearcomp(params["year"]) if "year" in params else None,
|
ignore_tv_episodes=truthy(params.get("ignore_tv_episodes")),
|
||||||
limit_rows=as_int(params.get("per_page"), max=10, default=5),
|
include_unrated=truthy(params.get("include_unrated")),
|
||||||
user_ids=user_ids,
|
yearcomp=yearcomp(params["year"]) if "year" in params else None,
|
||||||
)
|
limit_rows=as_int(params.get("per_page"), max=10, default=5),
|
||||||
|
user_ids=user_ids,
|
||||||
|
)
|
||||||
|
|
||||||
ratings = (web_models.Rating(**r) for r in rows)
|
ratings = (web_models.Rating(**r) for r in rows)
|
||||||
|
|
||||||
|
|
@ -261,7 +267,8 @@ async def list_movies(request):
|
||||||
if group_id := params.get("group_id"):
|
if group_id := params.get("group_id"):
|
||||||
group_id = as_ulid(group_id)
|
group_id = as_ulid(group_id)
|
||||||
|
|
||||||
group = await db.get(Group, id=str(group_id))
|
async with db.new_connection() as conn:
|
||||||
|
group = await db.get(conn, Group, id=str(group_id))
|
||||||
if not group:
|
if not group:
|
||||||
return not_found("Group not found.")
|
return not_found("Group not found.")
|
||||||
|
|
||||||
|
|
@ -286,26 +293,31 @@ async def list_movies(request):
|
||||||
|
|
||||||
if imdb_id or unwind_id:
|
if imdb_id or unwind_id:
|
||||||
# XXX missing support for user_ids and user_scores
|
# XXX missing support for user_ids and user_scores
|
||||||
movies = (
|
async with db.new_connection() as conn:
|
||||||
[m] if (m := await db.get(Movie, id=unwind_id, imdb_id=imdb_id)) else []
|
movies = (
|
||||||
)
|
[m]
|
||||||
|
if (m := await db.get(conn, Movie, id=unwind_id, imdb_id=imdb_id))
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
|
||||||
resp = [asplain(web_models.RatingAggregate.from_movie(m)) for m in movies]
|
resp = [asplain(web_models.RatingAggregate.from_movie(m)) for m in movies]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
per_page = as_int(params.get("per_page"), max=1000, default=5)
|
per_page = as_int(params.get("per_page"), max=1000, default=5)
|
||||||
page = as_int(params.get("page"), min=1, default=1)
|
page = as_int(params.get("page"), min=1, default=1)
|
||||||
movieratings = await find_movies(
|
async with db.new_connection() as conn:
|
||||||
title=params.get("title"),
|
movieratings = await find_movies(
|
||||||
media_type=params.get("media_type"),
|
conn,
|
||||||
exact=truthy(params.get("exact")),
|
title=params.get("title"),
|
||||||
ignore_tv_episodes=truthy(params.get("ignore_tv_episodes")),
|
media_type=params.get("media_type"),
|
||||||
include_unrated=truthy(params.get("include_unrated")),
|
exact=truthy(params.get("exact")),
|
||||||
yearcomp=yearcomp(params["year"]) if "year" in params else None,
|
ignore_tv_episodes=truthy(params.get("ignore_tv_episodes")),
|
||||||
limit_rows=per_page,
|
include_unrated=truthy(params.get("include_unrated")),
|
||||||
skip_rows=(page - 1) * per_page,
|
yearcomp=yearcomp(params["year"]) if "year" in params else None,
|
||||||
user_ids=list(user_ids),
|
limit_rows=per_page,
|
||||||
)
|
skip_rows=(page - 1) * per_page,
|
||||||
|
user_ids=list(user_ids),
|
||||||
|
)
|
||||||
|
|
||||||
resp = []
|
resp = []
|
||||||
for movie, ratings in movieratings:
|
for movie, ratings in movieratings:
|
||||||
|
|
@ -325,7 +337,8 @@ async def add_movie(request):
|
||||||
@route("/movies/_reload_imdb", methods=["GET"])
|
@route("/movies/_reload_imdb", methods=["GET"])
|
||||||
@requires(["authenticated", "admin"])
|
@requires(["authenticated", "admin"])
|
||||||
async def progress_for_load_imdb_movies(request):
|
async def progress_for_load_imdb_movies(request):
|
||||||
progress = await db.get_import_progress()
|
async with db.new_connection() as conn:
|
||||||
|
progress = await db.get_import_progress(conn)
|
||||||
if not progress:
|
if not progress:
|
||||||
return JSONResponse({"status": "No import exists."}, status_code=404)
|
return JSONResponse({"status": "No import exists."}, status_code=404)
|
||||||
|
|
||||||
|
|
@ -364,14 +377,16 @@ async def load_imdb_movies(request):
|
||||||
force = truthy(params.get("force"))
|
force = truthy(params.get("force"))
|
||||||
|
|
||||||
async with _import_lock:
|
async with _import_lock:
|
||||||
progress = await db.get_import_progress()
|
async with db.new_connection() as conn:
|
||||||
|
progress = await db.get_import_progress(conn)
|
||||||
if progress and not progress.stopped:
|
if progress and not progress.stopped:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
{"status": "Import is running.", "progress": progress.percent},
|
{"status": "Import is running.", "progress": progress.percent},
|
||||||
status_code=409,
|
status_code=409,
|
||||||
)
|
)
|
||||||
|
|
||||||
await db.set_import_progress(0)
|
async with db.transaction() as conn:
|
||||||
|
await db.set_import_progress(conn, 0)
|
||||||
|
|
||||||
task = BackgroundTask(imdb_import.load_from_web, force=force)
|
task = BackgroundTask(imdb_import.load_from_web, force=force)
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
|
|
@ -382,7 +397,8 @@ async def load_imdb_movies(request):
|
||||||
@route("/users")
|
@route("/users")
|
||||||
@requires(["authenticated", "admin"])
|
@requires(["authenticated", "admin"])
|
||||||
async def list_users(request):
|
async def list_users(request):
|
||||||
users = await db.get_all(User)
|
async with db.new_connection() as conn:
|
||||||
|
users = await db.get_all(conn, User)
|
||||||
|
|
||||||
return JSONResponse([asplain(u) for u in users])
|
return JSONResponse([asplain(u) for u in users])
|
||||||
|
|
||||||
|
|
@ -398,7 +414,8 @@ async def add_user(request):
|
||||||
secret = secrets.token_bytes()
|
secret = secrets.token_bytes()
|
||||||
|
|
||||||
user = User(name=name, imdb_id=imdb_id, secret=phc_scrypt(secret))
|
user = User(name=name, imdb_id=imdb_id, secret=phc_scrypt(secret))
|
||||||
await db.add(user)
|
async with db.transaction() as conn:
|
||||||
|
await db.add(conn, user)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
{
|
{
|
||||||
|
|
@ -414,7 +431,8 @@ async def show_user(request):
|
||||||
user_id = as_ulid(request.path_params["user_id"])
|
user_id = as_ulid(request.path_params["user_id"])
|
||||||
|
|
||||||
if is_admin(request):
|
if is_admin(request):
|
||||||
user = await db.get(User, id=str(user_id))
|
async with db.new_connection() as conn:
|
||||||
|
user = await db.get(conn, User, id=str(user_id))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
user = await auth_user(request)
|
user = await auth_user(request)
|
||||||
|
|
@ -441,14 +459,15 @@ async def show_user(request):
|
||||||
async def remove_user(request):
|
async def remove_user(request):
|
||||||
user_id = as_ulid(request.path_params["user_id"])
|
user_id = as_ulid(request.path_params["user_id"])
|
||||||
|
|
||||||
user = await db.get(User, id=str(user_id))
|
async with db.new_connection() as conn:
|
||||||
|
user = await db.get(conn, User, id=str(user_id))
|
||||||
if not user:
|
if not user:
|
||||||
return not_found()
|
return not_found()
|
||||||
|
|
||||||
async with db.transaction():
|
async with db.transaction() as conn:
|
||||||
# XXX remove user refs from groups and ratings
|
# XXX remove user refs from groups and ratings
|
||||||
|
|
||||||
await db.remove(user)
|
await db.remove(conn, user)
|
||||||
|
|
||||||
return JSONResponse(asplain(user))
|
return JSONResponse(asplain(user))
|
||||||
|
|
||||||
|
|
@ -459,7 +478,8 @@ async def modify_user(request):
|
||||||
user_id = as_ulid(request.path_params["user_id"])
|
user_id = as_ulid(request.path_params["user_id"])
|
||||||
|
|
||||||
if is_admin(request):
|
if is_admin(request):
|
||||||
user = await db.get(User, id=str(user_id))
|
async with db.new_connection() as conn:
|
||||||
|
user = await db.get(conn, User, id=str(user_id))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
user = await auth_user(request)
|
user = await auth_user(request)
|
||||||
|
|
@ -495,7 +515,8 @@ async def modify_user(request):
|
||||||
|
|
||||||
user.secret = phc_scrypt(secret)
|
user.secret = phc_scrypt(secret)
|
||||||
|
|
||||||
await db.update(user)
|
async with db.transaction() as conn:
|
||||||
|
await db.update(conn, user)
|
||||||
|
|
||||||
return JSONResponse(asplain(user))
|
return JSONResponse(asplain(user))
|
||||||
|
|
||||||
|
|
@ -505,13 +526,15 @@ async def modify_user(request):
|
||||||
async def add_group_to_user(request):
|
async def add_group_to_user(request):
|
||||||
user_id = as_ulid(request.path_params["user_id"])
|
user_id = as_ulid(request.path_params["user_id"])
|
||||||
|
|
||||||
user = await db.get(User, id=str(user_id))
|
async with db.new_connection() as conn:
|
||||||
|
user = await db.get(conn, User, id=str(user_id))
|
||||||
if not user:
|
if not user:
|
||||||
return not_found("User not found")
|
return not_found("User not found")
|
||||||
|
|
||||||
(group_id, access) = await json_from_body(request, ["group", "access"])
|
(group_id, access) = await json_from_body(request, ["group", "access"])
|
||||||
|
|
||||||
group = await db.get(Group, id=str(group_id))
|
async with db.new_connection() as conn:
|
||||||
|
group = await db.get(conn, Group, id=str(group_id))
|
||||||
if not group:
|
if not group:
|
||||||
return not_found("Group not found")
|
return not_found("Group not found")
|
||||||
|
|
||||||
|
|
@ -519,7 +542,8 @@ async def add_group_to_user(request):
|
||||||
raise HTTPException(422, f"Invalid access level.")
|
raise HTTPException(422, f"Invalid access level.")
|
||||||
|
|
||||||
user.set_access(group_id, access)
|
user.set_access(group_id, access)
|
||||||
await db.update(user)
|
async with db.transaction() as conn:
|
||||||
|
await db.update(conn, user)
|
||||||
|
|
||||||
return JSONResponse(asplain(user))
|
return JSONResponse(asplain(user))
|
||||||
|
|
||||||
|
|
@ -547,7 +571,8 @@ async def load_imdb_user_ratings(request):
|
||||||
@route("/groups")
|
@route("/groups")
|
||||||
@requires(["authenticated", "admin"])
|
@requires(["authenticated", "admin"])
|
||||||
async def list_groups(request):
|
async def list_groups(request):
|
||||||
groups = await db.get_all(Group)
|
async with db.new_connection() as conn:
|
||||||
|
groups = await db.get_all(conn, Group)
|
||||||
|
|
||||||
return JSONResponse([asplain(g) for g in groups])
|
return JSONResponse([asplain(g) for g in groups])
|
||||||
|
|
||||||
|
|
@ -560,7 +585,8 @@ async def add_group(request):
|
||||||
# XXX restrict name
|
# XXX restrict name
|
||||||
|
|
||||||
group = Group(name=name)
|
group = Group(name=name)
|
||||||
await db.add(group)
|
async with db.transaction() as conn:
|
||||||
|
await db.add(conn, group)
|
||||||
|
|
||||||
return JSONResponse(asplain(group))
|
return JSONResponse(asplain(group))
|
||||||
|
|
||||||
|
|
@ -569,7 +595,8 @@ async def add_group(request):
|
||||||
@requires(["authenticated"])
|
@requires(["authenticated"])
|
||||||
async def add_user_to_group(request):
|
async def add_user_to_group(request):
|
||||||
group_id = as_ulid(request.path_params["group_id"])
|
group_id = as_ulid(request.path_params["group_id"])
|
||||||
group = await db.get(Group, id=str(group_id))
|
async with db.new_connection() as conn:
|
||||||
|
group = await db.get(conn, Group, id=str(group_id))
|
||||||
|
|
||||||
if not group:
|
if not group:
|
||||||
return not_found()
|
return not_found()
|
||||||
|
|
@ -596,7 +623,8 @@ async def add_user_to_group(request):
|
||||||
else:
|
else:
|
||||||
group.users.append({"name": name, "id": user_id})
|
group.users.append({"name": name, "id": user_id})
|
||||||
|
|
||||||
await db.update(group)
|
async with db.transaction() as conn:
|
||||||
|
await db.update(conn, group)
|
||||||
|
|
||||||
return JSONResponse(asplain(group))
|
return JSONResponse(asplain(group))
|
||||||
|
|
||||||
|
|
|
||||||