diff --git a/unwind/db.py b/unwind/db.py index fa035e7..ddbb007 100644 --- a/unwind/db.py +++ b/unwind/db.py @@ -1,6 +1,5 @@ import logging import re -from dataclasses import fields from pathlib import Path from typing import Iterable, Literal, Optional, Type, TypeVar, Union @@ -8,7 +7,7 @@ import sqlalchemy from databases import Database from . import config -from .models import Movie, Rating, User, asplain, fromplain, optional_fields +from .models import Movie, Rating, User, asplain, fields, fromplain, optional_fields log = logging.getLogger(__name__) diff --git a/unwind/models.py b/unwind/models.py index 2a60071..2789583 100644 --- a/unwind/models.py +++ b/unwind/models.py @@ -1,11 +1,46 @@ import json -from dataclasses import asdict, dataclass, field, fields +from dataclasses import dataclass, field +from dataclasses import fields as _fields from datetime import datetime, timezone -from typing import Any, ClassVar, Optional, Type, Union, get_args, get_origin +from typing import ( + Annotated, + Any, + ClassVar, + Optional, + Type, + TypeVar, + Union, + get_args, + get_origin, +) from .types import ULID +def annotations(tp: Type) -> Optional[tuple]: + return tp.__metadata__ if hasattr(tp, "__metadata__") else None + + +def fields(class_or_instance): + """Like dataclass' `fields` but with extra support for our models. + + This function is a drop-in replacement for dataclass' `fields` and + SHOULD be used instead of it everywhere. + This function filters out fields marked as `Relation`. `Relation` + fields are meant to allow to store the data referenced by an ID field + directly on the instance. + """ + + # XXX this might be a little slow (not sure), if so, memoize + + for f in _fields(class_or_instance): + + if (attn := annotations(f.type)) and _RelationSentinel in attn: + continue # Relations are ignored + + yield f + + def is_optional(tp: Type): if get_origin(tp) is not Union: return False @@ -34,7 +69,7 @@ def optional_fields(o): def asplain(o) -> dict[str, Any]: validate(o) - d = asdict(o) + d = {} for f in fields(o): target = f.type @@ -44,7 +79,7 @@ def asplain(o) -> dict[str, Any]: if (otype := get_origin(target)) is not None: target = otype - v = d[f.name] + v = getattr(o, f.name) if target is ULID: d[f.name] = str(v) elif target in {datetime}: @@ -54,7 +89,7 @@ def asplain(o) -> dict[str, Any]: elif target in {list}: d[f.name] = json.dumps(list(v)) elif target in {bool, str, int, float, None}: - pass + d[f.name] = v else: raise ValueError(f"Unsupported value type: {f.name}: {type(v)}") @@ -123,18 +158,47 @@ class Movie: updated: datetime = field(default_factory=utcnow) +T = TypeVar("T") +_RelationSentinel = object() +"""Mark a model field as containing external data. + +For each field marked as a Relation there should be another field on the +dataclass containing the ID of the linked data. +The contents of the Relation are ignored or discarded when using +`asplain`, `fromplain`, and `validate`. +""" +Relation = Annotated[Optional[T], _RelationSentinel] + + @dataclass class Rating: _table: ClassVar[str] = "ratings" id: ULID = field(default_factory=ULID) + movie_id: ULID = None + movie: Relation[Movie] = None + user_id: ULID = None + user: Relation["User"] = None + score: int = None # range: [0,100] rating_date: datetime = None favorite: Optional[bool] = None finished: Optional[bool] = None + def __eq__(self, other): + """Return wether two Ratings are equal. + + This operation compares all fields as expected, except that it + ignores any field marked as Relation. + """ + if type(other) is not type(self): + return False + return all( + getattr(self, f.name) == getattr(other, f.name) for f in fields(self) + ) + @dataclass class User: