diff --git a/poetry.lock b/poetry.lock index 9dd4032..df69612 100644 --- a/poetry.lock +++ b/poetry.lock @@ -153,31 +153,6 @@ files = [ [package.extras] toml = ["tomli"] -[[package]] -name = "databases" -version = "0.7.0" -description = "Async database support for Python." -optional = false -python-versions = ">=3.7" -files = [ - {file = "databases-0.7.0-py3-none-any.whl", hash = "sha256:cf5da4b8a3e3cd038c459529725ebb64931cbbb7a091102664f20ef8f6cefd0d"}, - {file = "databases-0.7.0.tar.gz", hash = "sha256:ea2d419d3d2eb80595b7ceb8f282056f080af62efe2fb9bcd83562f93ec4b674"}, -] - -[package.dependencies] -aiosqlite = {version = "*", optional = true, markers = "extra == \"sqlite\""} -sqlalchemy = ">=1.4.42,<1.5" - -[package.extras] -aiomysql = ["aiomysql"] -aiopg = ["aiopg"] -aiosqlite = ["aiosqlite"] -asyncmy = ["asyncmy"] -asyncpg = ["asyncpg"] -mysql = ["aiomysql"] -postgresql = ["asyncpg"] -sqlite = ["aiosqlite"] - [[package]] name = "greenlet" version = "3.0.1" @@ -554,62 +529,90 @@ files = [ [[package]] name = "sqlalchemy" -version = "1.4.50" +version = "2.0.23" description = "Database Abstraction Library" optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +python-versions = ">=3.7" files = [ - {file = "SQLAlchemy-1.4.50-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d00665725063692c42badfd521d0c4392e83c6c826795d38eb88fb108e5660e5"}, - {file = "SQLAlchemy-1.4.50-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85292ff52ddf85a39367057c3d7968a12ee1fb84565331a36a8fead346f08796"}, - {file = "SQLAlchemy-1.4.50-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:d0fed0f791d78e7767c2db28d34068649dfeea027b83ed18c45a423f741425cb"}, - {file = "SQLAlchemy-1.4.50-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db4db3c08ffbb18582f856545f058a7a5e4ab6f17f75795ca90b3c38ee0a8ba4"}, - {file = "SQLAlchemy-1.4.50-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:14b0cacdc8a4759a1e1bd47dc3ee3f5db997129eb091330beda1da5a0e9e5bd7"}, - {file = "SQLAlchemy-1.4.50-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1fb9cb60e0f33040e4f4681e6658a7eb03b5cb4643284172f91410d8c493dace"}, - {file = "SQLAlchemy-1.4.50-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c4cb501d585aa74a0f86d0ea6263b9c5e1d1463f8f9071392477fd401bd3c7cc"}, - {file = "SQLAlchemy-1.4.50-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a7a66297e46f85a04d68981917c75723e377d2e0599d15fbe7a56abed5e2d75"}, - {file = "SQLAlchemy-1.4.50-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c1db0221cb26d66294f4ca18c533e427211673ab86c1fbaca8d6d9ff78654293"}, - {file = "SQLAlchemy-1.4.50-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b7dbe6369677a2bea68fe9812c6e4bbca06ebfa4b5cde257b2b0bf208709131"}, - {file = "SQLAlchemy-1.4.50-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a9bddb60566dc45c57fd0a5e14dd2d9e5f106d2241e0a2dc0c1da144f9444516"}, - {file = "SQLAlchemy-1.4.50-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:82dd4131d88395df7c318eeeef367ec768c2a6fe5bd69423f7720c4edb79473c"}, - {file = "SQLAlchemy-1.4.50-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:273505fcad22e58cc67329cefab2e436006fc68e3c5423056ee0513e6523268a"}, - {file = "SQLAlchemy-1.4.50-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3257a6e09626d32b28a0c5b4f1a97bced585e319cfa90b417f9ab0f6145c33c"}, - {file = "SQLAlchemy-1.4.50-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:d69738d582e3a24125f0c246ed8d712b03bd21e148268421e4a4d09c34f521a5"}, - {file = "SQLAlchemy-1.4.50-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:34e1c5d9cd3e6bf3d1ce56971c62a40c06bfc02861728f368dcfec8aeedb2814"}, - {file = "SQLAlchemy-1.4.50-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1fcee5a2c859eecb4ed179edac5ffbc7c84ab09a5420219078ccc6edda45436"}, - {file = "SQLAlchemy-1.4.50-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fbaf6643a604aa17e7a7afd74f665f9db882df5c297bdd86c38368f2c471f37d"}, - {file = "SQLAlchemy-1.4.50-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2e70e0673d7d12fa6cd363453a0d22dac0d9978500aa6b46aa96e22690a55eab"}, - {file = "SQLAlchemy-1.4.50-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b881ac07d15fb3e4f68c5a67aa5cdaf9eb8f09eb5545aaf4b0a5f5f4659be18"}, - {file = "SQLAlchemy-1.4.50-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f6997da81114daef9203d30aabfa6b218a577fc2bd797c795c9c88c9eb78d49"}, - {file = "SQLAlchemy-1.4.50-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bdb77e1789e7596b77fd48d99ec1d2108c3349abd20227eea0d48d3f8cf398d9"}, - {file = "SQLAlchemy-1.4.50-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:128a948bd40780667114b0297e2cc6d657b71effa942e0a368d8cc24293febb3"}, - {file = "SQLAlchemy-1.4.50-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f2d526aeea1bd6a442abc7c9b4b00386fd70253b80d54a0930c0a216230a35be"}, - {file = "SQLAlchemy-1.4.50.tar.gz", hash = "sha256:3b97ddf509fc21e10b09403b5219b06c5b558b27fc2453150274fa4e70707dbf"}, + {file = "SQLAlchemy-2.0.23-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:638c2c0b6b4661a4fd264f6fb804eccd392745c5887f9317feb64bb7cb03b3ea"}, + {file = "SQLAlchemy-2.0.23-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e3b5036aa326dc2df50cba3c958e29b291a80f604b1afa4c8ce73e78e1c9f01d"}, + {file = "SQLAlchemy-2.0.23-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:787af80107fb691934a01889ca8f82a44adedbf5ef3d6ad7d0f0b9ac557e0c34"}, + {file = "SQLAlchemy-2.0.23-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c14eba45983d2f48f7546bb32b47937ee2cafae353646295f0e99f35b14286ab"}, + {file = "SQLAlchemy-2.0.23-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0666031df46b9badba9bed00092a1ffa3aa063a5e68fa244acd9f08070e936d3"}, + {file = "SQLAlchemy-2.0.23-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:89a01238fcb9a8af118eaad3ffcc5dedaacbd429dc6fdc43fe430d3a941ff965"}, + {file = "SQLAlchemy-2.0.23-cp310-cp310-win32.whl", hash = "sha256:cabafc7837b6cec61c0e1e5c6d14ef250b675fa9c3060ed8a7e38653bd732ff8"}, + {file = "SQLAlchemy-2.0.23-cp310-cp310-win_amd64.whl", hash = "sha256:87a3d6b53c39cd173990de2f5f4b83431d534a74f0e2f88bd16eabb5667e65c6"}, + {file = "SQLAlchemy-2.0.23-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d5578e6863eeb998980c212a39106ea139bdc0b3f73291b96e27c929c90cd8e1"}, + {file = "SQLAlchemy-2.0.23-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:62d9e964870ea5ade4bc870ac4004c456efe75fb50404c03c5fd61f8bc669a72"}, + {file = "SQLAlchemy-2.0.23-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c80c38bd2ea35b97cbf7c21aeb129dcbebbf344ee01a7141016ab7b851464f8e"}, + {file = "SQLAlchemy-2.0.23-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75eefe09e98043cff2fb8af9796e20747ae870c903dc61d41b0c2e55128f958d"}, + {file = "SQLAlchemy-2.0.23-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:bd45a5b6c68357578263d74daab6ff9439517f87da63442d244f9f23df56138d"}, + {file = "SQLAlchemy-2.0.23-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a86cb7063e2c9fb8e774f77fbf8475516d270a3e989da55fa05d08089d77f8c4"}, + {file = "SQLAlchemy-2.0.23-cp311-cp311-win32.whl", hash = "sha256:b41f5d65b54cdf4934ecede2f41b9c60c9f785620416e8e6c48349ab18643855"}, + {file = "SQLAlchemy-2.0.23-cp311-cp311-win_amd64.whl", hash = "sha256:9ca922f305d67605668e93991aaf2c12239c78207bca3b891cd51a4515c72e22"}, + {file = "SQLAlchemy-2.0.23-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d0f7fb0c7527c41fa6fcae2be537ac137f636a41b4c5a4c58914541e2f436b45"}, + {file = "SQLAlchemy-2.0.23-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7c424983ab447dab126c39d3ce3be5bee95700783204a72549c3dceffe0fc8f4"}, + {file = "SQLAlchemy-2.0.23-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f508ba8f89e0a5ecdfd3761f82dda2a3d7b678a626967608f4273e0dba8f07ac"}, + {file = "SQLAlchemy-2.0.23-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6463aa765cf02b9247e38b35853923edbf2f6fd1963df88706bc1d02410a5577"}, + {file = "SQLAlchemy-2.0.23-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e599a51acf3cc4d31d1a0cf248d8f8d863b6386d2b6782c5074427ebb7803bda"}, + {file = "SQLAlchemy-2.0.23-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:fd54601ef9cc455a0c61e5245f690c8a3ad67ddb03d3b91c361d076def0b4c60"}, + {file = "SQLAlchemy-2.0.23-cp312-cp312-win32.whl", hash = "sha256:42d0b0290a8fb0165ea2c2781ae66e95cca6e27a2fbe1016ff8db3112ac1e846"}, + {file = "SQLAlchemy-2.0.23-cp312-cp312-win_amd64.whl", hash = "sha256:227135ef1e48165f37590b8bfc44ed7ff4c074bf04dc8d6f8e7f1c14a94aa6ca"}, + {file = "SQLAlchemy-2.0.23-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:14aebfe28b99f24f8a4c1346c48bc3d63705b1f919a24c27471136d2f219f02d"}, + {file = "SQLAlchemy-2.0.23-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e983fa42164577d073778d06d2cc5d020322425a509a08119bdcee70ad856bf"}, + {file = "SQLAlchemy-2.0.23-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e0dc9031baa46ad0dd5a269cb7a92a73284d1309228be1d5935dac8fb3cae24"}, + {file = "SQLAlchemy-2.0.23-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:5f94aeb99f43729960638e7468d4688f6efccb837a858b34574e01143cf11f89"}, + {file = "SQLAlchemy-2.0.23-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:63bfc3acc970776036f6d1d0e65faa7473be9f3135d37a463c5eba5efcdb24c8"}, + {file = "SQLAlchemy-2.0.23-cp37-cp37m-win32.whl", hash = "sha256:f48ed89dd11c3c586f45e9eec1e437b355b3b6f6884ea4a4c3111a3358fd0c18"}, + {file = "SQLAlchemy-2.0.23-cp37-cp37m-win_amd64.whl", hash = "sha256:1e018aba8363adb0599e745af245306cb8c46b9ad0a6fc0a86745b6ff7d940fc"}, + {file = "SQLAlchemy-2.0.23-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:64ac935a90bc479fee77f9463f298943b0e60005fe5de2aa654d9cdef46c54df"}, + {file = "SQLAlchemy-2.0.23-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c4722f3bc3c1c2fcc3702dbe0016ba31148dd6efcd2a2fd33c1b4897c6a19693"}, + {file = "SQLAlchemy-2.0.23-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4af79c06825e2836de21439cb2a6ce22b2ca129bad74f359bddd173f39582bf5"}, + {file = "SQLAlchemy-2.0.23-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:683ef58ca8eea4747737a1c35c11372ffeb84578d3aab8f3e10b1d13d66f2bc4"}, + {file = "SQLAlchemy-2.0.23-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:d4041ad05b35f1f4da481f6b811b4af2f29e83af253bf37c3c4582b2c68934ab"}, + {file = "SQLAlchemy-2.0.23-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:aeb397de65a0a62f14c257f36a726945a7f7bb60253462e8602d9b97b5cbe204"}, + {file = "SQLAlchemy-2.0.23-cp38-cp38-win32.whl", hash = "sha256:42ede90148b73fe4ab4a089f3126b2cfae8cfefc955c8174d697bb46210c8306"}, + {file = "SQLAlchemy-2.0.23-cp38-cp38-win_amd64.whl", hash = "sha256:964971b52daab357d2c0875825e36584d58f536e920f2968df8d581054eada4b"}, + {file = "SQLAlchemy-2.0.23-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:616fe7bcff0a05098f64b4478b78ec2dfa03225c23734d83d6c169eb41a93e55"}, + {file = "SQLAlchemy-2.0.23-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0e680527245895aba86afbd5bef6c316831c02aa988d1aad83c47ffe92655e74"}, + {file = "SQLAlchemy-2.0.23-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9585b646ffb048c0250acc7dad92536591ffe35dba624bb8fd9b471e25212a35"}, + {file = "SQLAlchemy-2.0.23-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4895a63e2c271ffc7a81ea424b94060f7b3b03b4ea0cd58ab5bb676ed02f4221"}, + {file = "SQLAlchemy-2.0.23-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:cc1d21576f958c42d9aec68eba5c1a7d715e5fc07825a629015fe8e3b0657fb0"}, + {file = "SQLAlchemy-2.0.23-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:967c0b71156f793e6662dd839da54f884631755275ed71f1539c95bbada9aaab"}, + {file = "SQLAlchemy-2.0.23-cp39-cp39-win32.whl", hash = "sha256:0a8c6aa506893e25a04233bc721c6b6cf844bafd7250535abb56cb6cc1368884"}, + {file = "SQLAlchemy-2.0.23-cp39-cp39-win_amd64.whl", hash = "sha256:f3420d00d2cb42432c1d0e44540ae83185ccbbc67a6054dcc8ab5387add6620b"}, + {file = "SQLAlchemy-2.0.23-py3-none-any.whl", hash = "sha256:31952bbc527d633b9479f5f81e8b9dfada00b91d6baba021a869095f1a97006d"}, + {file = "SQLAlchemy-2.0.23.tar.gz", hash = "sha256:c1bda93cbbe4aa2aa0aa8655c5aeda505cd219ff3e8da91d1d329e143e4aff69"}, ] [package.dependencies] -aiosqlite = {version = "*", optional = true, markers = "python_version >= \"3\" and extra == \"aiosqlite\""} -greenlet = {version = "!=0.4.17", optional = true, markers = "python_version >= \"3\" and (platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\" or extra == \"aiosqlite\")"} -typing-extensions = {version = "!=3.10.0.1", optional = true, markers = "extra == \"aiosqlite\""} +aiosqlite = {version = "*", optional = true, markers = "extra == \"aiosqlite\""} +greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\" or extra == \"aiosqlite\""} +typing-extensions = {version = ">=4.2.0", optional = true, markers = "extra == \"aiosqlite\""} [package.extras] aiomysql = ["aiomysql (>=0.2.0)", "greenlet (!=0.4.17)"] +aioodbc = ["aioodbc", "greenlet (!=0.4.17)"] aiosqlite = ["aiosqlite", "greenlet (!=0.4.17)", "typing-extensions (!=3.10.0.1)"] asyncio = ["greenlet (!=0.4.17)"] -asyncmy = ["asyncmy (>=0.2.3,!=0.2.4)", "greenlet (!=0.4.17)"] -mariadb-connector = ["mariadb (>=1.0.1,!=1.1.2)"] +asyncmy = ["asyncmy (>=0.2.3,!=0.2.4,!=0.2.6)", "greenlet (!=0.4.17)"] +mariadb-connector = ["mariadb (>=1.0.1,!=1.1.2,!=1.1.5)"] mssql = ["pyodbc"] mssql-pymssql = ["pymssql"] mssql-pyodbc = ["pyodbc"] -mypy = ["mypy (>=0.910)", "sqlalchemy2-stubs"] -mysql = ["mysqlclient (>=1.4.0)", "mysqlclient (>=1.4.0,<2)"] +mypy = ["mypy (>=0.910)"] +mysql = ["mysqlclient (>=1.4.0)"] mysql-connector = ["mysql-connector-python"] -oracle = ["cx-oracle (>=7)", "cx-oracle (>=7,<8)"] +oracle = ["cx-oracle (>=8)"] +oracle-oracledb = ["oracledb (>=1.0.1)"] postgresql = ["psycopg2 (>=2.7)"] postgresql-asyncpg = ["asyncpg", "greenlet (!=0.4.17)"] -postgresql-pg8000 = ["pg8000 (>=1.16.6,!=1.29.0)"] +postgresql-pg8000 = ["pg8000 (>=1.29.1)"] +postgresql-psycopg = ["psycopg (>=3.0.7)"] postgresql-psycopg2binary = ["psycopg2-binary"] postgresql-psycopg2cffi = ["psycopg2cffi"] -pymysql = ["pymysql", "pymysql (<1)"] +postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"] +pymysql = ["pymysql"] sqlcipher = ["sqlcipher3-binary"] [[package]] @@ -683,4 +686,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.12" -content-hash = "8d0ddcdcd96f4736bb3608df11678d78776f5cf7c6883474b61b158c99ac4732" +content-hash = "fc07028820963701634eb55b42ea12962fd7c6fc25ef76ddadf30f2c74544b5f" diff --git a/pyproject.toml b/pyproject.toml index 92b2b9d..f7c1b6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,10 +11,9 @@ beautifulsoup4 = "^4.9.3" html5lib = "^1.1" starlette = "^0.30" ulid-py = "^1.1.0" -databases = {extras = ["sqlite"], version = "^0.7.0"} uvicorn = "^0.23" httpx = "^0.24" -sqlalchemy = {version = "^1.4", extras = ["aiosqlite"]} +sqlalchemy = {version = "^2.0", extras = ["aiosqlite"]} [tool.poetry.group.dev] optional = true diff --git a/scripts/tests b/scripts/tests index 8261f1e..5eefc57 100755 --- a/scripts/tests +++ b/scripts/tests @@ -6,7 +6,7 @@ dbfile="${UNWIND_DATA:-./data}/tests.sqlite" # Rollback in Databases is currently broken, so we have to rebuild the database # each time; see https://github.com/encode/databases/issues/403 -trap 'rm "$dbfile"' EXIT TERM INT QUIT +trap 'rm "$dbfile" "${dbfile}-shm" "${dbfile}-wal"' EXIT TERM INT QUIT [ -z "${DEBUG:-}" ] || set -x diff --git a/tests/conftest.py b/tests/conftest.py index 470bc4d..17ce01a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,16 +17,19 @@ def event_loop(): @pytest_asyncio.fixture(scope="session") async def shared_conn(): - c = db._shared_connection() - await c.connect() + """A database connection, ready to use.""" + await db.open_connection_pool() - await db.apply_db_patches(c) - yield c + async with db.new_connection() as c: + db._test_connection = c + yield c + db._test_connection = None - await c.disconnect() + await db.close_connection_pool() @pytest_asyncio.fixture -async def conn(shared_conn): - async with shared_conn.transaction(force_rollback=True): +async def conn(shared_conn: db.Connection): + """A transacted database connection, will be rolled back after use.""" + async with db.transacted(shared_conn, force_rollback=True): yield shared_conn diff --git a/tests/test_db.py b/tests/test_db.py index cd5f295..3619497 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -4,7 +4,7 @@ import pytest from unwind import db, models, web_models -_movie_imdb_id = 1234567 +_movie_imdb_id = 1230000 def a_movie(**kwds) -> models.Movie: @@ -21,394 +21,399 @@ def a_movie(**kwds) -> models.Movie: @pytest.mark.asyncio -async def test_current_patch_level(shared_conn: db.Database): - async with shared_conn.transaction(force_rollback=True): - patch_level = "some-patch-level" - assert patch_level != await db.current_patch_level(shared_conn) - await db.set_current_patch_level(shared_conn, patch_level) - assert patch_level == await db.current_patch_level(shared_conn) +async def test_current_patch_level(conn: db.Connection): + patch_level = "some-patch-level" + assert patch_level != await db.current_patch_level(conn) + await db.set_current_patch_level(conn, patch_level) + assert patch_level == await db.current_patch_level(conn) @pytest.mark.asyncio -async def test_get(shared_conn: db.Database): - async with shared_conn.transaction(force_rollback=True): - m1 = a_movie() - await db.add(m1) +async def test_get(conn: db.Connection): + m1 = a_movie() + await db.add(conn, m1) - m2 = a_movie(release_year=m1.release_year + 1) - await db.add(m2) + m2 = a_movie(release_year=m1.release_year + 1) + await db.add(conn, m2) - assert None is await db.get(models.Movie) - assert None is await db.get(models.Movie, id="blerp") - assert m1 == await db.get(models.Movie, id=str(m1.id)) - assert m2 == await db.get(models.Movie, release_year=m2.release_year) - assert None is await db.get( - models.Movie, id=str(m1.id), release_year=m2.release_year - ) - assert m2 == await db.get( - models.Movie, id=str(m2.id), release_year=m2.release_year - ) - assert m1 == await db.get( - models.Movie, - media_type=m1.media_type, - order_by=(models.movies.c.release_year, "asc"), - ) - assert m2 == await db.get( - models.Movie, - media_type=m1.media_type, - order_by=(models.movies.c.release_year, "desc"), - ) + assert None is await db.get(conn, models.Movie) + assert None is await db.get(conn, models.Movie, id="blerp") + assert m1 == await db.get(conn, models.Movie, id=str(m1.id)) + assert m2 == await db.get(conn, models.Movie, release_year=m2.release_year) + assert None is await db.get( + conn, models.Movie, id=str(m1.id), release_year=m2.release_year + ) + assert m2 == await db.get( + conn, models.Movie, id=str(m2.id), release_year=m2.release_year + ) + assert m1 == await db.get( + conn, + models.Movie, + media_type=m1.media_type, + order_by=(models.movies.c.release_year, "asc"), + ) + assert m2 == await db.get( + conn, + models.Movie, + media_type=m1.media_type, + order_by=(models.movies.c.release_year, "desc"), + ) @pytest.mark.asyncio -async def test_get_all(shared_conn: db.Database): - async with shared_conn.transaction(force_rollback=True): - m1 = a_movie() - await db.add(m1) +async def test_get_all(conn: db.Connection): + m1 = a_movie() + await db.add(conn, m1) - m2 = a_movie(release_year=m1.release_year) - await db.add(m2) + m2 = a_movie(release_year=m1.release_year) + await db.add(conn, m2) - m3 = a_movie(release_year=m1.release_year + 1) - await db.add(m3) + m3 = a_movie(release_year=m1.release_year + 1) + await db.add(conn, m3) - assert [] == list(await db.get_all(models.Movie, id="blerp")) - assert [m1] == list(await db.get_all(models.Movie, id=str(m1.id))) - assert [m1, m2] == list( - await db.get_all(models.Movie, release_year=m1.release_year) - ) - assert [m1, m2, m3] == list(await db.get_all(models.Movie)) + assert [] == list(await db.get_all(conn, models.Movie, id="blerp")) + assert [m1] == list(await db.get_all(conn, models.Movie, id=str(m1.id))) + assert [m1, m2] == list( + await db.get_all(conn, models.Movie, release_year=m1.release_year) + ) + assert [m1, m2, m3] == list(await db.get_all(conn, models.Movie)) @pytest.mark.asyncio -async def test_get_many(shared_conn: db.Database): - async with shared_conn.transaction(force_rollback=True): - m1 = a_movie() - await db.add(m1) +async def test_get_many(conn: db.Connection): + m1 = a_movie() + await db.add(conn, m1) - m2 = a_movie(release_year=m1.release_year) - await db.add(m2) + m2 = a_movie(release_year=m1.release_year) + await db.add(conn, m2) - m3 = a_movie(release_year=m1.release_year + 1) - await db.add(m3) + m3 = a_movie(release_year=m1.release_year + 1) + await db.add(conn, m3) - assert [] == list(await db.get_many(models.Movie)), "selected nothing" - assert [m1] == list(await db.get_many(models.Movie, id=[str(m1.id)])) - assert [m1] == list(await db.get_many(models.Movie, id={str(m1.id)})) - assert [m1, m2] == list( - await db.get_many(models.Movie, release_year=[m1.release_year]) - ) - assert [m1, m2, m3] == list( - await db.get_many( - models.Movie, release_year=[m1.release_year, m3.release_year] - ) + assert [] == list(await db.get_many(conn, models.Movie)), "selected nothing" + assert [m1] == list(await db.get_many(conn, models.Movie, id=[str(m1.id)])) + assert [m1] == list(await db.get_many(conn, models.Movie, id={str(m1.id)})) + assert [m1, m2] == list( + await db.get_many(conn, models.Movie, release_year=[m1.release_year]) + ) + assert [m1, m2, m3] == list( + await db.get_many( + conn, models.Movie, release_year=[m1.release_year, m3.release_year] ) + ) @pytest.mark.asyncio -async def test_add_and_get(shared_conn: db.Database): - async with shared_conn.transaction(force_rollback=True): - m1 = a_movie() - await db.add(m1) +async def test_add_and_get(conn: db.Connection): + m1 = a_movie() + await db.add(conn, m1) - m2 = a_movie() - await db.add(m2) + m2 = a_movie() + await db.add(conn, m2) - assert m1 == await db.get(models.Movie, id=str(m1.id)) - assert m2 == await db.get(models.Movie, id=str(m2.id)) + assert m1 == await db.get(conn, models.Movie, id=str(m1.id)) + assert m2 == await db.get(conn, models.Movie, id=str(m2.id)) @pytest.mark.asyncio -async def test_update(shared_conn: db.Database): - async with shared_conn.transaction(force_rollback=True): - m = a_movie() - await db.add(m) +async def test_update(conn: db.Connection): + m = a_movie() + await db.add(conn, m) - assert m == await db.get(models.Movie, id=str(m.id)) - m.title += "something else" - assert m != await db.get(models.Movie, id=str(m.id)) + assert m == await db.get(conn, models.Movie, id=str(m.id)) + m.title += "something else" + assert m != await db.get(conn, models.Movie, id=str(m.id)) - await db.update(m) - assert m == await db.get(models.Movie, id=str(m.id)) + await db.update(conn, m) + assert m == await db.get(conn, models.Movie, id=str(m.id)) @pytest.mark.asyncio -async def test_remove(shared_conn: db.Database): - async with shared_conn.transaction(force_rollback=True): - m1 = a_movie() - await db.add(m1) - assert m1 == await db.get(models.Movie, id=str(m1.id)) +async def test_remove(conn: db.Connection): + m1 = a_movie() + await db.add(conn, m1) + assert m1 == await db.get(conn, models.Movie, id=str(m1.id)) - await db.remove(m1) - assert None is await db.get(models.Movie, id=str(m1.id)) + await db.remove(conn, m1) + assert None is await db.get(conn, models.Movie, id=str(m1.id)) @pytest.mark.asyncio -async def test_find_ratings(shared_conn: db.Database): - async with shared_conn.transaction(force_rollback=True): - m1 = a_movie( - title="test movie", - release_year=2013, - genres={"genre-1"}, - ) - await db.add(m1) +async def test_find_ratings(conn: db.Connection): + m1 = a_movie( + title="test movie", + release_year=2013, + genres={"genre-1"}, + ) + await db.add(conn, m1) - m2 = a_movie( - title="it's anöther Movie, Part 2", - release_year=2015, - genres={"genre-2"}, - ) - await db.add(m2) + m2 = a_movie( + title="it's anöther Movie, Part 2", + release_year=2015, + genres={"genre-2"}, + ) + await db.add(conn, m2) - m3 = a_movie( - title="movie it's, Part 3", - release_year=m2.release_year, - genres=m2.genres, - ) - await db.add(m3) + m3 = a_movie( + title="movie it's, Part 3", + release_year=m2.release_year, + genres=m2.genres, + ) + await db.add(conn, m3) - u1 = models.User( - imdb_id="u00001", - name="User1", - secret="secret1", - ) - await db.add(u1) + u1 = models.User( + imdb_id="u00001", + name="User1", + secret="secret1", + ) + await db.add(conn, u1) - u2 = models.User( - imdb_id="u00002", - name="User2", - secret="secret2", - ) - await db.add(u2) + u2 = models.User( + imdb_id="u00002", + name="User2", + secret="secret2", + ) + await db.add(conn, u2) - r1 = models.Rating( - movie_id=m2.id, - movie=m2, - user_id=u1.id, - user=u1, - score=66, - rating_date=datetime.now(), - ) - await db.add(r1) + r1 = models.Rating( + movie_id=m2.id, + movie=m2, + user_id=u1.id, + user=u1, + score=66, + rating_date=datetime.now(), + ) + await db.add(conn, r1) - r2 = models.Rating( - movie_id=m2.id, - movie=m2, - user_id=u2.id, - user=u2, - score=77, - rating_date=datetime.now(), - ) - await db.add(r2) + r2 = models.Rating( + movie_id=m2.id, + movie=m2, + user_id=u2.id, + user=u2, + score=77, + rating_date=datetime.now(), + ) + await db.add(conn, r2) - # --- + # --- - rows = await db.find_ratings( - title=m1.title, - media_type=m1.media_type, - exact=True, - ignore_tv_episodes=True, - include_unrated=True, - yearcomp=("=", m1.release_year), - limit_rows=3, - user_ids=[], - ) - ratings = (web_models.Rating(**r) for r in rows) - assert (web_models.RatingAggregate.from_movie(m1),) == tuple( - web_models.aggregate_ratings(ratings, user_ids=[]) - ) + rows = await db.find_ratings( + conn, + title=m1.title, + media_type=m1.media_type, + exact=True, + ignore_tv_episodes=True, + include_unrated=True, + yearcomp=("=", m1.release_year), + limit_rows=3, + user_ids=[], + ) + ratings = (web_models.Rating(**r) for r in rows) + assert (web_models.RatingAggregate.from_movie(m1),) == tuple( + web_models.aggregate_ratings(ratings, user_ids=[]) + ) - rows = await db.find_ratings(title="movie", include_unrated=False) - ratings = tuple(web_models.Rating(**r) for r in rows) - assert ( - web_models.Rating.from_movie(m2, rating=r1), - web_models.Rating.from_movie(m2, rating=r2), - ) == ratings + rows = await db.find_ratings(conn, title="movie", include_unrated=False) + ratings = tuple(web_models.Rating(**r) for r in rows) + assert ( + web_models.Rating.from_movie(m2, rating=r1), + web_models.Rating.from_movie(m2, rating=r2), + ) == ratings - rows = await db.find_ratings(title="movie", include_unrated=True) - ratings = tuple(web_models.Rating(**r) for r in rows) - assert ( - web_models.Rating.from_movie(m1), - web_models.Rating.from_movie(m2, rating=r1), - web_models.Rating.from_movie(m2, rating=r2), - web_models.Rating.from_movie(m3), - ) == ratings + rows = await db.find_ratings(conn, title="movie", include_unrated=True) + ratings = tuple(web_models.Rating(**r) for r in rows) + assert ( + web_models.Rating.from_movie(m1), + web_models.Rating.from_movie(m2, rating=r1), + web_models.Rating.from_movie(m2, rating=r2), + web_models.Rating.from_movie(m3), + ) == ratings - aggr = web_models.aggregate_ratings(ratings, user_ids=[]) - assert tuple( - web_models.RatingAggregate.from_movie(m) for m in [m1, m2, m3] - ) == tuple(aggr) + aggr = web_models.aggregate_ratings(ratings, user_ids=[]) + assert tuple( + web_models.RatingAggregate.from_movie(m) for m in [m1, m2, m3] + ) == tuple(aggr) - aggr = web_models.aggregate_ratings(ratings, user_ids=[str(u1.id)]) - assert ( - web_models.RatingAggregate.from_movie(m1), - web_models.RatingAggregate.from_movie(m2, ratings=[r1]), - web_models.RatingAggregate.from_movie(m3), - ) == tuple(aggr) + aggr = web_models.aggregate_ratings(ratings, user_ids=[str(u1.id)]) + assert ( + web_models.RatingAggregate.from_movie(m1), + web_models.RatingAggregate.from_movie(m2, ratings=[r1]), + web_models.RatingAggregate.from_movie(m3), + ) == tuple(aggr) - aggr = web_models.aggregate_ratings(ratings, user_ids=[str(u1.id), str(u2.id)]) - assert ( - web_models.RatingAggregate.from_movie(m1), - web_models.RatingAggregate.from_movie(m2, ratings=[r1, r2]), - web_models.RatingAggregate.from_movie(m3), - ) == tuple(aggr) + aggr = web_models.aggregate_ratings(ratings, user_ids=[str(u1.id), str(u2.id)]) + assert ( + web_models.RatingAggregate.from_movie(m1), + web_models.RatingAggregate.from_movie(m2, ratings=[r1, r2]), + web_models.RatingAggregate.from_movie(m3), + ) == tuple(aggr) - rows = await db.find_ratings(title="movie", include_unrated=True) - ratings = (web_models.Rating(**r) for r in rows) - aggr = web_models.aggregate_ratings(ratings, user_ids=[]) - assert tuple( - web_models.RatingAggregate.from_movie(m) for m in [m1, m2, m3] - ) == tuple(aggr) + rows = await db.find_ratings(conn, title="movie", include_unrated=True) + ratings = (web_models.Rating(**r) for r in rows) + aggr = web_models.aggregate_ratings(ratings, user_ids=[]) + assert tuple( + web_models.RatingAggregate.from_movie(m) for m in [m1, m2, m3] + ) == tuple(aggr) - rows = await db.find_ratings(title="test", include_unrated=True) - ratings = tuple(web_models.Rating(**r) for r in rows) - assert (web_models.Rating.from_movie(m1),) == ratings + rows = await db.find_ratings(conn, title="test", include_unrated=True) + ratings = tuple(web_models.Rating(**r) for r in rows) + assert (web_models.Rating.from_movie(m1),) == ratings @pytest.mark.asyncio -async def test_ratings_for_movies(shared_conn: db.Database): - async with shared_conn.transaction(force_rollback=True): - m1 = a_movie() - await db.add(m1) +async def test_ratings_for_movies(conn: db.Connection): + m1 = a_movie() + await db.add(conn, m1) - m2 = a_movie() - await db.add(m2) + m2 = a_movie() + await db.add(conn, m2) - u1 = models.User( - imdb_id="u00001", - name="User1", - secret="secret1", - ) - await db.add(u1) + u1 = models.User( + imdb_id="u00001", + name="User1", + secret="secret1", + ) + await db.add(conn, u1) - u2 = models.User( - imdb_id="u00002", - name="User2", - secret="secret2", - ) - await db.add(u2) + u2 = models.User( + imdb_id="u00002", + name="User2", + secret="secret2", + ) + await db.add(conn, u2) - r1 = models.Rating( - movie_id=m2.id, - movie=m2, - user_id=u1.id, - user=u1, - score=66, - rating_date=datetime.now(), - ) - await db.add(r1) + r1 = models.Rating( + movie_id=m2.id, + movie=m2, + user_id=u1.id, + user=u1, + score=66, + rating_date=datetime.now(), + ) + await db.add(conn, r1) - # --- + # --- - movie_ids = [m1.id] - user_ids = [] - assert tuple() == tuple( - await db.ratings_for_movies(movie_ids=movie_ids, user_ids=user_ids) - ) + movie_ids = [m1.id] + user_ids = [] + assert tuple() == tuple( + await db.ratings_for_movies(conn, movie_ids=movie_ids, user_ids=user_ids) + ) - movie_ids = [m2.id] - user_ids = [] - assert (r1,) == tuple( - await db.ratings_for_movies(movie_ids=movie_ids, user_ids=user_ids) - ) + movie_ids = [m2.id] + user_ids = [] + assert (r1,) == tuple( + await db.ratings_for_movies(conn, movie_ids=movie_ids, user_ids=user_ids) + ) - movie_ids = [m2.id] - user_ids = [u2.id] - assert tuple() == tuple( - await db.ratings_for_movies(movie_ids=movie_ids, user_ids=user_ids) - ) + movie_ids = [m2.id] + user_ids = [u2.id] + assert tuple() == tuple( + await db.ratings_for_movies(conn, movie_ids=movie_ids, user_ids=user_ids) + ) - movie_ids = [m2.id] - user_ids = [u1.id] - assert (r1,) == tuple( - await db.ratings_for_movies(movie_ids=movie_ids, user_ids=user_ids) - ) + movie_ids = [m2.id] + user_ids = [u1.id] + assert (r1,) == tuple( + await db.ratings_for_movies(conn, movie_ids=movie_ids, user_ids=user_ids) + ) - movie_ids = [m1.id, m2.id] - user_ids = [u1.id, u2.id] - assert (r1,) == tuple( - await db.ratings_for_movies(movie_ids=movie_ids, user_ids=user_ids) - ) + movie_ids = [m1.id, m2.id] + user_ids = [u1.id, u2.id] + assert (r1,) == tuple( + await db.ratings_for_movies(conn, movie_ids=movie_ids, user_ids=user_ids) + ) @pytest.mark.asyncio -async def test_find_movies(shared_conn: db.Database): - async with shared_conn.transaction(force_rollback=True): - m1 = a_movie(title="movie one") - await db.add(m1) +async def test_find_movies(conn: db.Connection): + m1 = a_movie(title="movie one") + await db.add(conn, m1) - m2 = a_movie(title="movie two", imdb_score=33, release_year=m1.release_year + 1) - await db.add(m2) + m2 = a_movie(title="movie two", imdb_score=33, release_year=m1.release_year + 1) + await db.add(conn, m2) - u1 = models.User( - imdb_id="u00001", - name="User1", - secret="secret1", - ) - await db.add(u1) + u1 = models.User( + imdb_id="u00001", + name="User1", + secret="secret1", + ) + await db.add(conn, u1) - u2 = models.User( - imdb_id="u00002", - name="User2", - secret="secret2", - ) - await db.add(u2) + u2 = models.User( + imdb_id="u00002", + name="User2", + secret="secret2", + ) + await db.add(conn, u2) - r1 = models.Rating( - movie_id=m2.id, - movie=m2, - user_id=u1.id, - user=u1, - score=66, - rating_date=datetime.now(), - ) - await db.add(r1) + r1 = models.Rating( + movie_id=m2.id, + movie=m2, + user_id=u1.id, + user=u1, + score=66, + rating_date=datetime.now(), + ) + await db.add(conn, r1) - # --- + # --- - assert () == tuple(await db.find_movies(title=m1.title, include_unrated=False)) - assert ((m1, []),) == tuple( - await db.find_movies(title=m1.title, include_unrated=True) - ) + assert () == tuple( + await db.find_movies(conn, title=m1.title, include_unrated=False) + ) + assert ((m1, []),) == tuple( + await db.find_movies(conn, title=m1.title, include_unrated=True) + ) - assert ((m1, []),) == tuple( - await db.find_movies(title="mo on", exact=False, include_unrated=True) - ) - assert ((m1, []),) == tuple( - await db.find_movies(title="movie one", exact=True, include_unrated=True) - ) - assert () == tuple( - await db.find_movies(title="mo on", exact=True, include_unrated=True) - ) + assert ((m1, []),) == tuple( + await db.find_movies(conn, title="mo on", exact=False, include_unrated=True) + ) + assert ((m1, []),) == tuple( + await db.find_movies(conn, title="movie one", exact=True, include_unrated=True) + ) + assert () == tuple( + await db.find_movies(conn, title="mo on", exact=True, include_unrated=True) + ) - assert ((m2, []),) == tuple( - await db.find_movies(title="movie", exact=False, include_unrated=False) - ) - assert ((m2, []), (m1, [])) == tuple( - await db.find_movies(title="movie", exact=False, include_unrated=True) - ) + assert ((m2, []),) == tuple( + await db.find_movies(conn, title="movie", exact=False, include_unrated=False) + ) + assert ((m2, []), (m1, [])) == tuple( + await db.find_movies(conn, title="movie", exact=False, include_unrated=True) + ) - assert ((m1, []),) == tuple( - await db.find_movies(include_unrated=True, yearcomp=("=", m1.release_year)) + assert ((m1, []),) == tuple( + await db.find_movies( + conn, include_unrated=True, yearcomp=("=", m1.release_year) ) - assert ((m2, []),) == tuple( - await db.find_movies(include_unrated=True, yearcomp=("=", m2.release_year)) + ) + assert ((m2, []),) == tuple( + await db.find_movies( + conn, include_unrated=True, yearcomp=("=", m2.release_year) ) - assert ((m1, []),) == tuple( - await db.find_movies(include_unrated=True, yearcomp=("<", m2.release_year)) + ) + assert ((m1, []),) == tuple( + await db.find_movies( + conn, include_unrated=True, yearcomp=("<", m2.release_year) ) - assert ((m2, []),) == tuple( - await db.find_movies(include_unrated=True, yearcomp=(">", m1.release_year)) + ) + assert ((m2, []),) == tuple( + await db.find_movies( + conn, include_unrated=True, yearcomp=(">", m1.release_year) ) + ) - assert ((m2, []), (m1, [])) == tuple(await db.find_movies(include_unrated=True)) - assert ((m2, []),) == tuple( - await db.find_movies(include_unrated=True, limit_rows=1) - ) - assert ((m1, []),) == tuple( - await db.find_movies(include_unrated=True, skip_rows=1) - ) + assert ((m2, []), (m1, [])) == tuple( + await db.find_movies(conn, include_unrated=True) + ) + assert ((m2, []),) == tuple( + await db.find_movies(conn, include_unrated=True, limit_rows=1) + ) + assert ((m1, []),) == tuple( + await db.find_movies(conn, include_unrated=True, skip_rows=1) + ) - assert ((m2, [r1]), (m1, [])) == tuple( - await db.find_movies(include_unrated=True, user_ids=[u1.id, u2.id]) - ) + assert ((m2, [r1]), (m1, [])) == tuple( + await db.find_movies(conn, include_unrated=True, user_ids=[u1.id, u2.id]) + ) diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..f4bd7b6 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,11 @@ +import pytest + +from unwind import models + + +@pytest.mark.parametrize("mapper", models.mapper_registry.mappers) +def test_fields(mapper): + """Test that models.fields() matches exactly all table columns.""" + dcfields = {f.name for f in models.fields(mapper.class_)} + mfields = {c.name for c in mapper.columns} + assert dcfields == mfields diff --git a/tests/test_web.py b/tests/test_web.py index 5a4c3c5..0444406 100644 --- a/tests/test_web.py +++ b/tests/test_web.py @@ -34,7 +34,7 @@ def admin_client() -> TestClient: @pytest.mark.asyncio async def test_get_ratings_for_group( - shared_conn: db.Database, unauthorized_client: TestClient + conn: db.Connection, unauthorized_client: TestClient ): user = models.User( imdb_id="ur12345678", @@ -48,201 +48,196 @@ async def test_get_ratings_for_group( ) user.groups = [models.UserGroup(id=str(group.id), access="r")] path = app.url_path_for("get_ratings_for_group", group_id=str(group.id)) - async with shared_conn.transaction(force_rollback=True): - resp = unauthorized_client.get(path) - assert resp.status_code == 404, "Group does not exist (yet)" - await db.add(user) - await db.add(group) + resp = unauthorized_client.get(path) + assert resp.status_code == 404, "Group does not exist (yet)" - resp = unauthorized_client.get(path) - assert resp.status_code == 200 - assert resp.json() == [] + await db.add(conn, user) + await db.add(conn, group) - movie = models.Movie( - title="test movie", - release_year=2013, - media_type="Movie", - imdb_id="tt12345678", - genres={"genre-1"}, - ) - await db.add(movie) + resp = unauthorized_client.get(path) + assert resp.status_code == 200 + assert resp.json() == [] - rating = models.Rating( - movie_id=movie.id, user_id=user.id, score=66, rating_date=datetime.now() - ) - await db.add(rating) + movie = models.Movie( + title="test movie", + release_year=2013, + media_type="Movie", + imdb_id="tt12345678", + genres={"genre-1"}, + ) + await db.add(conn, movie) - rating_aggregate = { - "canonical_title": movie.title, - "imdb_score": movie.imdb_score, - "imdb_votes": movie.imdb_votes, - "link": imdb.movie_url(movie.imdb_id), - "media_type": movie.media_type, - "original_title": movie.original_title, - "user_scores": [rating.score], - "year": movie.release_year, - } + rating = models.Rating( + movie_id=movie.id, user_id=user.id, score=66, rating_date=datetime.now() + ) + await db.add(conn, rating) - resp = unauthorized_client.get(path) + rating_aggregate = { + "canonical_title": movie.title, + "imdb_score": movie.imdb_score, + "imdb_votes": movie.imdb_votes, + "link": imdb.movie_url(movie.imdb_id), + "media_type": movie.media_type, + "original_title": movie.original_title, + "user_scores": [rating.score], + "year": movie.release_year, + } + + resp = unauthorized_client.get(path) + assert resp.status_code == 200 + assert resp.json() == [rating_aggregate] + + filters = { + "imdb_id": movie.imdb_id, + "unwind_id": str(movie.id), + "title": movie.title, + "media_type": movie.media_type, + "year": movie.release_year, + } + for k, v in filters.items(): + resp = unauthorized_client.get(path, params={k: v}) assert resp.status_code == 200 assert resp.json() == [rating_aggregate] - filters = { - "imdb_id": movie.imdb_id, - "unwind_id": str(movie.id), - "title": movie.title, - "media_type": movie.media_type, - "year": movie.release_year, - } - for k, v in filters.items(): - resp = unauthorized_client.get(path, params={k: v}) - assert resp.status_code == 200 - assert resp.json() == [rating_aggregate] + resp = unauthorized_client.get(path, params={"title": "no such thing"}) + assert resp.status_code == 200 + assert resp.json() == [] - resp = unauthorized_client.get(path, params={"title": "no such thing"}) - assert resp.status_code == 200 - assert resp.json() == [] + # Test "exact" query param. + resp = unauthorized_client.get( + path, params={"title": "test movie", "exact": "true"} + ) + assert resp.status_code == 200 + assert resp.json() == [rating_aggregate] + resp = unauthorized_client.get(path, params={"title": "te mo", "exact": "false"}) + assert resp.status_code == 200 + assert resp.json() == [rating_aggregate] + resp = unauthorized_client.get(path, params={"title": "te mo", "exact": "true"}) + assert resp.status_code == 200 + assert resp.json() == [] - # Test "exact" query param. - resp = unauthorized_client.get( - path, params={"title": "test movie", "exact": "true"} - ) - assert resp.status_code == 200 - assert resp.json() == [rating_aggregate] - resp = unauthorized_client.get( - path, params={"title": "te mo", "exact": "false"} - ) - assert resp.status_code == 200 - assert resp.json() == [rating_aggregate] - resp = unauthorized_client.get(path, params={"title": "te mo", "exact": "true"}) - assert resp.status_code == 200 - assert resp.json() == [] - - # XXX Test "ignore_tv_episodes" query param. - # XXX Test "include_unrated" query param. - # XXX Test "per_page" query param. + # XXX Test "ignore_tv_episodes" query param. + # XXX Test "include_unrated" query param. + # XXX Test "per_page" query param. @pytest.mark.asyncio async def test_list_movies( - shared_conn: db.Database, + conn: db.Connection, unauthorized_client: TestClient, authorized_client: TestClient, ): path = app.url_path_for("list_movies") - async with shared_conn.transaction(force_rollback=True): - response = unauthorized_client.get(path) - assert response.status_code == 403 + response = unauthorized_client.get(path) + assert response.status_code == 403 - response = authorized_client.get(path) - assert response.status_code == 200 - assert response.json() == [] + response = authorized_client.get(path) + assert response.status_code == 200 + assert response.json() == [] - m = models.Movie( - title="test movie", - release_year=2013, - media_type="Movie", - imdb_id="tt12345678", - genres={"genre-1"}, - ) - await db.add(m) + m = models.Movie( + title="test movie", + release_year=2013, + media_type="Movie", + imdb_id="tt12345678", + genres={"genre-1"}, + ) + await db.add(conn, m) - response = authorized_client.get(path, params={"include_unrated": 1}) - assert response.status_code == 200 - assert response.json() == [{**models.asplain(m), "user_scores": []}] + response = authorized_client.get(path, params={"include_unrated": 1}) + assert response.status_code == 200 + assert response.json() == [{**models.asplain(m), "user_scores": []}] - m_plain = { - "canonical_title": m.title, - "imdb_score": m.imdb_score, - "imdb_votes": m.imdb_votes, - "link": imdb.movie_url(m.imdb_id), - "media_type": m.media_type, - "original_title": m.original_title, - "user_scores": [], - "year": m.release_year, - } + m_plain = { + "canonical_title": m.title, + "imdb_score": m.imdb_score, + "imdb_votes": m.imdb_votes, + "link": imdb.movie_url(m.imdb_id), + "media_type": m.media_type, + "original_title": m.original_title, + "user_scores": [], + "year": m.release_year, + } - response = authorized_client.get(path, params={"imdb_id": m.imdb_id}) - assert response.status_code == 200 - assert response.json() == [m_plain] + response = authorized_client.get(path, params={"imdb_id": m.imdb_id}) + assert response.status_code == 200 + assert response.json() == [m_plain] - response = authorized_client.get(path, params={"unwind_id": str(m.id)}) - assert response.status_code == 200 - assert response.json() == [m_plain] + response = authorized_client.get(path, params={"unwind_id": str(m.id)}) + assert response.status_code == 200 + assert response.json() == [m_plain] @pytest.mark.asyncio async def test_list_users( - shared_conn: db.Database, + conn: db.Connection, unauthorized_client: TestClient, authorized_client: TestClient, admin_client: TestClient, ): path = app.url_path_for("list_users") - async with shared_conn.transaction(force_rollback=True): - response = unauthorized_client.get(path) - assert response.status_code == 403 + response = unauthorized_client.get(path) + assert response.status_code == 403 - response = authorized_client.get(path) - assert response.status_code == 403 + response = authorized_client.get(path) + assert response.status_code == 403 - response = admin_client.get(path) - assert response.status_code == 200 - assert response.json() == [] + response = admin_client.get(path) + assert response.status_code == 200 + assert response.json() == [] - m = models.User( - imdb_id="ur12345678", - name="user-1", - secret="secret-1", - groups=[], - ) - await db.add(m) + m = models.User( + imdb_id="ur12345678", + name="user-1", + secret="secret-1", + groups=[], + ) + await db.add(conn, m) - m_plain = { - "groups": m.groups, - "id": m.id, - "imdb_id": m.imdb_id, - "name": m.name, - "secret": m.secret, - } + m_plain = { + "groups": m.groups, + "id": m.id, + "imdb_id": m.imdb_id, + "name": m.name, + "secret": m.secret, + } - response = admin_client.get(path) - assert response.status_code == 200 - assert response.json() == [m_plain] + response = admin_client.get(path) + assert response.status_code == 200 + assert response.json() == [m_plain] @pytest.mark.asyncio async def test_list_groups( - shared_conn: db.Database, + conn: db.Connection, unauthorized_client: TestClient, authorized_client: TestClient, admin_client: TestClient, ): path = app.url_path_for("list_groups") - async with shared_conn.transaction(force_rollback=True): - response = unauthorized_client.get(path) - assert response.status_code == 403 + response = unauthorized_client.get(path) + assert response.status_code == 403 - response = authorized_client.get(path) - assert response.status_code == 403 + response = authorized_client.get(path) + assert response.status_code == 403 - response = admin_client.get(path) - assert response.status_code == 200 - assert response.json() == [] + response = admin_client.get(path) + assert response.status_code == 200 + assert response.json() == [] - m = models.Group( - name="group-1", - users=[models.GroupUser(id="123", name="itsa-me")], - ) - await db.add(m) + m = models.Group( + name="group-1", + users=[models.GroupUser(id="123", name="itsa-me")], + ) + await db.add(conn, m) - m_plain = { - "users": m.users, - "id": m.id, - "name": m.name, - } + m_plain = { + "users": m.users, + "id": m.id, + "name": m.name, + } - response = admin_client.get(path) - assert response.status_code == 200 - assert response.json() == [m_plain] + response = admin_client.get(path) + assert response.status_code == 200 + assert response.json() == [m_plain] diff --git a/unwind/db.py b/unwind/db.py index 51e66d5..3ebca66 100644 --- a/unwind/db.py +++ b/unwind/db.py @@ -1,13 +1,11 @@ -import asyncio import contextlib import logging -import threading from pathlib import Path -from typing import Any, AsyncGenerator, Iterable, Literal, Type, TypeVar +from typing import Any, AsyncGenerator, Iterable, Literal, Sequence, Type, TypeVar import sqlalchemy as sa -from databases import Database from sqlalchemy.dialects.sqlite import insert +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine from . import config from .models import ( @@ -31,7 +29,9 @@ from .types import ULID log = logging.getLogger(__name__) T = TypeVar("T") -_database: Database | None = None +_engine: AsyncEngine | None = None + +type Connection = AsyncConnection async def open_connection_pool() -> None: @@ -39,12 +39,13 @@ async def open_connection_pool() -> None: This function needs to be called before any access to the database can happen. """ - db = _shared_connection() - await db.connect() + async with transaction() as conn: + await conn.execute(sa.text("PRAGMA journal_mode=WAL")) - await db.execute(sa.text("PRAGMA journal_mode=WAL")) + await conn.run_sync(metadata.create_all, tables=[db_patches]) - await apply_db_patches(db) + async with new_connection() as conn: + await apply_db_patches(conn) async def close_connection_pool() -> None: @@ -53,32 +54,33 @@ async def close_connection_pool() -> None: This function should be called before the app shuts down to ensure all data has been flushed to the database. """ - db = _shared_connection() + engine = _shared_engine() - # Run automatic ANALYZE prior to closing the db, - # see https://sqlite.com/lang_analyze.html. - await db.execute(sa.text("PRAGMA analysis_limit=400")) - await db.execute(sa.text("PRAGMA optimize")) + async with engine.begin() as conn: + # Run automatic ANALYZE prior to closing the db, + # see https://sqlite.com/lang_analyze.html. + await conn.execute(sa.text("PRAGMA analysis_limit=400")) + await conn.execute(sa.text("PRAGMA optimize")) - await db.disconnect() + await engine.dispose() -async def current_patch_level(db: Database) -> str: +async def current_patch_level(conn: Connection, /) -> str: query = sa.select(db_patches.c.current) - current = await db.fetch_val(query) + current = await conn.scalar(query) return current or "" -async def set_current_patch_level(db: Database, current: str) -> None: +async def set_current_patch_level(conn: Connection, /, current: str) -> None: stmt = insert(db_patches).values(id=1, current=current) stmt = stmt.on_conflict_do_update(set_={"current": stmt.excluded.current}) - await db.execute(stmt) + await conn.execute(stmt) db_patches_dir = Path(__file__).parent / "sql" -async def apply_db_patches(db: Database) -> None: +async def apply_db_patches(conn: Connection, /) -> None: """Apply all remaining patches to the database. Beware that patches will be applied in lexicographical order, @@ -90,7 +92,7 @@ async def apply_db_patches(db: Database) -> None: using two consecutive semi-colons (;). Failing to do so will result in an error. """ - applied_lvl = await current_patch_level(db) + applied_lvl = await current_patch_level(conn) did_patch = False @@ -109,31 +111,52 @@ async def apply_db_patches(db: Database) -> None: ) raise RuntimeError("No statement found.") - async with db.transaction(): + async with transacted(conn): for query in queries: - await db.execute(sa.text(query)) + await conn.execute(sa.text(query)) - await set_current_patch_level(db, patch_lvl) + await set_current_patch_level(conn, patch_lvl) did_patch = True if did_patch: - await db.execute(sa.text("vacuum")) + await _vacuum(conn) -async def get_import_progress() -> Progress | None: +async def _vacuum(conn: Connection, /) -> None: + """Vacuum the database. + + This function cannot be run on a connection with an open transaction. + """ + # With SQLAlchemy's "autobegin" behavior we need to switch the connection + # to "autocommit" first to keep it from automatically starting a transaction, + # as VACUUM cannot be run inside a transaction for most databases. + await conn.commit() + isolation_level = await conn.get_isolation_level() + log.debug("Previous isolation_level: %a", isolation_level) + await conn.execution_options(isolation_level="AUTOCOMMIT") + try: + await conn.execute(sa.text("vacuum")) + await conn.commit() + finally: + await conn.execution_options(isolation_level=isolation_level) + + +async def get_import_progress(conn: Connection, /) -> Progress | None: """Return the latest import progress.""" return await get( - Progress, type="import-imdb-movies", order_by=(progress.c.started, "desc") + conn, Progress, type="import-imdb-movies", order_by=(progress.c.started, "desc") ) -async def stop_import_progress(*, error: BaseException | None = None) -> None: +async def stop_import_progress( + conn: Connection, /, *, error: BaseException | None = None +) -> None: """Stop the current import. If an error is given, it will be logged to the progress state. """ - current = await get_import_progress() + current = await get_import_progress(conn) is_running = current and current.stopped is None if not is_running: @@ -144,17 +167,17 @@ async def stop_import_progress(*, error: BaseException | None = None) -> None: current.error = repr(error) current.stopped = utcnow().isoformat() - await update(current) + await update(conn, current) -async def set_import_progress(progress: float) -> Progress: +async def set_import_progress(conn: Connection, /, progress: float) -> Progress: """Set the current import progress percentage. If no import is currently running, this will create a new one. """ progress = min(max(0.0, progress), 100.0) # clamp to 0 <= progress <= 100 - current = await get_import_progress() + current = await get_import_progress(conn) is_running = current and current.stopped is None if not is_running: @@ -164,71 +187,88 @@ async def set_import_progress(progress: float) -> Progress: current.percent = progress if is_running: - await update(current) + await update(conn, current) else: - await add(current) + await add(conn, current) return current -_lock = threading.Lock() -_prelock = threading.Lock() +def _new_engine() -> AsyncEngine: + uri = f"sqlite+aiosqlite:///{config.storage_path}" + + return create_async_engine( + uri, + isolation_level="SERIALIZABLE", + ) + + +def _shared_engine() -> AsyncEngine: + global _engine + + if _engine is None: + _engine = _new_engine() + + return _engine + + +def _new_connection() -> Connection: + return _shared_engine().connect() @contextlib.asynccontextmanager -async def single_threaded(): - """Ensure the nested code is run only by a single thread at a time.""" - wait = 1e-5 # XXX not sure if there's a better magic value here +async def transaction( + *, force_rollback: bool = False +) -> AsyncGenerator[Connection, None]: + async with new_connection() as conn: + yield conn - # The pre-lock (a lock for the lock) allows for multiple threads to hand of - # the main lock. - # With only a single lock the contending thread will spend most of its time - # in the asyncio.sleep and the reigning thread will have time to finish - # whatever it's doing and simply acquire the lock again before the other - # thread has had a change to try. - # By having another lock (and the same sleep time!) the contending thread - # will always have a chance to acquire the main lock. - while not _prelock.acquire(blocking=False): - await asyncio.sleep(wait) + if not force_rollback: + await conn.commit() - try: - while not _lock.acquire(blocking=False): - await asyncio.sleep(wait) - finally: - _prelock.release() - try: - yield - - finally: - _lock.release() +# The _test_connection allows pinning a connection that will be shared across the app. +# This can (and should only) be used when running tests, NOT IN PRODUCTION! +_test_connection: Connection | None = None @contextlib.asynccontextmanager -async def _locked_connection(): - async with single_threaded(): - yield _shared_connection() +async def new_connection() -> AsyncGenerator[Connection, None]: + """Return a new connection. + + Any changes will be rolled back, unless `.commit()` is called on the + connection. + + If you want to commit changes, consider using `transaction()` instead. + """ + conn = _test_connection or _new_connection() + + # Support reusing the same connection for _test_connection. + is_started = conn.sync_connection is not None + if is_started: + yield conn + return + + async with conn: + yield conn -def _shared_connection() -> Database: - global _database +@contextlib.asynccontextmanager +async def transacted( + conn: Connection, /, *, force_rollback: bool = False +) -> AsyncGenerator[None, None]: + transaction = contextlib.nullcontext() if conn.in_transaction() else conn.begin() - if _database is None: - uri = f"sqlite:///{config.storage_path}" - # uri = f"sqlite+aiosqlite:///{config.storage_path}" - _database = Database(uri) + async with transaction: + try: + yield - engine = sa.create_engine(uri, future=True) - metadata.create_all(engine, tables=[db_patches]) - - return _database + finally: + if force_rollback: + await conn.rollback() -def transaction(): - return _shared_connection().transaction() - - -async def add(item: Model) -> None: +async def add(conn: Connection, /, item: Model) -> None: # Support late initializing - used for optimization. if getattr(item, "_is_lazy", False): assert hasattr(item, "_lazy_init") @@ -237,14 +277,29 @@ async def add(item: Model) -> None: table: sa.Table = item.__table__ values = asplain(item, serialize=True) stmt = table.insert().values(values) - async with _locked_connection() as conn: - await conn.execute(stmt) + await conn.execute(stmt) + + +async def fetch_all( + conn: Connection, /, query: sa.Executable, values: "dict | None" = None +) -> Sequence[sa.Row]: + result = await conn.execute(query, values) + return result.all() + + +async def fetch_one( + conn: Connection, /, query: sa.Executable, values: "dict | None" = None +) -> sa.Row | None: + result = await conn.execute(query, values) + return result.first() ModelType = TypeVar("ModelType", bound=Model) async def get( + conn: Connection, + /, model: Type[ModelType], *, order_by: tuple[sa.Column, Literal["asc", "desc"]] | None = None, @@ -268,13 +323,12 @@ async def get( query = query.order_by( order_col.asc() if order_dir == "asc" else order_col.desc() ) - async with _locked_connection() as conn: - row = await conn.fetch_one(query) + row = await fetch_one(conn, query) return fromplain(model, row._mapping, serialized=True) if row else None async def get_many( - model: Type[ModelType], **field_sets: set | list + conn: Connection, /, model: Type[ModelType], **field_sets: set | list ) -> Iterable[ModelType]: """Return the items with any values matching all given field sets. @@ -288,12 +342,13 @@ async def get_many( table: sa.Table = model.__table__ query = sa.select(model).where(*(table.c[k].in_(v) for k, v in field_sets.items())) - async with _locked_connection() as conn: - rows = await conn.fetch_all(query) + rows = await fetch_all(conn, query) return (fromplain(model, row._mapping, serialized=True) for row in rows) -async def get_all(model: Type[ModelType], **field_values) -> Iterable[ModelType]: +async def get_all( + conn: Connection, /, model: Type[ModelType], **field_values +) -> Iterable[ModelType]: """Filter all items by comparing all given field values. If no filters are given, all items will be returned. @@ -302,12 +357,11 @@ async def get_all(model: Type[ModelType], **field_values) -> Iterable[ModelType] query = sa.select(model).where( *(table.c[k] == v for k, v in field_values.items() if v is not None) ) - async with _locked_connection() as conn: - rows = await conn.fetch_all(query) + rows = await fetch_all(conn, query) return (fromplain(model, row._mapping, serialized=True) for row in rows) -async def update(item: Model) -> None: +async def update(conn: Connection, /, item: Model) -> None: # Support late initializing - used for optimization. if getattr(item, "_is_lazy", False): assert hasattr(item, "_lazy_init") @@ -316,30 +370,28 @@ async def update(item: Model) -> None: table: sa.Table = item.__table__ values = asplain(item, serialize=True) stmt = table.update().where(table.c.id == values["id"]).values(values) - async with _locked_connection() as conn: - await conn.execute(stmt) + await conn.execute(stmt) -async def remove(item: Model) -> None: +async def remove(conn: Connection, /, item: Model) -> None: table: sa.Table = item.__table__ values = asplain(item, filter_fields={"id"}, serialize=True) stmt = table.delete().where(table.c.id == values["id"]) - async with _locked_connection() as conn: - await conn.execute(stmt) + await conn.execute(stmt) -async def add_or_update_user(user: User) -> None: - db_user = await get(User, imdb_id=user.imdb_id) +async def add_or_update_user(conn: Connection, /, user: User) -> None: + db_user = await get(conn, User, imdb_id=user.imdb_id) if not db_user: - await add(user) + await add(conn, user) else: user.id = db_user.id if user != db_user: - await update(user) + await update(conn, user) -async def add_or_update_many_movies(movies: list[Movie]) -> None: +async def add_or_update_many_movies(conn: Connection, /, movies: list[Movie]) -> None: """Add or update Movies in the database. This is an optimized version of `add_or_update_movie` for the purpose @@ -348,12 +400,13 @@ async def add_or_update_many_movies(movies: list[Movie]) -> None: # for movie in movies: # await add_or_update_movie(movie) db_movies = { - m.imdb_id: m for m in await get_many(Movie, imdb_id=[m.imdb_id for m in movies]) + m.imdb_id: m + for m in await get_many(conn, Movie, imdb_id=[m.imdb_id for m in movies]) } for movie in movies: # XXX optimize bulk add & update as well if movie.imdb_id not in db_movies: - await add(movie) + await add(conn, movie) else: db_movie = db_movies[movie.imdb_id] movie.id = db_movie.id @@ -366,10 +419,10 @@ async def add_or_update_many_movies(movies: list[Movie]) -> None: if movie.updated <= db_movie.updated: return - await update(movie) + await update(conn, movie) -async def add_or_update_movie(movie: Movie) -> None: +async def add_or_update_movie(conn: Connection, /, movie: Movie) -> None: """Add or update a Movie in the database. This is an upsert operation, but it will also update the Movie you pass @@ -377,9 +430,9 @@ async def add_or_update_movie(movie: Movie) -> None: set all optional values on your Movie that might be unset but exist in the database. It's a bidirectional sync. """ - db_movie = await get(Movie, imdb_id=movie.imdb_id) + db_movie = await get(conn, Movie, imdb_id=movie.imdb_id) if not db_movie: - await add(movie) + await add(conn, movie) else: movie.id = db_movie.id @@ -391,23 +444,23 @@ async def add_or_update_movie(movie: Movie) -> None: if movie.updated <= db_movie.updated: return - await update(movie) + await update(conn, movie) -async def add_or_update_rating(rating: Rating) -> bool: +async def add_or_update_rating(conn: Connection, /, rating: Rating) -> bool: db_rating = await get( - Rating, movie_id=str(rating.movie_id), user_id=str(rating.user_id) + conn, Rating, movie_id=str(rating.movie_id), user_id=str(rating.user_id) ) if not db_rating: - await add(rating) + await add(conn, rating) return True else: rating.id = db_rating.id if rating != db_rating: - await update(rating) + await update(conn, rating) return True return False @@ -418,6 +471,8 @@ def sql_escape(s: str, char: str = "#") -> str: async def find_ratings( + conn: Connection, + /, *, title: str | None = None, media_type: str | None = None, @@ -475,9 +530,8 @@ async def find_ratings( ) .limit(limit_rows) ) - async with _locked_connection() as conn: - rating_rows: AsyncGenerator[Rating, None] = conn.iterate(query) # type: ignore - movie_ids = [r.movie_id async for r in rating_rows] + rating_rows: sa.CursorResult[Rating] = await conn.execute(query) + movie_ids = [r.movie_id for r in rating_rows] if include_unrated and len(movie_ids) < limit_rows: query = ( @@ -491,15 +545,17 @@ async def find_ratings( ) .limit(limit_rows - len(movie_ids)) ) - async with _locked_connection() as conn: - movie_rows: AsyncGenerator[Movie, None] = conn.iterate(query) # type: ignore - movie_ids += [r.id async for r in movie_rows] + movie_rows: sa.CursorResult[Movie] = await conn.execute(query) + movie_ids += [r.id for r in movie_rows] - return await ratings_for_movie_ids(ids=movie_ids) + return await ratings_for_movie_ids(conn, ids=movie_ids) async def ratings_for_movie_ids( - ids: Iterable[ULID | str] = [], imdb_ids: Iterable[str] = [] + conn: Connection, + /, + ids: Iterable[ULID | str] = [], + imdb_ids: Iterable[str] = [], ) -> Iterable[dict[str, Any]]: conds = [] @@ -527,13 +583,12 @@ async def ratings_for_movie_ids( .outerjoin_from(movies, ratings, movies.c.id == ratings.c.movie_id) .where(sa.or_(*conds)) ) - async with _locked_connection() as conn: - rows = await conn.fetch_all(query) + rows = await fetch_all(conn, query) return tuple(dict(r._mapping) for r in rows) async def ratings_for_movies( - movie_ids: Iterable[ULID], user_ids: Iterable[ULID] = [] + conn: Connection, /, movie_ids: Iterable[ULID], user_ids: Iterable[ULID] = [] ) -> Iterable[Rating]: conditions = [ratings.c.movie_id.in_(str(x) for x in movie_ids)] @@ -542,13 +597,14 @@ async def ratings_for_movies( query = sa.select(ratings).where(*conditions) - async with _locked_connection() as conn: - rows = await conn.fetch_all(query) + rows = await fetch_all(conn, query) return (fromplain(Rating, row._mapping, serialized=True) for row in rows) async def find_movies( + conn: Connection, + /, *, title: str | None = None, media_type: str | None = None, @@ -606,15 +662,14 @@ async def find_movies( .offset(skip_rows) ) - async with _locked_connection() as conn: - rows = await conn.fetch_all(query) + rows = await fetch_all(conn, query) movies_ = [fromplain(Movie, row._mapping, serialized=True) for row in rows] if not user_ids: return ((m, []) for m in movies_) - ratings = await ratings_for_movies((m.id for m in movies_), user_ids) + ratings = await ratings_for_movies(conn, (m.id for m in movies_), user_ids) aggreg: dict[ULID, tuple[Movie, list[Rating]]] = {m.id: (m, []) for m in movies_} for rating in ratings: diff --git a/unwind/imdb.py b/unwind/imdb.py index 6858fc7..631a088 100644 --- a/unwind/imdb.py +++ b/unwind/imdb.py @@ -40,7 +40,9 @@ async def refresh_user_ratings_from_imdb(stop_on_dupe: bool = True): async with asession() as s: s.headers["Accept-Language"] = "en-US, en;q=0.5" - for user in await db.get_all(User): + async with db.new_connection() as conn: + users = list(await db.get_all(conn, User)) + for user in users: log.info("⚡️ Loading data for %s ...", user.name) try: @@ -96,7 +98,7 @@ find_year = re.compile( find_movie_id = re.compile(r"/title/(?Ptt\d+)/").search -def movie_and_rating_from_item(item) -> tuple[Movie, Rating]: +def movie_and_rating_from_item(item: bs4.Tag) -> tuple[Movie, Rating]: genres = (genre := item.find("span", "genre")) and genre.string or "" movie = Movie( title=item.h3.a.string.strip(), @@ -161,9 +163,10 @@ async def parse_page(url: str) -> tuple[list[Rating], str | None]: assert isinstance(meta, bs4.Tag) imdb_id = meta["content"] assert isinstance(imdb_id, str) - user = await db.get(User, imdb_id=imdb_id) or User( - imdb_id=imdb_id, name="", secret="" - ) + async with db.new_connection() as conn: + user = await db.get(conn, User, imdb_id=imdb_id) or User( + imdb_id=imdb_id, name="", secret="" + ) if (headline := soup.h1) is None: raise RuntimeError("No headline found.") @@ -213,14 +216,15 @@ async def load_ratings(user_id: str): for i, rating in enumerate(ratings): assert rating.user and rating.movie - if i == 0: - # All rating objects share the same user. - await db.add_or_update_user(rating.user) - rating.user_id = rating.user.id + async with db.transaction() as conn: + if i == 0: + # All rating objects share the same user. + await db.add_or_update_user(conn, rating.user) + rating.user_id = rating.user.id - await db.add_or_update_movie(rating.movie) - rating.movie_id = rating.movie.id + await db.add_or_update_movie(conn, rating.movie) + rating.movie_id = rating.movie.id - is_updated = await db.add_or_update_rating(rating) + is_updated = await db.add_or_update_rating(conn, rating) yield rating, is_updated diff --git a/unwind/imdb_import.py b/unwind/imdb_import.py index 705db2f..dad419e 100644 --- a/unwind/imdb_import.py +++ b/unwind/imdb_import.py @@ -209,7 +209,8 @@ async def import_from_file(*, basics_path: Path, ratings_path: Path): for i, m in enumerate(read_basics(basics_path)): perc = 100 * i / total if perc >= perc_next_report: - await db.set_import_progress(perc) + async with db.transaction() as conn: + await db.set_import_progress(conn, perc) log.info("⏳ Imported %s%%", round(perc, 1)) perc_next_report += perc_step @@ -233,15 +234,18 @@ async def import_from_file(*, basics_path: Path, ratings_path: Path): chunk.append(m) if len(chunk) > 1000: - await add_or_update_many_movies(chunk) + async with db.transaction() as conn: + await add_or_update_many_movies(conn, chunk) chunk = [] if chunk: - await add_or_update_many_movies(chunk) + async with db.transaction() as conn: + await add_or_update_many_movies(conn, chunk) chunk = [] log.info("👍 Imported 100%") - await db.set_import_progress(100) + async with db.transaction() as conn: + await db.set_import_progress(conn, 100) async def download_datasets(*, basics_path: Path, ratings_path: Path) -> None: @@ -270,7 +274,8 @@ async def load_from_web(*, force: bool = False) -> None: See https://www.imdb.com/interfaces/ and https://datasets.imdbws.com/ for more information on the IMDb database dumps. """ - await db.set_import_progress(0) + async with db.transaction() as conn: + await db.set_import_progress(conn, 0) try: ratings_file = config.datadir / "imdb/title.ratings.tsv.gz" @@ -290,8 +295,10 @@ async def load_from_web(*, force: bool = False) -> None: await import_from_file(basics_path=basics_file, ratings_path=ratings_file) except BaseException as err: - await db.stop_import_progress(error=err) + async with db.transaction() as conn: + await db.stop_import_progress(conn, error=err) raise else: - await db.stop_import_progress() + async with db.transaction() as conn: + await db.stop_import_progress(conn) diff --git a/unwind/models.py b/unwind/models.py index ff961fc..3e52225 100644 --- a/unwind/models.py +++ b/unwind/models.py @@ -354,49 +354,6 @@ The contents of the Relation are ignored or discarded when using Relation = Annotated[T | None, _RelationSentinel] -@mapper_registry.mapped -@dataclass -class Rating: - __table__: ClassVar[Table] = Table( - "ratings", - metadata, - Column("id", String, primary_key=True), # ULID - Column("movie_id", ForeignKey("movies.id"), nullable=False), # ULID - Column("user_id", ForeignKey("users.id"), nullable=False), # ULID - Column("score", Integer, nullable=False), - Column("rating_date", String, nullable=False), # datetime - Column("favorite", Integer), # bool - Column("finished", Integer), # bool - ) - - id: ULID = field(default_factory=ULID) - - movie_id: ULID = None - movie: Relation[Movie] = None - - user_id: ULID = None - user: Relation["User"] = None - - score: int = None # range: [0,100] - rating_date: datetime = None - favorite: bool | None = None - finished: bool | None = None - - def __eq__(self, other): - """Return wether two Ratings are equal. - - This operation compares all fields as expected, except that it - ignores any field marked as Relation. - """ - if type(other) is not type(self): - return False - return all( - getattr(self, f.name) == getattr(other, f.name) for f in fields(self) - ) - - -ratings = Rating.__table__ - Access = Literal[ "r", # read "i", # index @@ -442,6 +399,50 @@ class User: self.groups.append({"id": group_id, "access": access}) +@mapper_registry.mapped +@dataclass +class Rating: + __table__: ClassVar[Table] = Table( + "ratings", + metadata, + Column("id", String, primary_key=True), # ULID + Column("movie_id", ForeignKey("movies.id"), nullable=False), # ULID + Column("user_id", ForeignKey("users.id"), nullable=False), # ULID + Column("score", Integer, nullable=False), + Column("rating_date", String, nullable=False), # datetime + Column("favorite", Integer), # bool + Column("finished", Integer), # bool + ) + + id: ULID = field(default_factory=ULID) + + movie_id: ULID = None + movie: Relation[Movie] = None + + user_id: ULID = None + user: Relation[User] = None + + score: int = None # range: [0,100] + rating_date: datetime = None + favorite: bool | None = None + finished: bool | None = None + + def __eq__(self, other): + """Return wether two Ratings are equal. + + This operation compares all fields as expected, except that it + ignores any field marked as Relation. + """ + if type(other) is not type(self): + return False + return all( + getattr(self, f.name) == getattr(other, f.name) for f in fields(self) + ) + + +ratings = Rating.__table__ + + class GroupUser(TypedDict): id: str name: str diff --git a/unwind/web.py b/unwind/web.py index bddd54d..8b02863 100644 --- a/unwind/web.py +++ b/unwind/web.py @@ -168,7 +168,8 @@ async def auth_user(request) -> User | None: if not isinstance(request.user, AuthedUser): return - user = await db.get(User, id=request.user.user_id) + async with db.new_connection() as conn: + user = await db.get(conn, User, id=request.user.user_id) if not user: return @@ -195,8 +196,9 @@ def route(path: str, *, methods: list[str] | None = None, **kwds): async def get_ratings_for_group(request): group_id = as_ulid(request.path_params["group_id"]) - if (group := await db.get(Group, id=str(group_id))) is None: - return not_found() + async with db.new_connection() as conn: + if (group := await db.get(conn, Group, id=str(group_id))) is None: + return not_found() user_ids = {u["id"] for u in group.users} @@ -207,22 +209,26 @@ async def get_ratings_for_group(request): # if (imdb_id or unwind_id) and (movie := await db.get(Movie, id=unwind_id, imdb_id=imdb_id)): if unwind_id: - rows = await db.ratings_for_movie_ids(ids=[unwind_id]) + async with db.new_connection() as conn: + rows = await db.ratings_for_movie_ids(conn, ids=[unwind_id]) elif imdb_id: - rows = await db.ratings_for_movie_ids(imdb_ids=[imdb_id]) + async with db.new_connection() as conn: + rows = await db.ratings_for_movie_ids(conn, imdb_ids=[imdb_id]) else: - rows = await find_ratings( - title=params.get("title"), - media_type=params.get("media_type"), - exact=truthy(params.get("exact")), - ignore_tv_episodes=truthy(params.get("ignore_tv_episodes")), - include_unrated=truthy(params.get("include_unrated")), - yearcomp=yearcomp(params["year"]) if "year" in params else None, - limit_rows=as_int(params.get("per_page"), max=10, default=5), - user_ids=user_ids, - ) + async with db.new_connection() as conn: + rows = await find_ratings( + conn, + title=params.get("title"), + media_type=params.get("media_type"), + exact=truthy(params.get("exact")), + ignore_tv_episodes=truthy(params.get("ignore_tv_episodes")), + include_unrated=truthy(params.get("include_unrated")), + yearcomp=yearcomp(params["year"]) if "year" in params else None, + limit_rows=as_int(params.get("per_page"), max=10, default=5), + user_ids=user_ids, + ) ratings = (web_models.Rating(**r) for r in rows) @@ -261,7 +267,8 @@ async def list_movies(request): if group_id := params.get("group_id"): group_id = as_ulid(group_id) - group = await db.get(Group, id=str(group_id)) + async with db.new_connection() as conn: + group = await db.get(conn, Group, id=str(group_id)) if not group: return not_found("Group not found.") @@ -286,26 +293,31 @@ async def list_movies(request): if imdb_id or unwind_id: # XXX missing support for user_ids and user_scores - movies = ( - [m] if (m := await db.get(Movie, id=unwind_id, imdb_id=imdb_id)) else [] - ) + async with db.new_connection() as conn: + movies = ( + [m] + if (m := await db.get(conn, Movie, id=unwind_id, imdb_id=imdb_id)) + else [] + ) resp = [asplain(web_models.RatingAggregate.from_movie(m)) for m in movies] else: per_page = as_int(params.get("per_page"), max=1000, default=5) page = as_int(params.get("page"), min=1, default=1) - movieratings = await find_movies( - title=params.get("title"), - media_type=params.get("media_type"), - exact=truthy(params.get("exact")), - ignore_tv_episodes=truthy(params.get("ignore_tv_episodes")), - include_unrated=truthy(params.get("include_unrated")), - yearcomp=yearcomp(params["year"]) if "year" in params else None, - limit_rows=per_page, - skip_rows=(page - 1) * per_page, - user_ids=list(user_ids), - ) + async with db.new_connection() as conn: + movieratings = await find_movies( + conn, + title=params.get("title"), + media_type=params.get("media_type"), + exact=truthy(params.get("exact")), + ignore_tv_episodes=truthy(params.get("ignore_tv_episodes")), + include_unrated=truthy(params.get("include_unrated")), + yearcomp=yearcomp(params["year"]) if "year" in params else None, + limit_rows=per_page, + skip_rows=(page - 1) * per_page, + user_ids=list(user_ids), + ) resp = [] for movie, ratings in movieratings: @@ -325,7 +337,8 @@ async def add_movie(request): @route("/movies/_reload_imdb", methods=["GET"]) @requires(["authenticated", "admin"]) async def progress_for_load_imdb_movies(request): - progress = await db.get_import_progress() + async with db.new_connection() as conn: + progress = await db.get_import_progress(conn) if not progress: return JSONResponse({"status": "No import exists."}, status_code=404) @@ -364,14 +377,16 @@ async def load_imdb_movies(request): force = truthy(params.get("force")) async with _import_lock: - progress = await db.get_import_progress() + async with db.new_connection() as conn: + progress = await db.get_import_progress(conn) if progress and not progress.stopped: return JSONResponse( {"status": "Import is running.", "progress": progress.percent}, status_code=409, ) - await db.set_import_progress(0) + async with db.transaction() as conn: + await db.set_import_progress(conn, 0) task = BackgroundTask(imdb_import.load_from_web, force=force) return JSONResponse( @@ -382,7 +397,8 @@ async def load_imdb_movies(request): @route("/users") @requires(["authenticated", "admin"]) async def list_users(request): - users = await db.get_all(User) + async with db.new_connection() as conn: + users = await db.get_all(conn, User) return JSONResponse([asplain(u) for u in users]) @@ -398,7 +414,8 @@ async def add_user(request): secret = secrets.token_bytes() user = User(name=name, imdb_id=imdb_id, secret=phc_scrypt(secret)) - await db.add(user) + async with db.transaction() as conn: + await db.add(conn, user) return JSONResponse( { @@ -414,7 +431,8 @@ async def show_user(request): user_id = as_ulid(request.path_params["user_id"]) if is_admin(request): - user = await db.get(User, id=str(user_id)) + async with db.new_connection() as conn: + user = await db.get(conn, User, id=str(user_id)) else: user = await auth_user(request) @@ -441,14 +459,15 @@ async def show_user(request): async def remove_user(request): user_id = as_ulid(request.path_params["user_id"]) - user = await db.get(User, id=str(user_id)) + async with db.new_connection() as conn: + user = await db.get(conn, User, id=str(user_id)) if not user: return not_found() - async with db.transaction(): + async with db.transaction() as conn: # XXX remove user refs from groups and ratings - await db.remove(user) + await db.remove(conn, user) return JSONResponse(asplain(user)) @@ -459,7 +478,8 @@ async def modify_user(request): user_id = as_ulid(request.path_params["user_id"]) if is_admin(request): - user = await db.get(User, id=str(user_id)) + async with db.new_connection() as conn: + user = await db.get(conn, User, id=str(user_id)) else: user = await auth_user(request) @@ -495,7 +515,8 @@ async def modify_user(request): user.secret = phc_scrypt(secret) - await db.update(user) + async with db.transaction() as conn: + await db.update(conn, user) return JSONResponse(asplain(user)) @@ -505,13 +526,15 @@ async def modify_user(request): async def add_group_to_user(request): user_id = as_ulid(request.path_params["user_id"]) - user = await db.get(User, id=str(user_id)) + async with db.new_connection() as conn: + user = await db.get(conn, User, id=str(user_id)) if not user: return not_found("User not found") (group_id, access) = await json_from_body(request, ["group", "access"]) - group = await db.get(Group, id=str(group_id)) + async with db.new_connection() as conn: + group = await db.get(conn, Group, id=str(group_id)) if not group: return not_found("Group not found") @@ -519,7 +542,8 @@ async def add_group_to_user(request): raise HTTPException(422, f"Invalid access level.") user.set_access(group_id, access) - await db.update(user) + async with db.transaction() as conn: + await db.update(conn, user) return JSONResponse(asplain(user)) @@ -547,7 +571,8 @@ async def load_imdb_user_ratings(request): @route("/groups") @requires(["authenticated", "admin"]) async def list_groups(request): - groups = await db.get_all(Group) + async with db.new_connection() as conn: + groups = await db.get_all(conn, Group) return JSONResponse([asplain(g) for g in groups]) @@ -560,7 +585,8 @@ async def add_group(request): # XXX restrict name group = Group(name=name) - await db.add(group) + async with db.transaction() as conn: + await db.add(conn, group) return JSONResponse(asplain(group)) @@ -569,7 +595,8 @@ async def add_group(request): @requires(["authenticated"]) async def add_user_to_group(request): group_id = as_ulid(request.path_params["group_id"]) - group = await db.get(Group, id=str(group_id)) + async with db.new_connection() as conn: + group = await db.get(conn, Group, id=str(group_id)) if not group: return not_found() @@ -596,7 +623,8 @@ async def add_user_to_group(request): else: group.users.append({"name": name, "id": user_id}) - await db.update(group) + async with db.transaction() as conn: + await db.update(conn, group) return JSONResponse(asplain(group))