From a17b49bc0b90e93e1551e84c95f20c45af3b4cb3 Mon Sep 17 00:00:00 2001 From: ducklet Date: Sun, 19 Dec 2021 19:25:31 +0100 Subject: [PATCH] fix & improve `asplain` func It would previously encode a JSON encoded string coming from the DB doubly, because there was no differentiation. It would also not handle optional values set to None correctly. There's still other problems with the function, but those are now fixed. --- unwind/db.py | 25 ++++++++++------ unwind/models.py | 75 ++++++++++++++++++++++++++++++++++++++---------- 2 files changed, 76 insertions(+), 24 deletions(-) diff --git a/unwind/db.py b/unwind/db.py index a34b2ba..4b71aff 100644 --- a/unwind/db.py +++ b/unwind/db.py @@ -231,7 +231,7 @@ async def add(item): if getattr(item, "_is_lazy", False): item._lazy_init() - values = asplain(item) + values = asplain(item, serialize=True) keys = ", ".join(f"{k}" for k in values) placeholders = ", ".join(f":{k}" for k in values) query = f"INSERT INTO {item._table} ({keys}) VALUES ({placeholders})" @@ -245,6 +245,12 @@ ModelType = TypeVar("ModelType") async def get( model: Type[ModelType], *, order_by: str = None, **kwds ) -> Optional[ModelType]: + """Load a model instance from the database. + + Passing `kwds` allows filtering which instance to load. You have to encode the + values as the appropriate data type for the database prior to passing them + to this function. 
+ """ values = {k: v for k, v in kwds.items() if v is not None} if not values: return @@ -256,7 +262,7 @@ async def get( query += f" ORDER BY {order_by}" async with locked_connection() as conn: row = await conn.fetch_one(query=query, values=values) - return fromplain(model, row) if row else None + return fromplain(model, row, serialized=True) if row else None async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]: @@ -276,7 +282,7 @@ async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]: query = f"SELECT {fields_} FROM {model._table} WHERE {cond}" async with locked_connection() as conn: rows = await conn.fetch_all(query=query, values=values) - return (fromplain(model, row) for row in rows) + return (fromplain(model, row, serialized=True) for row in rows) async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]: @@ -287,7 +293,7 @@ async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]: query = f"SELECT {fields_} FROM {model._table} WHERE {cond}" async with locked_connection() as conn: rows = await conn.fetch_all(query=query, values=values) - return (fromplain(model, row) for row in rows) + return (fromplain(model, row, serialized=True) for row in rows) async def update(item): @@ -295,7 +301,7 @@ async def update(item): if getattr(item, "_is_lazy", False): item._lazy_init() - values = asplain(item) + values = asplain(item, serialize=True) keys = ", ".join(f"{k}=:{k}" for k in values if k != "id") query = f"UPDATE {item._table} SET {keys} WHERE id=:id" async with locked_connection() as conn: @@ -303,7 +309,7 @@ async def update(item): async def remove(item): - values = asplain(item, fields_={"id"}) + values = asplain(item, filter_fields={"id"}, serialize=True) query = f"DELETE FROM {item._table} WHERE id=:id" async with locked_connection() as conn: await conn.execute(query=query, values=values) @@ -523,7 +529,8 @@ def mux(*tps: Type): def demux(tp: Type[ModelType], row) -> ModelType: - return 
fromplain(tp, {n: row[k] for k, (_, n) in sql_fieldmap(tp).items()}) + d = {n: row[k] for k, (_, n) in sql_fieldmap(tp).items()} + return fromplain(tp, d, serialized=True) def sql_in(column: str, values: list[T]) -> tuple[str, dict[str, T]]: @@ -557,7 +564,7 @@ async def ratings_for_movies( async with locked_connection() as conn: rows = await conn.fetch_all(query, values) - return (fromplain(Rating, row) for row in rows) + return (fromplain(Rating, row, serialized=True) for row in rows) async def find_movies( @@ -624,7 +631,7 @@ async def find_movies( async with locked_connection() as conn: rows = await conn.fetch_all(bindparams(query, values)) - movies = [fromplain(Movie, row) for row in rows] + movies = [fromplain(Movie, row, serialized=True) for row in rows] if not user_ids: return ((m, []) for m in movies) diff --git a/unwind/models.py b/unwind/models.py index d85547d..37cd48d 100644 --- a/unwind/models.py +++ b/unwind/models.py @@ -7,6 +7,7 @@ from typing import ( Annotated, Any, ClassVar, + Container, Literal, Optional, Type, @@ -18,6 +19,9 @@ from typing import ( from .types import ULID +JSON = Union[int, float, str, None, list["JSON"], dict[str, "JSON"]] +JSONObject = dict[str, JSON] + T = TypeVar("T") @@ -48,7 +52,8 @@ def fields(class_or_instance): yield f -def is_optional(tp: Type): +def is_optional(tp: Type) -> bool: + """Return whether the given type is optional.""" if get_origin(tp) is not Union: return False @@ -56,15 +61,21 @@ def is_optional(tp: Type): return len(args) == 2 and type(None) in args -def optional_type(tp: Type): +def optional_type(tp: Type) -> Optional[Type]: + """Return the wrapped type from an optional type. + + For example this will return `int` for `Optional[int]`. + Since they're equivalent this also works for other notations for optional types, like + `Union[int, None]` and `int | None`. 
+ """ if get_origin(tp) is not Union: return None args = get_args(tp) - if len(args) != 2 or args[1] is not type(None): + if len(args) != 2 or type(None) not in args: return None - return args[0] + return args[0] if args[1] is type(None) else args[1] def optional_fields(o): @@ -76,13 +87,30 @@ def optional_fields(o): json_dump = partial(json.dumps, separators=(",", ":")) -def asplain(o, *, fields_: set = None) -> dict[str, Any]: +def _id(x: T) -> T: + return x + + +def asplain( + o: object, *, filter_fields: Container[str] = None, serialize: bool = False +) -> dict[str, Any]: + """Return the given model instance as `dict` with JSON compatible plain datatypes. + + If `filter_fields` is given only matching field names will be included in + the resulting `dict`. + If `serialize` is `True`, collection types (lists, dicts, etc.) will be + serialized as strings; this can be useful to store them in a database. Be + sure to set `serialized=True` when using `fromplain` to successfully restore + the object. 
+ """ validate(o) - d = {} + dump = json_dump if serialize else _id + + d: JSONObject = {} for f in fields(o): - if fields_ is not None and f.name not in fields_: + if filter_fields is not None and f.name not in filter_fields: continue target = f.type @@ -93,15 +121,24 @@ def asplain(o, *, fields_: set = None) -> dict[str, Any]: target = otype v = getattr(o, f.name) - if target is ULID: + if is_optional(f.type) and v is None: + d[f.name] = None + elif target is ULID: + assert isinstance(v, ULID) d[f.name] = str(v) elif target in {datetime}: + assert isinstance(v, datetime) d[f.name] = v.isoformat() elif target in {set}: - d[f.name] = json_dump(list(sorted(v))) + assert isinstance(v, set) + d[f.name] = dump(list(sorted(v))) elif target in {list}: - d[f.name] = json_dump(list(v)) - elif target in {bool, str, int, float, None}: + assert isinstance(v, list) + d[f.name] = dump(list(v)) + elif target in {bool, str, int, float}: + assert isinstance( + v, target + ), f"Type mismatch: {f.name} ({target} != {type(v)})" d[f.name] = v else: raise ValueError(f"Unsupported value type: {f.name}: {type(v)}") @@ -109,8 +146,16 @@ def asplain(o, *, fields_: set = None) -> dict[str, Any]: return d -def fromplain(cls: Type[T], d: dict[str, Any]) -> T: - dd = {} +def fromplain(cls: Type[T], d: dict[str, Any], *, serialized: bool = False) -> T: + """Return an instance of the given model using the given data. + + If `serialized` is `True`, collection types (lists, dicts, etc.) will be + deserialized from string. This is the opposite operation of `serialize` for + `asplain`. 
+ """ + load = json.loads if serialized else _id + + dd: JSONObject = {} for f in fields(cls): target = f.type @@ -127,7 +172,7 @@ def fromplain(cls: Type[T], d: dict[str, Any]) -> T: elif isinstance(v, target): dd[f.name] = v elif target in {set, list}: - dd[f.name] = target(json.loads(v)) + dd[f.name] = target(load(v)) elif target in {datetime}: dd[f.name] = target.fromisoformat(v) else: @@ -138,7 +183,7 @@ def fromplain(cls: Type[T], d: dict[str, Any]) -> T: return o -def validate(o): +def validate(o: object) -> None: for f in fields(o): vtype = type(getattr(o, f.name)) if vtype is not f.type: