# unwind/unwind/models.py
# (copied from a web source view: 291 lines, 7.4 KiB, Python)

import json
from dataclasses import dataclass, field
from dataclasses import fields as _fields
from datetime import datetime, timezone
from typing import (
Annotated,
Any,
ClassVar,
Optional,
Type,
TypeVar,
Union,
get_args,
get_origin,
)
from .types import ULID
def annotations(tp: Type) -> Optional[tuple]:
    """Return the ``Annotated`` metadata attached to *tp*, if any.

    ``Annotated[X, a, b]`` exposes its extra arguments as the tuple
    ``__metadata__``; any type without that attribute yields ``None``.
    """
    return getattr(tp, "__metadata__", None)
def fields(class_or_instance):
    """Yield the dataclass fields of a model, minus internal bookkeeping.

    Drop-in replacement for :func:`dataclasses.fields` that SHOULD be used
    everywhere instead of it. Two kinds of fields are skipped:

    - ``_is_lazy``, the internal flag set by lazily constructed models;
    - fields annotated as ``Relation``. Those only exist to hold the data
      referenced by an ID field directly on the instance.
    """
    # XXX this might be a little slow (not sure), if so, memoize
    for f in _fields(class_or_instance):
        # Name check first so the annotation lookup is skipped for the flag.
        if f.name != "_is_lazy" and _RelationSentinel not in (
            annotations(f.type) or ()
        ):
            yield f
def is_optional(tp: Type):
    """Return whether *tp* is a two-member ``Union`` that includes ``None``.

    ``Optional[X]`` (i.e. ``Union[X, None]``) qualifies, with ``None`` in
    either position; bare types and unions with more than one non-``None``
    member do not.
    """
    if get_origin(tp) is Union:
        members = get_args(tp)
        return len(members) == 2 and type(None) in members
    return False
def optional_type(tp: Type):
    """Return ``X`` if *tp* is ``Optional[X]``, otherwise ``None``.

    Recognizes exactly the types :func:`is_optional` accepts: a two-member
    ``Union`` with ``None`` in either position. ``Optional[X]`` normalizes to
    ``Union[X, None]``, but a hand-written ``Union[None, X]`` keeps ``None``
    first and must be unwrapped just the same.
    """
    if get_origin(tp) is not Union:
        return None
    args = get_args(tp)
    if len(args) != 2 or type(None) not in args:
        return None
    # ``None`` may be either member; return the other one.
    return args[0] if args[1] is type(None) else args[1]
def optional_fields(o):
    """Yield the model fields of *o* whose declared type is ``Optional``."""
    yield from (f for f in fields(o) if is_optional(f.type))
def asplain(o) -> dict[str, Any]:
    """Serialize model instance *o* into a dict of plain values.

    The instance is validated first. For each field, ``Optional[X]`` is
    unwrapped to ``X`` and a generic alias such as ``set[str]`` to its origin
    ``set``. The value is then converted as follows:

    - ``None`` in an optional field is kept as ``None``;
    - ``ULID`` values become ``str``;
    - ``datetime`` values become ISO 8601 strings;
    - ``set`` values become a JSON array of the sorted items;
    - ``list`` values become a JSON array (order preserved);
    - ``bool``/``str``/``int``/``float`` values are passed through.

    Raises:
        ValueError: if validation fails or a field type is unsupported.
    """
    validate(o)

    d = {}
    for f in fields(o):
        # XXX this doesn't properly support any kind of nested types
        otype = optional_type(f.type)
        is_opt = otype is not None
        target = otype if is_opt else f.type
        if (xtype := get_origin(target)) is not None:
            target = xtype

        v = getattr(o, f.name)
        if is_opt and v is None:
            # Mirror `fromplain`: an unset optional value stays None instead
            # of going through the converters below (str(None) would store
            # "None", None.isoformat() / sorted(None) would raise).
            d[f.name] = None
        elif target is ULID:
            d[f.name] = str(v)
        elif target is datetime:
            d[f.name] = v.isoformat()
        elif target is set:
            d[f.name] = json.dumps(sorted(v))
        elif target is list:
            d[f.name] = json.dumps(list(v))
        elif target in {bool, str, int, float, None}:
            d[f.name] = v
        else:
            raise ValueError(f"Unsupported value type: {f.name}: {type(v)}")
    return d
def fromplain(cls, d: dict[str, Any]):
    """Build a validated *cls* instance from a dict of plain values.

    Every model field of *cls* must be present in *d*. For each field,
    ``Optional[X]`` is unwrapped to ``X`` and a generic alias such as
    ``set[str]`` to its origin ``set``. The value is then converted back:

    - ``None`` for an optional field stays ``None``;
    - values already of the target type are used unchanged;
    - ``set``/``list`` fields are decoded from a JSON string;
    - ``datetime`` fields are parsed from an ISO 8601 string;
    - anything else is passed to the target type's constructor.

    Raises:
        KeyError: if *d* lacks one of the fields.
        ValueError: if the resulting instance fails `validate`.
    Errors raised by the target constructors propagate unchanged.
    """

    def convert(f):
        inner = optional_type(f.type)
        target = f.type if inner is None else inner
        origin = get_origin(target)
        if origin is not None:
            target = origin

        value = d[f.name]
        if inner is not None and value is None:
            return value
        if isinstance(value, target):
            return value
        if target in {set, list}:
            return target(json.loads(value))
        if target in {datetime}:
            return target.fromisoformat(value)
        return target(value)

    instance = cls(**{f.name: convert(f) for f in fields(cls)})
    validate(instance)
    return instance
def validate(o):
    """Check that every model field of *o* holds a value of its declared type.

    A value is accepted when its exact type (subclasses are not accepted) is
    one of the following:

    - the declared type itself;
    - the origin of a declared generic alias (``set`` for ``set[str]``);
    - a member of a declared ``Union``/``Optional``. Generic members are
      likewise matched by their origin (``set`` for ``Optional[set[str]]``).

    Raises:
        ValueError: naming the first field whose value has the wrong type.
    """
    for f in fields(o):
        vtype = type(getattr(o, f.name))
        if vtype is f.type:
            continue
        origin = get_origin(f.type)
        if origin is vtype:
            continue
        if origin is Union:
            # Compare against each member's origin: `type(value)` can return
            # `set` but never `set[str]`, so `Optional[set[str]]` must match
            # a plain `set` value.
            members = tuple(get_origin(a) or a for a in get_args(f.type))
            if vtype in members:
                continue
        raise ValueError(f"Invalid value type: {f.name}: {vtype}")
def utcnow():
    """Return the current time as a timezone-aware UTC ``datetime``."""
    # `datetime.utcnow()` is deprecated since Python 3.12 and returns a naive
    # value that had to be re-tagged; asking for UTC directly yields the same
    # instant, already carrying `timezone.utc`.
    return datetime.now(timezone.utc)
@dataclass
class Progress:
    """Progress model whose ``state`` column holds a compact JSON object.

    The ``percent`` and ``error`` properties read and write individual keys
    of that JSON object.
    """

    _table: ClassVar[str] = "progress"

    id: ULID = field(default_factory=ULID)
    type: str = None
    state: str = None  # JSON object, accessed through `_state`
    started: datetime = field(default_factory=utcnow)
    stopped: Optional[str] = None

    @property
    def _state(self) -> dict:
        """The decoded ``state`` JSON; an empty dict while ``state`` is unset."""
        raw = self.state or "{}"
        return json.loads(raw)

    @_state.setter
    def _state(self, state: dict):
        # Compact separators keep the stored string minimal.
        self.state = json.dumps(state, separators=(",", ":"))

    @property
    def percent(self) -> float:
        """The ``percent`` key of the state; raises ``KeyError`` if unset."""
        current = self._state
        return current["percent"]

    @percent.setter
    def percent(self, percent: float):
        self._state = {**self._state, "percent": percent}

    @property
    def error(self) -> str:
        """The ``error`` key of the state, or ``""`` when none was recorded."""
        current = self._state
        return current.get("error", "")

    @error.setter
    def error(self, error: str):
        self._state = {**self._state, "error": error}
@dataclass
class Movie:
    """Movie model; `lazy` builds instances without running default factories."""

    _table: ClassVar[str] = "movies"

    id: ULID = field(default_factory=ULID)

    title: str = None  # canonical title (usually English)
    # original title (usually transcribed to Latin script)
    original_title: Optional[str] = None
    release_year: int = None  # canonical release year

    media_type: str = None
    imdb_id: str = None

    imdb_score: Optional[int] = None  # range: [0,100]
    imdb_votes: Optional[int] = None
    runtime: Optional[int] = None  # minutes
    genres: set[str] = None

    created: datetime = field(default_factory=utcnow)
    updated: datetime = field(default_factory=utcnow)

    # Internal flag; skipped by the module's `fields()` helper.
    _is_lazy: bool = field(default=False, init=False, repr=False, compare=False)

    @classmethod
    def lazy(cls, **kwds):
        """Return a new instance without running default factories.

        This is meant purely for optimization purposes, to postpone possibly
        expensive initialization operations. ``id``, ``created`` and
        ``updated`` default to ``None`` unless given in *kwds*; `_lazy_init`
        fills them in later.
        """
        # XXX optimize using a metaclass & storing field refs on the class
        kwds.setdefault("id", None)
        kwds.setdefault("created", None)
        kwds.setdefault("updated", None)
        movie = cls(**kwds)
        movie._is_lazy = True
        return movie

    def _lazy_init(self):
        """Run the default factories that `lazy` postponed.

        Every field that is still ``None`` and has a default factory is
        populated, and the instance stops being lazy. Does nothing for
        instances that are not lazy.
        """
        if not self._is_lazy:
            return
        # `fields(self)` rather than `fields(Movie)` so subclasses' extra
        # fields are initialized too; `f` avoids shadowing dataclasses' `field`.
        for f in fields(self):
            if getattr(self, f.name) is None and callable(f.default_factory):
                setattr(self, f.name, f.default_factory())
        self._is_lazy = False
T = TypeVar("T")

# Private marker placed in the `Annotated` metadata of every `Relation` field.
_RelationSentinel = object()

# `Relation[X]` marks a model field as containing external data: it is an
# `Optional[X]` annotated with `_RelationSentinel`.
#
# For each field marked as a Relation there should be another field on the
# dataclass containing the ID of the linked data. The contents of the
# Relation are ignored or discarded when using `asplain`, `fromplain`, and
# `validate`.
Relation = Annotated[Optional[T], _RelationSentinel]
@dataclass
class Rating:
    """Rating that links a `Movie` and a `User` by their IDs."""

    _table: ClassVar[str] = "ratings"

    id: ULID = field(default_factory=ULID)

    movie_id: ULID = None
    movie: Relation[Movie] = None  # linked data for `movie_id`
    user_id: ULID = None
    user: Relation["User"] = None  # linked data for `user_id`

    score: int = None  # range: [0,100]
    rating_date: datetime = None
    favorite: Optional[bool] = None
    finished: Optional[bool] = None

    def __eq__(self, other):
        """Return whether two Ratings are equal.

        This operation compares all fields as expected, except that it
        ignores any field marked as Relation. For an operand of a different
        type it returns ``NotImplemented`` so the other operand may handle
        the comparison; a plain ``==`` then falls back to identity and still
        yields ``False``.
        """
        if type(other) is not type(self):
            return NotImplemented
        return all(
            getattr(self, f.name) == getattr(other, f.name) for f in fields(self)
        )
@dataclass
class User:
    """User model (``_table`` is "users")."""

    _table: ClassVar[str] = "users"

    id: ULID = field(default_factory=ULID)
    imdb_id: str = None  # NOTE(review): presumably the user's IMDb ID — confirm
    name: str = None  # canonical user name
@dataclass
class Group:
    """Group model (``_table`` is "groups")."""

    _table: ClassVar[str] = "groups"

    id: ULID = field(default_factory=ULID)
    name: str = None
    secret: str = None
    # default_factory gives each instance its own list (no shared mutable
    # default); `asplain` serializes this field as a JSON array.
    # NOTE(review): the keys of each user dict are not defined here — confirm
    # the expected schema with the callers.
    users: list[dict[str, str]] = field(default_factory=list)