diff --git a/Dockerfile b/Dockerfile index 66014c6..1571d75 100644 --- a/Dockerfile +++ b/Dockerfile @@ -18,15 +18,18 @@ RUN pip install --no-cache-dir --upgrade \ USER 10000:10001 -COPY run ./ +COPY alembic.ini entrypoint.sh pyproject.toml run ./ +COPY alembic ./alembic COPY scripts ./scripts COPY unwind ./unwind +RUN pip install --no-cache-dir --editable . + ENV UNWIND_DATA="/data" VOLUME $UNWIND_DATA ENV UNWIND_PORT=8097 EXPOSE $UNWIND_PORT -ENTRYPOINT ["/var/app/run"] +ENTRYPOINT ["/var/app/entrypoint.sh"] CMD ["server"] diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 0000000..d8a741e --- /dev/null +++ b/alembic.ini @@ -0,0 +1,39 @@ +[alembic] +script_location = alembic +file_template = %%(epoch)s-%%(rev)s_%%(slug)s +timezone = UTC + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 0000000..b3ea427 --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,108 @@ +import asyncio +from logging.config import fileConfig + +import sqlalchemy as sa +from sqlalchemy import pool +from sqlalchemy.engine import Connection +from sqlalchemy.ext.asyncio import async_engine_from_config + +from alembic import context +from unwind import db, models + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + + +def is_different_type( + context, + inspected_column: sa.Column, + metadata_column: sa.Column, + inspected_type: sa.types.TypeEngine, + metadata_type: sa.types.TypeEngine, +) -> bool | None: + # We used "TEXT" in our manual SQL, which in SQLite is the same as VARCHAR, but + # for SQLAlchemy/Alembic looks different. + equiv_types = [(sa.TEXT, sa.String)] + for types in equiv_types: + if isinstance(inspected_type, types) and isinstance(metadata_type, types): + return False + return None # defer to default compare implementation + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. 
+ + """ + context.configure( + url=db._connection_uri(), + target_metadata=models.metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + compare_type=is_different_type, + render_as_batch=True, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def do_run_migrations(connection: Connection) -> None: + context.configure( + connection=connection, + target_metadata=models.metadata, + compare_type=is_different_type, + render_as_batch=True, + ) + + with context.begin_transaction(): + context.run_migrations() + + +async def run_async_migrations() -> None: + """In this scenario we need to create an Engine + and associate a connection with the context. + + """ + + connectable = async_engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + url=db._connection_uri(), + poolclass=pool.NullPool, + ) + + async with connectable.connect() as connection: + await connection.run_sync(do_run_migrations) + + await connectable.dispose() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode.""" + # Support having a (sync) connection passed in from another script. + if (conn := config.attributes.get("connection")) and isinstance( + conn, sa.Connection + ): + do_run_migrations(conn) + else: + asyncio.run(run_async_migrations()) + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 0000000..f31592a --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,26 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: str | None = ${repr(down_revision)} +branch_labels: str | Sequence[str] | None = ${repr(branch_labels)} +depends_on: str | Sequence[str] | None = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/1716049471-c08ae04dc482_fix_data_types.py b/alembic/versions/1716049471-c08ae04dc482_fix_data_types.py new file mode 100644 index 0000000..5a72f3a --- /dev/null +++ b/alembic/versions/1716049471-c08ae04dc482_fix_data_types.py @@ -0,0 +1,69 @@ +"""fix data types + +Revision ID: c08ae04dc482 +Revises: +Create Date: 2024-05-18 16:24:31.152480+00:00 + +""" + +from typing import Sequence + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "c08ae04dc482" +down_revision: str | None = None +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("ratings", schema=None) as batch_op: + batch_op.alter_column( + "score", + existing_type=sa.NUMERIC(), + type_=sa.Integer(), + existing_nullable=False, + ) + batch_op.alter_column( + "favorite", + existing_type=sa.NUMERIC(), + type_=sa.Integer(), + existing_nullable=True, + ) + batch_op.alter_column( + "finished", + existing_type=sa.NUMERIC(), + type_=sa.Integer(), + existing_nullable=True, + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table("ratings", schema=None) as batch_op: + batch_op.alter_column( + "finished", + existing_type=sa.Integer(), + type_=sa.NUMERIC(), + existing_nullable=True, + ) + batch_op.alter_column( + "favorite", + existing_type=sa.Integer(), + type_=sa.NUMERIC(), + existing_nullable=True, + ) + batch_op.alter_column( + "score", + existing_type=sa.Integer(), + type_=sa.NUMERIC(), + existing_nullable=False, + ) + + # ### end Alembic commands ### diff --git a/alembic/versions/1716050110-62882ef5e3ff_add_awards_table.py b/alembic/versions/1716050110-62882ef5e3ff_add_awards_table.py new file mode 100644 index 0000000..b66cee0 --- /dev/null +++ b/alembic/versions/1716050110-62882ef5e3ff_add_awards_table.py @@ -0,0 +1,44 @@ +"""add awards table + +Revision ID: 62882ef5e3ff +Revises: c08ae04dc482 +Create Date: 2024-05-18 16:35:10.145964+00:00 + +""" + +from typing import Sequence + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "62882ef5e3ff" +down_revision: str | None = "c08ae04dc482" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "awards", + sa.Column("id", sa.String(), nullable=False), + sa.Column("movie_id", sa.String(), nullable=False), + sa.Column("category", sa.String(), nullable=False), + sa.Column("details", sa.String(), nullable=False), + sa.Column("created", sa.String(), nullable=False), + sa.Column("updated", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["movie_id"], + ["movies.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("awards") + # ### end Alembic commands ### diff --git a/alembic/versions/1716051987-f17c7ca9afa4_use_named_constraints.py b/alembic/versions/1716051987-f17c7ca9afa4_use_named_constraints.py new file mode 100644 index 0000000..be21664 --- /dev/null +++ b/alembic/versions/1716051987-f17c7ca9afa4_use_named_constraints.py @@ -0,0 +1,41 @@ +"""use named constraints + +See https://alembic.sqlalchemy.org/en/latest/naming.html + +Revision ID: f17c7ca9afa4 +Revises: 62882ef5e3ff +Create Date: 2024-05-18 17:06:27.696713+00:00 + +""" + +from typing import Sequence + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "f17c7ca9afa4" +down_revision: str | None = "62882ef5e3ff" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("movies") as batch_op: + batch_op.create_unique_constraint(batch_op.f("uq_movies_imdb_id"), ["imdb_id"]) + + with op.batch_alter_table("users", schema=None) as batch_op: + batch_op.create_unique_constraint(batch_op.f("uq_users_imdb_id"), ["imdb_id"]) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table("users", schema=None) as batch_op: + batch_op.drop_constraint(batch_op.f("uq_users_imdb_id"), type_="unique") + + with op.batch_alter_table("movies", schema=None) as batch_op: + batch_op.drop_constraint(batch_op.f("uq_movies_imdb_id"), type_="unique") + + # ### end Alembic commands ### diff --git a/alembic/versions/1716077466-8b06e4916840_remove_db_patches_table.py b/alembic/versions/1716077466-8b06e4916840_remove_db_patches_table.py new file mode 100644 index 0000000..840cb33 --- /dev/null +++ b/alembic/versions/1716077466-8b06e4916840_remove_db_patches_table.py @@ -0,0 +1,38 @@ +"""remove db_patches table + +We replace our old patch process with Alembic's. + +Revision ID: 8b06e4916840 +Revises: f17c7ca9afa4 +Create Date: 2024-05-19 00:11:06.730421+00:00 + +""" + +from typing import Sequence + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "8b06e4916840" +down_revision: str | None = "f17c7ca9afa4" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("db_patches") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "db_patches", + sa.Column("id", sa.INTEGER(), nullable=False), + sa.Column("current", sa.VARCHAR(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### diff --git a/entrypoint.sh b/entrypoint.sh new file mode 100755 index 0000000..7df7daa --- /dev/null +++ b/entrypoint.sh @@ -0,0 +1,4 @@ +#!/bin/sh -eu + +alembic upgrade head +exec ./run "$@" diff --git a/poetry.lock b/poetry.lock index dc71990..d01be12 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "aiosqlite" @@ -18,6 +18,25 @@ typing_extensions = ">=4.0" dev = ["attribution (==1.7.0)", "black (==24.2.0)", "coverage[toml] (==7.4.1)", "flake8 (==7.0.0)", "flake8-bugbear (==24.2.6)", "flit (==3.9.0)", "mypy (==1.8.0)", "ufmt (==2.3.0)", "usort (==1.0.8.post1)"] docs = ["sphinx (==7.2.6)", "sphinx-mdinclude (==0.5.3)"] +[[package]] +name = "alembic" +version = "1.13.1" +description = "A database migration tool for SQLAlchemy." +optional = false +python-versions = ">=3.8" +files = [ + {file = "alembic-1.13.1-py3-none-any.whl", hash = "sha256:2edcc97bed0bd3272611ce3a98d98279e9c209e7186e43e75bbb1b2bdfdbcc43"}, + {file = "alembic-1.13.1.tar.gz", hash = "sha256:4932c8558bf68f2ee92b9bbcb8218671c627064d5b08939437af6d77dc05e595"}, +] + +[package.dependencies] +Mako = "*" +SQLAlchemy = ">=1.3.0" +typing-extensions = ">=4" + +[package.extras] +tz = ["backports.zoneinfo"] + [[package]] name = "anyio" version = "4.3.0" @@ -346,6 +365,94 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "mako" +version = "1.3.5" +description = "A super-fast templating language that borrows the best ideas from the existing templating languages." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "Mako-1.3.5-py3-none-any.whl", hash = "sha256:260f1dbc3a519453a9c856dedfe4beb4e50bd5a26d96386cb6c80856556bb91a"}, + {file = "Mako-1.3.5.tar.gz", hash = "sha256:48dbc20568c1d276a2698b36d968fa76161bf127194907ea6fc594fa81f943bc"}, +] + +[package.dependencies] +MarkupSafe = ">=0.9.2" + +[package.extras] +babel = ["Babel"] +lingua = ["lingua"] +testing = ["pytest"] + +[[package]] +name = "markupsafe" +version = "2.1.5" +description = "Safely add untrusted strings to HTML/XML markup." +optional = false +python-versions = ">=3.7" +files = [ + {file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:72b6be590cc35924b02c78ef34b467da4ba07e4e0f0454a2c5907f473fc50ce5"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e61659ba32cf2cf1481e575d0462554625196a1f2fc06a1c777d3f48e8865d46"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2174c595a0d73a3080ca3257b40096db99799265e1c27cc5a610743acd86d62f"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ae2ad8ae6ebee9d2d94b17fb62763125f3f374c25618198f40cbb8b525411900"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:075202fa5b72c86ad32dc7d0b56024ebdbcf2048c0ba09f1cde31bfdd57bcfff"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:598e3276b64aff0e7b3451b72e94fa3c238d452e7ddcd893c3ab324717456bad"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fce659a462a1be54d2ffcacea5e3ba2d74daa74f30f5f143fe0c58636e355fdd"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-win32.whl", hash = "sha256:d9fad5155d72433c921b782e58892377c44bd6252b5af2f67f16b194987338a4"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-win_amd64.whl", hash = "sha256:bf50cd79a75d181c9181df03572cdce0fbb75cc353bc350712073108cba98de5"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:629ddd2ca402ae6dbedfceeba9c46d5f7b2a61d9749597d4307f943ef198fc1f"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5b7b716f97b52c5a14bffdf688f971b2d5ef4029127f1ad7a513973cfd818df2"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ec585f69cec0aa07d945b20805be741395e28ac1627333b1c5b0105962ffced"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b91c037585eba9095565a3556f611e3cbfaa42ca1e865f7b8015fe5c7336d5a5"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7502934a33b54030eaf1194c21c692a534196063db72176b0c4028e140f8f32c"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0e397ac966fdf721b2c528cf028494e86172b4feba51d65f81ffd65c63798f3f"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:c061bb86a71b42465156a3ee7bd58c8c2ceacdbeb95d05a99893e08b8467359a"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3a57fdd7ce31c7ff06cdfbf31dafa96cc533c21e443d57f5b1ecc6cdc668ec7f"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-win32.whl", hash = 
"sha256:397081c1a0bfb5124355710fe79478cdbeb39626492b15d399526ae53422b906"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-win_amd64.whl", hash = "sha256:2b7c57a4dfc4f16f7142221afe5ba4e093e09e728ca65c51f5620c9aaeb9a617"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:8dec4936e9c3100156f8a2dc89c4b88d5c435175ff03413b443469c7c8c5f4d1"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:3c6b973f22eb18a789b1460b4b91bf04ae3f0c4234a0a6aa6b0a92f6f7b951d4"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac07bad82163452a6884fe8fa0963fb98c2346ba78d779ec06bd7a6262132aee"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5dfb42c4604dddc8e4305050aa6deb084540643ed5804d7455b5df8fe16f5e5"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ea3d8a3d18833cf4304cd2fc9cbb1efe188ca9b5efef2bdac7adc20594a0e46b"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d050b3361367a06d752db6ead6e7edeb0009be66bc3bae0ee9d97fb326badc2a"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bec0a414d016ac1a18862a519e54b2fd0fc8bbfd6890376898a6c0891dd82e9f"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:58c98fee265677f63a4385256a6d7683ab1832f3ddd1e66fe948d5880c21a169"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-win32.whl", hash = "sha256:8590b4ae07a35970728874632fed7bd57b26b0102df2d2b233b6d9d82f6c62ad"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-win_amd64.whl", hash = "sha256:823b65d8706e32ad2df51ed89496147a42a2a6e01c13cfb6ffb8b1e92bc910bb"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c8b29db45f8fe46ad280a7294f5c3ec36dbac9491f2d1c17345be8e69cc5928f"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec6a563cff360b50eed26f13adc43e61bc0c04d94b8be985e6fb24b81f6dcfdf"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a549b9c31bec33820e885335b451286e2969a2d9e24879f83fe904a5ce59d70a"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4f11aa001c540f62c6166c7726f71f7573b52c68c31f014c25cc7901deea0b52"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:7b2e5a267c855eea6b4283940daa6e88a285f5f2a67f2220203786dfa59b37e9"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:2d2d793e36e230fd32babe143b04cec8a8b3eb8a3122d2aceb4a371e6b09b8df"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:ce409136744f6521e39fd8e2a24c53fa18ad67aa5bc7c2cf83645cce5b5c4e50"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-win32.whl", hash = "sha256:4096e9de5c6fdf43fb4f04c26fb114f61ef0bf2e5604b6ee3019d51b69e8c371"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-win_amd64.whl", hash = "sha256:4275d846e41ecefa46e2015117a9f491e57a71ddd59bbead77e904dc02b1bed2"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:656f7526c69fac7f600bd1f400991cc282b417d17539a1b228617081106feb4a"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:97cafb1f3cbcd3fd2b6fbfb99ae11cdb14deea0736fc2b0952ee177f2b813a46"}, + {file = 
"MarkupSafe-2.1.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f3fbcb7ef1f16e48246f704ab79d79da8a46891e2da03f8783a5b6fa41a9532"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa9db3f79de01457b03d4f01b34cf91bc0048eb2c3846ff26f66687c2f6d16ab"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ffee1f21e5ef0d712f9033568f8344d5da8cc2869dbd08d87c84656e6a2d2f68"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5dedb4db619ba5a2787a94d877bc8ffc0566f92a01c0ef214865e54ecc9ee5e0"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:30b600cf0a7ac9234b2638fbc0fb6158ba5bdcdf46aeb631ead21248b9affbc4"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8dd717634f5a044f860435c1d8c16a270ddf0ef8588d4887037c5028b859b0c3"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-win32.whl", hash = "sha256:daa4ee5a243f0f20d528d939d06670a298dd39b1ad5f8a72a4275124a7819eff"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-win_amd64.whl", hash = "sha256:619bc166c4f2de5caa5a633b8b7326fbe98e0ccbfacabd87268a2b15ff73a029"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7a68b554d356a91cce1236aa7682dc01df0edba8d043fd1ce607c49dd3c1edcf"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:db0b55e0f3cc0be60c1f19efdde9a637c32740486004f20d1cff53c3c0ece4d2"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e53af139f8579a6d5f7b76549125f0d94d7e630761a2111bc431fd820e163b8"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:17b950fccb810b3293638215058e432159d2b71005c74371d784862b7e4683f3"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4c31f53cdae6ecfa91a77820e8b151dba54ab528ba65dfd235c80b086d68a465"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:bff1b4290a66b490a2f4719358c0cdcd9bafb6b8f061e45c7a2460866bf50c2e"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:bc1667f8b83f48511b94671e0e441401371dfd0f0a795c7daa4a3cd1dde55bea"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5049256f536511ee3f7e1b3f87d1d1209d327e818e6ae1365e8653d7e3abb6a6"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-win32.whl", hash = "sha256:00e046b6dd71aa03a41079792f8473dc494d564611a8f89bbbd7cb93295ebdcf"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-win_amd64.whl", hash = "sha256:fa173ec60341d6bb97a89f5ea19c85c5643c1e7dedebc22f5181eb73573142c5"}, + {file = "MarkupSafe-2.1.5.tar.gz", hash = "sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b"}, +] + [[package]] name = "nodeenv" version = "1.8.0" @@ -694,4 +801,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.12" -content-hash = "038fed338d6b75c17eb8eb88d36c2411ff936dab23887b70594e5ba1da518451" +content-hash = "9dbc732b312d6d39fbf4e8b8af22739aad6c25312cee92736f19d3a106f93129" diff --git a/pyproject.toml b/pyproject.toml index b230473..134c089 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ ulid-py = "^1.1.0" uvicorn = "^0.29.0" httpx = "^0.27.0" sqlalchemy = {version = "^2.0", extras = ["aiosqlite"]} +alembic = "^1.13.1" [tool.poetry.group.build.dependencies] # When we run poetry 
export, typing-extensions is a transient dependency via diff --git a/tests/test_db.py b/tests/test_db.py index 981e65b..c22359d 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -20,14 +20,6 @@ def a_movie(**kwds) -> models.Movie: return models.Movie(**args) -@pytest.mark.asyncio -async def test_current_patch_level(conn: db.Connection): - patch_level = "some-patch-level" - assert patch_level != await db.current_patch_level(conn) - await db.set_current_patch_level(conn, patch_level) - assert patch_level == await db.current_patch_level(conn) - - @pytest.mark.asyncio async def test_get(conn: db.Connection): m1 = a_movie() diff --git a/tests/test_web.py b/tests/test_web.py index b1e7e4b..46cc28a 100644 --- a/tests/test_web.py +++ b/tests/test_web.py @@ -32,6 +32,74 @@ def admin_client() -> TestClient: return client +@pytest.mark.asyncio +async def test_get_ratings_for_group_with_awards( + conn: db.Connection, unauthorized_client: TestClient +): + user = models.User( + imdb_id="ur12345678", + name="user-1", + secret="secret-1", # noqa: S106 + groups=[], + ) + group = models.Group( + name="group-1", + users=[models.GroupUser(id=str(user.id), name=user.name)], + ) + user.groups = [models.UserGroup(id=str(group.id), access="r")] + path = app.url_path_for("get_ratings_for_group", group_id=str(group.id)) + + await db.add(conn, user) + await db.add(conn, group) + + movie1 = models.Movie( + title="test movie", + release_year=2013, + media_type="Movie", + imdb_id="tt12345678", + genres={"genre-1"}, + ) + await db.add(conn, movie1) + movie2 = models.Movie( + title="test movie 2", + release_year=2014, + media_type="Movie", + imdb_id="tt12345679", + genres={"genre-2"}, + ) + await db.add(conn, movie2) + + award1 = models.Award( + movie_id=movie1.id, category="imdb-top-250", details='{"position":23}' + ) + award2 = models.Award( + movie_id=movie2.id, category="imdb-top-250", details='{"position":99}' + ) + await db.add(conn, award1) + await db.add(conn, award2) + + rating = models.Rating( + movie_id=movie1.id, user_id=user.id, score=66, rating_date=datetime.now(tz=UTC) + ) + await db.add(conn, rating) + + rating_aggregate = { + "canonical_title": movie1.title, + "imdb_score": movie1.imdb_score, + "imdb_votes": movie1.imdb_votes, + "link": imdb.movie_url(movie1.imdb_id), + "media_type": movie1.media_type, + "original_title": movie1.original_title, + "user_scores": [rating.score], + "year": movie1.release_year, + "awards": ["imdb-top-250:23"], + } + + resp = unauthorized_client.get(path) + assert resp.status_code == 200 + assert resp.json() == [rating_aggregate] + + @pytest.mark.asyncio async def test_get_ratings_for_group( conn: db.Connection, unauthorized_client: TestClient @@ -82,6 +150,7 @@ async def test_get_ratings_for_group( "original_title": movie.original_title, "user_scores": [rating.score], "year": movie.release_year, + "awards": [], } resp = unauthorized_client.get(path) @@ -158,6 +227,7 @@ async def test_list_movies( "original_title": m.original_title, "user_scores": [], "year": m.release_year, + "awards": [], } response = authorized_client.get(path, params={"imdb_id": m.imdb_id}) diff --git a/unwind/__main__.py b/unwind/__main__.py index 82abfee..a193d8e 100644 --- a/unwind/__main__.py +++ b/unwind/__main__.py @@ -1,149 +1,26 @@ import argparse import asyncio import logging -import secrets -from base64 import b64encode -from pathlib import Path +import sys -from . 
import config, db, models, utils -from .db import close_connection_pool, open_connection_pool -from .imdb import refresh_user_ratings_from_imdb -from .imdb_import import download_datasets, import_from_file +from . import cli, config -log = logging.getLogger(__name__) - - -async def run_add_user(user_id: str, name: str, overwrite_existing: bool): - if not user_id.startswith("ur"): - raise ValueError(f"Invalid IMDb user ID: {user_id!a}") - - await open_connection_pool() - - async with db.new_connection() as conn: - user = await db.get(conn, models.User, imdb_id=user_id) - - if user is not None: - if overwrite_existing: - log.warning("⚠️ Overwriting existing user: %a", user) - else: - log.error("❌ User already exists: %a", user) - return - - secret = secrets.token_bytes() - - user = models.User(name=name, imdb_id=user_id, secret=utils.phc_scrypt(secret)) - async with db.transaction() as conn: - await db.add_or_update_user(conn, user) - - user_data = { - "secret": b64encode(secret), - "user": models.asplain(user), - } - - log.info("✨ User created: %a", user_data) - - await close_connection_pool() - - -async def run_load_user_ratings_from_imdb(): - await open_connection_pool() - - i = 0 - async for _ in refresh_user_ratings_from_imdb(): - i += 1 - - log.info("✨ Imported %s new ratings.", i) - - await close_connection_pool() - - -async def run_import_imdb_dataset(basics_path: Path, ratings_path: Path): - await open_connection_pool() - - await import_from_file(basics_path=basics_path, ratings_path=ratings_path) - - await close_connection_pool() - - -async def run_download_imdb_dataset(basics_path: Path, ratings_path: Path): - await download_datasets(basics_path=basics_path, ratings_path=ratings_path) +log = logging.getLogger(__package__) def getargs(): - parser = argparse.ArgumentParser() - commands = parser.add_subparsers(required=True) + parser = argparse.ArgumentParser(prog="unwind", allow_abbrev=False) + commands = parser.add_subparsers(title="commands", metavar="COMMAND", dest="mode") - parser_import_imdb_dataset = commands.add_parser( - "import-imdb-dataset", - help="Import IMDb datasets.", - description=""" - Import IMDb datasets. - New datasets available from https://www.imdb.com/interfaces/. - """, - ) - parser_import_imdb_dataset.add_argument( - dest="mode", - action="store_const", - const="import-imdb-dataset", - ) - parser_import_imdb_dataset.add_argument( - "--basics", metavar="basics_file.tsv.gz", type=Path, required=True - ) - parser_import_imdb_dataset.add_argument( - "--ratings", metavar="ratings_file.tsv.gz", type=Path, required=True - ) - - parser_download_imdb_dataset = commands.add_parser( - "download-imdb-dataset", - help="Download IMDb datasets.", - description=""" - Download IMDb datasets. - """, - ) - parser_download_imdb_dataset.add_argument( - dest="mode", - action="store_const", - const="download-imdb-dataset", - ) - parser_download_imdb_dataset.add_argument( - "--basics", metavar="basics_file.tsv.gz", type=Path, required=True - ) - parser_download_imdb_dataset.add_argument( - "--ratings", metavar="ratings_file.tsv.gz", type=Path, required=True - ) - - parser_load_user_ratings_from_imdb = commands.add_parser( - "load-user-ratings-from-imdb", - help="Load user ratings from imdb.com.", - description=""" - Refresh user ratings for all registered users live from IMDb's website. 
- """, - ) - parser_load_user_ratings_from_imdb.add_argument( - dest="mode", - action="store_const", - const="load-user-ratings-from-imdb", - ) - - parser_add_user = commands.add_parser( - "add-user", - help="Add a new user.", - description=""" - Add a new user. - """, - ) - parser_add_user.add_argument( - dest="mode", - action="store_const", - const="add-user", - ) - parser_add_user.add_argument("--name", required=True) - parser_add_user.add_argument("--imdb-id", required=True) - parser_add_user.add_argument( - "--overwrite-existing", - action="store_true", - help="Allow overwriting an existing user. WARNING: This will reset the user's password!", - ) + for module in cli.modules: + help_, *descr = module.help.splitlines() + cmd = commands.add_parser( + module.name, + help=help_, + description="\n".join(descr) or help_, + allow_abbrev=False, + ) + module.add_args(cmd) try: args = parser.parse_args() @@ -151,6 +28,10 @@ def getargs(): parser.print_usage() raise + if args.mode is None: + parser.print_help() + sys.exit(1) + return args @@ -158,23 +39,16 @@ def main(): logging.basicConfig( format="%(asctime)s.%(msecs)03d [%(name)s:%(process)d] %(levelname)s: %(message)s", datefmt="%H:%M:%S", - level=config.loglevel, + # level=config.loglevel, ) + log.setLevel(config.loglevel) log.debug(f"Log level: {config.loglevel}") - try: - args = getargs() - except Exception: - return + args = getargs() - if args.mode == "load-user-ratings-from-imdb": - asyncio.run(run_load_user_ratings_from_imdb()) - elif args.mode == "add-user": - asyncio.run(run_add_user(args.imdb_id, args.name, args.overwrite_existing)) - elif args.mode == "import-imdb-dataset": - asyncio.run(run_import_imdb_dataset(args.basics, args.ratings)) - elif args.mode == "download-imdb-dataset": - asyncio.run(run_download_imdb_dataset(args.basics, args.ratings)) + modes = {m.name: m.main for m in cli.modules} + if handler := modes.get(args.mode): + asyncio.run(handler(args)) main() diff --git a/unwind/cli/__init__.py b/unwind/cli/__init__.py new file mode 100644 index 0000000..dd6f8cc --- /dev/null +++ b/unwind/cli/__init__.py @@ -0,0 +1,39 @@ +import argparse +import importlib +from pathlib import Path +from types import ModuleType +from typing import Any, Callable, Coroutine, Iterable, Protocol, TypeGuard + +type CommandHandler = Callable[[argparse.Namespace], Coroutine[Any, Any, None]] + + +class CliModule(Protocol): + name: str + help: str + add_args: Callable[[argparse.ArgumentParser], None] + main: CommandHandler + + +def _is_cli_module(m: ModuleType) -> TypeGuard[CliModule]: + return ( + hasattr(m, "name") + and hasattr(m, "help") + and hasattr(m, "add_args") + and hasattr(m, "main") + ) + + +_clidir = Path(__file__).parent + + +def _load_cmds() -> Iterable[CliModule]: + """Return all CLI command modules.""" + for f in _clidir.iterdir(): + if f.suffix == ".py" and not f.name.startswith("__"): + m = importlib.import_module(f"{__package__}.{f.stem}") + if not _is_cli_module(m): + raise ValueError(f"Invalid CLI module: {m!a}") + yield m + + +modules = sorted(_load_cmds(), key=lambda m: m.name) diff --git a/unwind/cli/add_user.py b/unwind/cli/add_user.py new file mode 100644 index 0000000..cc3d305 --- /dev/null +++ b/unwind/cli/add_user.py @@ -0,0 +1,56 @@ +import argparse +import logging +import secrets + +from unwind import db, models, utils + +log = logging.getLogger(__name__) + +name = "add-user" +help = "Add a new user." 
+ + +def add_args(cmd: argparse.ArgumentParser) -> None: + cmd.add_argument("--name", required=True) + cmd.add_argument("--imdb-id", required=True) + cmd.add_argument( + "--overwrite-existing", + action="store_true", + help="Allow overwriting an existing user. WARNING: This will reset the user's password!", + ) + + +async def main(args: argparse.Namespace) -> None: + user_id: str = args.imdb_id + name: str = args.name + overwrite_existing: bool = args.overwrite_existing + + if not user_id.startswith("ur"): + raise ValueError(f"Invalid IMDb user ID: {user_id!a}") + + await db.open_connection_pool() + + async with db.new_connection() as conn: + user = await db.get(conn, models.User, imdb_id=user_id) + + if user is not None: + if overwrite_existing: + log.warning("⚠️ Overwriting existing user: %a", user) + else: + log.error("❌ User already exists: %a", user) + return + + secret = secrets.token_bytes() + + user = models.User(name=name, imdb_id=user_id, secret=utils.phc_scrypt(secret)) + async with db.transaction() as conn: + await db.add_or_update_user(conn, user) + + user_data = { + "secret": utils.b64encode(secret), + "user": models.asplain(user), + } + + log.info("✨ User created: %a", user_data) + + await db.close_connection_pool() diff --git a/unwind/cli/download_imdb_dataset.py b/unwind/cli/download_imdb_dataset.py new file mode 100644 index 0000000..5b81045 --- /dev/null +++ b/unwind/cli/download_imdb_dataset.py @@ -0,0 +1,24 @@ +import argparse +import logging +from pathlib import Path + +from unwind.imdb_import import download_datasets + +log = logging.getLogger(__name__) + +name = "download-imdb-dataset" +help = "Download IMDb datasets." + + +def add_args(cmd: argparse.ArgumentParser) -> None: + cmd.add_argument("--basics", metavar="basics_file.tsv.gz", type=Path, required=True) + cmd.add_argument( + "--ratings", metavar="ratings_file.tsv.gz", type=Path, required=True + ) + + +async def main(args: argparse.Namespace) -> None: + basics_path: Path = args.basics + ratings_path: Path = args.ratings + + await download_datasets(basics_path=basics_path, ratings_path=ratings_path) diff --git a/unwind/cli/import_imdb_dataset.py b/unwind/cli/import_imdb_dataset.py new file mode 100644 index 0000000..3adb5da --- /dev/null +++ b/unwind/cli/import_imdb_dataset.py @@ -0,0 +1,31 @@ +import argparse +import logging +from pathlib import Path + +from unwind import db +from unwind.imdb_import import import_from_file + +log = logging.getLogger(__name__) + +name = "import-imdb-dataset" +help = """Import IMDb datasets. +New datasets available from https://www.imdb.com/interfaces/. 
+""" + + +def add_args(cmd: argparse.ArgumentParser) -> None: + cmd.add_argument("--basics", metavar="basics_file.tsv.gz", type=Path, required=True) + cmd.add_argument( + "--ratings", metavar="ratings_file.tsv.gz", type=Path, required=True + ) + + +async def main(args: argparse.Namespace) -> None: + basics_path: Path = args.basics + ratings_path: Path = args.ratings + + await db.open_connection_pool() + + await import_from_file(basics_path=basics_path, ratings_path=ratings_path) + + await db.close_connection_pool() diff --git a/unwind/cli/load_imdb_charts.py b/unwind/cli/load_imdb_charts.py new file mode 100644 index 0000000..2b82774 --- /dev/null +++ b/unwind/cli/load_imdb_charts.py @@ -0,0 +1,97 @@ +import argparse +import logging +from typing import Callable + +import sqlalchemy as sa + +from unwind import db, imdb, models, types, utils + +log = logging.getLogger(__name__) + +name = "load-imdb-charts" +help = "Load and import charts from imdb.com." + + +def add_args(cmd: argparse.ArgumentParser) -> None: + cmd.add_argument( + "--select", + action="append", + dest="charts", + default=[], + choices={"top250", "bottom100", "pop100"}, + help="Select which charts to refresh.", + ) + + +async def get_movie_ids( + conn: db.Connection, imdb_ids: list[imdb.MovieId] +) -> dict[imdb.MovieId, types.ULID]: + c = models.movies.c + query = sa.select(c.imdb_id, c.id).where(c.imdb_id.in_(imdb_ids)) + rows = await db.fetch_all(conn, query) + return {row.imdb_id: types.ULID(row.id) for row in rows} + + +async def remove_all_awards( + conn: db.Connection, category: models.AwardCategory +) -> None: + stmt = models.awards.delete().where(models.awards.c.category == category) + await conn.execute(stmt) + + +_award_handlers: dict[models.AwardCategory, Callable] = { + "imdb-pop-100": imdb.load_most_popular_100, + "imdb-top-250": imdb.load_top_250, + "imdb-bottom-100": imdb.load_bottom_100, +} + + +async def update_awards(conn: db.Connection, category: models.AwardCategory) -> None: + load_imdb_ids = _award_handlers[category] + imdb_ids = await load_imdb_ids() + + available = await get_movie_ids(conn, imdb_ids) + if missing := set(imdb_ids).difference(available): + log.warning( + "⚠️ Charts for category (%a) contained %i unknown movies: %a", + category, + len(missing), + missing, + ) + + await remove_all_awards(conn, category=category) + + for pos, imdb_id in enumerate(imdb_ids, 1): + if (movie_id := available.get(imdb_id)) is None: + continue + + award = models.Award( + movie_id=movie_id, + category=category, + details=utils.json_dump({"position": pos}), + ) + await db.add(conn, award) + + +async def main(args: argparse.Namespace) -> None: + await db.open_connection_pool() + + if not args.charts: + args.charts = {"top250", "bottom100", "pop100"} + + if "pop100" in args.charts: + async with db.transaction() as conn: + await update_awards(conn, "imdb-pop-100") + log.info("✨ Updated most popular 100 movies.") + + if "bottom100" in args.charts: + async with db.transaction() as conn: + await update_awards(conn, "imdb-bottom-100") + log.info("✨ Updated bottom 100 movies.") + + if "top250" in args.charts: + async with db.transaction() as conn: + await update_awards(conn, "imdb-top-250") + log.info("✨ Updated top 250 rated movies.") + + await db.close_connection_pool() diff --git a/unwind/cli/load_user_ratings_from_imdb.py b/unwind/cli/load_user_ratings_from_imdb.py new file mode 100644 index 0000000..b4a8e0f --- /dev/null +++ b/unwind/cli/load_user_ratings_from_imdb.py @@ -0,0 +1,28 @@ +import argparse +import logging 
+ +from unwind import db +from unwind.imdb import refresh_user_ratings_from_imdb + +log = logging.getLogger(__name__) + +name = "load-user-ratings-from-imdb" +help = """Load user ratings from imdb.com. +Refresh user ratings for all registered users live from IMDb's website. +""" + + +def add_args(cmd: argparse.ArgumentParser) -> None: + pass + + +async def main(args: argparse.Namespace) -> None: + await db.open_connection_pool() + + i = 0 + async for _ in refresh_user_ratings_from_imdb(): + i += 1 + + log.info("✨ Imported %s new ratings.", i) + + await db.close_connection_pool() diff --git a/unwind/db.py b/unwind/db.py index 7759f1b..a8e23d8 100644 --- a/unwind/db.py +++ b/unwind/db.py @@ -1,21 +1,25 @@ import contextlib import logging from pathlib import Path -from typing import Any, AsyncGenerator, Iterable, Literal, Sequence, Type, TypeVar +from typing import Any, AsyncGenerator, Iterable, Literal, Sequence, Type import sqlalchemy as sa -from sqlalchemy.dialects.sqlite import insert from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine +import alembic.command +import alembic.config +import alembic.migration + from . import config from .models import ( + Award, Model, Movie, Progress, Rating, User, asplain, - db_patches, + awards, fromplain, metadata, movies, @@ -24,15 +28,33 @@ from .models import ( ratings, utcnow, ) -from .types import ULID +from .types import ULID, ImdbMovieId, UserIdStr log = logging.getLogger(__name__) -T = TypeVar("T") _engine: AsyncEngine | None = None type Connection = AsyncConnection +_project_dir = Path(__file__).parent.parent +_alembic_ini = _project_dir / "alembic.ini" + + +def _init(conn: sa.Connection) -> None: + # See https://alembic.sqlalchemy.org/en/latest/cookbook.html#building-an-up-to-date-database-from-scratch + context = alembic.migration.MigrationContext.configure(conn) + heads = context.get_current_heads() + + is_empty_db = not heads # We consider a DB empty if Alembic hasn't touched it yet. + if is_empty_db: + log.info("⚡️ Initializing empty database.") + metadata.create_all(conn) + + # We pass our existing connection to Alembic's env.py, to avoid running another asyncio loop there. + alembic_cfg = alembic.config.Config(_alembic_ini) + alembic_cfg.attributes["connection"] = conn + alembic.command.stamp(alembic_cfg, "head") + async def open_connection_pool() -> None: """Open the DB connection pool. @@ -41,11 +63,7 @@ async def open_connection_pool() -> None: """ async with transaction() as conn: await conn.execute(sa.text("PRAGMA journal_mode=WAL")) - - await conn.run_sync(metadata.create_all, tables=[db_patches]) - - async with new_connection() as conn: - await apply_db_patches(conn) + await conn.run_sync(_init) async def close_connection_pool() -> None: @@ -65,65 +83,7 @@ async def close_connection_pool() -> None: await engine.dispose() -async def current_patch_level(conn: Connection, /) -> str: - query = sa.select(db_patches.c.current) - current = await conn.scalar(query) - return current or "" - - -async def set_current_patch_level(conn: Connection, /, current: str) -> None: - stmt = insert(db_patches).values(id=1, current=current) - stmt = stmt.on_conflict_do_update(set_={"current": stmt.excluded.current}) - await conn.execute(stmt) - - -db_patches_dir = Path(__file__).parent / "sql" - - -async def apply_db_patches(conn: Connection, /) -> None: - """Apply all remaining patches to the database. - - Beware that patches will be applied in lexicographical order, - i.e. "10" comes before "9". 
- - The current patch state is recorded in the DB itself. - - Please note that every SQL statement in a patch file MUST be terminated - using two consecutive semi-colons (;). - Failing to do so will result in an error. - """ - applied_lvl = await current_patch_level(conn) - - did_patch = False - - for patchfile in sorted(db_patches_dir.glob("*.sql"), key=lambda p: p.stem): - patch_lvl = patchfile.stem - if patch_lvl <= applied_lvl: - continue - - log.info("Applying patch: %s", patch_lvl) - - sql = patchfile.read_text() - queries = sql.split(";;") - if len(queries) < 2: - log.error( - "Patch file is missing statement terminator (`;;'): %s", patchfile - ) - raise RuntimeError("No statement found.") - - async with transacted(conn): - for query in queries: - await conn.execute(sa.text(query)) - - await set_current_patch_level(conn, patch_lvl) - - did_patch = True - - if did_patch: - await _vacuum(conn) - - -async def _vacuum(conn: Connection, /) -> None: +async def vacuum(conn: Connection, /) -> None: """Vacuum the database. This function cannot be run on a connection with an open transaction. @@ -194,11 +154,13 @@ async def set_import_progress(conn: Connection, /, progress: float) -> Progress: return current -def _new_engine() -> AsyncEngine: - uri = f"sqlite+aiosqlite:///{config.storage_path}" +def _connection_uri() -> str: + return f"sqlite+aiosqlite:///{config.storage_path}" + +def _new_engine() -> AsyncEngine: return create_async_engine( - uri, + _connection_uri(), isolation_level="SERIALIZABLE", ) @@ -257,6 +219,13 @@ async def new_connection() -> AsyncGenerator[Connection, None]: async def transacted( conn: Connection, /, *, force_rollback: bool = False ) -> AsyncGenerator[None, None]: + """Start a transaction for the given connection. + + If `force_rollback` is `True` any changes will be rolled back at the end of the + transaction, unless they are explicitly committed. + Nesting transactions is allowed, but mixing values for `force_rollback` will likely + yield unexpected results. + """ transaction = contextlib.nullcontext() if conn.in_transaction() else conn.begin() async with transaction: @@ -272,7 +241,7 @@ async def add(conn: Connection, /, item: Model) -> None: # Support late initializing - used for optimization. if getattr(item, "_is_lazy", False): assert hasattr(item, "_lazy_init") - item._lazy_init() # pyright: ignore [reportGeneralTypeIssues] + item._lazy_init() # pyright: ignore[reportAttributeAccessIssue] table: sa.Table = item.__table__ values = asplain(item, serialize=True) @@ -294,17 +263,14 @@ async def fetch_one( return result.first() -ModelType = TypeVar("ModelType", bound=Model) - - -async def get( +async def get[T: Model]( conn: Connection, /, - model: Type[ModelType], + model: Type[T], *, order_by: tuple[sa.Column, Literal["asc", "desc"]] | None = None, **field_values, -) -> ModelType | None: +) -> T | None: """Load a model instance from the database. Passing `field_values` allows to filter the item to load. You have to encode the @@ -327,9 +293,9 @@ async def get( return fromplain(model, row._mapping, serialized=True) if row else None -async def get_many( - conn: Connection, /, model: Type[ModelType], **field_sets: set | list -) -> Iterable[ModelType]: +async def get_many[T: Model]( + conn: Connection, /, model: Type[T], **field_sets: set | list +) -> Iterable[T]: """Return the items with any values matching all given field sets. 
This is similar to `get_all`, but instead of a scalar value a list of values @@ -346,9 +312,9 @@ async def get_many( return (fromplain(model, row._mapping, serialized=True) for row in rows) -async def get_all( - conn: Connection, /, model: Type[ModelType], **field_values -) -> Iterable[ModelType]: +async def get_all[T: Model]( + conn: Connection, /, model: Type[T], **field_values +) -> Iterable[T]: """Filter all items by comparing all given field values. If no filters are given, all items will be returned. @@ -365,7 +331,7 @@ async def update(conn: Connection, /, item: Model) -> None: # Support late initializing - used for optimization. if getattr(item, "_is_lazy", False): assert hasattr(item, "_lazy_init") - item._lazy_init() # pyright: ignore [reportGeneralTypeIssues] + item._lazy_init() # pyright: ignore[reportAttributeAccessIssue] table: sa.Table = item.__table__ values = asplain(item, serialize=True) @@ -466,6 +432,23 @@ async def add_or_update_rating(conn: Connection, /, rating: Rating) -> bool: return False +async def get_awards( + conn: Connection, /, imdb_ids: list[ImdbMovieId] +) -> dict[ImdbMovieId, list[Award]]: + query = ( + sa.select(Award, movies.c.imdb_id) + .join(movies, awards.c.movie_id == movies.c.id) + .where(movies.c.imdb_id.in_(imdb_ids)) + ) + rows = await fetch_all(conn, query) + awards_dict: dict[ImdbMovieId, list[Award]] = {} + for row in rows: + awards_dict.setdefault(row.imdb_id, []).append( + fromplain(Award, row._mapping, serialized=True) + ) + return awards_dict + + def sql_escape(s: str, char: str = "#") -> str: return s.replace(char, 2 * char).replace("%", f"{char}%").replace("_", f"{char}_") @@ -481,7 +464,7 @@ async def find_ratings( include_unrated: bool = False, yearcomp: tuple[Literal["<", "=", ">"], int] | None = None, limit_rows: int = 10, - user_ids: Iterable[str] = [], + user_ids: Iterable[UserIdStr] = [], ) -> Iterable[dict[str, Any]]: conditions = [] diff --git a/unwind/imdb.py b/unwind/imdb.py index cff1d68..0f46c24 100644 --- a/unwind/imdb.py +++ b/unwind/imdb.py @@ -12,6 +12,7 @@ import bs4 from . import db from .models import Movie, Rating, User from .request import adownload, asession, asoup_from_url, cache_path +from .utils import json_dump log = logging.getLogger(__name__) @@ -355,9 +356,12 @@ async def _load_ratings_page_legacy(url: str, soup: bs4.BeautifulSoup) -> _Ratin return page -async def load_and_store_ratings( - user_id: UserId, -) -> AsyncIterable[tuple[Rating, bool]]: +async def load_and_store_ratings(user_id: UserId) -> AsyncIterable[tuple[Rating, bool]]: + """Load user ratings from imdb.com and store them in our database. + + All loaded ratings are yielded together with the information whether each rating + was already present in our database. 
+ """ async with db.new_connection() as conn: user = await db.get(conn, User, imdb_id=user_id) or User( imdb_id=user_id, name="", secret="" @@ -385,6 +389,7 @@ async def load_and_store_ratings( async def load_ratings(user_id: UserId) -> AsyncIterable[Rating]: + """Return all ratings for the given user from imdb.com.""" next_url = user_ratings_url(user_id) while next_url: @@ -443,13 +448,15 @@ async def load_top_250() -> list[MovieId]: qgl_api_url = "https://caching.graphql.imdb.com/" query = { "operationName": "Top250MoviesPagination", - "variables": {"first": 250, "locale": "en-US"}, - "extensions": { - "persistedQuery": { - "sha256Hash": "26114ee01d97e04f65d6c8c7212ae8b7888fa57ceed105450d1fce09df749b2d", - "version": 1, + "variables": json_dump({"first": 250, "locale": "en-US"}), + "extensions": json_dump( + { + "persistedQuery": { + "sha256Hash": "26114ee01d97e04f65d6c8c7212ae8b7888fa57ceed105450d1fce09df749b2d", + "version": 1, + } } - }, + ), } headers = { "accept": "application/graphql+json, application/json", diff --git a/unwind/imdb_import.py b/unwind/imdb_import.py index 5464df0..28792e2 100644 --- a/unwind/imdb_import.py +++ b/unwind/imdb_import.py @@ -5,7 +5,7 @@ import logging from dataclasses import dataclass, fields from datetime import datetime, timezone from pathlib import Path -from typing import Generator, Literal, Type, TypeVar, overload +from typing import Generator, Literal, Type, overload from . import config, db, request from .db import add_or_update_many_movies @@ -14,8 +14,6 @@ from .models import Movie log = logging.getLogger(__name__) -T = TypeVar("T") - # See # - https://developer.imdb.com/non-commercial-datasets/ # - https://datasets.imdbws.com/ @@ -127,7 +125,7 @@ def read_imdb_tsv( @overload -def read_imdb_tsv( +def read_imdb_tsv[T]( path: Path, row_type: Type[T], *, unpack: Literal[True] = True ) -> Generator[T, None, None]: ... diff --git a/unwind/models.py b/unwind/models.py index 6ea13e9..f952686 100644 --- a/unwind/models.py +++ b/unwind/models.py @@ -2,7 +2,6 @@ import json from dataclasses import dataclass, field from dataclasses import fields as _fields from datetime import datetime, timezone -from functools import partial from types import UnionType from typing import ( Annotated, @@ -11,24 +10,33 @@ from typing import ( Container, Literal, Mapping, + NewType, Protocol, Type, + TypeAliasType, TypedDict, - TypeVar, Union, get_args, get_origin, ) -from sqlalchemy import Column, ForeignKey, Integer, String, Table +from sqlalchemy import Column, ForeignKey, Index, Integer, String, Table from sqlalchemy.orm import registry -from .types import ULID - -JSON = int | float | str | None | list["JSON"] | dict[str, "JSON"] -JSONObject = dict[str, JSON] - -T = TypeVar("T") +from .types import ( + ULID, + AwardId, + GroupId, + ImdbMovieId, + JSONObject, + JSONScalar, + MovieId, + RatingId, + Score100, + UserId, + UserIdStr, +) +from .utils import json_dump class Model(Protocol): @@ -38,8 +46,22 @@ class Model(Protocol): mapper_registry = registry() metadata = mapper_registry.metadata +# An explicit naming convention helps Alembic do its job, +# see https://alembic.sqlalchemy.org/en/latest/naming.html. +metadata.naming_convention = { + "ix": "ix_%(column_0_label)s", + "uq": "uq_%(table_name)s_%(column_0_name)s", + "ck": "ck_%(table_name)s_%(constraint_name)s", + "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", + "pk": "pk_%(table_name)s", +} + def annotations(tp: Type) -> tuple | None: + # Support type aliases and generic aliases. 
+ if isinstance(tp, TypeAliasType) or hasattr(tp, "__value__"): + tp = tp.__value__ + return tp.__metadata__ if hasattr(tp, "__metadata__") else None # type: ignore @@ -97,13 +119,24 @@ def optional_fields(o): yield f -json_dump = partial(json.dumps, separators=(",", ":")) - - -def _id(x: T) -> T: +def _id[T](x: T) -> T: + """Return the given argument, aka. the identity function.""" return x +def _unpack(type_: Any) -> Any: + """Return the wrapped type.""" + # Handle type aliases. + if isinstance(type_, TypeAliasType): + return _unpack(type_.__value__) + + # Handle newtypes. + if isinstance(type_, NewType): + return _unpack(type_.__supertype__) + + return type_ + + def asplain( o: object, *, filter_fields: Container[str] | None = None, serialize: bool = False ) -> dict[str, Any]: @@ -125,13 +158,16 @@ def asplain( if filter_fields is not None and f.name not in filter_fields: continue - target: Any = f.type + target: Any = _unpack(f.type) + # XXX this doesn't properly support any kind of nested types if (otype := optional_type(f.type)) is not None: target = otype if (otype := get_origin(target)) is not None: target = otype + target = _unpack(target) + v = getattr(o, f.name) if is_optional(f.type) and v is None: d[f.name] = None @@ -150,26 +186,31 @@ def asplain( elif target in {bool, str, int, float}: assert isinstance( v, target - ), f"Type mismatch: {f.name} ({target} != {type(v)})" + ), f"Type mismatch: {f.name!a} ({target!a} != {type(v)!a})" + d[f.name] = v + elif target in {Literal}: + assert isinstance(v, JSONScalar.__value__) d[f.name] = v else: - raise ValueError(f"Unsupported value type: {f.name}: {type(v)}") + raise ValueError(f"Unsupported value type: {f.name!a}: {type(v)!a}") return d -def fromplain(cls: Type[T], d: Mapping, *, serialized: bool = False) -> T: +def fromplain[T](cls: Type[T], d: Mapping, *, serialized: bool = False) -> T: """Return an instance of the given model using the given data. If `serialized` is `True`, collection types (lists, dicts, etc.) will be deserialized from string. This is the opposite operation of `serialize` for `asplain`. + Fields in the data that cannot be mapped to the given type are simply ignored. """ load = json.loads if serialized else _id dd: JSONObject = {} for f in fields(cls): - target: Any = f.type + target: Any = _unpack(f.type) + otype = optional_type(f.type) is_opt = otype is not None if is_opt: @@ -177,9 +218,17 @@ def fromplain(cls: Type[T], d: Mapping, *, serialized: bool = False) -> T: if (xtype := get_origin(target)) is not None: target = xtype + target = _unpack(target) + v = d[f.name] if is_opt and v is None: dd[f.name] = v + elif target is Literal: + # Support literal types. + vals = get_args(f.type.__value__) + if v not in vals: + raise ValueError(f"Invalid value: {f.name!a}: {v!a}") + dd[f.name] = v elif isinstance(v, target): dd[f.name] = v elif target in {set, list}: @@ -196,27 +245,38 @@ def fromplain(cls: Type[T], d: Mapping, *, serialized: bool = False) -> T: def validate(o: object) -> None: for f in fields(o): - vtype = type(getattr(o, f.name)) - if vtype is f.type: + ftype = _unpack(f.type) + + v = getattr(o, f.name) + vtype = type(v) + if vtype is ftype: continue - origin = get_origin(f.type) + origin = get_origin(ftype) if origin is vtype: continue - is_union = isinstance(f.type, UnionType) or origin is Union + is_union = isinstance(ftype, UnionType) or origin is Union if is_union: # Support unioned types. 
- utypes = get_args(f.type) + utypes = get_args(ftype) + utypes = [_unpack(t) for t in utypes] if vtype in utypes: continue # Support generic types (set[str], list[int], etc.) - gtypes = [g for u in utypes if (g := get_origin(u)) is not None] + gtypes = [_unpack(g) for u in utypes if (g := get_origin(u)) is not None] if any(vtype is gtype for gtype in gtypes): continue - raise ValueError(f"Invalid value type: {f.name}: {vtype}") + if origin is Literal: + # Support literal types. + vals = get_args(ftype) + if v in vals: + continue + raise ValueError(f"Invalid value: {f.name!a}: {v!a}") + + raise ValueError(f"Invalid value type: {f.name!a}: {vtype!a}") def utcnow() -> datetime: @@ -224,23 +284,6 @@ def utcnow() -> datetime: return datetime.now(timezone.utc) -@mapper_registry.mapped -@dataclass -class DbPatch: - __table__: ClassVar[Table] = Table( - "db_patches", - metadata, - Column("id", Integer, primary_key=True), - Column("current", String), - ) - - id: int - current: str - - -db_patches = DbPatch.__table__ - - @mapper_registry.mapped @dataclass class Progress: @@ -312,15 +355,15 @@ class Movie: Column("updated", String, nullable=False), # datetime ) - id: ULID = field(default_factory=ULID) + id: MovieId = field(default_factory=ULID) title: str = None # canonical title (usually English) original_title: str | None = ( None # original title (usually transscribed to latin script) ) release_year: int = None # canonical release date media_type: str = None - imdb_id: str = None - imdb_score: int | None = None # range: [0,100] + imdb_id: ImdbMovieId = None + imdb_score: Score100 | None = None # range: [0,100] imdb_votes: int | None = None runtime: int | None = None # minutes genres: set[str] | None = None @@ -365,10 +408,10 @@ dataclass containing the ID of the linked data. The contents of the Relation are ignored or discarded when using `asplain`, `fromplain`, and `validate`. 
""" -Relation = Annotated[T | None, _RelationSentinel] +type Relation[T] = Annotated[T | None, _RelationSentinel] -Access = Literal[ +type Access = Literal[ "r", # read "i", # index "w", # write @@ -393,8 +436,8 @@ class User: Column("groups", String, nullable=False), # JSON array ) - id: ULID = field(default_factory=ULID) - imdb_id: str = None + id: UserId = field(default_factory=ULID) + imdb_id: ImdbMovieId = None name: str = None # canonical user name secret: str = None groups: list[UserGroup] = field(default_factory=list) @@ -413,6 +456,9 @@ class User: self.groups.append({"id": group_id, "access": access}) +users = User.__table__ + + @mapper_registry.mapped @dataclass class Rating: @@ -428,15 +474,15 @@ class Rating: Column("finished", Integer), # bool ) - id: ULID = field(default_factory=ULID) + id: RatingId = field(default_factory=ULID) - movie_id: ULID = None + movie_id: MovieId = None movie: Relation[Movie] = None - user_id: ULID = None + user_id: UserId = None user: Relation[User] = None - score: int = None # range: [0,100] + score: Score100 = None # range: [0,100] rating_date: datetime = None favorite: bool | None = None finished: bool | None = None @@ -455,10 +501,11 @@ class Rating: ratings = Rating.__table__ +Index("ratings_index", ratings.c.movie_id, ratings.c.user_id, unique=True) class GroupUser(TypedDict): - id: str + id: UserIdStr name: str @@ -473,6 +520,62 @@ class Group: Column("users", String, nullable=False), # JSON array ) - id: ULID = field(default_factory=ULID) + id: GroupId = field(default_factory=ULID) name: str = None users: list[GroupUser] = field(default_factory=list) + + +type AwardCategory = Literal[ + "imdb-top-250", "imdb-bottom-100", "imdb-pop-100", "oscars" +] + + +@mapper_registry.mapped +@dataclass +class Award: + __table__: ClassVar[Table] = Table( + "awards", + metadata, + Column("id", String, primary_key=True), # ULID + Column("movie_id", ForeignKey("movies.id"), nullable=False), # ULID + Column( + "category", String, nullable=False + ), # Enum: "imdb-top-250", "imdb-bottom-100", "imdb-pop-100", "oscars", ... + Column( + "details", String, nullable=False + ), # e.g. "23" (position in list), "2024, nominee, best director", "1977, winner, best picture", ... 
+ Column("created", String, nullable=False), # datetime + Column("updated", String, nullable=False), # datetime + ) + + id: AwardId = field(default_factory=ULID) + + movie_id: MovieId = None + movie: Relation[Movie] = None + + category: AwardCategory = None + details: str = None + + created: datetime = field(default_factory=utcnow) + updated: datetime = field(default_factory=utcnow) + + @property + def _details(self) -> JSONObject: + return json.loads(self.details or "{}") + + @_details.setter + def _details(self, details: JSONObject): + self.details = json_dump(details) + + @property + def position(self) -> int: + return self._details["position"] + + @position.setter + def position(self, position: int): + details = self._details + details["position"] = position + self._details = details + + +awards = Award.__table__ diff --git a/unwind/request.py b/unwind/request.py index f12936b..46d1e9b 100644 --- a/unwind/request.py +++ b/unwind/request.py @@ -11,7 +11,7 @@ from hashlib import md5 from pathlib import Path from random import random from time import sleep, time -from typing import Any, Callable, ParamSpec, TypeVar, cast, overload +from typing import Any, Callable, cast, overload import bs4 import httpx @@ -24,13 +24,10 @@ if config.debug and config.cachedir: config.cachedir.mkdir(exist_ok=True) -_shared_asession = None - _ASession_T = httpx.AsyncClient -_Response_T = httpx.Response +type _Response_T = httpx.Response -_T = TypeVar("_T") -_P = ParamSpec("_P") +_shared_asession: _ASession_T | None = None @asynccontextmanager @@ -59,17 +56,17 @@ async def asession(): _shared_asession = None -def _throttle( +def _throttle[T, **P]( times: int, per_seconds: float, jitter: Callable[[], float] | None = None -) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: +) -> Callable[[Callable[P, T]], Callable[P, T]]: calls: deque[float] = deque(maxlen=times) if jitter is None: jitter = lambda: 0.0 # noqa: E731 - def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]: + def decorator(func: Callable[P, T]) -> Callable[P, T]: @wraps(func) - def inner(*args: _P.args, **kwds: _P.kwargs): + def inner(*args: P.args, **kwds: P.kwargs): # clean up while calls: if calls[0] + per_seconds > time(): diff --git a/unwind/sql/00000000-init-0.sql b/unwind/sql/00000000-init-0.sql deleted file mode 100644 index d0bd446..0000000 --- a/unwind/sql/00000000-init-0.sql +++ /dev/null @@ -1,36 +0,0 @@ -PRAGMA foreign_keys = ON;; - -CREATE TABLE IF NOT EXISTS users ( - id TEXT NOT NULL PRIMARY KEY, - imdb_id TEXT NOT NULL UNIQUE, - name TEXT NOT NULL -);; - -CREATE TABLE IF NOT EXISTS movies ( - id TEXT NOT NULL PRIMARY KEY, - title TEXT NOT NULL, - release_year NUMBER NOT NULL, - media_type TEXT NOT NULL, - imdb_id TEXT NOT NULL UNIQUE, - score NUMBER NOT NULL, - runtime NUMBER, - genres TEXT NOT NULL, - updated TEXT NOT NULL -);; - -CREATE TABLE IF NOT EXISTS ratings ( - id TEXT NOT NULL PRIMARY KEY, - movie_id TEXT NOT NULL, - user_id TEXT NOT NULL, - score NUMBER NOT NULL, - rating_date TEXT NOT NULL, - favorite NUMBER, - finished NUMBER, - FOREIGN KEY(movie_id) REFERENCES movies(id), - FOREIGN KEY(user_id) REFERENCES users(id) -);; - -CREATE UNIQUE INDEX IF NOT EXISTS ratings_index ON ratings ( - movie_id, - user_id -);; diff --git a/unwind/sql/00000000-init-1.sql b/unwind/sql/00000000-init-1.sql deleted file mode 100644 index 85d40a6..0000000 --- a/unwind/sql/00000000-init-1.sql +++ /dev/null @@ -1,40 +0,0 @@ --- add original_title to movies table - --- see https://www.sqlite.org/lang_altertable.html#caution --- 1. 
Create new table --- 2. Copy data --- 3. Drop old table --- 4. Rename new into old - -CREATE TABLE _migrate_movies ( - id TEXT NOT NULL PRIMARY KEY, - title TEXT NOT NULL, - original_title TEXT, - release_year NUMBER NOT NULL, - media_type TEXT NOT NULL, - imdb_id TEXT NOT NULL UNIQUE, - score NUMBER, - runtime NUMBER, - genres TEXT NOT NULL, - updated TEXT NOT NULL -);; - -INSERT INTO _migrate_movies -SELECT - id, - title, - NULL, - release_year, - media_type, - imdb_id, - score, - runtime, - genres, - updated -FROM movies -WHERE true;; - -DROP TABLE movies;; - -ALTER TABLE _migrate_movies -RENAME TO movies;; diff --git a/unwind/sql/00000000-init-2.sql b/unwind/sql/00000000-init-2.sql deleted file mode 100644 index 68fad70..0000000 --- a/unwind/sql/00000000-init-2.sql +++ /dev/null @@ -1,46 +0,0 @@ --- only set original_title if it differs from title, --- and normalize media_type with an extra table. - -CREATE TABLE mediatypes ( - id INTEGER PRIMARY KEY NOT NULL, - name TEXT NOT NULL UNIQUE -);; - -INSERT INTO mediatypes (name) -SELECT DISTINCT media_type -FROM movies -WHERE true;; - -CREATE TABLE _migrate_movies ( - id TEXT PRIMARY KEY NOT NULL, - title TEXT NOT NULL, - original_title TEXT, - release_year INTEGER NOT NULL, - media_type_id INTEGER NOT NULL, - imdb_id TEXT NOT NULL UNIQUE, - score INTEGER, - runtime INTEGER, - genres TEXT NOT NULL, - updated TEXT NOT NULL, - FOREIGN KEY(media_type_id) REFERENCES mediatypes(id) -);; - -INSERT INTO _migrate_movies -SELECT - id, - title, - (CASE WHEN original_title=title THEN NULL ELSE original_title END), - release_year, - (SELECT id FROM mediatypes WHERE name=media_type) AS media_type_id, - imdb_id, - score, - runtime, - genres, - updated -FROM movies -WHERE true;; - -DROP TABLE movies;; - -ALTER TABLE _migrate_movies -RENAME TO movies;; diff --git a/unwind/sql/00000000-init-3.sql b/unwind/sql/00000000-init-3.sql deleted file mode 100644 index 98380c7..0000000 --- a/unwind/sql/00000000-init-3.sql +++ /dev/null @@ -1,62 +0,0 @@ --- add convenient view for movies - -CREATE VIEW IF NOT EXISTS movies_view -AS SELECT - movies.id, - movies.title, - movies.original_title, - movies.release_year, - mediatypes.name AS media_type, - movies.imdb_id, - movies.score, - movies.runtime, - movies.genres, - movies.updated -FROM movies -JOIN mediatypes ON mediatypes.id=movies.media_type_id;; - -CREATE TRIGGER IF NOT EXISTS insert_movies_view - INSTEAD OF INSERT - ON movies_view -BEGIN - INSERT INTO movies ( - id, - title, - original_title, - release_year, - media_type_id, - imdb_id, - score, - runtime, - genres, - updated - ) VALUES ( - NEW.id, - NEW.title, - NEW.original_title, - NEW.release_year, - (SELECT id FROM mediatypes WHERE name=NEW.media_type), - NEW.imdb_id, - NEW.score, - NEW.runtime, - NEW.genres, - NEW.updated - ); -END;; - -CREATE TRIGGER IF NOT EXISTS update_movies_view - INSTEAD OF UPDATE OF media_type - ON movies_view -BEGIN - UPDATE movies - SET media_type_id=(SELECT id FROM mediatypes WHERE name=NEW.media_type) - WHERE id=OLD.id; -END;; - -CREATE TRIGGER IF NOT EXISTS delete_movies_view - INSTEAD OF DELETE - ON movies_view -BEGIN - DELETE FROM movies - WHERE movies.id=OLD.id; -END;; diff --git a/unwind/sql/00000000-init-4.sql b/unwind/sql/00000000-init-4.sql deleted file mode 100644 index 984ef37..0000000 --- a/unwind/sql/00000000-init-4.sql +++ /dev/null @@ -1,37 +0,0 @@ --- denormalize movie media_type - -CREATE TABLE _migrate_movies ( - id TEXT PRIMARY KEY NOT NULL, - title TEXT NOT NULL, - original_title TEXT, - release_year INTEGER 
NOT NULL, - media_type TEXT NOT NULL, - imdb_id TEXT NOT NULL UNIQUE, - score INTEGER, - runtime INTEGER, - genres TEXT NOT NULL, - updated TEXT NOT NULL -);; - -INSERT INTO _migrate_movies -SELECT - id, - title, - original_title, - release_year, - (SELECT name FROM mediatypes WHERE id=media_type_id) AS media_type, - imdb_id, - score, - runtime, - genres, - updated -FROM movies -WHERE true;; - -DROP VIEW movies_view;; -DROP TABLE mediatypes;; - -DROP TABLE movies;; - -ALTER TABLE _migrate_movies -RENAME TO movies;; diff --git a/unwind/sql/00000001-fix-db.sql.disabled b/unwind/sql/00000001-fix-db.sql.disabled deleted file mode 100644 index e6376a8..0000000 --- a/unwind/sql/00000001-fix-db.sql.disabled +++ /dev/null @@ -1,2 +0,0 @@ --- see the commit of this file for details. -;; diff --git a/unwind/sql/20210705-224139.sql b/unwind/sql/20210705-224139.sql deleted file mode 100644 index e714b4e..0000000 --- a/unwind/sql/20210705-224139.sql +++ /dev/null @@ -1,8 +0,0 @@ --- add groups table - -CREATE TABLE groups ( - id TEXT PRIMARY KEY NOT NULL, - name TEXT NOT NULL, - secret TEXT NOT NULL, - users TEXT NOT NULL -- JSON array -);; diff --git a/unwind/sql/20210711-172808--progress-table.sql b/unwind/sql/20210711-172808--progress-table.sql deleted file mode 100644 index 1ee6a5f..0000000 --- a/unwind/sql/20210711-172808--progress-table.sql +++ /dev/null @@ -1,7 +0,0 @@ --- add progress table - -CREATE TABLE progress ( - id TEXT PRIMARY KEY NOT NULL, - state TEXT NOT NULL, - started TEXT NOT NULL -);; diff --git a/unwind/sql/20210720-213416.sql b/unwind/sql/20210720-213416.sql deleted file mode 100644 index 286e094..0000000 --- a/unwind/sql/20210720-213416.sql +++ /dev/null @@ -1,36 +0,0 @@ --- add IMDb vote count - -CREATE TABLE _migrate_movies ( - id TEXT PRIMARY KEY NOT NULL, - title TEXT NOT NULL, - original_title TEXT, - release_year INTEGER NOT NULL, - media_type TEXT NOT NULL, - imdb_id TEXT NOT NULL UNIQUE, - imdb_score INTEGER, - imdb_votes INTEGER, - runtime INTEGER, - genres TEXT NOT NULL, - updated TEXT NOT NULL -);; - -INSERT INTO _migrate_movies -SELECT - id, - title, - original_title, - release_year, - media_type, - imdb_id, - score AS imdb_score, - NULL AS imdb_votes, - runtime, - genres, - updated -FROM movies -WHERE true;; - -DROP TABLE movies;; - -ALTER TABLE _migrate_movies -RENAME TO movies;; diff --git a/unwind/sql/20210720-223416.sql b/unwind/sql/20210720-223416.sql deleted file mode 100644 index 95e1b78..0000000 --- a/unwind/sql/20210720-223416.sql +++ /dev/null @@ -1,24 +0,0 @@ --- add IMDb vote count - -CREATE TABLE _migrate_progress ( - id TEXT PRIMARY KEY NOT NULL, - type TEXT NOT NULL, - state TEXT NOT NULL, - started TEXT NOT NULL, - stopped TEXT -);; - -INSERT INTO _migrate_progress -SELECT - id, - 'import-imdb-movies' AS type, - state, - started, - NULL AS stopped -FROM progress -WHERE true;; - -DROP TABLE progress;; - -ALTER TABLE _migrate_progress -RENAME TO progress;; diff --git a/unwind/sql/20210721-213417.sql b/unwind/sql/20210721-213417.sql deleted file mode 100644 index 33e891a..0000000 --- a/unwind/sql/20210721-213417.sql +++ /dev/null @@ -1,38 +0,0 @@ --- add creation timestamp to movies - -CREATE TABLE _migrate_movies ( - id TEXT PRIMARY KEY NOT NULL, - title TEXT NOT NULL, - original_title TEXT, - release_year INTEGER NOT NULL, - media_type TEXT NOT NULL, - imdb_id TEXT NOT NULL UNIQUE, - imdb_score INTEGER, - imdb_votes INTEGER, - runtime INTEGER, - genres TEXT NOT NULL, - created TEXT NOT NULL, - updated TEXT NOT NULL -);; - -INSERT INTO 
_migrate_movies -SELECT - id, - title, - original_title, - release_year, - media_type, - imdb_id, - imdb_score, - imdb_votes, - runtime, - genres, - updated AS created, - updated -FROM movies -WHERE true;; - -DROP TABLE movies;; - -ALTER TABLE _migrate_movies -RENAME TO movies;; diff --git a/unwind/sql/20210728-223416.sql b/unwind/sql/20210728-223416.sql deleted file mode 100644 index 1581060..0000000 --- a/unwind/sql/20210728-223416.sql +++ /dev/null @@ -1,24 +0,0 @@ --- add IMDb vote count - -CREATE TABLE _migrate_progress ( - id TEXT PRIMARY KEY NOT NULL, - type TEXT NOT NULL, - state TEXT NOT NULL, - started TEXT NOT NULL, - stopped TEXT -);; - -INSERT INTO _migrate_progress -SELECT - id, - type, - '{"percent":' || state || '}' AS state, - started, - stopped -FROM progress -WHERE true;; - -DROP TABLE progress;; - -ALTER TABLE _migrate_progress -RENAME TO progress;; diff --git a/unwind/sql/20210801-201151--add-user-secret.sql b/unwind/sql/20210801-201151--add-user-secret.sql deleted file mode 100644 index 3294a56..0000000 --- a/unwind/sql/20210801-201151--add-user-secret.sql +++ /dev/null @@ -1,22 +0,0 @@ --- add secret to users - -CREATE TABLE _migrate_users ( - id TEXT PRIMARY KEY NOT NULL, - imdb_id TEXT NOT NULL UNIQUE, - name TEXT NOT NULL, - secret TEXT NOT NULL -);; - -INSERT INTO _migrate_users -SELECT - id, - imdb_id, - name, - '' AS secret -FROM users -WHERE true;; - -DROP TABLE users;; - -ALTER TABLE _migrate_users -RENAME TO users;; diff --git a/unwind/sql/20210802-212312--add-group-admins.sql b/unwind/sql/20210802-212312--add-group-admins.sql deleted file mode 100644 index 13f3105..0000000 --- a/unwind/sql/20210802-212312--add-group-admins.sql +++ /dev/null @@ -1,45 +0,0 @@ --- add group admins - ---- remove secrets from groups -CREATE TABLE _migrate_groups ( - id TEXT PRIMARY KEY NOT NULL, - name TEXT NOT NULL, - users TEXT NOT NULL -- JSON array -);; - -INSERT INTO _migrate_groups -SELECT - id, - name, - users -FROM groups -WHERE true;; - -DROP TABLE groups;; - -ALTER TABLE _migrate_groups -RENAME TO groups;; - ---- add group access to users -CREATE TABLE _migrate_users ( - id TEXT PRIMARY KEY NOT NULL, - imdb_id TEXT NOT NULL UNIQUE, - name TEXT NOT NULL, - secret TEXT NOT NULL, - groups TEXT NOT NULL -- JSON array -);; - -INSERT INTO _migrate_users -SELECT - id, - imdb_id, - name, - secret, - '[]' AS groups -FROM users -WHERE true;; - -DROP TABLE users;; - -ALTER TABLE _migrate_users -RENAME TO users;; diff --git a/unwind/sql/20240511-001949--remove-genres-notnull.sql b/unwind/sql/20240511-001949--remove-genres-notnull.sql deleted file mode 100644 index 98a7c16..0000000 --- a/unwind/sql/20240511-001949--remove-genres-notnull.sql +++ /dev/null @@ -1,38 +0,0 @@ --- remove NOTNULL constraint from movies.genres - -CREATE TABLE _migrate_movies ( - id TEXT PRIMARY KEY NOT NULL, - title TEXT NOT NULL, - original_title TEXT, - release_year INTEGER NOT NULL, - media_type TEXT NOT NULL, - imdb_id TEXT NOT NULL UNIQUE, - imdb_score INTEGER, - imdb_votes INTEGER, - runtime INTEGER, - genres TEXT, - created TEXT NOT NULL, - updated TEXT NOT NULL -);; - -INSERT INTO _migrate_movies -SELECT - id, - title, - original_title, - release_year, - media_type, - imdb_id, - imdb_score, - imdb_votes, - runtime, - genres, - created, - updated -FROM movies -WHERE true;; - -DROP TABLE movies;; - -ALTER TABLE _migrate_movies -RENAME TO movies;; diff --git a/unwind/types.py b/unwind/types.py index 94c0e00..76ce3e8 100644 --- a/unwind/types.py +++ b/unwind/types.py @@ -1,9 +1,13 @@ import re -from 
typing import cast +from typing import NewType, cast import ulid from ulid.hints import Buffer +type JSONScalar = int | float | str | None +type JSON = JSONScalar | list["JSON"] | dict[str, "JSON"] +type JSONObject = dict[str, JSON] + class ULID(ulid.ULID): """Extended ULID type. @@ -29,3 +33,14 @@ class ULID(ulid.ULID): buffer = cast(memoryview, ulid.new().memory) super().__init__(buffer) + + +AwardId = NewType("AwardId", ULID) +GroupId = NewType("GroupId", ULID) +ImdbMovieId = NewType("ImdbMovieId", str) +MovieId = NewType("MovieId", ULID) +MovieIdStr = NewType("MovieIdStr", str) +RatingId = NewType("RatingId", ULID) +Score100 = NewType("Score100", int) # [0, 100] +UserId = NewType("UserId", ULID) +UserIdStr = NewType("UserIdStr", str) diff --git a/unwind/utils.py b/unwind/utils.py index 6ad4d32..d1733a8 100644 --- a/unwind/utils.py +++ b/unwind/utils.py @@ -1,8 +1,12 @@ import base64 import hashlib +import json import secrets +from functools import partial from typing import Any, TypedDict +json_dump = partial(json.dumps, separators=(",", ":")) + def b64encode(b: bytes) -> str: return base64.b64encode(b).decode().rstrip("=") diff --git a/unwind/web.py b/unwind/web.py index b4ba575..3f62a53 100644 --- a/unwind/web.py +++ b/unwind/web.py @@ -3,7 +3,7 @@ import contextlib import logging import secrets from json.decoder import JSONDecodeError -from typing import Literal, overload +from typing import Any, Literal, Never, TypeGuard, overload from starlette.applications import Starlette from starlette.authentication import ( @@ -20,15 +20,15 @@ from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware from starlette.middleware.cors import CORSMiddleware from starlette.middleware.gzip import GZipMiddleware -from starlette.requests import HTTPConnection +from starlette.requests import HTTPConnection, Request from starlette.responses import JSONResponse from starlette.routing import Mount, Route from . import config, db, imdb, imdb_import, web_models from .db import close_connection_pool, find_movies, find_ratings, open_connection_pool from .middleware.responsetime import ResponseTimeMiddleware -from .models import Group, Movie, User, asplain -from .types import ULID +from .models import Access, Group, Movie, User, asplain +from .types import JSON, ULID from .utils import b64decode, b64encode, phc_compare, phc_scrypt log = logging.getLogger(__name__) @@ -83,11 +83,11 @@ class BearerAuthBackend(AuthenticationBackend): return AuthCredentials(["authenticated", *roles]), user -def truthy(s: str): +def truthy(s: str | None) -> bool: return bool(s) and s.lower() in {"1", "yes", "true"} -_Yearcomp = Literal["<", "=", ">"] +type _Yearcomp = Literal["<", "=", ">"] def yearcomp(s: str) -> tuple[_Yearcomp, int] | None: @@ -103,7 +103,7 @@ def yearcomp(s: str) -> tuple[_Yearcomp, int] | None: def as_int( - x, *, max: int | None = None, min: int | None = 1, default: int | None = None + x: Any, *, max: int | None = None, min: int | None = 1, default: int | None = None ) -> int: try: if not isinstance(x, int): @@ -121,9 +121,9 @@ def as_int( return default -def as_ulid(s: str) -> ULID: +def as_ulid(s: Any) -> ULID: try: - if not s: + if not isinstance(s, str) or not s: raise ValueError("Invalid ULID.") return ULID(s) @@ -133,14 +133,17 @@ def as_ulid(s: str) -> ULID: @overload -async def json_from_body(request) -> dict: ... +async def json_from_body(request: Request) -> dict[str, JSON]: ... 
@overload -async def json_from_body(request, keys: list[str]) -> list: ... +async def json_from_body(request: Request, keys: list[str]) -> list[JSON]: ... -async def json_from_body(request, keys: list[str] | None = None): +async def json_from_body( + request: Request, keys: list[str] | None = None +) -> dict[str, JSON] | list[JSON]: + data: dict[str, JSON] if not await request.body(): data = {} @@ -150,6 +153,9 @@ async def json_from_body(request, keys: list[str] | None = None): except JSONDecodeError as err: raise HTTPException(422, "Invalid JSON content.") from err + if not isinstance(data, dict): + raise HTTPException(422, f"Invalid JSON type: {type(data)!a}") + if not keys: return data @@ -159,11 +165,11 @@ async def json_from_body(request, keys: list[str] | None = None): raise HTTPException(422, f"Missing data for key: {err.args[0]}") from err -def is_admin(request): +def is_admin(request: Request) -> bool: return "admin" in request.auth.scopes -async def auth_user(request) -> User | None: +async def auth_user(request: Request) -> User | None: if not isinstance(request.user, AuthedUser): return @@ -192,7 +198,7 @@ def route(path: str, *, methods: list[str] | None = None, **kwds): @route("/groups/{group_id}/ratings") -async def get_ratings_for_group(request): +async def get_ratings_for_group(request: Request) -> JSONResponse: group_id = as_ulid(request.path_params["group_id"]) async with db.new_connection() as conn: @@ -229,11 +235,13 @@ async def get_ratings_for_group(request): user_ids=user_ids, ) - ratings = (web_models.Rating(**r) for r in rows) + ratings = [web_models.Rating(**r) for r in rows] - aggr = web_models.aggregate_ratings(ratings, user_ids) + awards = await db.get_awards(conn, imdb_ids=[r.movie_imdb_id for r in ratings]) - resp = tuple(asplain(r) for r in aggr) + aggrs = web_models.aggregate_ratings(ratings, user_ids, awards_dict=awards) + + resp = tuple(asplain(r) for r in aggrs) return JSONResponse(resp) @@ -250,13 +258,13 @@ def not_found(reason: str = "Not Found"): return JSONResponse({"error": reason}, status_code=404) -def not_implemented(): +def not_implemented() -> Never: raise HTTPException(404, "Not yet implemented.") @route("/movies") @requires(["authenticated"]) -async def list_movies(request): +async def list_movies(request: Request) -> JSONResponse: params = request.query_params user = await auth_user(request) @@ -329,13 +337,13 @@ async def list_movies(request): @route("/movies", methods=["POST"]) @requires(["authenticated", "admin"]) -async def add_movie(request): +async def add_movie(request: Request) -> JSONResponse: not_implemented() @route("/movies/_reload_imdb", methods=["GET"]) @requires(["authenticated", "admin"]) -async def progress_for_load_imdb_movies(request): +async def progress_for_load_imdb_movies(request: Request) -> JSONResponse: async with db.new_connection() as conn: progress = await db.get_import_progress(conn) if not progress: @@ -371,7 +379,7 @@ _import_lock = asyncio.Lock() @route("/movies/_reload_imdb", methods=["POST"]) @requires(["authenticated", "admin"]) -async def load_imdb_movies(request): +async def load_imdb_movies(request: Request) -> JSONResponse: params = request.query_params force = truthy(params.get("force")) @@ -395,7 +403,7 @@ async def load_imdb_movies(request): @route("/users") @requires(["authenticated", "admin"]) -async def list_users(request): +async def list_users(request: Request) -> JSONResponse: async with db.new_connection() as conn: users = await db.get_all(conn, User) @@ -404,7 +412,7 @@ async def 
list_users(request): @route("/users", methods=["POST"]) @requires(["authenticated", "admin"]) -async def add_user(request): +async def add_user(request: Request) -> JSONResponse: name, imdb_id = await json_from_body(request, ["name", "imdb_id"]) # XXX restrict name @@ -426,7 +434,7 @@ async def add_user(request): @route("/users/{user_id}") @requires(["authenticated"]) -async def show_user(request): +async def show_user(request: Request) -> JSONResponse: user_id = as_ulid(request.path_params["user_id"]) if is_admin(request): @@ -455,7 +463,7 @@ async def show_user(request): @route("/users/{user_id}", methods=["DELETE"]) @requires(["authenticated", "admin"]) -async def remove_user(request): +async def remove_user(request: Request) -> JSONResponse: user_id = as_ulid(request.path_params["user_id"]) async with db.new_connection() as conn: @@ -473,7 +481,7 @@ async def remove_user(request): @route("/users/{user_id}", methods=["PATCH"]) @requires(["authenticated"]) -async def modify_user(request): +async def modify_user(request: Request) -> JSONResponse: user_id = as_ulid(request.path_params["user_id"]) if is_admin(request): @@ -520,9 +528,13 @@ async def modify_user(request): return JSONResponse(asplain(user)) +def is_valid_access(x: Any) -> TypeGuard[Access]: + return isinstance(x, str) and x in set("riw") + + @route("/users/{user_id}/groups", methods=["POST"]) @requires(["authenticated", "admin"]) -async def add_group_to_user(request): +async def add_group_to_user(request: Request) -> JSONResponse: user_id = as_ulid(request.path_params["user_id"]) async with db.new_connection() as conn: @@ -537,7 +549,7 @@ async def add_group_to_user(request): if not group: return not_found("Group not found") - if access not in set("riw"): + if not is_valid_access(access): raise HTTPException(422, "Invalid access level.") user.set_access(group_id, access) @@ -549,19 +561,19 @@ async def add_group_to_user(request): @route("/users/{user_id}/ratings") @requires(["private"]) -async def ratings_for_user(request): +async def ratings_for_user(request: Request) -> JSONResponse: not_implemented() @route("/users/{user_id}/ratings", methods=["PUT"]) @requires("authenticated") -async def set_rating_for_user(request): +async def set_rating_for_user(request: Request) -> JSONResponse: not_implemented() @route("/users/_reload_ratings", methods=["POST"]) @requires(["authenticated", "admin"]) -async def load_imdb_user_ratings(request): +async def load_imdb_user_ratings(request: Request) -> JSONResponse: ratings = [rating async for rating in imdb.refresh_user_ratings_from_imdb()] return JSONResponse({"new_ratings": [asplain(r) for r in ratings]}) @@ -569,7 +581,7 @@ async def load_imdb_user_ratings(request): @route("/groups") @requires(["authenticated", "admin"]) -async def list_groups(request): +async def list_groups(request: Request) -> JSONResponse: async with db.new_connection() as conn: groups = await db.get_all(conn, Group) @@ -578,7 +590,7 @@ async def list_groups(request): @route("/groups", methods=["POST"]) @requires(["authenticated", "admin"]) -async def add_group(request): +async def add_group(request: Request) -> JSONResponse: (name,) = await json_from_body(request, ["name"]) # XXX restrict name @@ -592,7 +604,7 @@ async def add_group(request): @route("/groups/{group_id}/users", methods=["POST"]) @requires(["authenticated"]) -async def add_user_to_group(request): +async def add_user_to_group(request: Request) -> JSONResponse: group_id = as_ulid(request.path_params["group_id"]) async with db.new_connection() as conn: 
group = await db.get(conn, Group, id=str(group_id)) @@ -628,11 +640,12 @@ async def add_user_to_group(request): return JSONResponse(asplain(group)) -async def http_exception(request, exc): +async def http_exception(request: Request, exc: Exception) -> JSONResponse: + assert isinstance(exc, HTTPException) return JSONResponse({"error": exc.detail}, status_code=exc.status_code) -def auth_error(request, err): +def auth_error(conn: HTTPConnection, err: Exception) -> JSONResponse: return unauthorized(str(err)) diff --git a/unwind/web_models.py b/unwind/web_models.py index 6e83e1d..42cb4dc 100644 --- a/unwind/web_models.py +++ b/unwind/web_models.py @@ -1,23 +1,22 @@ from dataclasses import dataclass from typing import Container, Iterable -from . import imdb, models +from . import imdb, models, types -URL = str -Score100 = int # [0, 100] +type URL = str @dataclass class Rating: canonical_title: str - imdb_score: Score100 | None + imdb_score: types.Score100 | None imdb_votes: int | None media_type: str - movie_imdb_id: str + movie_imdb_id: types.ImdbMovieId original_title: str | None release_year: int - user_id: str | None - user_score: Score100 | None + user_id: types.UserIdStr | None + user_score: types.Score100 | None @classmethod def from_movie(cls, movie: models.Movie, *, rating: models.Rating | None = None): @@ -37,13 +36,14 @@ class Rating: @dataclass class RatingAggregate: canonical_title: str - imdb_score: Score100 | None + imdb_score: types.Score100 | None imdb_votes: int | None link: URL media_type: str original_title: str | None - user_scores: list[Score100] + user_scores: list[types.Score100] year: int + awards: list[str] @classmethod def from_movie(cls, movie: models.Movie, *, ratings: Iterable[models.Rating] = []): @@ -56,15 +56,23 @@ class RatingAggregate: original_title=movie.original_title, user_scores=[r.score for r in ratings], year=movie.release_year, + awards=[], ) def aggregate_ratings( - ratings: Iterable[Rating], user_ids: Container[str] + ratings: Iterable[Rating], + user_ids: Container[types.UserIdStr], + *, + awards_dict: dict[types.ImdbMovieId, list[models.Award]] | None = None, ) -> Iterable[RatingAggregate]: - aggr: dict[str, RatingAggregate] = {} + if awards_dict is None: + awards_dict = {} + + aggr: dict[types.ImdbMovieId, RatingAggregate] = {} for r in ratings: + awards = awards_dict.get(r.movie_imdb_id, []) mov = aggr.setdefault( r.movie_imdb_id, RatingAggregate( @@ -76,6 +84,7 @@ def aggregate_ratings( original_title=r.original_title, user_scores=[], year=r.release_year, + awards=[f"{a.category}:{a.position}" for a in awards], ), ) # XXX do we need this? why don't we just get the ratings we're supposed to aggregate?
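
As a quick illustration of the new awards plumbing, the sketch below shows how `web_models.aggregate_ratings` might be called with the added `awards_dict` keyword and how an award ends up rendered as `"category:position"` on the aggregate. This is a minimal sketch, not part of the diff: the titles, IDs, score, and award position are made-up example values, and it assumes the function returns the aggregates it builds (as the hunk suggests).

from unwind import models, web_models

# Hypothetical rating row, as produced by Rating.from_movie elsewhere.
ratings = [
    web_models.Rating(
        canonical_title="Example Movie",
        imdb_score=83,
        imdb_votes=1200,
        media_type="movie",
        movie_imdb_id="tt0000001",
        original_title=None,
        release_year=1999,
        user_id="01ARZ3NDEKTSV4RRFFQ69G5FAV",
        user_score=90,
    )
]

# The position property stores {"position": 42} as JSON in Award.details.
award = models.Award(category="imdb-top-250")
award.position = 42

aggrs = list(
    web_models.aggregate_ratings(
        ratings,
        user_ids={"01ARZ3NDEKTSV4RRFFQ69G5FAV"},
        awards_dict={"tt0000001": [award]},
    )
)
# aggrs[0].awards would then be ["imdb-top-250:42"].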