Merge branch 'feat/charts'

This commit is contained in:
ducklet 2024-05-20 17:01:10 +02:00
commit 73d5b1fd73
44 changed files with 1197 additions and 842 deletions

View file

@ -18,15 +18,18 @@ RUN pip install --no-cache-dir --upgrade \
USER 10000:10001 USER 10000:10001
COPY run ./ COPY alembic.ini entrypoint.sh pyproject.toml run ./
COPY alembic ./alembic
COPY scripts ./scripts COPY scripts ./scripts
COPY unwind ./unwind COPY unwind ./unwind
RUN pip install --no-cache-dir --editable .
ENV UNWIND_DATA="/data" ENV UNWIND_DATA="/data"
VOLUME $UNWIND_DATA VOLUME $UNWIND_DATA
ENV UNWIND_PORT=8097 ENV UNWIND_PORT=8097
EXPOSE $UNWIND_PORT EXPOSE $UNWIND_PORT
ENTRYPOINT ["/var/app/run"] ENTRYPOINT ["/var/app/entrypoint.sh"]
CMD ["server"] CMD ["server"]

39
alembic.ini Normal file
View file

@ -0,0 +1,39 @@
[alembic]
script_location = alembic
file_template = %%(epoch)s-%%(rev)s_%%(slug)s
timezone = UTC
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

108
alembic/env.py Normal file
View file

@ -0,0 +1,108 @@
import asyncio
from logging.config import fileConfig
import sqlalchemy as sa
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import async_engine_from_config
from alembic import context
from unwind import db, models
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
def is_different_type(
context,
inspected_column: sa.Column,
metadata_column: sa.Column,
inspected_type: sa.types.TypeEngine,
metadata_type: sa.types.TypeEngine,
) -> bool | None:
# We used "TEXT" in our manual SQL, which in SQLite is the same as VARCHAR, but
# for SQLAlchemy/Alembic looks different.
equiv_types = [(sa.TEXT, sa.String)]
for types in equiv_types:
if isinstance(inspected_type, types) and isinstance(metadata_type, types):
return False
return None # defer to default compare implementation
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
context.configure(
url=db._connection_uri(),
target_metadata=models.metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
compare_type=is_different_type,
render_as_batch=True,
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection: Connection) -> None:
context.configure(
connection=connection,
target_metadata=models.metadata,
compare_type=is_different_type,
render_as_batch=True,
)
with context.begin_transaction():
context.run_migrations()
async def run_async_migrations() -> None:
"""In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = async_engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
url=db._connection_uri(),
poolclass=pool.NullPool,
)
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
await connectable.dispose()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode."""
# Support having a (sync) connection passed in from another script.
if (conn := config.attributes.get("connection")) and isinstance(
conn, sa.Connection
):
do_run_migrations(conn)
else:
asyncio.run(run_async_migrations())
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

26
alembic/script.py.mako Normal file
View file

@ -0,0 +1,26 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: str | None = ${repr(down_revision)}
branch_labels: str | Sequence[str] | None = ${repr(branch_labels)}
depends_on: str | Sequence[str] | None = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View file

@ -0,0 +1,69 @@
"""fix data types
Revision ID: c08ae04dc482
Revises:
Create Date: 2024-05-18 16:24:31.152480+00:00
"""
from typing import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "c08ae04dc482"
down_revision: str | None = None
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("ratings", schema=None) as batch_op:
batch_op.alter_column(
"score",
existing_type=sa.NUMERIC(),
type_=sa.Integer(),
existing_nullable=False,
)
batch_op.alter_column(
"favorite",
existing_type=sa.NUMERIC(),
type_=sa.Integer(),
existing_nullable=True,
)
batch_op.alter_column(
"finished",
existing_type=sa.NUMERIC(),
type_=sa.Integer(),
existing_nullable=True,
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("ratings", schema=None) as batch_op:
batch_op.alter_column(
"finished",
existing_type=sa.Integer(),
type_=sa.NUMERIC(),
existing_nullable=True,
)
batch_op.alter_column(
"favorite",
existing_type=sa.Integer(),
type_=sa.NUMERIC(),
existing_nullable=True,
)
batch_op.alter_column(
"score",
existing_type=sa.Integer(),
type_=sa.NUMERIC(),
existing_nullable=False,
)
# ### end Alembic commands ###

View file

@ -0,0 +1,44 @@
"""add awards table
Revision ID: 62882ef5e3ff
Revises: c08ae04dc482
Create Date: 2024-05-18 16:35:10.145964+00:00
"""
from typing import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "62882ef5e3ff"
down_revision: str | None = "c08ae04dc482"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"awards",
sa.Column("id", sa.String(), nullable=False),
sa.Column("movie_id", sa.String(), nullable=False),
sa.Column("category", sa.String(), nullable=False),
sa.Column("details", sa.String(), nullable=False),
sa.Column("created", sa.String(), nullable=False),
sa.Column("updated", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["movie_id"],
["movies.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("awards")
# ### end Alembic commands ###

View file

@ -0,0 +1,41 @@
"""use named constraints
See https://alembic.sqlalchemy.org/en/latest/naming.html
Revision ID: f17c7ca9afa4
Revises: 62882ef5e3ff
Create Date: 2024-05-18 17:06:27.696713+00:00
"""
from typing import Sequence
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "f17c7ca9afa4"
down_revision: str | None = "62882ef5e3ff"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("movies") as batch_op:
batch_op.create_unique_constraint(batch_op.f("uq_movies_imdb_id"), ["imdb_id"])
with op.batch_alter_table("users", schema=None) as batch_op:
batch_op.create_unique_constraint(batch_op.f("uq_users_imdb_id"), ["imdb_id"])
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("users", schema=None) as batch_op:
batch_op.drop_constraint(batch_op.f("uq_users_imdb_id"), type_="unique")
with op.batch_alter_table("movies", schema=None) as batch_op:
batch_op.drop_constraint(batch_op.f("uq_movies_imdb_id"), type_="unique")
# ### end Alembic commands ###

View file

@ -0,0 +1,38 @@
"""remove db_patches table
We replace our old patch process with Alembic's.
Revision ID: 8b06e4916840
Revises: f17c7ca9afa4
Create Date: 2024-05-19 00:11:06.730421+00:00
"""
from typing import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "8b06e4916840"
down_revision: str | None = "f17c7ca9afa4"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("db_patches")
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"db_patches",
sa.Column("id", sa.INTEGER(), nullable=False),
sa.Column("current", sa.VARCHAR(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
# ### end Alembic commands ###

4
entrypoint.sh Executable file
View file

@ -0,0 +1,4 @@
#!/bin/sh -eu
alembic upgrade head
exec ./run "$@"

111
poetry.lock generated
View file

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. # This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
[[package]] [[package]]
name = "aiosqlite" name = "aiosqlite"
@ -18,6 +18,25 @@ typing_extensions = ">=4.0"
dev = ["attribution (==1.7.0)", "black (==24.2.0)", "coverage[toml] (==7.4.1)", "flake8 (==7.0.0)", "flake8-bugbear (==24.2.6)", "flit (==3.9.0)", "mypy (==1.8.0)", "ufmt (==2.3.0)", "usort (==1.0.8.post1)"] dev = ["attribution (==1.7.0)", "black (==24.2.0)", "coverage[toml] (==7.4.1)", "flake8 (==7.0.0)", "flake8-bugbear (==24.2.6)", "flit (==3.9.0)", "mypy (==1.8.0)", "ufmt (==2.3.0)", "usort (==1.0.8.post1)"]
docs = ["sphinx (==7.2.6)", "sphinx-mdinclude (==0.5.3)"] docs = ["sphinx (==7.2.6)", "sphinx-mdinclude (==0.5.3)"]
[[package]]
name = "alembic"
version = "1.13.1"
description = "A database migration tool for SQLAlchemy."
optional = false
python-versions = ">=3.8"
files = [
{file = "alembic-1.13.1-py3-none-any.whl", hash = "sha256:2edcc97bed0bd3272611ce3a98d98279e9c209e7186e43e75bbb1b2bdfdbcc43"},
{file = "alembic-1.13.1.tar.gz", hash = "sha256:4932c8558bf68f2ee92b9bbcb8218671c627064d5b08939437af6d77dc05e595"},
]
[package.dependencies]
Mako = "*"
SQLAlchemy = ">=1.3.0"
typing-extensions = ">=4"
[package.extras]
tz = ["backports.zoneinfo"]
[[package]] [[package]]
name = "anyio" name = "anyio"
version = "4.3.0" version = "4.3.0"
@ -346,6 +365,94 @@ files = [
{file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
] ]
[[package]]
name = "mako"
version = "1.3.5"
description = "A super-fast templating language that borrows the best ideas from the existing templating languages."
optional = false
python-versions = ">=3.8"
files = [
{file = "Mako-1.3.5-py3-none-any.whl", hash = "sha256:260f1dbc3a519453a9c856dedfe4beb4e50bd5a26d96386cb6c80856556bb91a"},
{file = "Mako-1.3.5.tar.gz", hash = "sha256:48dbc20568c1d276a2698b36d968fa76161bf127194907ea6fc594fa81f943bc"},
]
[package.dependencies]
MarkupSafe = ">=0.9.2"
[package.extras]
babel = ["Babel"]
lingua = ["lingua"]
testing = ["pytest"]
[[package]]
name = "markupsafe"
version = "2.1.5"
description = "Safely add untrusted strings to HTML/XML markup."
optional = false
python-versions = ">=3.7"
files = [
{file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc"},
{file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:72b6be590cc35924b02c78ef34b467da4ba07e4e0f0454a2c5907f473fc50ce5"},
{file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e61659ba32cf2cf1481e575d0462554625196a1f2fc06a1c777d3f48e8865d46"},
{file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2174c595a0d73a3080ca3257b40096db99799265e1c27cc5a610743acd86d62f"},
{file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ae2ad8ae6ebee9d2d94b17fb62763125f3f374c25618198f40cbb8b525411900"},
{file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:075202fa5b72c86ad32dc7d0b56024ebdbcf2048c0ba09f1cde31bfdd57bcfff"},
{file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:598e3276b64aff0e7b3451b72e94fa3c238d452e7ddcd893c3ab324717456bad"},
{file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fce659a462a1be54d2ffcacea5e3ba2d74daa74f30f5f143fe0c58636e355fdd"},
{file = "MarkupSafe-2.1.5-cp310-cp310-win32.whl", hash = "sha256:d9fad5155d72433c921b782e58892377c44bd6252b5af2f67f16b194987338a4"},
{file = "MarkupSafe-2.1.5-cp310-cp310-win_amd64.whl", hash = "sha256:bf50cd79a75d181c9181df03572cdce0fbb75cc353bc350712073108cba98de5"},
{file = "MarkupSafe-2.1.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:629ddd2ca402ae6dbedfceeba9c46d5f7b2a61d9749597d4307f943ef198fc1f"},
{file = "MarkupSafe-2.1.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5b7b716f97b52c5a14bffdf688f971b2d5ef4029127f1ad7a513973cfd818df2"},
{file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ec585f69cec0aa07d945b20805be741395e28ac1627333b1c5b0105962ffced"},
{file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b91c037585eba9095565a3556f611e3cbfaa42ca1e865f7b8015fe5c7336d5a5"},
{file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7502934a33b54030eaf1194c21c692a534196063db72176b0c4028e140f8f32c"},
{file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0e397ac966fdf721b2c528cf028494e86172b4feba51d65f81ffd65c63798f3f"},
{file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:c061bb86a71b42465156a3ee7bd58c8c2ceacdbeb95d05a99893e08b8467359a"},
{file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3a57fdd7ce31c7ff06cdfbf31dafa96cc533c21e443d57f5b1ecc6cdc668ec7f"},
{file = "MarkupSafe-2.1.5-cp311-cp311-win32.whl", hash = "sha256:397081c1a0bfb5124355710fe79478cdbeb39626492b15d399526ae53422b906"},
{file = "MarkupSafe-2.1.5-cp311-cp311-win_amd64.whl", hash = "sha256:2b7c57a4dfc4f16f7142221afe5ba4e093e09e728ca65c51f5620c9aaeb9a617"},
{file = "MarkupSafe-2.1.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:8dec4936e9c3100156f8a2dc89c4b88d5c435175ff03413b443469c7c8c5f4d1"},
{file = "MarkupSafe-2.1.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:3c6b973f22eb18a789b1460b4b91bf04ae3f0c4234a0a6aa6b0a92f6f7b951d4"},
{file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac07bad82163452a6884fe8fa0963fb98c2346ba78d779ec06bd7a6262132aee"},
{file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5dfb42c4604dddc8e4305050aa6deb084540643ed5804d7455b5df8fe16f5e5"},
{file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ea3d8a3d18833cf4304cd2fc9cbb1efe188ca9b5efef2bdac7adc20594a0e46b"},
{file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d050b3361367a06d752db6ead6e7edeb0009be66bc3bae0ee9d97fb326badc2a"},
{file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bec0a414d016ac1a18862a519e54b2fd0fc8bbfd6890376898a6c0891dd82e9f"},
{file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:58c98fee265677f63a4385256a6d7683ab1832f3ddd1e66fe948d5880c21a169"},
{file = "MarkupSafe-2.1.5-cp312-cp312-win32.whl", hash = "sha256:8590b4ae07a35970728874632fed7bd57b26b0102df2d2b233b6d9d82f6c62ad"},
{file = "MarkupSafe-2.1.5-cp312-cp312-win_amd64.whl", hash = "sha256:823b65d8706e32ad2df51ed89496147a42a2a6e01c13cfb6ffb8b1e92bc910bb"},
{file = "MarkupSafe-2.1.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c8b29db45f8fe46ad280a7294f5c3ec36dbac9491f2d1c17345be8e69cc5928f"},
{file = "MarkupSafe-2.1.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec6a563cff360b50eed26f13adc43e61bc0c04d94b8be985e6fb24b81f6dcfdf"},
{file = "MarkupSafe-2.1.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a549b9c31bec33820e885335b451286e2969a2d9e24879f83fe904a5ce59d70a"},
{file = "MarkupSafe-2.1.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4f11aa001c540f62c6166c7726f71f7573b52c68c31f014c25cc7901deea0b52"},
{file = "MarkupSafe-2.1.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:7b2e5a267c855eea6b4283940daa6e88a285f5f2a67f2220203786dfa59b37e9"},
{file = "MarkupSafe-2.1.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:2d2d793e36e230fd32babe143b04cec8a8b3eb8a3122d2aceb4a371e6b09b8df"},
{file = "MarkupSafe-2.1.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:ce409136744f6521e39fd8e2a24c53fa18ad67aa5bc7c2cf83645cce5b5c4e50"},
{file = "MarkupSafe-2.1.5-cp37-cp37m-win32.whl", hash = "sha256:4096e9de5c6fdf43fb4f04c26fb114f61ef0bf2e5604b6ee3019d51b69e8c371"},
{file = "MarkupSafe-2.1.5-cp37-cp37m-win_amd64.whl", hash = "sha256:4275d846e41ecefa46e2015117a9f491e57a71ddd59bbead77e904dc02b1bed2"},
{file = "MarkupSafe-2.1.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:656f7526c69fac7f600bd1f400991cc282b417d17539a1b228617081106feb4a"},
{file = "MarkupSafe-2.1.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:97cafb1f3cbcd3fd2b6fbfb99ae11cdb14deea0736fc2b0952ee177f2b813a46"},
{file = "MarkupSafe-2.1.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f3fbcb7ef1f16e48246f704ab79d79da8a46891e2da03f8783a5b6fa41a9532"},
{file = "MarkupSafe-2.1.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa9db3f79de01457b03d4f01b34cf91bc0048eb2c3846ff26f66687c2f6d16ab"},
{file = "MarkupSafe-2.1.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ffee1f21e5ef0d712f9033568f8344d5da8cc2869dbd08d87c84656e6a2d2f68"},
{file = "MarkupSafe-2.1.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5dedb4db619ba5a2787a94d877bc8ffc0566f92a01c0ef214865e54ecc9ee5e0"},
{file = "MarkupSafe-2.1.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:30b600cf0a7ac9234b2638fbc0fb6158ba5bdcdf46aeb631ead21248b9affbc4"},
{file = "MarkupSafe-2.1.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8dd717634f5a044f860435c1d8c16a270ddf0ef8588d4887037c5028b859b0c3"},
{file = "MarkupSafe-2.1.5-cp38-cp38-win32.whl", hash = "sha256:daa4ee5a243f0f20d528d939d06670a298dd39b1ad5f8a72a4275124a7819eff"},
{file = "MarkupSafe-2.1.5-cp38-cp38-win_amd64.whl", hash = "sha256:619bc166c4f2de5caa5a633b8b7326fbe98e0ccbfacabd87268a2b15ff73a029"},
{file = "MarkupSafe-2.1.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7a68b554d356a91cce1236aa7682dc01df0edba8d043fd1ce607c49dd3c1edcf"},
{file = "MarkupSafe-2.1.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:db0b55e0f3cc0be60c1f19efdde9a637c32740486004f20d1cff53c3c0ece4d2"},
{file = "MarkupSafe-2.1.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e53af139f8579a6d5f7b76549125f0d94d7e630761a2111bc431fd820e163b8"},
{file = "MarkupSafe-2.1.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:17b950fccb810b3293638215058e432159d2b71005c74371d784862b7e4683f3"},
{file = "MarkupSafe-2.1.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4c31f53cdae6ecfa91a77820e8b151dba54ab528ba65dfd235c80b086d68a465"},
{file = "MarkupSafe-2.1.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:bff1b4290a66b490a2f4719358c0cdcd9bafb6b8f061e45c7a2460866bf50c2e"},
{file = "MarkupSafe-2.1.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:bc1667f8b83f48511b94671e0e441401371dfd0f0a795c7daa4a3cd1dde55bea"},
{file = "MarkupSafe-2.1.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5049256f536511ee3f7e1b3f87d1d1209d327e818e6ae1365e8653d7e3abb6a6"},
{file = "MarkupSafe-2.1.5-cp39-cp39-win32.whl", hash = "sha256:00e046b6dd71aa03a41079792f8473dc494d564611a8f89bbbd7cb93295ebdcf"},
{file = "MarkupSafe-2.1.5-cp39-cp39-win_amd64.whl", hash = "sha256:fa173ec60341d6bb97a89f5ea19c85c5643c1e7dedebc22f5181eb73573142c5"},
{file = "MarkupSafe-2.1.5.tar.gz", hash = "sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b"},
]
[[package]] [[package]]
name = "nodeenv" name = "nodeenv"
version = "1.8.0" version = "1.8.0"
@ -694,4 +801,4 @@ files = [
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.12" python-versions = "^3.12"
content-hash = "038fed338d6b75c17eb8eb88d36c2411ff936dab23887b70594e5ba1da518451" content-hash = "9dbc732b312d6d39fbf4e8b8af22739aad6c25312cee92736f19d3a106f93129"

View file

@ -18,6 +18,7 @@ ulid-py = "^1.1.0"
uvicorn = "^0.29.0" uvicorn = "^0.29.0"
httpx = "^0.27.0" httpx = "^0.27.0"
sqlalchemy = {version = "^2.0", extras = ["aiosqlite"]} sqlalchemy = {version = "^2.0", extras = ["aiosqlite"]}
alembic = "^1.13.1"
[tool.poetry.group.build.dependencies] [tool.poetry.group.build.dependencies]
# When we run poetry export, typing-extensions is a transient dependency via # When we run poetry export, typing-extensions is a transient dependency via

View file

@ -20,14 +20,6 @@ def a_movie(**kwds) -> models.Movie:
return models.Movie(**args) return models.Movie(**args)
@pytest.mark.asyncio
async def test_current_patch_level(conn: db.Connection):
patch_level = "some-patch-level"
assert patch_level != await db.current_patch_level(conn)
await db.set_current_patch_level(conn, patch_level)
assert patch_level == await db.current_patch_level(conn)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get(conn: db.Connection): async def test_get(conn: db.Connection):
m1 = a_movie() m1 = a_movie()

View file

@ -32,6 +32,74 @@ def admin_client() -> TestClient:
return client return client
@pytest.mark.asyncio
async def test_get_ratings_for_group_with_awards(
conn: db.Connection, unauthorized_client: TestClient
):
user = models.User(
imdb_id="ur12345678",
name="user-1",
secret="secret-1", # noqa: S106
groups=[],
)
group = models.Group(
name="group-1",
users=[models.GroupUser(id=str(user.id), name=user.name)],
)
user.groups = [models.UserGroup(id=str(group.id), access="r")]
path = app.url_path_for("get_ratings_for_group", group_id=str(group.id))
await db.add(conn, user)
await db.add(conn, group)
movie1 = models.Movie(
title="test movie",
release_year=2013,
media_type="Movie",
imdb_id="tt12345678",
genres={"genre-1"},
)
await db.add(conn, movie1)
movie2 = models.Movie(
title="test movie 2",
release_year=2014,
media_type="Movie",
imdb_id="tt12345679",
genres={"genre-2"},
)
await db.add(conn, movie2)
award1 = models.Award(
movie_id=movie1.id, category="imdb-top-250", details='{"position":23}'
)
award2 = models.Award(
movie_id=movie2.id, category="imdb-top-250", details='{"position":99}'
)
await db.add(conn, award1)
await db.add(conn, award2)
rating = models.Rating(
movie_id=movie1.id, user_id=user.id, score=66, rating_date=datetime.now(tz=UTC)
)
await db.add(conn, rating)
rating_aggregate = {
"canonical_title": movie1.title,
"imdb_score": movie1.imdb_score,
"imdb_votes": movie1.imdb_votes,
"link": imdb.movie_url(movie1.imdb_id),
"media_type": movie1.media_type,
"original_title": movie1.original_title,
"user_scores": [rating.score],
"year": movie1.release_year,
"awards": ["imdb-top-250:23"],
}
resp = unauthorized_client.get(path)
assert resp.status_code == 200
assert resp.json() == [rating_aggregate]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_ratings_for_group( async def test_get_ratings_for_group(
conn: db.Connection, unauthorized_client: TestClient conn: db.Connection, unauthorized_client: TestClient
@ -82,6 +150,7 @@ async def test_get_ratings_for_group(
"original_title": movie.original_title, "original_title": movie.original_title,
"user_scores": [rating.score], "user_scores": [rating.score],
"year": movie.release_year, "year": movie.release_year,
"awards": [],
} }
resp = unauthorized_client.get(path) resp = unauthorized_client.get(path)
@ -158,6 +227,7 @@ async def test_list_movies(
"original_title": m.original_title, "original_title": m.original_title,
"user_scores": [], "user_scores": [],
"year": m.release_year, "year": m.release_year,
"awards": [],
} }
response = authorized_client.get(path, params={"imdb_id": m.imdb_id}) response = authorized_client.get(path, params={"imdb_id": m.imdb_id})

View file

@ -1,149 +1,26 @@
import argparse import argparse
import asyncio import asyncio
import logging import logging
import secrets import sys
from base64 import b64encode
from pathlib import Path
from . import config, db, models, utils from . import cli, config
from .db import close_connection_pool, open_connection_pool
from .imdb import refresh_user_ratings_from_imdb
from .imdb_import import download_datasets, import_from_file
log = logging.getLogger(__name__) log = logging.getLogger(__package__)
async def run_add_user(user_id: str, name: str, overwrite_existing: bool):
if not user_id.startswith("ur"):
raise ValueError(f"Invalid IMDb user ID: {user_id!a}")
await open_connection_pool()
async with db.new_connection() as conn:
user = await db.get(conn, models.User, imdb_id=user_id)
if user is not None:
if overwrite_existing:
log.warning("⚠️ Overwriting existing user: %a", user)
else:
log.error("❌ User already exists: %a", user)
return
secret = secrets.token_bytes()
user = models.User(name=name, imdb_id=user_id, secret=utils.phc_scrypt(secret))
async with db.transaction() as conn:
await db.add_or_update_user(conn, user)
user_data = {
"secret": b64encode(secret),
"user": models.asplain(user),
}
log.info("✨ User created: %a", user_data)
await close_connection_pool()
async def run_load_user_ratings_from_imdb():
await open_connection_pool()
i = 0
async for _ in refresh_user_ratings_from_imdb():
i += 1
log.info("✨ Imported %s new ratings.", i)
await close_connection_pool()
async def run_import_imdb_dataset(basics_path: Path, ratings_path: Path):
await open_connection_pool()
await import_from_file(basics_path=basics_path, ratings_path=ratings_path)
await close_connection_pool()
async def run_download_imdb_dataset(basics_path: Path, ratings_path: Path):
await download_datasets(basics_path=basics_path, ratings_path=ratings_path)
def getargs(): def getargs():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser(prog="unwind", allow_abbrev=False)
commands = parser.add_subparsers(required=True) commands = parser.add_subparsers(title="commands", metavar="COMMAND", dest="mode")
parser_import_imdb_dataset = commands.add_parser( for module in cli.modules:
"import-imdb-dataset", help_, *descr = module.help.splitlines()
help="Import IMDb datasets.", cmd = commands.add_parser(
description=""" module.name,
Import IMDb datasets. help=help_,
New datasets available from https://www.imdb.com/interfaces/. description="\n".join(descr) or help_,
""", allow_abbrev=False,
) )
parser_import_imdb_dataset.add_argument( module.add_args(cmd)
dest="mode",
action="store_const",
const="import-imdb-dataset",
)
parser_import_imdb_dataset.add_argument(
"--basics", metavar="basics_file.tsv.gz", type=Path, required=True
)
parser_import_imdb_dataset.add_argument(
"--ratings", metavar="ratings_file.tsv.gz", type=Path, required=True
)
parser_download_imdb_dataset = commands.add_parser(
"download-imdb-dataset",
help="Download IMDb datasets.",
description="""
Download IMDb datasets.
""",
)
parser_download_imdb_dataset.add_argument(
dest="mode",
action="store_const",
const="download-imdb-dataset",
)
parser_download_imdb_dataset.add_argument(
"--basics", metavar="basics_file.tsv.gz", type=Path, required=True
)
parser_download_imdb_dataset.add_argument(
"--ratings", metavar="ratings_file.tsv.gz", type=Path, required=True
)
parser_load_user_ratings_from_imdb = commands.add_parser(
"load-user-ratings-from-imdb",
help="Load user ratings from imdb.com.",
description="""
Refresh user ratings for all registered users live from IMDb's website.
""",
)
parser_load_user_ratings_from_imdb.add_argument(
dest="mode",
action="store_const",
const="load-user-ratings-from-imdb",
)
parser_add_user = commands.add_parser(
"add-user",
help="Add a new user.",
description="""
Add a new user.
""",
)
parser_add_user.add_argument(
dest="mode",
action="store_const",
const="add-user",
)
parser_add_user.add_argument("--name", required=True)
parser_add_user.add_argument("--imdb-id", required=True)
parser_add_user.add_argument(
"--overwrite-existing",
action="store_true",
help="Allow overwriting an existing user. WARNING: This will reset the user's password!",
)
try: try:
args = parser.parse_args() args = parser.parse_args()
@ -151,6 +28,10 @@ def getargs():
parser.print_usage() parser.print_usage()
raise raise
if args.mode is None:
parser.print_help()
sys.exit(1)
return args return args
@ -158,23 +39,16 @@ def main():
logging.basicConfig( logging.basicConfig(
format="%(asctime)s.%(msecs)03d [%(name)s:%(process)d] %(levelname)s: %(message)s", format="%(asctime)s.%(msecs)03d [%(name)s:%(process)d] %(levelname)s: %(message)s",
datefmt="%H:%M:%S", datefmt="%H:%M:%S",
level=config.loglevel, # level=config.loglevel,
) )
log.setLevel(config.loglevel)
log.debug(f"Log level: {config.loglevel}") log.debug(f"Log level: {config.loglevel}")
try: args = getargs()
args = getargs()
except Exception:
return
if args.mode == "load-user-ratings-from-imdb": modes = {m.name: m.main for m in cli.modules}
asyncio.run(run_load_user_ratings_from_imdb()) if handler := modes.get(args.mode):
elif args.mode == "add-user": asyncio.run(handler(args))
asyncio.run(run_add_user(args.imdb_id, args.name, args.overwrite_existing))
elif args.mode == "import-imdb-dataset":
asyncio.run(run_import_imdb_dataset(args.basics, args.ratings))
elif args.mode == "download-imdb-dataset":
asyncio.run(run_download_imdb_dataset(args.basics, args.ratings))
main() main()

39
unwind/cli/__init__.py Normal file
View file

@ -0,0 +1,39 @@
import argparse
import importlib
from pathlib import Path
from types import ModuleType
from typing import Any, Callable, Coroutine, Iterable, Protocol, TypeGuard
type CommandHandler = Callable[[argparse.Namespace], Coroutine[Any, Any, None]]
class CliModule(Protocol):
name: str
help: str
add_args: Callable[[argparse.ArgumentParser], None]
main: CommandHandler
def _is_cli_module(m: ModuleType) -> TypeGuard[CliModule]:
return (
hasattr(m, "name")
and hasattr(m, "help")
and hasattr(m, "add_args")
and hasattr(m, "main")
)
_clidir = Path(__file__).parent
def _load_cmds() -> Iterable[CliModule]:
"""Return all CLI command modules."""
for f in _clidir.iterdir():
if f.suffix == ".py" and not f.name.startswith("__"):
m = importlib.import_module(f"{__package__}.{f.stem}")
if not _is_cli_module(m):
raise ValueError(f"Invalid CLI module: {m!a}")
yield m
modules = sorted(_load_cmds(), key=lambda m: m.name)

56
unwind/cli/add_user.py Normal file
View file

@ -0,0 +1,56 @@
import argparse
import logging
import secrets
from unwind import db, models, utils
log = logging.getLogger(__name__)
name = "add-user"
help = "Add a new user."
def add_args(cmd: argparse.ArgumentParser) -> None:
cmd.add_argument("--name", required=True)
cmd.add_argument("--imdb-id", required=True)
cmd.add_argument(
"--overwrite-existing",
action="store_true",
help="Allow overwriting an existing user. WARNING: This will reset the user's password!",
)
async def main(args: argparse.Namespace) -> None:
user_id: str = args.imdb_id
name: str = args.name
overwrite_existing: bool = args.overwrite_existing
if not user_id.startswith("ur"):
raise ValueError(f"Invalid IMDb user ID: {user_id!a}")
await db.open_connection_pool()
async with db.new_connection() as conn:
user = await db.get(conn, models.User, imdb_id=user_id)
if user is not None:
if overwrite_existing:
log.warning("⚠️ Overwriting existing user: %a", user)
else:
log.error("❌ User already exists: %a", user)
return
secret = secrets.token_bytes()
user = models.User(name=name, imdb_id=user_id, secret=utils.phc_scrypt(secret))
async with db.transaction() as conn:
await db.add_or_update_user(conn, user)
user_data = {
"secret": utils.b64encode(secret),
"user": models.asplain(user),
}
log.info("✨ User created: %a", user_data)
await db.close_connection_pool()

View file

@ -0,0 +1,24 @@
import argparse
import logging
from pathlib import Path
from unwind.imdb_import import download_datasets
log = logging.getLogger(__name__)
name = "download-imdb-dataset"
help = "Download IMDb datasets."
def add_args(cmd: argparse.ArgumentParser) -> None:
cmd.add_argument("--basics", metavar="basics_file.tsv.gz", type=Path, required=True)
cmd.add_argument(
"--ratings", metavar="ratings_file.tsv.gz", type=Path, required=True
)
async def main(args: argparse.Namespace) -> None:
basics_path: Path = args.basics
ratings_path: Path = args.ratings
await download_datasets(basics_path=basics_path, ratings_path=ratings_path)

View file

@ -0,0 +1,31 @@
import argparse
import logging
from pathlib import Path
from unwind import db
from unwind.imdb_import import import_from_file
log = logging.getLogger(__name__)
name = "import-imdb-dataset"
help = """Import IMDb datasets.
New datasets available from https://www.imdb.com/interfaces/.
"""
def add_args(cmd: argparse.ArgumentParser) -> None:
cmd.add_argument("--basics", metavar="basics_file.tsv.gz", type=Path, required=True)
cmd.add_argument(
"--ratings", metavar="ratings_file.tsv.gz", type=Path, required=True
)
async def main(args: argparse.Namespace) -> None:
basics_path: Path = args.basics
ratings_path: Path = args.ratings
await db.open_connection_pool()
await import_from_file(basics_path=basics_path, ratings_path=ratings_path)
await db.close_connection_pool()

View file

@ -0,0 +1,97 @@
import argparse
import logging
from typing import Callable
import sqlalchemy as sa
from unwind import db, imdb, models, types, utils
log = logging.getLogger(__name__)
name = "load-imdb-charts"
help = "Load and import charts from imdb.com."
def add_args(cmd: argparse.ArgumentParser) -> None:
cmd.add_argument(
"--select",
action="append",
dest="charts",
default=[],
choices={"top250", "bottom100", "pop100"},
help="Select which charts to refresh.",
)
async def get_movie_ids(
conn: db.Connection, imdb_ids: list[imdb.MovieId]
) -> dict[imdb.MovieId, types.ULID]:
c = models.movies.c
query = sa.select(c.imdb_id, c.id).where(c.imdb_id.in_(imdb_ids))
rows = await db.fetch_all(conn, query)
return {row.imdb_id: types.ULID(row.id) for row in rows}
async def remove_all_awards(
conn: db.Connection, category: models.AwardCategory
) -> None:
stmt = models.awards.delete().where(models.awards.c.category == category)
await conn.execute(stmt)
_award_handlers: dict[models.AwardCategory, Callable] = {
"imdb-pop-100": imdb.load_most_popular_100,
"imdb-top-250": imdb.load_top_250,
"imdb-bottom-100": imdb.load_bottom_100,
}
async def update_awards(conn: db.Connection, category: models.AwardCategory) -> None:
load_imdb_ids = _award_handlers[category]
imdb_ids = await load_imdb_ids()
available = await get_movie_ids(conn, imdb_ids)
if missing := set(imdb_ids).difference(available):
log.warning(
"⚠️ Charts for category (%a) contained %i unknown movies: %a",
category,
len(missing),
missing,
)
await remove_all_awards(conn, category=category)
for pos, imdb_id in enumerate(imdb_ids, 1):
if (movie_id := available.get(imdb_id)) is None:
continue
award = models.Award(
movie_id=movie_id,
category=category,
details=utils.json_dump({"position": pos}),
)
await db.add(conn, award)
async def main(args: argparse.Namespace) -> None:
await db.open_connection_pool()
if not args.charts:
args.charts = {"top250", "bottom100", "pop100"}
if "pop100" in args.charts:
async with db.transaction() as conn:
await update_awards(conn, "imdb-pop-100")
log.info("✨ Updated most popular 100 movies.")
if "bottom100" in args.charts:
async with db.transaction() as conn:
await update_awards(conn, "imdb-bottom-100")
log.info("✨ Updated bottom 100 movies.")
if "top250" in args.charts:
async with db.transaction() as conn:
await update_awards(conn, "imdb-top-250")
log.info("✨ Updated top 250 rated movies.")
await db.close_connection_pool()

View file

@ -0,0 +1,28 @@
import argparse
import logging
from unwind import db
from unwind.imdb import refresh_user_ratings_from_imdb
log = logging.getLogger(__name__)
name = "load-user-ratings-from-imdb"
help = """Load user ratings from imdb.com.
Refresh user ratings for all registered users live from IMDb's website.
"""
def add_args(cmd: argparse.ArgumentParser) -> None:
pass
async def main(args: argparse.Namespace) -> None:
await db.open_connection_pool()
i = 0
async for _ in refresh_user_ratings_from_imdb():
i += 1
log.info("✨ Imported %s new ratings.", i)
await db.close_connection_pool()

View file

@ -1,21 +1,25 @@
import contextlib import contextlib
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Any, AsyncGenerator, Iterable, Literal, Sequence, Type, TypeVar from typing import Any, AsyncGenerator, Iterable, Literal, Sequence, Type
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.dialects.sqlite import insert
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
import alembic.command
import alembic.config
import alembic.migration
from . import config from . import config
from .models import ( from .models import (
Award,
Model, Model,
Movie, Movie,
Progress, Progress,
Rating, Rating,
User, User,
asplain, asplain,
db_patches, awards,
fromplain, fromplain,
metadata, metadata,
movies, movies,
@ -24,15 +28,33 @@ from .models import (
ratings, ratings,
utcnow, utcnow,
) )
from .types import ULID from .types import ULID, ImdbMovieId, UserIdStr
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
T = TypeVar("T")
_engine: AsyncEngine | None = None _engine: AsyncEngine | None = None
type Connection = AsyncConnection type Connection = AsyncConnection
_project_dir = Path(__file__).parent.parent
_alembic_ini = _project_dir / "alembic.ini"
def _init(conn: sa.Connection) -> None:
# See https://alembic.sqlalchemy.org/en/latest/cookbook.html#building-an-up-to-date-database-from-scratch
context = alembic.migration.MigrationContext.configure(conn)
heads = context.get_current_heads()
is_empty_db = not heads # We consider a DB empty if Alembic hasn't touched it yet.
if is_empty_db:
log.info("⚡️ Initializing empty database.")
metadata.create_all(conn)
# We pass our existing connection to Alembic's env.py, to avoid running another asyncio loop there.
alembic_cfg = alembic.config.Config(_alembic_ini)
alembic_cfg.attributes["connection"] = conn
alembic.command.stamp(alembic_cfg, "head")
async def open_connection_pool() -> None: async def open_connection_pool() -> None:
"""Open the DB connection pool. """Open the DB connection pool.
@ -41,11 +63,7 @@ async def open_connection_pool() -> None:
""" """
async with transaction() as conn: async with transaction() as conn:
await conn.execute(sa.text("PRAGMA journal_mode=WAL")) await conn.execute(sa.text("PRAGMA journal_mode=WAL"))
await conn.run_sync(_init)
await conn.run_sync(metadata.create_all, tables=[db_patches])
async with new_connection() as conn:
await apply_db_patches(conn)
async def close_connection_pool() -> None: async def close_connection_pool() -> None:
@ -65,65 +83,7 @@ async def close_connection_pool() -> None:
await engine.dispose() await engine.dispose()
async def current_patch_level(conn: Connection, /) -> str: async def vacuum(conn: Connection, /) -> None:
query = sa.select(db_patches.c.current)
current = await conn.scalar(query)
return current or ""
async def set_current_patch_level(conn: Connection, /, current: str) -> None:
stmt = insert(db_patches).values(id=1, current=current)
stmt = stmt.on_conflict_do_update(set_={"current": stmt.excluded.current})
await conn.execute(stmt)
db_patches_dir = Path(__file__).parent / "sql"
async def apply_db_patches(conn: Connection, /) -> None:
"""Apply all remaining patches to the database.
Beware that patches will be applied in lexicographical order,
i.e. "10" comes before "9".
The current patch state is recorded in the DB itself.
Please note that every SQL statement in a patch file MUST be terminated
using two consecutive semi-colons (;).
Failing to do so will result in an error.
"""
applied_lvl = await current_patch_level(conn)
did_patch = False
for patchfile in sorted(db_patches_dir.glob("*.sql"), key=lambda p: p.stem):
patch_lvl = patchfile.stem
if patch_lvl <= applied_lvl:
continue
log.info("Applying patch: %s", patch_lvl)
sql = patchfile.read_text()
queries = sql.split(";;")
if len(queries) < 2:
log.error(
"Patch file is missing statement terminator (`;;'): %s", patchfile
)
raise RuntimeError("No statement found.")
async with transacted(conn):
for query in queries:
await conn.execute(sa.text(query))
await set_current_patch_level(conn, patch_lvl)
did_patch = True
if did_patch:
await _vacuum(conn)
async def _vacuum(conn: Connection, /) -> None:
"""Vacuum the database. """Vacuum the database.
This function cannot be run on a connection with an open transaction. This function cannot be run on a connection with an open transaction.
@ -194,11 +154,13 @@ async def set_import_progress(conn: Connection, /, progress: float) -> Progress:
return current return current
def _new_engine() -> AsyncEngine: def _connection_uri() -> str:
uri = f"sqlite+aiosqlite:///{config.storage_path}" return f"sqlite+aiosqlite:///{config.storage_path}"
def _new_engine() -> AsyncEngine:
return create_async_engine( return create_async_engine(
uri, _connection_uri(),
isolation_level="SERIALIZABLE", isolation_level="SERIALIZABLE",
) )
@ -257,6 +219,13 @@ async def new_connection() -> AsyncGenerator[Connection, None]:
async def transacted( async def transacted(
conn: Connection, /, *, force_rollback: bool = False conn: Connection, /, *, force_rollback: bool = False
) -> AsyncGenerator[None, None]: ) -> AsyncGenerator[None, None]:
"""Start a transaction for the given connection.
If `force_rollback` is `True` any changes will be rolled back at the end of the
transaction, unless they are explicitly committed.
Nesting transactions is allowed, but mixing values for `force_rollback` will likely
yield unexpected results.
"""
transaction = contextlib.nullcontext() if conn.in_transaction() else conn.begin() transaction = contextlib.nullcontext() if conn.in_transaction() else conn.begin()
async with transaction: async with transaction:
@ -272,7 +241,7 @@ async def add(conn: Connection, /, item: Model) -> None:
# Support late initializing - used for optimization. # Support late initializing - used for optimization.
if getattr(item, "_is_lazy", False): if getattr(item, "_is_lazy", False):
assert hasattr(item, "_lazy_init") assert hasattr(item, "_lazy_init")
item._lazy_init() # pyright: ignore [reportGeneralTypeIssues] item._lazy_init() # pyright: ignore[reportAttributeAccessIssue]
table: sa.Table = item.__table__ table: sa.Table = item.__table__
values = asplain(item, serialize=True) values = asplain(item, serialize=True)
@ -294,17 +263,14 @@ async def fetch_one(
return result.first() return result.first()
ModelType = TypeVar("ModelType", bound=Model) async def get[T: Model](
async def get(
conn: Connection, conn: Connection,
/, /,
model: Type[ModelType], model: Type[T],
*, *,
order_by: tuple[sa.Column, Literal["asc", "desc"]] | None = None, order_by: tuple[sa.Column, Literal["asc", "desc"]] | None = None,
**field_values, **field_values,
) -> ModelType | None: ) -> T | None:
"""Load a model instance from the database. """Load a model instance from the database.
Passing `field_values` allows to filter the item to load. You have to encode the Passing `field_values` allows to filter the item to load. You have to encode the
@ -327,9 +293,9 @@ async def get(
return fromplain(model, row._mapping, serialized=True) if row else None return fromplain(model, row._mapping, serialized=True) if row else None
async def get_many( async def get_many[T: Model](
conn: Connection, /, model: Type[ModelType], **field_sets: set | list conn: Connection, /, model: Type[T], **field_sets: set | list
) -> Iterable[ModelType]: ) -> Iterable[T]:
"""Return the items with any values matching all given field sets. """Return the items with any values matching all given field sets.
This is similar to `get_all`, but instead of a scalar value a list of values This is similar to `get_all`, but instead of a scalar value a list of values
@ -346,9 +312,9 @@ async def get_many(
return (fromplain(model, row._mapping, serialized=True) for row in rows) return (fromplain(model, row._mapping, serialized=True) for row in rows)
async def get_all( async def get_all[T: Model](
conn: Connection, /, model: Type[ModelType], **field_values conn: Connection, /, model: Type[T], **field_values
) -> Iterable[ModelType]: ) -> Iterable[T]:
"""Filter all items by comparing all given field values. """Filter all items by comparing all given field values.
If no filters are given, all items will be returned. If no filters are given, all items will be returned.
@ -365,7 +331,7 @@ async def update(conn: Connection, /, item: Model) -> None:
# Support late initializing - used for optimization. # Support late initializing - used for optimization.
if getattr(item, "_is_lazy", False): if getattr(item, "_is_lazy", False):
assert hasattr(item, "_lazy_init") assert hasattr(item, "_lazy_init")
item._lazy_init() # pyright: ignore [reportGeneralTypeIssues] item._lazy_init() # pyright: ignore[reportAttributeAccessIssue]
table: sa.Table = item.__table__ table: sa.Table = item.__table__
values = asplain(item, serialize=True) values = asplain(item, serialize=True)
@ -466,6 +432,23 @@ async def add_or_update_rating(conn: Connection, /, rating: Rating) -> bool:
return False return False
async def get_awards(
conn: Connection, /, imdb_ids: list[ImdbMovieId]
) -> dict[ImdbMovieId, list[Award]]:
query = (
sa.select(Award, movies.c.imdb_id)
.join(movies, awards.c.movie_id == movies.c.id)
.where(movies.c.imdb_id.in_(imdb_ids))
)
rows = await fetch_all(conn, query)
awards_dict: dict[ImdbMovieId, list[Award]] = {}
for row in rows:
awards_dict.setdefault(row.imdb_id, []).append(
fromplain(Award, row._mapping, serialized=True)
)
return awards_dict
def sql_escape(s: str, char: str = "#") -> str: def sql_escape(s: str, char: str = "#") -> str:
return s.replace(char, 2 * char).replace("%", f"{char}%").replace("_", f"{char}_") return s.replace(char, 2 * char).replace("%", f"{char}%").replace("_", f"{char}_")
@ -481,7 +464,7 @@ async def find_ratings(
include_unrated: bool = False, include_unrated: bool = False,
yearcomp: tuple[Literal["<", "=", ">"], int] | None = None, yearcomp: tuple[Literal["<", "=", ">"], int] | None = None,
limit_rows: int = 10, limit_rows: int = 10,
user_ids: Iterable[str] = [], user_ids: Iterable[UserIdStr] = [],
) -> Iterable[dict[str, Any]]: ) -> Iterable[dict[str, Any]]:
conditions = [] conditions = []

View file

@ -12,6 +12,7 @@ import bs4
from . import db from . import db
from .models import Movie, Rating, User from .models import Movie, Rating, User
from .request import adownload, asession, asoup_from_url, cache_path from .request import adownload, asession, asoup_from_url, cache_path
from .utils import json_dump
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -355,9 +356,12 @@ async def _load_ratings_page_legacy(url: str, soup: bs4.BeautifulSoup) -> _Ratin
return page return page
async def load_and_store_ratings( async def load_and_store_ratings(user_id: UserId) -> AsyncIterable[tuple[Rating, bool]]:
user_id: UserId, """Load user ratings from imdb.com and store them in our database.
) -> AsyncIterable[tuple[Rating, bool]]:
All loaded ratings are yielded together with the information whether each rating
was already present in our database.
"""
async with db.new_connection() as conn: async with db.new_connection() as conn:
user = await db.get(conn, User, imdb_id=user_id) or User( user = await db.get(conn, User, imdb_id=user_id) or User(
imdb_id=user_id, name="", secret="" imdb_id=user_id, name="", secret=""
@ -385,6 +389,7 @@ async def load_and_store_ratings(
async def load_ratings(user_id: UserId) -> AsyncIterable[Rating]: async def load_ratings(user_id: UserId) -> AsyncIterable[Rating]:
"""Return all ratings for the given user from imdb.com."""
next_url = user_ratings_url(user_id) next_url = user_ratings_url(user_id)
while next_url: while next_url:
@ -443,13 +448,15 @@ async def load_top_250() -> list[MovieId]:
qgl_api_url = "https://caching.graphql.imdb.com/" qgl_api_url = "https://caching.graphql.imdb.com/"
query = { query = {
"operationName": "Top250MoviesPagination", "operationName": "Top250MoviesPagination",
"variables": {"first": 250, "locale": "en-US"}, "variables": json_dump({"first": 250, "locale": "en-US"}),
"extensions": { "extensions": json_dump(
"persistedQuery": { {
"sha256Hash": "26114ee01d97e04f65d6c8c7212ae8b7888fa57ceed105450d1fce09df749b2d", "persistedQuery": {
"version": 1, "sha256Hash": "26114ee01d97e04f65d6c8c7212ae8b7888fa57ceed105450d1fce09df749b2d",
"version": 1,
}
} }
}, ),
} }
headers = { headers = {
"accept": "application/graphql+json, application/json", "accept": "application/graphql+json, application/json",

View file

@ -5,7 +5,7 @@ import logging
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Generator, Literal, Type, TypeVar, overload from typing import Generator, Literal, Type, overload
from . import config, db, request from . import config, db, request
from .db import add_or_update_many_movies from .db import add_or_update_many_movies
@ -14,8 +14,6 @@ from .models import Movie
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
T = TypeVar("T")
# See # See
# - https://developer.imdb.com/non-commercial-datasets/ # - https://developer.imdb.com/non-commercial-datasets/
# - https://datasets.imdbws.com/ # - https://datasets.imdbws.com/
@ -127,7 +125,7 @@ def read_imdb_tsv(
@overload @overload
def read_imdb_tsv( def read_imdb_tsv[T](
path: Path, row_type: Type[T], *, unpack: Literal[True] = True path: Path, row_type: Type[T], *, unpack: Literal[True] = True
) -> Generator[T, None, None]: ... ) -> Generator[T, None, None]: ...

View file

@ -2,7 +2,6 @@ import json
from dataclasses import dataclass, field from dataclasses import dataclass, field
from dataclasses import fields as _fields from dataclasses import fields as _fields
from datetime import datetime, timezone from datetime import datetime, timezone
from functools import partial
from types import UnionType from types import UnionType
from typing import ( from typing import (
Annotated, Annotated,
@ -11,24 +10,33 @@ from typing import (
Container, Container,
Literal, Literal,
Mapping, Mapping,
NewType,
Protocol, Protocol,
Type, Type,
TypeAliasType,
TypedDict, TypedDict,
TypeVar,
Union, Union,
get_args, get_args,
get_origin, get_origin,
) )
from sqlalchemy import Column, ForeignKey, Integer, String, Table from sqlalchemy import Column, ForeignKey, Index, Integer, String, Table
from sqlalchemy.orm import registry from sqlalchemy.orm import registry
from .types import ULID from .types import (
ULID,
JSON = int | float | str | None | list["JSON"] | dict[str, "JSON"] AwardId,
JSONObject = dict[str, JSON] GroupId,
ImdbMovieId,
T = TypeVar("T") JSONObject,
JSONScalar,
MovieId,
RatingId,
Score100,
UserId,
UserIdStr,
)
from .utils import json_dump
class Model(Protocol): class Model(Protocol):
@ -38,8 +46,22 @@ class Model(Protocol):
mapper_registry = registry() mapper_registry = registry()
metadata = mapper_registry.metadata metadata = mapper_registry.metadata
# An explicit naming convention helps Alembic do its job,
# see https://alembic.sqlalchemy.org/en/latest/naming.html.
metadata.naming_convention = {
"ix": "ix_%(column_0_label)s",
"uq": "uq_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s",
}
def annotations(tp: Type) -> tuple | None: def annotations(tp: Type) -> tuple | None:
# Support type aliases and generic aliases.
if isinstance(tp, TypeAliasType) or hasattr(tp, "__value__"):
tp = tp.__value__
return tp.__metadata__ if hasattr(tp, "__metadata__") else None # type: ignore return tp.__metadata__ if hasattr(tp, "__metadata__") else None # type: ignore
@ -97,13 +119,24 @@ def optional_fields(o):
yield f yield f
json_dump = partial(json.dumps, separators=(",", ":")) def _id[T](x: T) -> T:
"""Return the given argument, aka. the identity function."""
def _id(x: T) -> T:
return x return x
def _unpack(type_: Any) -> Any:
"""Return the wrapped type."""
# Handle type aliases.
if isinstance(type_, TypeAliasType):
return _unpack(type_.__value__)
# Handle newtypes.
if isinstance(type_, NewType):
return _unpack(type_.__supertype__)
return type_
def asplain( def asplain(
o: object, *, filter_fields: Container[str] | None = None, serialize: bool = False o: object, *, filter_fields: Container[str] | None = None, serialize: bool = False
) -> dict[str, Any]: ) -> dict[str, Any]:
@ -125,13 +158,16 @@ def asplain(
if filter_fields is not None and f.name not in filter_fields: if filter_fields is not None and f.name not in filter_fields:
continue continue
target: Any = f.type target: Any = _unpack(f.type)
# XXX this doesn't properly support any kind of nested types # XXX this doesn't properly support any kind of nested types
if (otype := optional_type(f.type)) is not None: if (otype := optional_type(f.type)) is not None:
target = otype target = otype
if (otype := get_origin(target)) is not None: if (otype := get_origin(target)) is not None:
target = otype target = otype
target = _unpack(target)
v = getattr(o, f.name) v = getattr(o, f.name)
if is_optional(f.type) and v is None: if is_optional(f.type) and v is None:
d[f.name] = None d[f.name] = None
@ -150,26 +186,31 @@ def asplain(
elif target in {bool, str, int, float}: elif target in {bool, str, int, float}:
assert isinstance( assert isinstance(
v, target v, target
), f"Type mismatch: {f.name} ({target} != {type(v)})" ), f"Type mismatch: {f.name!a} ({target!a} != {type(v)!a})"
d[f.name] = v
elif target in {Literal}:
assert isinstance(v, JSONScalar.__value__)
d[f.name] = v d[f.name] = v
else: else:
raise ValueError(f"Unsupported value type: {f.name}: {type(v)}") raise ValueError(f"Unsupported value type: {f.name!a}: {type(v)!a}")
return d return d
def fromplain(cls: Type[T], d: Mapping, *, serialized: bool = False) -> T: def fromplain[T](cls: Type[T], d: Mapping, *, serialized: bool = False) -> T:
"""Return an instance of the given model using the given data. """Return an instance of the given model using the given data.
If `serialized` is `True`, collection types (lists, dicts, etc.) will be If `serialized` is `True`, collection types (lists, dicts, etc.) will be
deserialized from string. This is the opposite operation of `serialize` for deserialized from string. This is the opposite operation of `serialize` for
`asplain`. `asplain`.
Fields in the data that cannot be mapped to the given type are simply ignored.
""" """
load = json.loads if serialized else _id load = json.loads if serialized else _id
dd: JSONObject = {} dd: JSONObject = {}
for f in fields(cls): for f in fields(cls):
target: Any = f.type target: Any = _unpack(f.type)
otype = optional_type(f.type) otype = optional_type(f.type)
is_opt = otype is not None is_opt = otype is not None
if is_opt: if is_opt:
@ -177,9 +218,17 @@ def fromplain(cls: Type[T], d: Mapping, *, serialized: bool = False) -> T:
if (xtype := get_origin(target)) is not None: if (xtype := get_origin(target)) is not None:
target = xtype target = xtype
target = _unpack(target)
v = d[f.name] v = d[f.name]
if is_opt and v is None: if is_opt and v is None:
dd[f.name] = v dd[f.name] = v
elif target is Literal:
# Support literal types.
vals = get_args(f.type.__value__)
if v not in vals:
raise ValueError(f"Invalid value: {f.name!a}: {v!a}")
dd[f.name] = v
elif isinstance(v, target): elif isinstance(v, target):
dd[f.name] = v dd[f.name] = v
elif target in {set, list}: elif target in {set, list}:
@ -196,27 +245,38 @@ def fromplain(cls: Type[T], d: Mapping, *, serialized: bool = False) -> T:
def validate(o: object) -> None: def validate(o: object) -> None:
for f in fields(o): for f in fields(o):
vtype = type(getattr(o, f.name)) ftype = _unpack(f.type)
if vtype is f.type:
v = getattr(o, f.name)
vtype = type(v)
if vtype is ftype:
continue continue
origin = get_origin(f.type) origin = get_origin(ftype)
if origin is vtype: if origin is vtype:
continue continue
is_union = isinstance(f.type, UnionType) or origin is Union is_union = isinstance(ftype, UnionType) or origin is Union
if is_union: if is_union:
# Support unioned types. # Support unioned types.
utypes = get_args(f.type) utypes = get_args(ftype)
utypes = [_unpack(t) for t in utypes]
if vtype in utypes: if vtype in utypes:
continue continue
# Support generic types (set[str], list[int], etc.) # Support generic types (set[str], list[int], etc.)
gtypes = [g for u in utypes if (g := get_origin(u)) is not None] gtypes = [_unpack(g) for u in utypes if (g := get_origin(u)) is not None]
if any(vtype is gtype for gtype in gtypes): if any(vtype is gtype for gtype in gtypes):
continue continue
raise ValueError(f"Invalid value type: {f.name}: {vtype}") if origin is Literal:
# Support literal types.
vals = get_args(ftype)
if v in vals:
continue
raise ValueError(f"Invalid value: {f.name!a}: {v!a}")
raise ValueError(f"Invalid value type: {f.name!a}: {vtype!a}")
def utcnow() -> datetime: def utcnow() -> datetime:
@ -224,23 +284,6 @@ def utcnow() -> datetime:
return datetime.now(timezone.utc) return datetime.now(timezone.utc)
@mapper_registry.mapped
@dataclass
class DbPatch:
__table__: ClassVar[Table] = Table(
"db_patches",
metadata,
Column("id", Integer, primary_key=True),
Column("current", String),
)
id: int
current: str
db_patches = DbPatch.__table__
@mapper_registry.mapped @mapper_registry.mapped
@dataclass @dataclass
class Progress: class Progress:
@ -312,15 +355,15 @@ class Movie:
Column("updated", String, nullable=False), # datetime Column("updated", String, nullable=False), # datetime
) )
id: ULID = field(default_factory=ULID) id: MovieId = field(default_factory=ULID)
title: str = None # canonical title (usually English) title: str = None # canonical title (usually English)
original_title: str | None = ( original_title: str | None = (
None # original title (usually transscribed to latin script) None # original title (usually transscribed to latin script)
) )
release_year: int = None # canonical release date release_year: int = None # canonical release date
media_type: str = None media_type: str = None
imdb_id: str = None imdb_id: ImdbMovieId = None
imdb_score: int | None = None # range: [0,100] imdb_score: Score100 | None = None # range: [0,100]
imdb_votes: int | None = None imdb_votes: int | None = None
runtime: int | None = None # minutes runtime: int | None = None # minutes
genres: set[str] | None = None genres: set[str] | None = None
@ -365,10 +408,10 @@ dataclass containing the ID of the linked data.
The contents of the Relation are ignored or discarded when using The contents of the Relation are ignored or discarded when using
`asplain`, `fromplain`, and `validate`. `asplain`, `fromplain`, and `validate`.
""" """
Relation = Annotated[T | None, _RelationSentinel] type Relation[T] = Annotated[T | None, _RelationSentinel]
Access = Literal[ type Access = Literal[
"r", # read "r", # read
"i", # index "i", # index
"w", # write "w", # write
@ -393,8 +436,8 @@ class User:
Column("groups", String, nullable=False), # JSON array Column("groups", String, nullable=False), # JSON array
) )
id: ULID = field(default_factory=ULID) id: UserId = field(default_factory=ULID)
imdb_id: str = None imdb_id: ImdbMovieId = None
name: str = None # canonical user name name: str = None # canonical user name
secret: str = None secret: str = None
groups: list[UserGroup] = field(default_factory=list) groups: list[UserGroup] = field(default_factory=list)
@ -413,6 +456,9 @@ class User:
self.groups.append({"id": group_id, "access": access}) self.groups.append({"id": group_id, "access": access})
users = User.__table__
@mapper_registry.mapped @mapper_registry.mapped
@dataclass @dataclass
class Rating: class Rating:
@ -428,15 +474,15 @@ class Rating:
Column("finished", Integer), # bool Column("finished", Integer), # bool
) )
id: ULID = field(default_factory=ULID) id: RatingId = field(default_factory=ULID)
movie_id: ULID = None movie_id: MovieId = None
movie: Relation[Movie] = None movie: Relation[Movie] = None
user_id: ULID = None user_id: UserId = None
user: Relation[User] = None user: Relation[User] = None
score: int = None # range: [0,100] score: Score100 = None # range: [0,100]
rating_date: datetime = None rating_date: datetime = None
favorite: bool | None = None favorite: bool | None = None
finished: bool | None = None finished: bool | None = None
@ -455,10 +501,11 @@ class Rating:
ratings = Rating.__table__ ratings = Rating.__table__
Index("ratings_index", ratings.c.movie_id, ratings.c.user_id, unique=True)
class GroupUser(TypedDict): class GroupUser(TypedDict):
id: str id: UserIdStr
name: str name: str
@ -473,6 +520,62 @@ class Group:
Column("users", String, nullable=False), # JSON array Column("users", String, nullable=False), # JSON array
) )
id: ULID = field(default_factory=ULID) id: GroupId = field(default_factory=ULID)
name: str = None name: str = None
users: list[GroupUser] = field(default_factory=list) users: list[GroupUser] = field(default_factory=list)
type AwardCategory = Literal[
"imdb-top-250", "imdb-bottom-100", "imdb-pop-100", "oscars"
]
@mapper_registry.mapped
@dataclass
class Award:
__table__: ClassVar[Table] = Table(
"awards",
metadata,
Column("id", String, primary_key=True), # ULID
Column("movie_id", ForeignKey("movies.id"), nullable=False), # ULID
Column(
"category", String, nullable=False
), # Enum: "imdb-top-250", "imdb-bottom-100", "imdb-pop-100", "oscars", ...
Column(
"details", String, nullable=False
), # e.g. "23" (position in list), "2024, nominee, best director", "1977, winner, best picture", ...
Column("created", String, nullable=False), # datetime
Column("updated", String, nullable=False), # datetime
)
id: AwardId = field(default_factory=ULID)
movie_id: MovieId = None
movie: Relation[Movie] = None
category: AwardCategory = None
details: str = None
created: datetime = field(default_factory=utcnow)
updated: datetime = field(default_factory=utcnow)
@property
def _details(self) -> JSONObject:
return json.loads(self.details or "{}")
@_details.setter
def _details(self, details: JSONObject):
self.details = json_dump(details)
@property
def position(self) -> int:
return self._details["position"]
@position.setter
def position(self, position: int):
details = self._details
details["position"] = position
self._details = details
awards = Award.__table__

View file

@ -11,7 +11,7 @@ from hashlib import md5
from pathlib import Path from pathlib import Path
from random import random from random import random
from time import sleep, time from time import sleep, time
from typing import Any, Callable, ParamSpec, TypeVar, cast, overload from typing import Any, Callable, cast, overload
import bs4 import bs4
import httpx import httpx
@ -24,13 +24,10 @@ if config.debug and config.cachedir:
config.cachedir.mkdir(exist_ok=True) config.cachedir.mkdir(exist_ok=True)
_shared_asession = None
_ASession_T = httpx.AsyncClient _ASession_T = httpx.AsyncClient
_Response_T = httpx.Response type _Response_T = httpx.Response
_T = TypeVar("_T") _shared_asession: _ASession_T | None = None
_P = ParamSpec("_P")
@asynccontextmanager @asynccontextmanager
@ -59,17 +56,17 @@ async def asession():
_shared_asession = None _shared_asession = None
def _throttle( def _throttle[T, **P](
times: int, per_seconds: float, jitter: Callable[[], float] | None = None times: int, per_seconds: float, jitter: Callable[[], float] | None = None
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ) -> Callable[[Callable[P, T]], Callable[P, T]]:
calls: deque[float] = deque(maxlen=times) calls: deque[float] = deque(maxlen=times)
if jitter is None: if jitter is None:
jitter = lambda: 0.0 # noqa: E731 jitter = lambda: 0.0 # noqa: E731
def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]: def decorator(func: Callable[P, T]) -> Callable[P, T]:
@wraps(func) @wraps(func)
def inner(*args: _P.args, **kwds: _P.kwargs): def inner(*args: P.args, **kwds: P.kwargs):
# clean up # clean up
while calls: while calls:
if calls[0] + per_seconds > time(): if calls[0] + per_seconds > time():

View file

@ -1,36 +0,0 @@
PRAGMA foreign_keys = ON;;
CREATE TABLE IF NOT EXISTS users (
id TEXT NOT NULL PRIMARY KEY,
imdb_id TEXT NOT NULL UNIQUE,
name TEXT NOT NULL
);;
CREATE TABLE IF NOT EXISTS movies (
id TEXT NOT NULL PRIMARY KEY,
title TEXT NOT NULL,
release_year NUMBER NOT NULL,
media_type TEXT NOT NULL,
imdb_id TEXT NOT NULL UNIQUE,
score NUMBER NOT NULL,
runtime NUMBER,
genres TEXT NOT NULL,
updated TEXT NOT NULL
);;
CREATE TABLE IF NOT EXISTS ratings (
id TEXT NOT NULL PRIMARY KEY,
movie_id TEXT NOT NULL,
user_id TEXT NOT NULL,
score NUMBER NOT NULL,
rating_date TEXT NOT NULL,
favorite NUMBER,
finished NUMBER,
FOREIGN KEY(movie_id) REFERENCES movies(id),
FOREIGN KEY(user_id) REFERENCES users(id)
);;
CREATE UNIQUE INDEX IF NOT EXISTS ratings_index ON ratings (
movie_id,
user_id
);;

View file

@ -1,40 +0,0 @@
-- add original_title to movies table
-- see https://www.sqlite.org/lang_altertable.html#caution
-- 1. Create new table
-- 2. Copy data
-- 3. Drop old table
-- 4. Rename new into old
CREATE TABLE _migrate_movies (
id TEXT NOT NULL PRIMARY KEY,
title TEXT NOT NULL,
original_title TEXT,
release_year NUMBER NOT NULL,
media_type TEXT NOT NULL,
imdb_id TEXT NOT NULL UNIQUE,
score NUMBER,
runtime NUMBER,
genres TEXT NOT NULL,
updated TEXT NOT NULL
);;
INSERT INTO _migrate_movies
SELECT
id,
title,
NULL,
release_year,
media_type,
imdb_id,
score,
runtime,
genres,
updated
FROM movies
WHERE true;;
DROP TABLE movies;;
ALTER TABLE _migrate_movies
RENAME TO movies;;

View file

@ -1,46 +0,0 @@
-- only set original_title if it differs from title,
-- and normalize media_type with an extra table.
CREATE TABLE mediatypes (
id INTEGER PRIMARY KEY NOT NULL,
name TEXT NOT NULL UNIQUE
);;
INSERT INTO mediatypes (name)
SELECT DISTINCT media_type
FROM movies
WHERE true;;
CREATE TABLE _migrate_movies (
id TEXT PRIMARY KEY NOT NULL,
title TEXT NOT NULL,
original_title TEXT,
release_year INTEGER NOT NULL,
media_type_id INTEGER NOT NULL,
imdb_id TEXT NOT NULL UNIQUE,
score INTEGER,
runtime INTEGER,
genres TEXT NOT NULL,
updated TEXT NOT NULL,
FOREIGN KEY(media_type_id) REFERENCES mediatypes(id)
);;
INSERT INTO _migrate_movies
SELECT
id,
title,
(CASE WHEN original_title=title THEN NULL ELSE original_title END),
release_year,
(SELECT id FROM mediatypes WHERE name=media_type) AS media_type_id,
imdb_id,
score,
runtime,
genres,
updated
FROM movies
WHERE true;;
DROP TABLE movies;;
ALTER TABLE _migrate_movies
RENAME TO movies;;

View file

@ -1,62 +0,0 @@
-- add convenient view for movies
CREATE VIEW IF NOT EXISTS movies_view
AS SELECT
movies.id,
movies.title,
movies.original_title,
movies.release_year,
mediatypes.name AS media_type,
movies.imdb_id,
movies.score,
movies.runtime,
movies.genres,
movies.updated
FROM movies
JOIN mediatypes ON mediatypes.id=movies.media_type_id;;
CREATE TRIGGER IF NOT EXISTS insert_movies_view
INSTEAD OF INSERT
ON movies_view
BEGIN
INSERT INTO movies (
id,
title,
original_title,
release_year,
media_type_id,
imdb_id,
score,
runtime,
genres,
updated
) VALUES (
NEW.id,
NEW.title,
NEW.original_title,
NEW.release_year,
(SELECT id FROM mediatypes WHERE name=NEW.media_type),
NEW.imdb_id,
NEW.score,
NEW.runtime,
NEW.genres,
NEW.updated
);
END;;
CREATE TRIGGER IF NOT EXISTS update_movies_view
INSTEAD OF UPDATE OF media_type
ON movies_view
BEGIN
UPDATE movies
SET media_type_id=(SELECT id FROM mediatypes WHERE name=NEW.media_type)
WHERE id=OLD.id;
END;;
CREATE TRIGGER IF NOT EXISTS delete_movies_view
INSTEAD OF DELETE
ON movies_view
BEGIN
DELETE FROM movies
WHERE movies.id=OLD.id;
END;;

View file

@ -1,37 +0,0 @@
-- denormalize movie media_type
CREATE TABLE _migrate_movies (
id TEXT PRIMARY KEY NOT NULL,
title TEXT NOT NULL,
original_title TEXT,
release_year INTEGER NOT NULL,
media_type TEXT NOT NULL,
imdb_id TEXT NOT NULL UNIQUE,
score INTEGER,
runtime INTEGER,
genres TEXT NOT NULL,
updated TEXT NOT NULL
);;
INSERT INTO _migrate_movies
SELECT
id,
title,
original_title,
release_year,
(SELECT name FROM mediatypes WHERE id=media_type_id) AS media_type,
imdb_id,
score,
runtime,
genres,
updated
FROM movies
WHERE true;;
DROP VIEW movies_view;;
DROP TABLE mediatypes;;
DROP TABLE movies;;
ALTER TABLE _migrate_movies
RENAME TO movies;;

View file

@ -1,2 +0,0 @@
-- see the commit of this file for details.
;;

View file

@ -1,8 +0,0 @@
-- add groups table
CREATE TABLE groups (
id TEXT PRIMARY KEY NOT NULL,
name TEXT NOT NULL,
secret TEXT NOT NULL,
users TEXT NOT NULL -- JSON array
);;

View file

@ -1,7 +0,0 @@
-- add progress table
CREATE TABLE progress (
id TEXT PRIMARY KEY NOT NULL,
state TEXT NOT NULL,
started TEXT NOT NULL
);;

View file

@ -1,36 +0,0 @@
-- add IMDb vote count
CREATE TABLE _migrate_movies (
id TEXT PRIMARY KEY NOT NULL,
title TEXT NOT NULL,
original_title TEXT,
release_year INTEGER NOT NULL,
media_type TEXT NOT NULL,
imdb_id TEXT NOT NULL UNIQUE,
imdb_score INTEGER,
imdb_votes INTEGER,
runtime INTEGER,
genres TEXT NOT NULL,
updated TEXT NOT NULL
);;
INSERT INTO _migrate_movies
SELECT
id,
title,
original_title,
release_year,
media_type,
imdb_id,
score AS imdb_score,
NULL AS imdb_votes,
runtime,
genres,
updated
FROM movies
WHERE true;;
DROP TABLE movies;;
ALTER TABLE _migrate_movies
RENAME TO movies;;

View file

@ -1,24 +0,0 @@
-- add IMDb vote count
CREATE TABLE _migrate_progress (
id TEXT PRIMARY KEY NOT NULL,
type TEXT NOT NULL,
state TEXT NOT NULL,
started TEXT NOT NULL,
stopped TEXT
);;
INSERT INTO _migrate_progress
SELECT
id,
'import-imdb-movies' AS type,
state,
started,
NULL AS stopped
FROM progress
WHERE true;;
DROP TABLE progress;;
ALTER TABLE _migrate_progress
RENAME TO progress;;

View file

@ -1,38 +0,0 @@
-- add creation timestamp to movies
CREATE TABLE _migrate_movies (
id TEXT PRIMARY KEY NOT NULL,
title TEXT NOT NULL,
original_title TEXT,
release_year INTEGER NOT NULL,
media_type TEXT NOT NULL,
imdb_id TEXT NOT NULL UNIQUE,
imdb_score INTEGER,
imdb_votes INTEGER,
runtime INTEGER,
genres TEXT NOT NULL,
created TEXT NOT NULL,
updated TEXT NOT NULL
);;
INSERT INTO _migrate_movies
SELECT
id,
title,
original_title,
release_year,
media_type,
imdb_id,
imdb_score,
imdb_votes,
runtime,
genres,
updated AS created,
updated
FROM movies
WHERE true;;
DROP TABLE movies;;
ALTER TABLE _migrate_movies
RENAME TO movies;;

View file

@ -1,24 +0,0 @@
-- add IMDb vote count
CREATE TABLE _migrate_progress (
id TEXT PRIMARY KEY NOT NULL,
type TEXT NOT NULL,
state TEXT NOT NULL,
started TEXT NOT NULL,
stopped TEXT
);;
INSERT INTO _migrate_progress
SELECT
id,
type,
'{"percent":' || state || '}' AS state,
started,
stopped
FROM progress
WHERE true;;
DROP TABLE progress;;
ALTER TABLE _migrate_progress
RENAME TO progress;;

View file

@ -1,22 +0,0 @@
-- add secret to users
CREATE TABLE _migrate_users (
id TEXT PRIMARY KEY NOT NULL,
imdb_id TEXT NOT NULL UNIQUE,
name TEXT NOT NULL,
secret TEXT NOT NULL
);;
INSERT INTO _migrate_users
SELECT
id,
imdb_id,
name,
'' AS secret
FROM users
WHERE true;;
DROP TABLE users;;
ALTER TABLE _migrate_users
RENAME TO users;;

View file

@ -1,45 +0,0 @@
-- add group admins
--- remove secrets from groups
CREATE TABLE _migrate_groups (
id TEXT PRIMARY KEY NOT NULL,
name TEXT NOT NULL,
users TEXT NOT NULL -- JSON array
);;
INSERT INTO _migrate_groups
SELECT
id,
name,
users
FROM groups
WHERE true;;
DROP TABLE groups;;
ALTER TABLE _migrate_groups
RENAME TO groups;;
--- add group access to users
CREATE TABLE _migrate_users (
id TEXT PRIMARY KEY NOT NULL,
imdb_id TEXT NOT NULL UNIQUE,
name TEXT NOT NULL,
secret TEXT NOT NULL,
groups TEXT NOT NULL -- JSON array
);;
INSERT INTO _migrate_users
SELECT
id,
imdb_id,
name,
secret,
'[]' AS groups
FROM users
WHERE true;;
DROP TABLE users;;
ALTER TABLE _migrate_users
RENAME TO users;;

View file

@ -1,38 +0,0 @@
-- remove NOTNULL constraint from movies.genres
CREATE TABLE _migrate_movies (
id TEXT PRIMARY KEY NOT NULL,
title TEXT NOT NULL,
original_title TEXT,
release_year INTEGER NOT NULL,
media_type TEXT NOT NULL,
imdb_id TEXT NOT NULL UNIQUE,
imdb_score INTEGER,
imdb_votes INTEGER,
runtime INTEGER,
genres TEXT,
created TEXT NOT NULL,
updated TEXT NOT NULL
);;
INSERT INTO _migrate_movies
SELECT
id,
title,
original_title,
release_year,
media_type,
imdb_id,
imdb_score,
imdb_votes,
runtime,
genres,
created,
updated
FROM movies
WHERE true;;
DROP TABLE movies;;
ALTER TABLE _migrate_movies
RENAME TO movies;;

View file

@ -1,9 +1,13 @@
import re import re
from typing import cast from typing import NewType, cast
import ulid import ulid
from ulid.hints import Buffer from ulid.hints import Buffer
type JSONScalar = int | float | str | None
type JSON = JSONScalar | list["JSON"] | dict[str, "JSON"]
type JSONObject = dict[str, JSON]
class ULID(ulid.ULID): class ULID(ulid.ULID):
"""Extended ULID type. """Extended ULID type.
@ -29,3 +33,14 @@ class ULID(ulid.ULID):
buffer = cast(memoryview, ulid.new().memory) buffer = cast(memoryview, ulid.new().memory)
super().__init__(buffer) super().__init__(buffer)
AwardId = NewType("AwardId", ULID)
GroupId = NewType("GroupId", ULID)
ImdbMovieId = NewType("ImdbMovieId", str)
MovieId = NewType("MovieId", ULID)
MovieIdStr = NewType("MovieIdStr", str)
RatingId = NewType("RatingId", ULID)
Score100 = NewType("Score100", int) # [0, 100]
UserId = NewType("UserId", ULID)
UserIdStr = NewType("UserIdStr", str)

View file

@ -1,8 +1,12 @@
import base64 import base64
import hashlib import hashlib
import json
import secrets import secrets
from functools import partial
from typing import Any, TypedDict from typing import Any, TypedDict
json_dump = partial(json.dumps, separators=(",", ":"))
def b64encode(b: bytes) -> str: def b64encode(b: bytes) -> str:
return base64.b64encode(b).decode().rstrip("=") return base64.b64encode(b).decode().rstrip("=")

View file

@ -3,7 +3,7 @@ import contextlib
import logging import logging
import secrets import secrets
from json.decoder import JSONDecodeError from json.decoder import JSONDecodeError
from typing import Literal, overload from typing import Any, Literal, Never, TypeGuard, overload
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.authentication import ( from starlette.authentication import (
@ -20,15 +20,15 @@ from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.middleware.cors import CORSMiddleware from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.gzip import GZipMiddleware from starlette.middleware.gzip import GZipMiddleware
from starlette.requests import HTTPConnection from starlette.requests import HTTPConnection, Request
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
from starlette.routing import Mount, Route from starlette.routing import Mount, Route
from . import config, db, imdb, imdb_import, web_models from . import config, db, imdb, imdb_import, web_models
from .db import close_connection_pool, find_movies, find_ratings, open_connection_pool from .db import close_connection_pool, find_movies, find_ratings, open_connection_pool
from .middleware.responsetime import ResponseTimeMiddleware from .middleware.responsetime import ResponseTimeMiddleware
from .models import Group, Movie, User, asplain from .models import Access, Group, Movie, User, asplain
from .types import ULID from .types import JSON, ULID
from .utils import b64decode, b64encode, phc_compare, phc_scrypt from .utils import b64decode, b64encode, phc_compare, phc_scrypt
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -83,11 +83,11 @@ class BearerAuthBackend(AuthenticationBackend):
return AuthCredentials(["authenticated", *roles]), user return AuthCredentials(["authenticated", *roles]), user
def truthy(s: str): def truthy(s: str | None) -> bool:
return bool(s) and s.lower() in {"1", "yes", "true"} return bool(s) and s.lower() in {"1", "yes", "true"}
_Yearcomp = Literal["<", "=", ">"] type _Yearcomp = Literal["<", "=", ">"]
def yearcomp(s: str) -> tuple[_Yearcomp, int] | None: def yearcomp(s: str) -> tuple[_Yearcomp, int] | None:
@ -103,7 +103,7 @@ def yearcomp(s: str) -> tuple[_Yearcomp, int] | None:
def as_int( def as_int(
x, *, max: int | None = None, min: int | None = 1, default: int | None = None x: Any, *, max: int | None = None, min: int | None = 1, default: int | None = None
) -> int: ) -> int:
try: try:
if not isinstance(x, int): if not isinstance(x, int):
@ -121,9 +121,9 @@ def as_int(
return default return default
def as_ulid(s: str) -> ULID: def as_ulid(s: Any) -> ULID:
try: try:
if not s: if not isinstance(s, str) or not s:
raise ValueError("Invalid ULID.") raise ValueError("Invalid ULID.")
return ULID(s) return ULID(s)
@ -133,14 +133,17 @@ def as_ulid(s: str) -> ULID:
@overload @overload
async def json_from_body(request) -> dict: ... async def json_from_body(request: Request) -> dict[str, JSON]: ...
@overload @overload
async def json_from_body(request, keys: list[str]) -> list: ... async def json_from_body(request: Request, keys: list[str]) -> list[JSON]: ...
async def json_from_body(request, keys: list[str] | None = None): async def json_from_body(
request: Request, keys: list[str] | None = None
) -> dict[str, JSON] | list[JSON]:
data: dict[str, JSON]
if not await request.body(): if not await request.body():
data = {} data = {}
@ -150,6 +153,9 @@ async def json_from_body(request, keys: list[str] | None = None):
except JSONDecodeError as err: except JSONDecodeError as err:
raise HTTPException(422, "Invalid JSON content.") from err raise HTTPException(422, "Invalid JSON content.") from err
if not isinstance(data, dict):
raise HTTPException(422, f"Invalid JSON type: {type(data)!a}")
if not keys: if not keys:
return data return data
@ -159,11 +165,11 @@ async def json_from_body(request, keys: list[str] | None = None):
raise HTTPException(422, f"Missing data for key: {err.args[0]}") from err raise HTTPException(422, f"Missing data for key: {err.args[0]}") from err
def is_admin(request): def is_admin(request: Request) -> bool:
return "admin" in request.auth.scopes return "admin" in request.auth.scopes
async def auth_user(request) -> User | None: async def auth_user(request: Request) -> User | None:
if not isinstance(request.user, AuthedUser): if not isinstance(request.user, AuthedUser):
return return
@ -192,7 +198,7 @@ def route(path: str, *, methods: list[str] | None = None, **kwds):
@route("/groups/{group_id}/ratings") @route("/groups/{group_id}/ratings")
async def get_ratings_for_group(request): async def get_ratings_for_group(request: Request) -> JSONResponse:
group_id = as_ulid(request.path_params["group_id"]) group_id = as_ulid(request.path_params["group_id"])
async with db.new_connection() as conn: async with db.new_connection() as conn:
@ -229,11 +235,13 @@ async def get_ratings_for_group(request):
user_ids=user_ids, user_ids=user_ids,
) )
ratings = (web_models.Rating(**r) for r in rows) ratings = [web_models.Rating(**r) for r in rows]
aggr = web_models.aggregate_ratings(ratings, user_ids) awards = await db.get_awards(conn, imdb_ids=[r.movie_imdb_id for r in ratings])
resp = tuple(asplain(r) for r in aggr) aggrs = web_models.aggregate_ratings(ratings, user_ids, awards_dict=awards)
resp = tuple(asplain(r) for r in aggrs)
return JSONResponse(resp) return JSONResponse(resp)
@ -250,13 +258,13 @@ def not_found(reason: str = "Not Found"):
return JSONResponse({"error": reason}, status_code=404) return JSONResponse({"error": reason}, status_code=404)
def not_implemented(): def not_implemented() -> Never:
raise HTTPException(404, "Not yet implemented.") raise HTTPException(404, "Not yet implemented.")
@route("/movies") @route("/movies")
@requires(["authenticated"]) @requires(["authenticated"])
async def list_movies(request): async def list_movies(request: Request) -> JSONResponse:
params = request.query_params params = request.query_params
user = await auth_user(request) user = await auth_user(request)
@ -329,13 +337,13 @@ async def list_movies(request):
@route("/movies", methods=["POST"]) @route("/movies", methods=["POST"])
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def add_movie(request): async def add_movie(request: Request) -> JSONResponse:
not_implemented() not_implemented()
@route("/movies/_reload_imdb", methods=["GET"]) @route("/movies/_reload_imdb", methods=["GET"])
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def progress_for_load_imdb_movies(request): async def progress_for_load_imdb_movies(request: Request) -> JSONResponse:
async with db.new_connection() as conn: async with db.new_connection() as conn:
progress = await db.get_import_progress(conn) progress = await db.get_import_progress(conn)
if not progress: if not progress:
@ -371,7 +379,7 @@ _import_lock = asyncio.Lock()
@route("/movies/_reload_imdb", methods=["POST"]) @route("/movies/_reload_imdb", methods=["POST"])
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def load_imdb_movies(request): async def load_imdb_movies(request: Request) -> JSONResponse:
params = request.query_params params = request.query_params
force = truthy(params.get("force")) force = truthy(params.get("force"))
@ -395,7 +403,7 @@ async def load_imdb_movies(request):
@route("/users") @route("/users")
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def list_users(request): async def list_users(request: Request) -> JSONResponse:
async with db.new_connection() as conn: async with db.new_connection() as conn:
users = await db.get_all(conn, User) users = await db.get_all(conn, User)
@ -404,7 +412,7 @@ async def list_users(request):
@route("/users", methods=["POST"]) @route("/users", methods=["POST"])
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def add_user(request): async def add_user(request: Request) -> JSONResponse:
name, imdb_id = await json_from_body(request, ["name", "imdb_id"]) name, imdb_id = await json_from_body(request, ["name", "imdb_id"])
# XXX restrict name # XXX restrict name
@ -426,7 +434,7 @@ async def add_user(request):
@route("/users/{user_id}") @route("/users/{user_id}")
@requires(["authenticated"]) @requires(["authenticated"])
async def show_user(request): async def show_user(request: Request) -> JSONResponse:
user_id = as_ulid(request.path_params["user_id"]) user_id = as_ulid(request.path_params["user_id"])
if is_admin(request): if is_admin(request):
@ -455,7 +463,7 @@ async def show_user(request):
@route("/users/{user_id}", methods=["DELETE"]) @route("/users/{user_id}", methods=["DELETE"])
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def remove_user(request): async def remove_user(request: Request) -> JSONResponse:
user_id = as_ulid(request.path_params["user_id"]) user_id = as_ulid(request.path_params["user_id"])
async with db.new_connection() as conn: async with db.new_connection() as conn:
@ -473,7 +481,7 @@ async def remove_user(request):
@route("/users/{user_id}", methods=["PATCH"]) @route("/users/{user_id}", methods=["PATCH"])
@requires(["authenticated"]) @requires(["authenticated"])
async def modify_user(request): async def modify_user(request: Request) -> JSONResponse:
user_id = as_ulid(request.path_params["user_id"]) user_id = as_ulid(request.path_params["user_id"])
if is_admin(request): if is_admin(request):
@ -520,9 +528,13 @@ async def modify_user(request):
return JSONResponse(asplain(user)) return JSONResponse(asplain(user))
def is_valid_access(x: Any) -> TypeGuard[Access]:
return isinstance(x, str) and x in set("riw")
@route("/users/{user_id}/groups", methods=["POST"]) @route("/users/{user_id}/groups", methods=["POST"])
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def add_group_to_user(request): async def add_group_to_user(request: Request) -> JSONResponse:
user_id = as_ulid(request.path_params["user_id"]) user_id = as_ulid(request.path_params["user_id"])
async with db.new_connection() as conn: async with db.new_connection() as conn:
@ -537,7 +549,7 @@ async def add_group_to_user(request):
if not group: if not group:
return not_found("Group not found") return not_found("Group not found")
if access not in set("riw"): if not is_valid_access(access):
raise HTTPException(422, "Invalid access level.") raise HTTPException(422, "Invalid access level.")
user.set_access(group_id, access) user.set_access(group_id, access)
@ -549,19 +561,19 @@ async def add_group_to_user(request):
@route("/users/{user_id}/ratings") @route("/users/{user_id}/ratings")
@requires(["private"]) @requires(["private"])
async def ratings_for_user(request): async def ratings_for_user(request: Request) -> JSONResponse:
not_implemented() not_implemented()
@route("/users/{user_id}/ratings", methods=["PUT"]) @route("/users/{user_id}/ratings", methods=["PUT"])
@requires("authenticated") @requires("authenticated")
async def set_rating_for_user(request): async def set_rating_for_user(request: Request) -> JSONResponse:
not_implemented() not_implemented()
@route("/users/_reload_ratings", methods=["POST"]) @route("/users/_reload_ratings", methods=["POST"])
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def load_imdb_user_ratings(request): async def load_imdb_user_ratings(request: Request) -> JSONResponse:
ratings = [rating async for rating in imdb.refresh_user_ratings_from_imdb()] ratings = [rating async for rating in imdb.refresh_user_ratings_from_imdb()]
return JSONResponse({"new_ratings": [asplain(r) for r in ratings]}) return JSONResponse({"new_ratings": [asplain(r) for r in ratings]})
@ -569,7 +581,7 @@ async def load_imdb_user_ratings(request):
@route("/groups") @route("/groups")
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def list_groups(request): async def list_groups(request: Request) -> JSONResponse:
async with db.new_connection() as conn: async with db.new_connection() as conn:
groups = await db.get_all(conn, Group) groups = await db.get_all(conn, Group)
@ -578,7 +590,7 @@ async def list_groups(request):
@route("/groups", methods=["POST"]) @route("/groups", methods=["POST"])
@requires(["authenticated", "admin"]) @requires(["authenticated", "admin"])
async def add_group(request): async def add_group(request: Request) -> JSONResponse:
(name,) = await json_from_body(request, ["name"]) (name,) = await json_from_body(request, ["name"])
# XXX restrict name # XXX restrict name
@ -592,7 +604,7 @@ async def add_group(request):
@route("/groups/{group_id}/users", methods=["POST"]) @route("/groups/{group_id}/users", methods=["POST"])
@requires(["authenticated"]) @requires(["authenticated"])
async def add_user_to_group(request): async def add_user_to_group(request: Request) -> JSONResponse:
group_id = as_ulid(request.path_params["group_id"]) group_id = as_ulid(request.path_params["group_id"])
async with db.new_connection() as conn: async with db.new_connection() as conn:
group = await db.get(conn, Group, id=str(group_id)) group = await db.get(conn, Group, id=str(group_id))
@ -628,11 +640,12 @@ async def add_user_to_group(request):
return JSONResponse(asplain(group)) return JSONResponse(asplain(group))
async def http_exception(request, exc): async def http_exception(request: Request, exc: Exception) -> JSONResponse:
assert isinstance(exc, HTTPException)
return JSONResponse({"error": exc.detail}, status_code=exc.status_code) return JSONResponse({"error": exc.detail}, status_code=exc.status_code)
def auth_error(request, err): def auth_error(conn: HTTPConnection, err: Exception) -> JSONResponse:
return unauthorized(str(err)) return unauthorized(str(err))

View file

@ -1,23 +1,22 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Container, Iterable from typing import Container, Iterable
from . import imdb, models from . import imdb, models, types
URL = str type URL = str
Score100 = int # [0, 100]
@dataclass @dataclass
class Rating: class Rating:
canonical_title: str canonical_title: str
imdb_score: Score100 | None imdb_score: types.Score100 | None
imdb_votes: int | None imdb_votes: int | None
media_type: str media_type: str
movie_imdb_id: str movie_imdb_id: types.ImdbMovieId
original_title: str | None original_title: str | None
release_year: int release_year: int
user_id: str | None user_id: types.UserIdStr | None
user_score: Score100 | None user_score: types.Score100 | None
@classmethod @classmethod
def from_movie(cls, movie: models.Movie, *, rating: models.Rating | None = None): def from_movie(cls, movie: models.Movie, *, rating: models.Rating | None = None):
@ -37,13 +36,14 @@ class Rating:
@dataclass @dataclass
class RatingAggregate: class RatingAggregate:
canonical_title: str canonical_title: str
imdb_score: Score100 | None imdb_score: types.Score100 | None
imdb_votes: int | None imdb_votes: int | None
link: URL link: URL
media_type: str media_type: str
original_title: str | None original_title: str | None
user_scores: list[Score100] user_scores: list[types.Score100]
year: int year: int
awards: list[str]
@classmethod @classmethod
def from_movie(cls, movie: models.Movie, *, ratings: Iterable[models.Rating] = []): def from_movie(cls, movie: models.Movie, *, ratings: Iterable[models.Rating] = []):
@ -56,15 +56,23 @@ class RatingAggregate:
original_title=movie.original_title, original_title=movie.original_title,
user_scores=[r.score for r in ratings], user_scores=[r.score for r in ratings],
year=movie.release_year, year=movie.release_year,
awards=[],
) )
def aggregate_ratings( def aggregate_ratings(
ratings: Iterable[Rating], user_ids: Container[str] ratings: Iterable[Rating],
user_ids: Container[types.UserIdStr],
*,
awards_dict: dict[types.ImdbMovieId, list[models.Award]] | None = None,
) -> Iterable[RatingAggregate]: ) -> Iterable[RatingAggregate]:
aggr: dict[str, RatingAggregate] = {} if awards_dict is None:
awards_dict = {}
aggr: dict[types.ImdbMovieId, RatingAggregate] = {}
for r in ratings: for r in ratings:
awards = awards_dict.get(r.movie_imdb_id, [])
mov = aggr.setdefault( mov = aggr.setdefault(
r.movie_imdb_id, r.movie_imdb_id,
RatingAggregate( RatingAggregate(
@ -76,6 +84,7 @@ def aggregate_ratings(
original_title=r.original_title, original_title=r.original_title,
user_scores=[], user_scores=[],
year=r.release_year, year=r.release_year,
awards=[f"{a.category}:{a.position}" for a in awards],
), ),
) )
# XXX do we need this? why don't we just get the ratings we're supposed to aggregate? # XXX do we need this? why don't we just get the ratings we're supposed to aggregate?