fix & improve asplain func

It would previously encode a JSON encoded string coming from the DB
doubly, because there was no differentiation.  It would also not handle
optional values set to None correctly.
There's still other problems with the function, but those are now fixed.
This commit is contained in:
ducklet 2021-12-19 19:25:31 +01:00
parent e49ea603ee
commit a17b49bc0b
2 changed files with 76 additions and 24 deletions

View file

@ -231,7 +231,7 @@ async def add(item):
if getattr(item, "_is_lazy", False):
item._lazy_init()
values = asplain(item)
values = asplain(item, serialize=True)
keys = ", ".join(f"{k}" for k in values)
placeholders = ", ".join(f":{k}" for k in values)
query = f"INSERT INTO {item._table} ({keys}) VALUES ({placeholders})"
@ -245,6 +245,12 @@ ModelType = TypeVar("ModelType")
async def get(
model: Type[ModelType], *, order_by: str = None, **kwds
) -> Optional[ModelType]:
"""Load a model instance from the database.
Passing `kwds` allows to filter the instance to load. You have to encode the
values as the appropriate data type for the database prior to passing them
to this function.
"""
values = {k: v for k, v in kwds.items() if v is not None}
if not values:
return
@ -256,7 +262,7 @@ async def get(
query += f" ORDER BY {order_by}"
async with locked_connection() as conn:
row = await conn.fetch_one(query=query, values=values)
return fromplain(model, row) if row else None
return fromplain(model, row, serialized=True) if row else None
async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
@ -276,7 +282,7 @@ async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
async with locked_connection() as conn:
rows = await conn.fetch_all(query=query, values=values)
return (fromplain(model, row) for row in rows)
return (fromplain(model, row, serialized=True) for row in rows)
async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
@ -287,7 +293,7 @@ async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
async with locked_connection() as conn:
rows = await conn.fetch_all(query=query, values=values)
return (fromplain(model, row) for row in rows)
return (fromplain(model, row, serialized=True) for row in rows)
async def update(item):
@ -295,7 +301,7 @@ async def update(item):
if getattr(item, "_is_lazy", False):
item._lazy_init()
values = asplain(item)
values = asplain(item, serialize=True)
keys = ", ".join(f"{k}=:{k}" for k in values if k != "id")
query = f"UPDATE {item._table} SET {keys} WHERE id=:id"
async with locked_connection() as conn:
@ -303,7 +309,7 @@ async def update(item):
async def remove(item):
values = asplain(item, fields_={"id"})
values = asplain(item, filter_fields={"id"}, serialize=True)
query = f"DELETE FROM {item._table} WHERE id=:id"
async with locked_connection() as conn:
await conn.execute(query=query, values=values)
@ -523,7 +529,8 @@ def mux(*tps: Type):
def demux(tp: Type[ModelType], row) -> ModelType:
return fromplain(tp, {n: row[k] for k, (_, n) in sql_fieldmap(tp).items()})
d = {n: row[k] for k, (_, n) in sql_fieldmap(tp).items()}
return fromplain(tp, d, serialized=True)
def sql_in(column: str, values: list[T]) -> tuple[str, dict[str, T]]:
@ -557,7 +564,7 @@ async def ratings_for_movies(
async with locked_connection() as conn:
rows = await conn.fetch_all(query, values)
return (fromplain(Rating, row) for row in rows)
return (fromplain(Rating, row, serialized=True) for row in rows)
async def find_movies(
@ -624,7 +631,7 @@ async def find_movies(
async with locked_connection() as conn:
rows = await conn.fetch_all(bindparams(query, values))
movies = [fromplain(Movie, row) for row in rows]
movies = [fromplain(Movie, row, serialized=True) for row in rows]
if not user_ids:
return ((m, []) for m in movies)

View file

@ -7,6 +7,7 @@ from typing import (
Annotated,
Any,
ClassVar,
Container,
Literal,
Optional,
Type,
@ -18,6 +19,9 @@ from typing import (
from .types import ULID
JSON = Union[int, float, str, None, list["JSON"], dict[str, "JSON"]]
JSONObject = dict[str, JSON]
T = TypeVar("T")
@ -48,7 +52,8 @@ def fields(class_or_instance):
yield f
def is_optional(tp: Type):
def is_optional(tp: Type) -> bool:
"""Return wether the given type is optional."""
if get_origin(tp) is not Union:
return False
@ -56,15 +61,21 @@ def is_optional(tp: Type):
return len(args) == 2 and type(None) in args
def optional_type(tp: Type):
def optional_type(tp: Type) -> Optional[Type]:
"""Return the wrapped type from an optional type.
For example this will return `int` for `Optional[int]`.
Since they're equivalent this also works for other optioning notations, like
`Union[int, None]` and `int | None`.
"""
if get_origin(tp) is not Union:
return None
args = get_args(tp)
if len(args) != 2 or args[1] is not type(None):
if len(args) != 2 or type(None) not in args:
return None
return args[0]
return args[0] if args[1] is type(None) else args[1]
def optional_fields(o):
@ -76,13 +87,30 @@ def optional_fields(o):
json_dump = partial(json.dumps, separators=(",", ":"))
def asplain(o, *, fields_: set = None) -> dict[str, Any]:
def _id(x: T) -> T:
return x
def asplain(
o: object, *, filter_fields: Container[str] = None, serialize: bool = False
) -> dict[str, Any]:
"""Return the given model instance as `dict` with JSON compatible plain datatypes.
If `filter_fields` is given only matching field names will be included in
the resulting `dict`.
If `serialize` is `True`, collection types (lists, dicts, etc.) will be
serialized as strings; this can be useful to store them in a database. Be
sure to set `serialized=True` when using `fromplain` to successfully restore
the object.
"""
validate(o)
d = {}
dump = json_dump if serialize else _id
d: JSONObject = {}
for f in fields(o):
if fields_ is not None and f.name not in fields_:
if filter_fields is not None and f.name not in filter_fields:
continue
target = f.type
@ -93,15 +121,24 @@ def asplain(o, *, fields_: set = None) -> dict[str, Any]:
target = otype
v = getattr(o, f.name)
if target is ULID:
if is_optional(f.type) and v is None:
d[f.name] = None
elif target is ULID:
assert isinstance(v, ULID)
d[f.name] = str(v)
elif target in {datetime}:
assert isinstance(v, datetime)
d[f.name] = v.isoformat()
elif target in {set}:
d[f.name] = json_dump(list(sorted(v)))
assert isinstance(v, set)
d[f.name] = dump(list(sorted(v)))
elif target in {list}:
d[f.name] = json_dump(list(v))
elif target in {bool, str, int, float, None}:
assert isinstance(v, list)
d[f.name] = dump(list(v))
elif target in {bool, str, int, float}:
assert isinstance(
v, target
), f"Type mismatch: {f.name} ({target} != {type(v)})"
d[f.name] = v
else:
raise ValueError(f"Unsupported value type: {f.name}: {type(v)}")
@ -109,8 +146,16 @@ def asplain(o, *, fields_: set = None) -> dict[str, Any]:
return d
def fromplain(cls: Type[T], d: dict[str, Any]) -> T:
dd = {}
def fromplain(cls: Type[T], d: dict[str, Any], *, serialized: bool = False) -> T:
"""Return an instance of the given model using the given data.
If `serialized` is `True`, collection types (lists, dicts, etc.) will be
deserialized from string. This is the opposite operation of `serialize` for
`asplain`.
"""
load = json.loads if serialized else _id
dd: JSONObject = {}
for f in fields(cls):
target = f.type
@ -127,7 +172,7 @@ def fromplain(cls: Type[T], d: dict[str, Any]) -> T:
elif isinstance(v, target):
dd[f.name] = v
elif target in {set, list}:
dd[f.name] = target(json.loads(v))
dd[f.name] = target(load(v))
elif target in {datetime}:
dd[f.name] = target.fromisoformat(v)
else:
@ -138,7 +183,7 @@ def fromplain(cls: Type[T], d: dict[str, Any]) -> T:
return o
def validate(o):
def validate(o: object) -> None:
for f in fields(o):
vtype = type(getattr(o, f.name))
if vtype is not f.type: