diff --git a/unwind/db.py b/unwind/db.py index a34b2ba..4b71aff 100644 --- a/unwind/db.py +++ b/unwind/db.py @@ -231,7 +231,7 @@ async def add(item): if getattr(item, "_is_lazy", False): item._lazy_init() - values = asplain(item) + values = asplain(item, serialize=True) keys = ", ".join(f"{k}" for k in values) placeholders = ", ".join(f":{k}" for k in values) query = f"INSERT INTO {item._table} ({keys}) VALUES ({placeholders})" @@ -245,6 +245,12 @@ ModelType = TypeVar("ModelType") async def get( model: Type[ModelType], *, order_by: str = None, **kwds ) -> Optional[ModelType]: + """Load a model instance from the database. + + Passing `kwds` allows to filter the instance to load. You have to encode the + values as the appropriate data type for the database prior to passing them + to this function. + """ values = {k: v for k, v in kwds.items() if v is not None} if not values: return @@ -256,7 +262,7 @@ async def get( query += f" ORDER BY {order_by}" async with locked_connection() as conn: row = await conn.fetch_one(query=query, values=values) - return fromplain(model, row) if row else None + return fromplain(model, row, serialized=True) if row else None async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]: @@ -276,7 +282,7 @@ async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]: query = f"SELECT {fields_} FROM {model._table} WHERE {cond}" async with locked_connection() as conn: rows = await conn.fetch_all(query=query, values=values) - return (fromplain(model, row) for row in rows) + return (fromplain(model, row, serialized=True) for row in rows) async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]: @@ -287,7 +293,7 @@ async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]: query = f"SELECT {fields_} FROM {model._table} WHERE {cond}" async with locked_connection() as conn: rows = await conn.fetch_all(query=query, values=values) - return (fromplain(model, row) for row in rows) + return (fromplain(model, row, serialized=True) for row in rows) async def update(item): @@ -295,7 +301,7 @@ async def update(item): if getattr(item, "_is_lazy", False): item._lazy_init() - values = asplain(item) + values = asplain(item, serialize=True) keys = ", ".join(f"{k}=:{k}" for k in values if k != "id") query = f"UPDATE {item._table} SET {keys} WHERE id=:id" async with locked_connection() as conn: @@ -303,7 +309,7 @@ async def update(item): async def remove(item): - values = asplain(item, fields_={"id"}) + values = asplain(item, filter_fields={"id"}, serialize=True) query = f"DELETE FROM {item._table} WHERE id=:id" async with locked_connection() as conn: await conn.execute(query=query, values=values) @@ -523,7 +529,8 @@ def mux(*tps: Type): def demux(tp: Type[ModelType], row) -> ModelType: - return fromplain(tp, {n: row[k] for k, (_, n) in sql_fieldmap(tp).items()}) + d = {n: row[k] for k, (_, n) in sql_fieldmap(tp).items()} + return fromplain(tp, d, serialized=True) def sql_in(column: str, values: list[T]) -> tuple[str, dict[str, T]]: @@ -557,7 +564,7 @@ async def ratings_for_movies( async with locked_connection() as conn: rows = await conn.fetch_all(query, values) - return (fromplain(Rating, row) for row in rows) + return (fromplain(Rating, row, serialized=True) for row in rows) async def find_movies( @@ -624,7 +631,7 @@ async def find_movies( async with locked_connection() as conn: rows = await conn.fetch_all(bindparams(query, values)) - movies = [fromplain(Movie, row) for row in rows] + movies = [fromplain(Movie, row, serialized=True) for row in rows] if not user_ids: return ((m, []) for m in movies) diff --git a/unwind/models.py b/unwind/models.py index d85547d..37cd48d 100644 --- a/unwind/models.py +++ b/unwind/models.py @@ -7,6 +7,7 @@ from typing import ( Annotated, Any, ClassVar, + Container, Literal, Optional, Type, @@ -18,6 +19,9 @@ from typing import ( from .types import ULID +JSON = Union[int, float, str, None, list["JSON"], dict[str, "JSON"]] +JSONObject = dict[str, JSON] + T = TypeVar("T") @@ -48,7 +52,8 @@ def fields(class_or_instance): yield f -def is_optional(tp: Type): +def is_optional(tp: Type) -> bool: + """Return wether the given type is optional.""" if get_origin(tp) is not Union: return False @@ -56,15 +61,21 @@ def is_optional(tp: Type): return len(args) == 2 and type(None) in args -def optional_type(tp: Type): +def optional_type(tp: Type) -> Optional[Type]: + """Return the wrapped type from an optional type. + + For example this will return `int` for `Optional[int]`. + Since they're equivalent this also works for other optioning notations, like + `Union[int, None]` and `int | None`. + """ if get_origin(tp) is not Union: return None args = get_args(tp) - if len(args) != 2 or args[1] is not type(None): + if len(args) != 2 or type(None) not in args: return None - return args[0] + return args[0] if args[1] is type(None) else args[1] def optional_fields(o): @@ -76,13 +87,30 @@ def optional_fields(o): json_dump = partial(json.dumps, separators=(",", ":")) -def asplain(o, *, fields_: set = None) -> dict[str, Any]: +def _id(x: T) -> T: + return x + + +def asplain( + o: object, *, filter_fields: Container[str] = None, serialize: bool = False +) -> dict[str, Any]: + """Return the given model instance as `dict` with JSON compatible plain datatypes. + + If `filter_fields` is given only matching field names will be included in + the resulting `dict`. + If `serialize` is `True`, collection types (lists, dicts, etc.) will be + serialized as strings; this can be useful to store them in a database. Be + sure to set `serialized=True` when using `fromplain` to successfully restore + the object. + """ validate(o) - d = {} + dump = json_dump if serialize else _id + + d: JSONObject = {} for f in fields(o): - if fields_ is not None and f.name not in fields_: + if filter_fields is not None and f.name not in filter_fields: continue target = f.type @@ -93,15 +121,24 @@ def asplain(o, *, fields_: set = None) -> dict[str, Any]: target = otype v = getattr(o, f.name) - if target is ULID: + if is_optional(f.type) and v is None: + d[f.name] = None + elif target is ULID: + assert isinstance(v, ULID) d[f.name] = str(v) elif target in {datetime}: + assert isinstance(v, datetime) d[f.name] = v.isoformat() elif target in {set}: - d[f.name] = json_dump(list(sorted(v))) + assert isinstance(v, set) + d[f.name] = dump(list(sorted(v))) elif target in {list}: - d[f.name] = json_dump(list(v)) - elif target in {bool, str, int, float, None}: + assert isinstance(v, list) + d[f.name] = dump(list(v)) + elif target in {bool, str, int, float}: + assert isinstance( + v, target + ), f"Type mismatch: {f.name} ({target} != {type(v)})" d[f.name] = v else: raise ValueError(f"Unsupported value type: {f.name}: {type(v)}") @@ -109,8 +146,16 @@ def asplain(o, *, fields_: set = None) -> dict[str, Any]: return d -def fromplain(cls: Type[T], d: dict[str, Any]) -> T: - dd = {} +def fromplain(cls: Type[T], d: dict[str, Any], *, serialized: bool = False) -> T: + """Return an instance of the given model using the given data. + + If `serialized` is `True`, collection types (lists, dicts, etc.) will be + deserialized from string. This is the opposite operation of `serialize` for + `asplain`. + """ + load = json.loads if serialized else _id + + dd: JSONObject = {} for f in fields(cls): target = f.type @@ -127,7 +172,7 @@ def fromplain(cls: Type[T], d: dict[str, Any]) -> T: elif isinstance(v, target): dd[f.name] = v elif target in {set, list}: - dd[f.name] = target(json.loads(v)) + dd[f.name] = target(load(v)) elif target in {datetime}: dd[f.name] = target.fromisoformat(v) else: @@ -138,7 +183,7 @@ def fromplain(cls: Type[T], d: dict[str, Any]) -> T: return o -def validate(o): +def validate(o: object) -> None: for f in fields(o): vtype = type(getattr(o, f.name)) if vtype is not f.type: