fix & improve asplain func

It would previously encode a JSON encoded string coming from the DB
doubly, because there was no differentiation.  It would also not handle
optional values set to None correctly.
There's still other problems with the function, but those are now fixed.
This commit is contained in:
ducklet 2021-12-19 19:25:31 +01:00
parent e49ea603ee
commit a17b49bc0b
2 changed files with 76 additions and 24 deletions

View file

@ -7,6 +7,7 @@ from typing import (
Annotated,
Any,
ClassVar,
Container,
Literal,
Optional,
Type,
@ -18,6 +19,9 @@ from typing import (
from .types import ULID
JSON = Union[int, float, str, None, list["JSON"], dict[str, "JSON"]]
JSONObject = dict[str, JSON]
T = TypeVar("T")
@ -48,7 +52,8 @@ def fields(class_or_instance):
yield f
def is_optional(tp: Type):
def is_optional(tp: Type) -> bool:
"""Return wether the given type is optional."""
if get_origin(tp) is not Union:
return False
@ -56,15 +61,21 @@ def is_optional(tp: Type):
return len(args) == 2 and type(None) in args
def optional_type(tp: Type):
def optional_type(tp: Type) -> Optional[Type]:
"""Return the wrapped type from an optional type.
For example this will return `int` for `Optional[int]`.
Since they're equivalent this also works for other optioning notations, like
`Union[int, None]` and `int | None`.
"""
if get_origin(tp) is not Union:
return None
args = get_args(tp)
if len(args) != 2 or args[1] is not type(None):
if len(args) != 2 or type(None) not in args:
return None
return args[0]
return args[0] if args[1] is type(None) else args[1]
def optional_fields(o):
@ -76,13 +87,30 @@ def optional_fields(o):
json_dump = partial(json.dumps, separators=(",", ":"))
def asplain(o, *, fields_: set = None) -> dict[str, Any]:
def _id(x: T) -> T:
return x
def asplain(
o: object, *, filter_fields: Container[str] = None, serialize: bool = False
) -> dict[str, Any]:
"""Return the given model instance as `dict` with JSON compatible plain datatypes.
If `filter_fields` is given only matching field names will be included in
the resulting `dict`.
If `serialize` is `True`, collection types (lists, dicts, etc.) will be
serialized as strings; this can be useful to store them in a database. Be
sure to set `serialized=True` when using `fromplain` to successfully restore
the object.
"""
validate(o)
d = {}
dump = json_dump if serialize else _id
d: JSONObject = {}
for f in fields(o):
if fields_ is not None and f.name not in fields_:
if filter_fields is not None and f.name not in filter_fields:
continue
target = f.type
@ -93,15 +121,24 @@ def asplain(o, *, fields_: set = None) -> dict[str, Any]:
target = otype
v = getattr(o, f.name)
if target is ULID:
if is_optional(f.type) and v is None:
d[f.name] = None
elif target is ULID:
assert isinstance(v, ULID)
d[f.name] = str(v)
elif target in {datetime}:
assert isinstance(v, datetime)
d[f.name] = v.isoformat()
elif target in {set}:
d[f.name] = json_dump(list(sorted(v)))
assert isinstance(v, set)
d[f.name] = dump(list(sorted(v)))
elif target in {list}:
d[f.name] = json_dump(list(v))
elif target in {bool, str, int, float, None}:
assert isinstance(v, list)
d[f.name] = dump(list(v))
elif target in {bool, str, int, float}:
assert isinstance(
v, target
), f"Type mismatch: {f.name} ({target} != {type(v)})"
d[f.name] = v
else:
raise ValueError(f"Unsupported value type: {f.name}: {type(v)}")
@ -109,8 +146,16 @@ def asplain(o, *, fields_: set = None) -> dict[str, Any]:
return d
def fromplain(cls: Type[T], d: dict[str, Any]) -> T:
dd = {}
def fromplain(cls: Type[T], d: dict[str, Any], *, serialized: bool = False) -> T:
"""Return an instance of the given model using the given data.
If `serialized` is `True`, collection types (lists, dicts, etc.) will be
deserialized from string. This is the opposite operation of `serialize` for
`asplain`.
"""
load = json.loads if serialized else _id
dd: JSONObject = {}
for f in fields(cls):
target = f.type
@ -127,7 +172,7 @@ def fromplain(cls: Type[T], d: dict[str, Any]) -> T:
elif isinstance(v, target):
dd[f.name] = v
elif target in {set, list}:
dd[f.name] = target(json.loads(v))
dd[f.name] = target(load(v))
elif target in {datetime}:
dd[f.name] = target.fromisoformat(v)
else:
@ -138,7 +183,7 @@ def fromplain(cls: Type[T], d: dict[str, Any]) -> T:
return o
def validate(o):
def validate(o: object) -> None:
for f in fields(o):
vtype = type(getattr(o, f.name))
if vtype is not f.type: