unwind/unwind/models.py
ducklet 4981de4a04 remove databases, use SQLAlchemy 2.0 instead
Among the many changes we switch to using SQLAlchemy's connection pool,
which means we are no longer required to guard against multiple threads
working on the database.
All db funcs now receive a connection to use as their first argument,
this allows the caller to control transaction & rollback behavior.
2023-11-27 23:24:35 +01:00

464 lines
13 KiB
Python

import json
from dataclasses import dataclass, field
from dataclasses import fields as _fields
from datetime import datetime, timezone
from functools import partial
from types import UnionType
from typing import (
Annotated,
Any,
ClassVar,
Container,
Literal,
Mapping,
Protocol,
Type,
TypedDict,
TypeVar,
Union,
get_args,
get_origin,
)
from sqlalchemy import Column, ForeignKey, Integer, String, Table
from sqlalchemy.orm import registry
from .types import ULID
# Recursive alias for any JSON compatible value; the quoted "JSON" entries are
# forward references to this same alias.
JSON = int | float | str | None | list["JSON"] | dict[str, "JSON"]
# A JSON object: string keys mapped to JSON values.
JSONObject = dict[str, JSON]
# Generic type variable used by the typed helpers in this module.
T = TypeVar("T")
class Model(Protocol):
    """Structural type for classes exposing a SQLAlchemy `Table` as `__table__`."""

    __table__: ClassVar[Table]
# Registry that the model classes in this module are mapped with.
mapper_registry = registry()
# Metadata collection of that registry; passed to every `Table(...)` in this module.
metadata = mapper_registry.metadata
def annotations(tp: Type) -> tuple | None:
return tp.__metadata__ if hasattr(tp, "__metadata__") else None # type: ignore
def fields(class_or_instance):
    """Yield the dataclass fields that hold a model's own data.

    This is a drop-in replacement for dataclass' `fields` and SHOULD be used
    instead of it everywhere. Two kinds of fields are left out:

    - the internal `_is_lazy` flag, and
    - fields marked as `Relation`, which are meant to store the data
      referenced by an ID field directly on the instance.
    """
    # XXX this might be a little slow (not sure), if so, memoize
    for candidate in _fields(class_or_instance):
        if candidate.name == "_is_lazy":
            continue
        extras = annotations(candidate.type)
        # A field is a relation when the sentinel is part of its metadata.
        is_relation = bool(extras) and _RelationSentinel in extras
        if not is_relation:
            yield candidate
def is_optional(tp: Type) -> bool:
"""Return wether the given type is optional."""
if not isinstance(tp, UnionType) and get_origin(tp) is not Union:
return False
args = get_args(tp)
return len(args) == 2 and type(None) in args
def optional_type(tp: Type) -> Type | None:
"""Return the wrapped type from an optional type.
For example this will return `int` for `Optional[int]`.
Since they're equivalent this also works for other optioning notations, like
`Union[int, None]` and `int | None`.
"""
if not isinstance(tp, UnionType) and get_origin(tp) is not Union:
return None
args = get_args(tp)
if len(args) != 2 or type(None) not in args:
return None
return args[0] if args[1] is type(None) else args[1]
def optional_fields(o):
    """Yield the model fields (as reported by `fields`) whose type is optional."""
    yield from (f for f in fields(o) if is_optional(f.type))
# `json.dumps` preset with separators that leave no whitespace after "," and ":".
json_dump = partial(json.dumps, separators=(",", ":"))
def _id(x: T) -> T:
    """Return the argument unchanged (identity function)."""
    return x
def asplain(
    o: object, *, filter_fields: Container[str] | None = None, serialize: bool = False
) -> dict[str, Any]:
    """Convert a model instance into a `dict` of JSON compatible plain values.

    The instance is validated first (see `validate`). `ULID` values become
    strings, `datetime` values ISO formatted strings, sets become sorted
    lists, lists are copied, and `bool`/`str`/`int`/`float` values are passed
    through. Optional fields holding `None` stay `None`. Any other field type
    raises a `ValueError`.

    If `filter_fields` is given, only fields whose name it contains are
    included in the result.

    If `serialize` is `True`, collection values (sets and lists) are
    additionally encoded as JSON strings; this can be useful to store them in
    a database. Be sure to pass `serialized=True` to `fromplain` to
    successfully restore the object.
    """
    validate(o)

    encode = json_dump if serialize else _id
    scalar_types = {bool, str, int, float}
    plain: JSONObject = {}

    for f in fields(o):
        if filter_fields is not None and f.name not in filter_fields:
            continue

        # XXX this doesn't properly support any kind of nested types
        unwrapped = optional_type(f.type)
        is_opt = unwrapped is not None
        target: Any = unwrapped if is_opt else f.type
        origin = get_origin(target)
        if origin is not None:
            # Generic aliases like `set[str]` are handled via their origin.
            target = origin

        value = getattr(o, f.name)

        if is_opt and value is None:
            plain[f.name] = None
            continue
        if target is ULID:
            assert isinstance(value, ULID)
            plain[f.name] = str(value)
            continue
        if target is datetime:
            assert isinstance(value, datetime)
            plain[f.name] = value.isoformat()
            continue
        if target is set:
            assert isinstance(value, set)
            # Sorting makes the output independent of set iteration order.
            plain[f.name] = encode(sorted(value))
            continue
        if target is list:
            assert isinstance(value, list)
            plain[f.name] = encode(list(value))
            continue
        if target in scalar_types:
            assert isinstance(
                value, target
            ), f"Type mismatch: {f.name} ({target} != {type(value)})"
            plain[f.name] = value
            continue
        raise ValueError(f"Unsupported value type: {f.name}: {type(value)}")

    return plain
def fromplain(cls: Type[T], d: Mapping, *, serialized: bool = False) -> T:
    """Build an instance of the model `cls` from the plain data in `d`.

    Every field reported by `fields(cls)` is looked up in `d` by name. Values
    that already are instances of the target type (and `None` for optional
    fields) are taken as they are. Anything else is converted: sets and lists
    are built from the value, datetimes are parsed with `fromisoformat`, and
    all remaining types are constructed by calling the type with the value.
    The new instance is validated before it is returned.

    If `serialized` is `True`, collection values (sets and lists) are decoded
    from JSON strings first. This is the opposite operation of `serialize`
    for `asplain`.
    """
    decode = json.loads if serialized else _id
    kwargs: JSONObject = {}

    for f in fields(cls):
        unwrapped = optional_type(f.type)
        is_opt = unwrapped is not None
        target: Any = unwrapped if is_opt else f.type
        origin = get_origin(target)
        if origin is not None:
            # Generic aliases like `list[str]` are handled via their origin.
            target = origin

        raw = d[f.name]

        if (is_opt and raw is None) or isinstance(raw, target):
            converted = raw
        elif target in {set, list}:
            converted = target(decode(raw))
        elif target in {datetime}:
            converted = target.fromisoformat(raw)
        else:
            converted = target(raw)
        kwargs[f.name] = converted

    instance = cls(**kwargs)
    validate(instance)
    return instance
def validate(o: object) -> None:
    """Check that every model field holds a value of its declared type.

    A value is accepted when its exact type is the declared type, the origin
    of a generic declaration (e.g. `list` for `list[str]`), or one of the
    members of a declared union. A `ValueError` is raised for the first field
    that matches none of these.
    """
    for f in fields(o):
        declared = f.type
        vtype = type(getattr(o, f.name))
        if vtype is declared:
            continue
        origin = get_origin(declared)
        if origin is vtype:
            continue
        is_union = isinstance(declared, UnionType) or origin is Union
        if is_union and vtype in get_args(declared):
            continue
        raise ValueError(f"Invalid value type: {f.name}: {vtype}")
def utcnow():
    """Return the current time as a timezone-aware `datetime` in UTC."""
    return datetime.now(tz=timezone.utc)
@mapper_registry.mapped
@dataclass
class DbPatch:
    """Row of the `db_patches` table."""

    __table__: ClassVar[Table] = Table(
        "db_patches",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("current", String),
    )

    id: int
    # NOTE(review): presumably names the currently applied DB patch — confirm.
    current: str
# Module-level shortcut to the `db_patches` table.
db_patches = DbPatch.__table__
@mapper_registry.mapped
@dataclass
class Progress:
    """Row of the `progress` table.

    The `state` field holds a JSON encoded object; the `percent` and `error`
    properties read and write single entries of that object.
    """

    __table__: ClassVar[Table] = Table(
        "progress",
        metadata,
        Column("id", String, primary_key=True),  # ULID
        Column("type", String, nullable=False),
        Column("state", String, nullable=False),  # JSON {"percent": ..., "error": ...}
        Column("started", String, nullable=False),  # datetime
        Column("stopped", String),
    )

    id: ULID = field(default_factory=ULID)
    type: str = None
    state: str = None
    started: datetime = field(default_factory=utcnow)
    stopped: str | None = None

    @property
    def _state(self) -> dict:
        """The decoded `state`; an empty or missing state decodes to `{}`."""
        encoded = self.state
        return json.loads(encoded if encoded else "{}")

    @_state.setter
    def _state(self, state: dict):
        """Store the given dict as compact JSON in `state`."""
        self.state = json_dump(state)

    @property
    def percent(self) -> float:
        """The `percent` entry of the state; `KeyError` if it was never set."""
        decoded = self._state
        return decoded["percent"]

    @percent.setter
    def percent(self, percent: float):
        # Decode, change the single entry, and encode the whole state again.
        self._state = {**self._state, "percent": percent}

    @property
    def error(self) -> str:
        """The `error` entry of the state, or `""` if there is none."""
        decoded = self._state
        return decoded.get("error", "")

    @error.setter
    def error(self, error: str):
        # Decode, change the single entry, and encode the whole state again.
        self._state = {**self._state, "error": error}
# Module-level shortcut to the `progress` table.
progress = Progress.__table__
@mapper_registry.mapped
@dataclass
class Movie:
    """Row of the `movies` table.

    Instances created through `lazy` skip the default factories of `id`,
    `created` and `updated`; `_lazy_init` runs them later.
    """

    __table__: ClassVar[Table] = Table(
        "movies",
        metadata,
        Column("id", String, primary_key=True),  # ULID
        Column("title", String, nullable=False),
        Column("original_title", String),
        Column("release_year", Integer, nullable=False),
        Column("media_type", String, nullable=False),
        Column("imdb_id", String, nullable=False, unique=True),
        Column("imdb_score", Integer),
        Column("imdb_votes", Integer),
        Column("runtime", Integer),
        Column("genres", String, nullable=False),
        Column("created", String, nullable=False),  # datetime
        Column("updated", String, nullable=False),  # datetime
    )

    id: ULID = field(default_factory=ULID)
    title: str = None  # canonical title (usually English)
    # original title (usually transcribed to latin script)
    original_title: str | None = None
    release_year: int = None  # canonical release date
    media_type: str = None
    imdb_id: str = None
    imdb_score: int | None = None  # range: [0,100]
    imdb_votes: int | None = None
    runtime: int | None = None  # minutes
    genres: set[str] = None
    created: datetime = field(default_factory=utcnow)
    updated: datetime = field(default_factory=utcnow)

    # Bookkeeping flag, not part of the model data (skipped by `fields`).
    _is_lazy: bool = field(default=False, init=False, repr=False, compare=False)

    @classmethod
    def lazy(cls, **kwds):
        """Return a new instance without running default factories.

        This is meant purely for optimization purposes, to postpone possibly
        expensive initialization operations.
        """
        # XXX optimize using a metaclass & storing field refs on the class
        for name in ("id", "created", "updated"):
            # An explicit `None` keeps the dataclass from calling the factory.
            kwds.setdefault(name, None)
        movie = cls(**kwds)
        movie._is_lazy = True
        return movie

    def _lazy_init(self):
        """Run the default factories that `lazy` skipped.

        Every field that is still `None` and has a default factory gets the
        factory's result assigned; afterwards the instance is no longer marked
        as lazy. Does nothing for instances that are not lazy.
        """
        if self._is_lazy:
            for f in fields(Movie):
                factory = f.default_factory
                if getattr(self, f.name) is None and callable(factory):
                    setattr(self, f.name, factory())
            self._is_lazy = False
# Module-level shortcut to the `movies` table.
movies = Movie.__table__
# Unique marker object; `fields` skips every field whose annotation metadata
# contains it.
_RelationSentinel = object()
"""Mark a model field as containing external data.
For each field marked as a Relation there should be another field on the
dataclass containing the ID of the linked data.
The contents of the Relation are ignored or discarded when using
`asplain`, `fromplain`, and `validate`.
"""
# Generic alias: `Relation[X]` is `X | None` annotated with the sentinel.
Relation = Annotated[T | None, _RelationSentinel]
# The access values a `UserGroup` entry can carry.
Access = Literal[
    "r",  # read
    "i",  # index
    "w",  # write
]
class UserGroup(TypedDict):
    """One entry of `User.groups`: a group ID and the user's access on it."""

    id: str  # group ID as a string (see `User.set_access`)
    access: Access
@mapper_registry.mapped
@dataclass
class User:
    """Row of the `users` table, including the user's group memberships."""

    __table__: ClassVar[Table] = Table(
        "users",
        metadata,
        Column("id", String, primary_key=True),  # ULID
        Column("imdb_id", String, nullable=False, unique=True),
        Column("name", String, nullable=False),
        Column("secret", String, nullable=False),
        Column("groups", String, nullable=False),  # JSON array
    )

    id: ULID = field(default_factory=ULID)
    imdb_id: str = None
    name: str = None  # canonical user name
    secret: str = None
    groups: list[UserGroup] = field(default_factory=list)

    def has_access(self, group_id: ULID | str, access: Access = "r"):
        """Return whether the user holds exactly `access` on the given group.

        Access values are compared for equality only; no value implies
        another.
        """
        wanted = group_id if isinstance(group_id, str) else str(group_id)
        for membership in self.groups:
            if membership["id"] == wanted and access == membership["access"]:
                return True
        return False

    def set_access(self, group_id: ULID | str, access: Access):
        """Set the user's access on the given group.

        The first existing entry for the group is updated in place; if there
        is none, a new entry is appended to `groups`.
        """
        wanted = group_id if isinstance(group_id, str) else str(group_id)
        existing = next((g for g in self.groups if g["id"] == wanted), None)
        if existing is None:
            self.groups.append({"id": wanted, "access": access})
        else:
            existing["access"] = access
@mapper_registry.mapped
@dataclass
class Rating:
    """Row of the `ratings` table: one user's rating of one movie.

    `movie` and `user` are `Relation` fields; they can carry the objects
    referenced by `movie_id` and `user_id` but are not part of the row data.
    """

    __table__: ClassVar[Table] = Table(
        "ratings",
        metadata,
        Column("id", String, primary_key=True),  # ULID
        Column("movie_id", ForeignKey("movies.id"), nullable=False),  # ULID
        Column("user_id", ForeignKey("users.id"), nullable=False),  # ULID
        Column("score", Integer, nullable=False),
        Column("rating_date", String, nullable=False),  # datetime
        Column("favorite", Integer),  # bool
        Column("finished", Integer),  # bool
    )

    id: ULID = field(default_factory=ULID)

    movie_id: ULID = None
    movie: Relation[Movie] = None

    user_id: ULID = None
    user: Relation[User] = None

    score: int = None  # range: [0,100]
    rating_date: datetime = None
    favorite: bool | None = None
    finished: bool | None = None

    def __eq__(self, other):
        """Return whether two Ratings are equal.

        All fields are compared as expected, except that fields marked as
        Relation are ignored. Objects of any other type are never equal.
        """
        if type(other) is not type(self):
            return False
        for f in fields(self):
            if not (getattr(self, f.name) == getattr(other, f.name)):
                return False
        return True
# Module-level shortcut to the `ratings` table.
ratings = Rating.__table__
class GroupUser(TypedDict):
    """One entry of `Group.users`."""

    id: str  # NOTE(review): presumably the user's ID as a string — confirm
    name: str
@mapper_registry.mapped
@dataclass
class Group:
    """Row of the `groups` table."""

    __table__: ClassVar[Table] = Table(
        "groups",
        metadata,
        Column("id", String, primary_key=True),  # ULID
        Column("name", String, nullable=False),
        Column("users", String, nullable=False),  # JSON array
    )

    id: ULID = field(default_factory=ULID)
    name: str = None
    users: list[GroupUser] = field(default_factory=list)