unwind/unwind/models.py
2023-03-18 23:30:40 +01:00

375 lines
10 KiB
Python

import json
from dataclasses import dataclass, field
from dataclasses import fields as _fields
from datetime import datetime, timezone
from functools import partial
from types import UnionType
from typing import (
Annotated,
Any,
ClassVar,
Container,
Literal,
Mapping,
Type,
TypeVar,
TypedDict,
Union,
get_args,
get_origin,
)
from .types import ULID
# JSON-compatible values; the string forward references make the alias
# recursive (lists and string-keyed dicts of further JSON values).
JSON = int | float | str | None | list["JSON"] | dict[str, "JSON"]
# A JSON object: string keys mapped to JSON values.
JSONObject = dict[str, JSON]
# Generic type variable shared by the helpers and the `Relation` alias.
T = TypeVar("T")
def annotations(tp: Type) -> tuple | None:
return tp.__metadata__ if hasattr(tp, "__metadata__") else None
def fields(class_or_instance):
    """Drop-in replacement for `dataclasses.fields` that knows about our models.

    Use this everywhere instead of dataclass' own `fields`. It yields the
    same `Field` objects, in the same order, but leaves out:

    - the internal `_is_lazy` bookkeeping field, and
    - every field whose type carries the Relation marker in its `Annotated`
      metadata (i.e. fields declared as `Relation[...]`). Such fields only
      hold data that is also referenced through a separate ID field.
    """
    # XXX this might be a little slow (not sure), if so, memoize
    for fld in _fields(class_or_instance):
        meta = annotations(fld.type)
        is_relation = bool(meta) and _RelationSentinel in meta
        if fld.name != "_is_lazy" and not is_relation:
            yield fld
def is_optional(tp: Type) -> bool:
"""Return wether the given type is optional."""
if not isinstance(tp, UnionType) and get_origin(tp) is not Union:
return False
args = get_args(tp)
return len(args) == 2 and type(None) in args
def optional_type(tp: Type) -> Type | None:
"""Return the wrapped type from an optional type.
For example this will return `int` for `Optional[int]`.
Since they're equivalent this also works for other optioning notations, like
`Union[int, None]` and `int | None`.
"""
if not isinstance(tp, UnionType) and get_origin(tp) is not Union:
return None
args = get_args(tp)
if len(args) != 2 or type(None) not in args:
return None
return args[0] if args[1] is type(None) else args[1]
def optional_fields(o):
    """Yield the model fields of `o` whose declared type is optional (`X | None`).

    Fields hidden by `fields` (relations, `_is_lazy`) are never yielded.
    """
    yield from filter(lambda fld: is_optional(fld.type), fields(o))
# Compact JSON encoder: `json.dumps` without whitespace after the "," and ":"
# separators.
json_dump = partial(json.dumps, separators=(",", ":"))
def _id(x: T) -> T:
    """Return `x` unchanged (identity function).

    Used as the no-op stand-in for `json_dump` / `json.loads` when
    `asplain` / `fromplain` are asked not to (de)serialize collections.
    """
    return x
def asplain(
    o: object, *, filter_fields: Container[str] | None = None, serialize: bool = False
) -> dict[str, Any]:
    """Return the given model instance as `dict` with JSON compatible plain datatypes.

    The instance is checked with `validate` first, so a field holding a value
    of the wrong type raises `ValueError` before anything is converted.

    Conversion per (unwrapped) field type:

    - `None` in an optional field stays `None`;
    - `ULID` becomes its string form;
    - `datetime` becomes its `isoformat()` string;
    - `set` becomes a sorted list, `list` a shallow copy;
    - `bool`, `str`, `int` and `float` are passed through unchanged.

    If `filter_fields` is given only matching field names will be included in
    the resulting `dict`.

    If `serialize` is `True`, the collection types (sets and lists) are
    additionally encoded as compact JSON strings; this can be useful to store
    them in a database. Be sure to set `serialized=True` when using
    `fromplain` to successfully restore the object.

    Raises:
        ValueError: if validation fails, or a field has a type not listed
            above (e.g. `dict`).
    """
    validate(o)

    dump = json_dump if serialize else _id
    d: JSONObject = {}
    for f in fields(o):
        if filter_fields is not None and f.name not in filter_fields:
            continue

        # XXX this doesn't properly support any kind of nested types
        otype = optional_type(f.type)
        is_opt = otype is not None
        target = otype if is_opt else f.type
        # Reduce generic aliases such as `set[str]` to their origin (`set`).
        target = get_origin(target) or target

        v = getattr(o, f.name)
        # The asserts below are internal invariants: `validate` above has
        # already rejected values whose type does not match the field.
        if is_opt and v is None:
            d[f.name] = None
        elif target is ULID:
            assert isinstance(v, ULID)
            d[f.name] = str(v)
        elif target is datetime:
            assert isinstance(v, datetime)
            d[f.name] = v.isoformat()
        elif target is set:
            assert isinstance(v, set)
            # Sorting gives a deterministic order for the unordered set.
            d[f.name] = dump(sorted(v))
        elif target is list:
            assert isinstance(v, list)
            d[f.name] = dump(list(v))
        elif target in {bool, str, int, float}:
            assert isinstance(
                v, target
            ), f"Type mismatch: {f.name} ({target} != {type(v)})"
            d[f.name] = v
        else:
            raise ValueError(f"Unsupported value type: {f.name}: {type(v)}")
    return d
def fromplain(cls: Type[T], d: Mapping, *, serialized: bool = False) -> T:
    """Build an instance of the model `cls` from the plain data in `d`.

    This is the reverse of `asplain`. For each model field of `cls`, the value
    `d[<field name>]` is converted back to the field's declared type:

    - `None` for an optional field, and any value that already is an instance
      of the (unwrapped) target type, are used as-is;
    - sets and lists are rebuilt from an iterable, or from a JSON string when
      `serialized` is `True` (the counterpart of `asplain(..., serialize=True)`);
    - datetimes are parsed with `datetime.fromisoformat`;
    - anything else is passed to the target type's constructor.

    The new instance is checked with `validate` before it is returned.

    Raises:
        KeyError: if `d` has no entry for one of the model fields.
        ValueError: if the constructed instance fails validation.
    """
    decode = json.loads if serialized else _id
    kwargs: JSONObject = {}
    for f in fields(cls):
        inner = optional_type(f.type)
        nullable = inner is not None
        target = inner if nullable else f.type
        # Reduce generic aliases such as `set[str]` to their origin (`set`).
        target = get_origin(target) or target

        value = d[f.name]
        if (nullable and value is None) or isinstance(value, target):
            kwargs[f.name] = value
        elif target in {set, list}:
            kwargs[f.name] = target(decode(value))
        elif target is datetime:
            kwargs[f.name] = datetime.fromisoformat(value)
        else:
            # NOTE(review): a `None` for a non-optional field also lands here,
            # e.g. str(None) silently becomes "None" — consider rejecting it.
            kwargs[f.name] = target(value)

    instance = cls(**kwargs)
    validate(instance)
    return instance
def validate(o: object) -> None:
    """Check that each model field of `o` holds a value of its declared type.

    Fields skipped by `fields` (relations and `_is_lazy`) are not checked.
    The value's exact type (subclasses do not count) is accepted when it

    - is the declared type itself, e.g. `str` for `str`;
    - is the origin of a declared generic alias, e.g. `set` for `set[str]`;
    - for a declared union/optional type, matches one of the union's members,
      where generic members are again compared by their origin, so a `list`
      is accepted for `list[int] | None`.

    Raises:
        ValueError: if a field's value has a mismatching type.
    """
    for f in fields(o):
        vtype = type(getattr(o, f.name))
        if vtype is f.type or get_origin(f.type) is vtype:
            continue
        if isinstance(f.type, UnionType) or get_origin(f.type) is Union:
            # Compare against each member's origin so optional generic types
            # (`list[X] | None`) validate the same way `asplain` and
            # `fromplain` unwrap them; plain members are compared unchanged.
            members = tuple(get_origin(a) or a for a in get_args(f.type))
            if vtype in members:
                continue
        raise ValueError(f"Invalid value type: {f.name}: {vtype}")
def utcnow() -> datetime:
    """Return the current time as a timezone-aware `datetime` in UTC."""
    # `datetime.utcnow()` is deprecated since Python 3.12 (and returns a naive
    # value); `now(timezone.utc)` yields the aware UTC time directly.
    return datetime.now(timezone.utc)
@dataclass
class Progress:
    _table: ClassVar[str] = "progress"

    id: ULID = field(default_factory=ULID)
    type: str = None
    state: str = None  # JSON text; read and written through `_state`
    started: datetime = field(default_factory=utcnow)
    stopped: str | None = None

    @property
    def _state(self) -> dict:
        """The decoded `state` JSON; an unset or empty `state` reads as `{}`."""
        encoded = self.state or "{}"
        return json.loads(encoded)

    @_state.setter
    def _state(self, state: dict):
        # Replace the stored JSON text wholesale with the compact encoding.
        self.state = json_dump(state)

    @property
    def percent(self) -> float:
        """The `"percent"` entry of the state; raises `KeyError` if it is unset."""
        current = self._state
        return current["percent"]

    @percent.setter
    def percent(self, percent: float):
        # Decode, overwrite this single key, and store the result back.
        self._state = {**self._state, "percent": percent}

    @property
    def error(self) -> str:
        """The `"error"` entry of the state, or `""` if none was recorded."""
        current = self._state
        return current.get("error", "")

    @error.setter
    def error(self, error: str):
        # Decode, overwrite this single key, and store the result back.
        self._state = {**self._state, "error": error}
@dataclass
class Movie:
    _table: ClassVar[str] = "movies"

    id: ULID = field(default_factory=ULID)
    title: str = None  # canonical title (usually English)
    original_title: str | None = (
        None  # original title (usually transcribed to Latin script)
    )
    release_year: int = None  # canonical release year
    media_type: str = None
    imdb_id: str = None
    imdb_score: int | None = None  # range: [0,100]
    imdb_votes: int | None = None
    runtime: int | None = None  # minutes
    genres: set[str] = None
    created: datetime = field(default_factory=utcnow)
    updated: datetime = field(default_factory=utcnow)

    # True while default factories postponed by `lazy()` still need to run.
    _is_lazy: bool = field(default=False, init=False, repr=False, compare=False)

    @classmethod
    def lazy(cls, **kwds):
        """Return a new instance without running default factories.

        This is meant purely for optimization purposes, to postpone possibly
        expensive initialization operations. `id`, `created` and `updated`
        are set to `None` unless given explicitly; `_lazy_init()` fills them
        in from their default factories later.
        """
        # XXX optimize using a metaclass & storing field refs on the class
        kwds.setdefault("id", None)
        kwds.setdefault("created", None)
        kwds.setdefault("updated", None)
        movie = cls(**kwds)
        movie._is_lazy = True
        return movie

    def _lazy_init(self):
        """Run the default factories that `lazy()` postponed.

        Every model field that is still `None` and has a `default_factory`
        gets a freshly generated value. Does nothing unless the instance was
        created through `lazy()`.
        """
        if not self._is_lazy:
            return
        # Use this instance's own fields (not a hard-coded `Movie`) so
        # subclasses work too; the loop name `fld` avoids shadowing the
        # module-level `dataclasses.field` import.
        for fld in fields(self):
            if getattr(self, fld.name) is None and callable(fld.default_factory):
                setattr(self, fld.name, fld.default_factory())
        self._is_lazy = False
# Unique marker stored in a field's `Annotated` metadata; `fields()` skips
# every field whose metadata contains it.
_RelationSentinel = object()
"""Mark a model field as containing external data.
For each field marked as a Relation there should be another field on the
dataclass containing the ID of the linked data.
The contents of the Relation are ignored or discarded when using
`asplain`, `fromplain`, and `validate`.
"""
# `Relation[X]` is the optional type `X | None` annotated with the marker above.
Relation = Annotated[T | None, _RelationSentinel]
@dataclass
class Rating:
    _table: ClassVar[str] = "ratings"

    id: ULID = field(default_factory=ULID)
    movie_id: ULID = None
    movie: Relation[Movie] = None
    user_id: ULID = None
    user: Relation["User"] = None
    score: int = None  # range: [0,100]
    rating_date: datetime = None
    favorite: bool | None = None
    finished: bool | None = None

    def __eq__(self, other):
        """Return whether two Ratings are equal.

        Every regular field is compared with `==`. Fields declared as
        `Relation` (`movie`, `user`) are ignored, so equality depends on the
        referenced IDs rather than on whether the related objects are loaded.
        Objects of any other type never compare equal.
        """
        if type(self) is not type(other):
            return False
        for fld in fields(self):
            mine = getattr(self, fld.name)
            theirs = getattr(other, fld.name)
            # Deliberately `not (a == b)` rather than `a != b`: equality is
            # defined by each value's `__eq__`.
            if not (mine == theirs):
                return False
        return True
# Access level a user can hold for a group (see `User.has_access`).
Access = Literal[
    "r",  # read
    "i",  # index
    "w",  # write
]
# One entry of `User.groups`: a group's ID (as a string) and the access level.
class UserGroup(TypedDict):
    id: str
    access: Access
@dataclass
class User:
    _table: ClassVar[str] = "users"

    id: ULID = field(default_factory=ULID)
    imdb_id: str = None
    name: str = None  # canonical user name
    secret: str = None
    groups: list[UserGroup] = field(default_factory=list)

    def has_access(self, group_id: ULID | str, access: Access = "r"):
        """Return whether this user holds `access` for the given group.

        `group_id` may be a `ULID` or its string form. Only an entry in
        `groups` with exactly this access level counts; no level implies
        another here.
        """
        wanted_id = group_id if isinstance(group_id, str) else str(group_id)
        for entry in self.groups:
            if entry["id"] == wanted_id and access == entry["access"]:
                return True
        return False

    def set_access(self, group_id: ULID | str, access: Access):
        """Set the user's access level for the given group.

        The first entry in `groups` for that group is updated in place; if
        there is none, a new entry is appended.
        """
        wanted_id = group_id if isinstance(group_id, str) else str(group_id)
        existing = next((e for e in self.groups if e["id"] == wanted_id), None)
        if existing is None:
            self.groups.append({"id": wanted_id, "access": access})
        else:
            existing["access"] = access
# One entry of `Group.users` (presumably a member's user ID and name — confirm).
class GroupUser(TypedDict):
    id: str
    name: str
@dataclass
class Group:
    _table: ClassVar[str] = "groups"  # presumably the storage table name, like the other models
    id: ULID = field(default_factory=ULID)
    name: str = None
    users: list[GroupUser] = field(default_factory=list)