From 1ad7a79d332807a94ecfea0439f905e2fc29faa8 Mon Sep 17 00:00:00 2001 From: ducklet Date: Sat, 10 Jul 2021 01:34:38 +0200 Subject: [PATCH] add a Relation type to models to store related model data Any field marked as Relation is ignored by all model operations (like converting to and from plain form). Fields marked as Relation are meant to store the actual model data for foreign keys stored on the model. --- unwind/db.py | 3 +- unwind/models.py | 74 ++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 70 insertions(+), 7 deletions(-) diff --git a/unwind/db.py b/unwind/db.py index fa035e7..ddbb007 100644 --- a/unwind/db.py +++ b/unwind/db.py @@ -1,6 +1,5 @@ import logging import re -from dataclasses import fields from pathlib import Path from typing import Iterable, Literal, Optional, Type, TypeVar, Union @@ -8,7 +7,7 @@ import sqlalchemy from databases import Database from . import config -from .models import Movie, Rating, User, asplain, fromplain, optional_fields +from .models import Movie, Rating, User, asplain, fields, fromplain, optional_fields log = logging.getLogger(__name__) diff --git a/unwind/models.py b/unwind/models.py index 2a60071..2789583 100644 --- a/unwind/models.py +++ b/unwind/models.py @@ -1,11 +1,46 @@ import json -from dataclasses import asdict, dataclass, field, fields +from dataclasses import dataclass, field +from dataclasses import fields as _fields from datetime import datetime, timezone -from typing import Any, ClassVar, Optional, Type, Union, get_args, get_origin +from typing import ( + Annotated, + Any, + ClassVar, + Optional, + Type, + TypeVar, + Union, + get_args, + get_origin, +) from .types import ULID +def annotations(tp: Type) -> Optional[tuple]: + return tp.__metadata__ if hasattr(tp, "__metadata__") else None + + +def fields(class_or_instance): + """Like dataclass' `fields` but with extra support for our models. + + This function is a drop-in replacement for dataclass' `fields` and + SHOULD be used instead of it everywhere. + This function filters out fields marked as `Relation`. `Relation` + fields are meant to allow to store the data referenced by an ID field + directly on the instance. + """ + + # XXX this might be a little slow (not sure), if so, memoize + + for f in _fields(class_or_instance): + + if (attn := annotations(f.type)) and _RelationSentinel in attn: + continue # Relations are ignored + + yield f + + def is_optional(tp: Type): if get_origin(tp) is not Union: return False @@ -34,7 +69,7 @@ def optional_fields(o): def asplain(o) -> dict[str, Any]: validate(o) - d = asdict(o) + d = {} for f in fields(o): target = f.type @@ -44,7 +79,7 @@ def asplain(o) -> dict[str, Any]: if (otype := get_origin(target)) is not None: target = otype - v = d[f.name] + v = getattr(o, f.name) if target is ULID: d[f.name] = str(v) elif target in {datetime}: @@ -54,7 +89,7 @@ def asplain(o) -> dict[str, Any]: elif target in {list}: d[f.name] = json.dumps(list(v)) elif target in {bool, str, int, float, None}: - pass + d[f.name] = v else: raise ValueError(f"Unsupported value type: {f.name}: {type(v)}") @@ -123,18 +158,47 @@ class Movie: updated: datetime = field(default_factory=utcnow) +T = TypeVar("T") +_RelationSentinel = object() +"""Mark a model field as containing external data. + +For each field marked as a Relation there should be another field on the +dataclass containing the ID of the linked data. +The contents of the Relation are ignored or discarded when using +`asplain`, `fromplain`, and `validate`. +""" +Relation = Annotated[Optional[T], _RelationSentinel] + + @dataclass class Rating: _table: ClassVar[str] = "ratings" id: ULID = field(default_factory=ULID) + movie_id: ULID = None + movie: Relation[Movie] = None + user_id: ULID = None + user: Relation["User"] = None + score: int = None # range: [0,100] rating_date: datetime = None favorite: Optional[bool] = None finished: Optional[bool] = None + def __eq__(self, other): + """Return wether two Ratings are equal. + + This operation compares all fields as expected, except that it + ignores any field marked as Relation. + """ + if type(other) is not type(self): + return False + return all( + getattr(self, f.name) == getattr(other, f.name) for f in fields(self) + ) + @dataclass class User: