import json
from dataclasses import dataclass, field
from dataclasses import fields as _fields
from datetime import datetime, timezone
from functools import partial
from types import UnionType
from typing import (
    Annotated,
    Any,
    ClassVar,
    Container,
    Literal,
    Mapping,
    Type,
    TypeVar,
    Union,
    get_args,
    get_origin,
)

from .types import ULID

JSON = int | float | str | None | list["JSON"] | dict[str, "JSON"]
JSONObject = dict[str, JSON]

T = TypeVar("T")


def annotations(tp: Type) -> tuple | None:
    """Return the `Annotated` metadata of the given type, or `None` if it has none."""
    return getattr(tp, "__metadata__", None)


def fields(class_or_instance):
    """Like dataclass' `fields` but with extra support for our models.

    This function is a drop-in replacement for dataclass' `fields` and
    SHOULD be used instead of it everywhere.
    This function filters out fields marked as `Relation`.
    `Relation` fields are meant to allow to store the data referenced by an ID
    field directly on the instance.
    The internal `_is_lazy` bookkeeping field is filtered out as well.
    """
    # XXX this might be a little slow (not sure), if so, memoize
    for f in _fields(class_or_instance):
        if f.name == "_is_lazy":
            continue

        # `_RelationSentinel` is defined further down in this module; it is
        # only looked up when this generator actually runs.
        if (attn := annotations(f.type)) and _RelationSentinel in attn:
            continue  # Relations are ignored

        yield f


def is_optional(tp: Type) -> bool:
    """Return whether the given type is optional."""
    # A type is optional exactly when `optional_type` can unwrap it.
    return optional_type(tp) is not None


def optional_type(tp: Type) -> Type | None:
    """Return the wrapped type from an optional type.

    For example this will return `int` for `Optional[int]`.
    Since they're equivalent this also works for other optional notations, like
    `Union[int, None]` and `int | None`.

    Returns `None` if the given type is not a two-member union containing
    `None`.
    """
    if not isinstance(tp, UnionType) and get_origin(tp) is not Union:
        return None

    args = get_args(tp)
    if len(args) != 2 or type(None) not in args:
        return None

    return args[0] if args[1] is type(None) else args[1]


def optional_fields(o):
    """Yield all (non-Relation) fields of the given model whose type is optional."""
    for f in fields(o):
        if is_optional(f.type):
            yield f


# Compact JSON encoder (no whitespace after separators).
json_dump = partial(json.dumps, separators=(",", ":"))


def _id(x: T) -> T:
    """Return the argument unchanged (identity function)."""
    return x


def asplain(
    o: object, *, filter_fields: Container[str] | None = None, serialize: bool = False
) -> dict[str, Any]:
    """Return the given model instance as `dict` with JSON compatible plain datatypes.

    If `filter_fields` is given only matching field names will be included in
    the resulting `dict`.

    If `serialize` is `True`, collection types (lists, dicts, etc.) will be
    serialized as strings; this can be useful to store them in a database. Be
    sure to set `serialized=True` when using `fromplain` to successfully restore
    the object.

    Raises `ValueError` if the instance fails `validate` or if a field has a
    type that is not supported here.
    """

    validate(o)

    dump = json_dump if serialize else _id

    d: JSONObject = {}
    for f in fields(o):
        if filter_fields is not None and f.name not in filter_fields:
            continue

        target = f.type
        # XXX this doesn't properly support any kind of nested types
        if (otype := optional_type(f.type)) is not None:
            target = otype
        # Reduce generic aliases like `set[str]` to their origin (`set`).
        if (otype := get_origin(target)) is not None:
            target = otype

        v = getattr(o, f.name)
        if is_optional(f.type) and v is None:
            d[f.name] = None
        elif target is ULID:
            assert isinstance(v, ULID)
            d[f.name] = str(v)
        elif target in {datetime}:
            assert isinstance(v, datetime)
            d[f.name] = v.isoformat()
        elif target in {set}:
            assert isinstance(v, set)
            # Sets are emitted as sorted lists to get a stable output.
            d[f.name] = dump(sorted(v))
        elif target in {list}:
            assert isinstance(v, list)
            d[f.name] = dump(list(v))
        elif target in {bool, str, int, float}:
            assert isinstance(
                v, target
            ), f"Type mismatch: {f.name} ({target} != {type(v)})"
            d[f.name] = v
        else:
            raise ValueError(f"Unsupported value type: {f.name}: {type(v)}")

    return d


def fromplain(cls: Type[T], d: Mapping, *, serialized: bool = False) -> T:
    """Return an instance of the given model using the given data.

    If `serialized` is `True`, collection types (lists, dicts, etc.) will be
    deserialized from string. This is the opposite operation of `serialize` for
    `asplain`.

    Every (non-Relation) field of the model must be present in `d`; a missing
    key raises `KeyError`. The resulting instance is checked with `validate`.
    """
    load = json.loads if serialized else _id

    dd: JSONObject = {}
    for f in fields(cls):
        target = f.type
        otype = optional_type(f.type)
        is_opt = otype is not None
        if is_opt:
            target = otype
        # Reduce generic aliases like `set[str]` to their origin (`set`).
        if (xtype := get_origin(target)) is not None:
            target = xtype

        v = d[f.name]
        if is_opt and v is None:
            dd[f.name] = v
        elif isinstance(v, target):
            dd[f.name] = v
        elif target in {set, list}:
            dd[f.name] = target(load(v))
        elif target in {datetime}:
            dd[f.name] = target.fromisoformat(v)
        else:
            # Fall back to calling the type itself as a converter.
            dd[f.name] = target(v)

    o = cls(**dd)
    validate(o)
    return o


def validate(o: object) -> None:
    """Check that every (non-Relation) field value matches its declared type.

    A value is accepted if its exact type is the declared type, the origin of
    a declared generic alias (e.g. `set` for `set[str]`), or a direct member of
    a declared union.

    Raises `ValueError` for the first field that does not match.
    """
    for f in fields(o):
        vtype = type(getattr(o, f.name))
        if vtype is not f.type:
            if get_origin(f.type) is vtype or (
                (isinstance(f.type, UnionType) or get_origin(f.type) is Union)
                and vtype in get_args(f.type)
            ):
                continue
            raise ValueError(f"Invalid value type: {f.name}: {vtype}")


def utcnow() -> datetime:
    """Return the current time as a timezone-aware `datetime` in UTC."""
    # `datetime.utcnow()` is deprecated; asking for an aware UTC time directly
    # yields the same value without the extra `replace`.
    return datetime.now(timezone.utc)


@dataclass
class Progress:
    """Model for a row of the `progress` table.

    `state` holds a JSON encoded object; `percent` and `error` are convenience
    accessors for entries of that object.
    """

    _table: ClassVar[str] = "progress"

    id: ULID = field(default_factory=ULID)
    type: str = None
    state: str = None
    started: datetime = field(default_factory=utcnow)
    stopped: str | None = None

    @property
    def _state(self) -> dict:
        """The decoded `state`; an empty dict if `state` is unset or empty."""
        return json.loads(self.state or "{}")

    @_state.setter
    def _state(self, state: dict):
        self.state = json_dump(state)

    @property
    def percent(self) -> float:
        """The `percent` entry of the state; raises `KeyError` if it was never set."""
        return self._state["percent"]

    @percent.setter
    def percent(self, percent: float):
        state = self._state
        state["percent"] = percent
        self._state = state

    @property
    def error(self) -> str:
        """The `error` entry of the state; an empty string if it was never set."""
        return self._state.get("error", "")

    @error.setter
    def error(self, error: str):
        state = self._state
        state["error"] = error
        self._state = state


@dataclass
class Movie:
    """Model for a row of the `movies` table."""

    _table: ClassVar[str] = "movies"

    id: ULID = field(default_factory=ULID)
    title: str = None  # canonical title (usually English)
    # original title (usually transcribed to latin script)
    original_title: str | None = None
    release_year: int = None  # canonical release date
    media_type: str = None
    imdb_id: str = None
    imdb_score: int | None = None  # range: [0,100]
    imdb_votes: int | None = None
    runtime: int | None = None  # minutes
    genres: set[str] = None
    created: datetime = field(default_factory=utcnow)
    updated: datetime = field(default_factory=utcnow)

    _is_lazy: bool = field(default=False, init=False, repr=False, compare=False)

    @classmethod
    def lazy(cls, **kwds):
        """Return a new instance without running default factories.

        This is meant purely for optimization purposes, to postpone possibly
        expensive initialization operations.
        """
        # XXX optimize using a metaclass & storing field refs on the class
        kwds.setdefault("id", None)
        kwds.setdefault("created", None)
        kwds.setdefault("updated", None)

        movie = cls(**kwds)
        movie._is_lazy = True
        return movie

    def _lazy_init(self):
        """Run the default factories that were skipped by `lazy`.

        Does nothing if the instance is not lazy. Only fields that are still
        `None` and that declare a default factory are filled in.
        """
        if not self._is_lazy:
            return

        # The loop variable is deliberately not called `field`, to avoid
        # shadowing the `dataclasses.field` import.
        for f in fields(Movie):
            if getattr(self, f.name) is None and callable(f.default_factory):
                setattr(self, f.name, f.default_factory())

        self._is_lazy = False


_RelationSentinel = object()
"""Mark a model field as containing external data.

For each field marked as a Relation there should be another field on the
dataclass containing the ID of the linked data.
The contents of the Relation are ignored or discarded when using
`asplain`, `fromplain`, and `validate`.
"""
Relation = Annotated[T | None, _RelationSentinel]


@dataclass
class Rating:
    """Model for a row of the `ratings` table."""

    _table: ClassVar[str] = "ratings"

    id: ULID = field(default_factory=ULID)

    movie_id: ULID = None
    movie: Relation[Movie] = None

    user_id: ULID = None
    user: Relation["User"] = None

    score: int = None  # range: [0,100]
    rating_date: datetime = None
    favorite: bool | None = None
    finished: bool | None = None

    def __eq__(self, other):
        """Return whether two Ratings are equal.

        This operation compares all fields as expected, except that it
        ignores any field marked as Relation.
        """
        if type(other) is not type(self):
            return False
        return all(
            getattr(self, f.name) == getattr(other, f.name) for f in fields(self)
        )


Access = Literal[
    "r",  # read
    "i",  # index
    "w",  # write
]


@dataclass
class User:
    """Model for a row of the `users` table.

    `groups` is a list of dicts, each with an `id` (group ID as string) and an
    `access` entry.
    """

    _table: ClassVar[str] = "users"

    id: ULID = field(default_factory=ULID)
    imdb_id: str = None
    name: str = None  # canonical user name
    secret: str = None
    groups: list[dict[str, str]] = field(default_factory=list)

    def has_access(self, group_id: ULID | str, access: Access = "r") -> bool:
        """Return whether the user has exactly the given access to the given group."""
        group_id = group_id if isinstance(group_id, str) else str(group_id)
        return any(g["id"] == group_id and access == g["access"] for g in self.groups)

    def set_access(self, group_id: ULID | str, access: Access) -> None:
        """Set the user's access for the given group, adding the group if needed."""
        group_id = group_id if isinstance(group_id, str) else str(group_id)
        for g in self.groups:
            if g["id"] == group_id:
                g["access"] = access
                break
        else:
            # No existing entry for this group; create one.
            self.groups.append({"id": group_id, "access": access})


@dataclass
class Group:
    """Model for a row of the `groups` table."""

    _table: ClassVar[str] = "groups"

    id: ULID = field(default_factory=ULID)
    name: str = None
    users: list[dict[str, str]] = field(default_factory=list)