Among the many changes, we switch to using SQLAlchemy's connection pool, which means we are no longer required to guard against multiple threads working on the database. All db funcs now receive a connection to use as their first argument; this allows the caller to control transaction & rollback behavior.
464 lines
13 KiB
Python
464 lines
13 KiB
Python
import json
|
|
from dataclasses import dataclass, field
|
|
from dataclasses import fields as _fields
|
|
from datetime import datetime, timezone
|
|
from functools import partial
|
|
from types import UnionType
|
|
from typing import (
|
|
Annotated,
|
|
Any,
|
|
ClassVar,
|
|
Container,
|
|
Literal,
|
|
Mapping,
|
|
Protocol,
|
|
Type,
|
|
TypedDict,
|
|
TypeVar,
|
|
Union,
|
|
get_args,
|
|
get_origin,
|
|
)
|
|
|
|
from sqlalchemy import Column, ForeignKey, Integer, String, Table
|
|
from sqlalchemy.orm import registry
|
|
|
|
from .types import ULID
|
|
|
|
# Recursive alias for JSON compatible values; the quoted "JSON" entries are
# forward references to this alias itself.
JSON = int | float | str | None | list["JSON"] | dict[str, "JSON"]
# A JSON object: string keys mapped to JSON values.
JSONObject = dict[str, JSON]

# Generic type variable used by `_id`, `fromplain`, and the `Relation` alias.
T = TypeVar("T")
|
|
|
|
|
|
class Model(Protocol):
    """Structural type for classes exposing a SQLAlchemy `Table` as `__table__`."""

    __table__: ClassVar[Table]
|
|
|
|
|
|
# Registry whose `mapped` decorator is applied to the model dataclasses in this
# module.
mapper_registry = registry()
# Metadata collection of the registry; every `Table` in this module is created
# against it.
metadata = mapper_registry.metadata
|
|
|
|
|
|
def annotations(tp: Type) -> tuple | None:
|
|
return tp.__metadata__ if hasattr(tp, "__metadata__") else None # type: ignore
|
|
|
|
|
|
def fields(class_or_instance):
    """Yield the dataclass fields of a model that hold the model's own data.

    This is a drop-in replacement for `dataclasses.fields` and SHOULD be used
    instead of it everywhere. It accepts the same argument (a dataclass or an
    instance of one) but leaves out

    - the internal `_is_lazy` bookkeeping field, and
    - fields marked as `Relation`, which only exist to store the data
      referenced by an ID field directly on the instance.
    """
    # XXX this might be a little slow (not sure), if so, memoize
    for candidate in _fields(class_or_instance):
        if candidate.name == "_is_lazy":
            continue

        meta = annotations(candidate.type)
        # Relations carry the sentinel in their `Annotated` metadata and are ignored.
        if not meta or _RelationSentinel not in meta:
            yield candidate
|
|
|
|
|
|
def is_optional(tp: Type) -> bool:
|
|
"""Return wether the given type is optional."""
|
|
if not isinstance(tp, UnionType) and get_origin(tp) is not Union:
|
|
return False
|
|
|
|
args = get_args(tp)
|
|
return len(args) == 2 and type(None) in args
|
|
|
|
|
|
def optional_type(tp: Type) -> Type | None:
|
|
"""Return the wrapped type from an optional type.
|
|
|
|
For example this will return `int` for `Optional[int]`.
|
|
Since they're equivalent this also works for other optioning notations, like
|
|
`Union[int, None]` and `int | None`.
|
|
"""
|
|
if not isinstance(tp, UnionType) and get_origin(tp) is not Union:
|
|
return None
|
|
|
|
args = get_args(tp)
|
|
if len(args) != 2 or type(None) not in args:
|
|
return None
|
|
|
|
return args[0] if args[1] is type(None) else args[1]
|
|
|
|
|
|
def optional_fields(o):
    """Yield the fields of the given model whose declared type is optional."""
    yield from (candidate for candidate in fields(o) if is_optional(candidate.type))
|
|
|
|
|
|
# Compact JSON encoder: `json.dumps` with separators that contain no whitespace.
json_dump = partial(json.dumps, separators=(",", ":"))
|
|
|
|
|
|
def _id(x: T) -> T:
    """Return the argument unchanged.

    Used as the no-op stand-in for the JSON encoder/decoder in `asplain` and
    `fromplain` when (de)serialization is disabled.
    """
    return x
|
|
|
|
|
|
def asplain(
    o: object, *, filter_fields: Container[str] | None = None, serialize: bool = False
) -> dict[str, Any]:
    """Convert a model instance into a `dict` of JSON compatible plain values.

    The instance is passed through `validate` first. `ULID` values become
    strings, `datetime` values become ISO 8601 strings, sets become sorted
    lists, and lists are copied. `bool`, `str`, `int`, and `float` values are
    stored as they are. Optional fields holding `None` stay `None`.

    If `filter_fields` is given, only fields whose name is contained in it
    are included in the result.

    If `serialize` is `True`, collection values (sets and lists) are encoded
    as compact JSON strings; this can be useful to store them in a database.
    Be sure to pass `serialized=True` to `fromplain` to restore the object.

    Raises `ValueError` for a field whose type is not supported, and
    `AssertionError` if a value does not match its field's type.
    """
    validate(o)

    encode = json_dump if serialize else _id

    plain: JSONObject = {}
    for fld in fields(o):
        name = fld.name
        if filter_fields is not None and name not in filter_fields:
            continue

        # Reduce the declared type to a plain class: strip `| None` first,
        # then reduce generics like `set[str]` to their origin (`set`).
        # XXX this doesn't properly support any kind of nested types
        unwrapped = optional_type(fld.type)
        kind: Any = fld.type if unwrapped is None else unwrapped
        origin = get_origin(kind)
        if origin is not None:
            kind = origin

        value = getattr(o, name)
        if is_optional(fld.type) and value is None:
            plain[name] = None
        elif kind is ULID:
            assert isinstance(value, ULID)
            plain[name] = str(value)
        elif kind is datetime:
            assert isinstance(value, datetime)
            plain[name] = value.isoformat()
        elif kind is set:
            assert isinstance(value, set)
            # Sort to get a stable ordering out of the unordered set.
            plain[name] = encode(sorted(value))
        elif kind is list:
            assert isinstance(value, list)
            plain[name] = encode(list(value))
        elif kind in {bool, str, int, float}:
            assert isinstance(
                value, kind
            ), f"Type mismatch: {name} ({kind} != {type(value)})"
            plain[name] = value
        else:
            raise ValueError(f"Unsupported value type: {name}: {type(value)}")

    return plain
|
|
|
|
|
|
def fromplain(cls: Type[T], d: Mapping, *, serialized: bool = False) -> T:
    """Build an instance of the given model from plain data.

    Every field of the model is looked up by name in `d` (a missing key
    raises `KeyError`). `None` is kept for optional fields, and values that
    already are instances of the field's type are kept as they are. Sets and
    lists are rebuilt from the given collection, datetimes are parsed from
    their ISO 8601 string, and anything else is passed to the field type's
    constructor.

    If `serialized` is `True`, collection values (sets and lists) are first
    decoded from their JSON string. This is the opposite operation of
    `serialize` for `asplain`.

    The new instance is passed through `validate` before it is returned.
    """
    decode = json.loads if serialized else _id

    kwargs: JSONObject = {}
    for fld in fields(cls):
        name = fld.name

        # Reduce the declared type to a plain class: strip `| None` first,
        # then reduce generics like `set[str]` to their origin (`set`).
        inner = optional_type(fld.type)
        nullable = inner is not None
        kind: Any = inner if nullable else fld.type
        origin = get_origin(kind)
        if origin is not None:
            kind = origin

        raw = d[name]
        if (nullable and raw is None) or isinstance(raw, kind):
            # Nothing to convert.
            kwargs[name] = raw
        elif kind in {set, list}:
            kwargs[name] = kind(decode(raw))
        elif kind is datetime:
            kwargs[name] = datetime.fromisoformat(raw)
        else:
            kwargs[name] = kind(raw)

    instance = cls(**kwargs)
    validate(instance)
    return instance
|
|
|
|
|
|
def validate(o: object) -> None:
    """Check that every field value of the model has the type its field declares.

    A value is accepted if its exact type (subclasses do not count) is

    - the declared type itself,
    - the origin of a declared generic (e.g. `list` for `list[str]`), or
    - one of the members of a declared union (e.g. `NoneType` for `int | None`).

    Raises `ValueError` naming the first field that fails the check.
    """
    for fld in fields(o):
        declared = fld.type
        actual = type(getattr(o, fld.name))

        if actual is declared:
            continue
        if get_origin(declared) is actual:
            continue

        is_union = isinstance(declared, UnionType) or get_origin(declared) is Union
        if is_union and actual in get_args(declared):
            continue

        raise ValueError(f"Invalid value type: {fld.name}: {actual}")
|
|
|
|
|
|
def utcnow() -> datetime:
    """Return the current time as a timezone-aware `datetime` in UTC."""
    return datetime.now(tz=timezone.utc)
|
|
|
|
|
|
@mapper_registry.mapped
@dataclass
class DbPatch:
    """Model for a row of the `db_patches` table."""

    __table__: ClassVar[Table] = Table(
        "db_patches",
        metadata,
        Column("id", Integer, primary_key=True),
        # NOTE(review): presumably identifies the currently applied patch -- confirm
        Column("current", String),
    )

    id: int
    current: str
|
|
|
|
|
|
# Module-level shortcut to the `db_patches` table.
db_patches = DbPatch.__table__
|
|
|
|
|
|
@mapper_registry.mapped
@dataclass
class Progress:
    """Model for a row of the `progress` table.

    The `state` column holds a JSON object as a string; the `percent` and
    `error` properties read and write individual keys of that object.
    """

    __table__: ClassVar[Table] = Table(
        "progress",
        metadata,
        Column("id", String, primary_key=True),  # ULID
        Column("type", String, nullable=False),
        Column("state", String, nullable=False),  # JSON {"percent": ..., "error": ...}
        Column("started", String, nullable=False),  # datetime
        Column("stopped", String),
    )

    # The `None` defaults on non-optional fields are placeholders only;
    # `validate` rejects them because `NoneType` is not the declared type.
    id: ULID = field(default_factory=ULID)
    type: str = None
    state: str = None
    started: datetime = field(default_factory=utcnow)
    stopped: str | None = None

    @property
    def _state(self) -> dict:
        """The decoded `state`; an unset or empty `state` decodes to `{}`."""
        raw = self.state or "{}"
        return json.loads(raw)

    @_state.setter
    def _state(self, state: dict):
        # Store the object in its compact JSON form.
        self.state = json_dump(state)

    @property
    def percent(self) -> float:
        """The `"percent"` entry of the state; raises `KeyError` if it is not set."""
        decoded = self._state
        return decoded["percent"]

    @percent.setter
    def percent(self, percent: float):
        # Decode, change the one key, and encode again.
        decoded = self._state
        decoded["percent"] = percent
        self._state = decoded

    @property
    def error(self) -> str:
        """The `"error"` entry of the state, or `""` if it is not set."""
        decoded = self._state
        return decoded.get("error", "")

    @error.setter
    def error(self, error: str):
        # Decode, change the one key, and encode again.
        decoded = self._state
        decoded["error"] = error
        self._state = decoded
|
|
|
|
|
|
# Module-level shortcut to the `progress` table.
progress = Progress.__table__
|
|
|
|
|
|
@mapper_registry.mapped
@dataclass
class Movie:
    """Model for a row of the `movies` table.

    Besides the regular constructor, instances can be created with `lazy`,
    which skips the default factories of `id`, `created`, and `updated`
    until `_lazy_init` is called.
    """

    __table__: ClassVar[Table] = Table(
        "movies",
        metadata,
        Column("id", String, primary_key=True),  # ULID
        Column("title", String, nullable=False),
        Column("original_title", String),
        Column("release_year", Integer, nullable=False),
        Column("media_type", String, nullable=False),
        Column("imdb_id", String, nullable=False, unique=True),
        Column("imdb_score", Integer),
        Column("imdb_votes", Integer),
        Column("runtime", Integer),
        Column("genres", String, nullable=False),
        Column("created", String, nullable=False),  # datetime
        Column("updated", String, nullable=False),  # datetime
    )

    # The `None` defaults on non-optional fields are placeholders only;
    # `validate` rejects them because `NoneType` is not the declared type.
    id: ULID = field(default_factory=ULID)
    # canonical title (usually English)
    title: str = None
    # original title (usually transcribed to latin script)
    original_title: str | None = None
    # canonical release date
    release_year: int = None
    media_type: str = None
    imdb_id: str = None
    # range: [0,100]
    imdb_score: int | None = None
    imdb_votes: int | None = None
    # minutes
    runtime: int | None = None
    genres: set[str] = None
    created: datetime = field(default_factory=utcnow)
    updated: datetime = field(default_factory=utcnow)

    # Bookkeeping flag set by `lazy`; not a constructor argument and skipped
    # by the module's `fields` helper.
    _is_lazy: bool = field(default=False, init=False, repr=False, compare=False)

    @classmethod
    def lazy(cls, **kwds):
        """Create an instance whose default factories have not been run yet.

        `id`, `created`, and `updated` are set to `None` unless given
        explicitly, and the instance is flagged as lazy. This is meant purely
        for optimization purposes, to postpone possibly expensive
        initialization operations.
        """
        # XXX optimize using a metaclass & storing field refs on the class
        for name in ("id", "created", "updated"):
            kwds.setdefault(name, None)

        movie = cls(**kwds)
        movie._is_lazy = True
        return movie

    def _lazy_init(self):
        """Run the default factories that `lazy` skipped.

        Every field that is still `None` and has a default factory gets a
        fresh value from that factory. Does nothing for instances that are
        not flagged as lazy.
        """
        if not self._is_lazy:
            return

        for fld in fields(Movie):
            if getattr(self, fld.name) is not None:
                continue
            # Fields without a default factory hold a non-callable marker here.
            if callable(fld.default_factory):
                setattr(self, fld.name, fld.default_factory())

        self._is_lazy = False
|
|
|
|
|
|
# Module-level shortcut to the `movies` table.
movies = Movie.__table__
|
|
|
|
# Unique marker object carried in the `Annotated` metadata of `Relation`; the
# module's `fields` helper looks for it to skip relation fields.
_RelationSentinel = object()
"""Mark a model field as containing external data.

For each field marked as a Relation there should be another field on the
dataclass containing the ID of the linked data.
The contents of the Relation are ignored or discarded when using
`asplain`, `fromplain`, and `validate`.
"""
# Generic alias: `Relation[X]` is `Annotated[X | None, _RelationSentinel]`.
Relation = Annotated[T | None, _RelationSentinel]
|
|
|
|
|
|
# The access levels that can be stored in a `UserGroup` entry.
Access = Literal[
    "r",  # read
    "i",  # index
    "w",  # write
]
|
|
|
|
|
|
class UserGroup(TypedDict):
    """Entry of `User.groups`: a group ID together with the access held on it."""

    id: str
    access: Access
|
|
|
|
|
|
@mapper_registry.mapped
@dataclass
class User:
    """Model for a row of the `users` table.

    `groups` lists the groups the user has access to, one `UserGroup` entry
    per group.
    """

    __table__: ClassVar[Table] = Table(
        "users",
        metadata,
        Column("id", String, primary_key=True),  # ULID
        Column("imdb_id", String, nullable=False, unique=True),
        Column("name", String, nullable=False),
        Column("secret", String, nullable=False),
        Column("groups", String, nullable=False),  # JSON array
    )

    # The `None` defaults on non-optional fields are placeholders only;
    # `validate` rejects them because `NoneType` is not the declared type.
    id: ULID = field(default_factory=ULID)
    imdb_id: str = None
    # canonical user name
    name: str = None
    secret: str = None
    groups: list[UserGroup] = field(default_factory=list)

    def has_access(self, group_id: ULID | str, access: Access = "r"):
        """Return whether the user holds exactly the given access on the group.

        `group_id` may be given as `ULID` or as string. Access levels are
        compared for equality; no level implies another.
        """
        wanted = group_id if isinstance(group_id, str) else str(group_id)
        for entry in self.groups:
            if entry["id"] == wanted and access == entry["access"]:
                return True
        return False

    def set_access(self, group_id: ULID | str, access: Access):
        """Set the user's access on the group, adding the group if necessary.

        `group_id` may be given as `ULID` or as string. The first existing
        entry for the group is updated in place; if there is none, a new
        entry is appended to `groups`.
        """
        wanted = group_id if isinstance(group_id, str) else str(group_id)
        for entry in self.groups:
            if entry["id"] == wanted:
                entry["access"] = access
                return

        self.groups.append({"id": wanted, "access": access})
|
|
|
|
|
|
@mapper_registry.mapped
@dataclass
class Rating:
    """Model for a row of the `ratings` table.

    `movie` and `user` are `Relation` fields: they can hold the objects
    referenced by `movie_id` and `user_id`, but they are not part of the
    stored data and are ignored by the module's `fields` helper.
    """

    __table__: ClassVar[Table] = Table(
        "ratings",
        metadata,
        Column("id", String, primary_key=True),  # ULID
        Column("movie_id", ForeignKey("movies.id"), nullable=False),  # ULID
        Column("user_id", ForeignKey("users.id"), nullable=False),  # ULID
        Column("score", Integer, nullable=False),
        Column("rating_date", String, nullable=False),  # datetime
        Column("favorite", Integer),  # bool
        Column("finished", Integer),  # bool
    )

    # The `None` defaults on non-optional fields are placeholders only;
    # `validate` rejects them because `NoneType` is not the declared type.
    id: ULID = field(default_factory=ULID)

    movie_id: ULID = None
    movie: Relation[Movie] = None

    user_id: ULID = None
    user: Relation[User] = None

    score: int = None  # range: [0,100]
    rating_date: datetime = None
    favorite: bool | None = None
    finished: bool | None = None

    def __eq__(self, other):
        """Return whether two Ratings are equal.

        This operation compares all fields as expected, except that it
        ignores any field marked as Relation.

        For an `other` of a different type this returns `NotImplemented`, so
        that Python can try the reflected comparison; `==` still evaluates to
        `False` unless `other` itself claims equality.
        """
        if type(other) is not type(self):
            return NotImplemented
        return all(
            getattr(self, f.name) == getattr(other, f.name) for f in fields(self)
        )
|
|
|
|
|
|
# Module-level shortcut to the `ratings` table.
ratings = Rating.__table__
|
|
|
|
|
|
class GroupUser(TypedDict):
    """Entry of `Group.users`: the ID and name of a user."""

    id: str
    name: str
|
|
|
|
|
|
@mapper_registry.mapped
@dataclass
class Group:
    """Model for a row of the `groups` table."""

    __table__: ClassVar[Table] = Table(
        "groups",
        metadata,
        Column("id", String, primary_key=True),  # ULID
        Column("name", String, nullable=False),
        Column("users", String, nullable=False),  # JSON array
    )

    # A new ULID is generated when no `id` is given.
    id: ULID = field(default_factory=ULID)
    # `None` is a placeholder default only; the annotation is not optional.
    name: str = None
    # One `GroupUser` entry per user; starts out as a fresh empty list.
    users: list[GroupUser] = field(default_factory=list)
|