fix & improve asplain func
It would previously encode a JSON encoded string coming from the DB doubly, because there was no differentiation. It would also not handle optional values set to None correctly. There's still other problems with the function, but those are now fixed.
This commit is contained in:
parent
e49ea603ee
commit
a17b49bc0b
2 changed files with 76 additions and 24 deletions
25
unwind/db.py
25
unwind/db.py
|
|
@ -231,7 +231,7 @@ async def add(item):
|
||||||
if getattr(item, "_is_lazy", False):
|
if getattr(item, "_is_lazy", False):
|
||||||
item._lazy_init()
|
item._lazy_init()
|
||||||
|
|
||||||
values = asplain(item)
|
values = asplain(item, serialize=True)
|
||||||
keys = ", ".join(f"{k}" for k in values)
|
keys = ", ".join(f"{k}" for k in values)
|
||||||
placeholders = ", ".join(f":{k}" for k in values)
|
placeholders = ", ".join(f":{k}" for k in values)
|
||||||
query = f"INSERT INTO {item._table} ({keys}) VALUES ({placeholders})"
|
query = f"INSERT INTO {item._table} ({keys}) VALUES ({placeholders})"
|
||||||
|
|
@ -245,6 +245,12 @@ ModelType = TypeVar("ModelType")
|
||||||
async def get(
|
async def get(
|
||||||
model: Type[ModelType], *, order_by: str = None, **kwds
|
model: Type[ModelType], *, order_by: str = None, **kwds
|
||||||
) -> Optional[ModelType]:
|
) -> Optional[ModelType]:
|
||||||
|
"""Load a model instance from the database.
|
||||||
|
|
||||||
|
Passing `kwds` allows to filter the instance to load. You have to encode the
|
||||||
|
values as the appropriate data type for the database prior to passing them
|
||||||
|
to this function.
|
||||||
|
"""
|
||||||
values = {k: v for k, v in kwds.items() if v is not None}
|
values = {k: v for k, v in kwds.items() if v is not None}
|
||||||
if not values:
|
if not values:
|
||||||
return
|
return
|
||||||
|
|
@ -256,7 +262,7 @@ async def get(
|
||||||
query += f" ORDER BY {order_by}"
|
query += f" ORDER BY {order_by}"
|
||||||
async with locked_connection() as conn:
|
async with locked_connection() as conn:
|
||||||
row = await conn.fetch_one(query=query, values=values)
|
row = await conn.fetch_one(query=query, values=values)
|
||||||
return fromplain(model, row) if row else None
|
return fromplain(model, row, serialized=True) if row else None
|
||||||
|
|
||||||
|
|
||||||
async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
|
async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
|
||||||
|
|
@ -276,7 +282,7 @@ async def get_many(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
|
||||||
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
|
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
|
||||||
async with locked_connection() as conn:
|
async with locked_connection() as conn:
|
||||||
rows = await conn.fetch_all(query=query, values=values)
|
rows = await conn.fetch_all(query=query, values=values)
|
||||||
return (fromplain(model, row) for row in rows)
|
return (fromplain(model, row, serialized=True) for row in rows)
|
||||||
|
|
||||||
|
|
||||||
async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
|
async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
|
||||||
|
|
@ -287,7 +293,7 @@ async def get_all(model: Type[ModelType], **kwds) -> Iterable[ModelType]:
|
||||||
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
|
query = f"SELECT {fields_} FROM {model._table} WHERE {cond}"
|
||||||
async with locked_connection() as conn:
|
async with locked_connection() as conn:
|
||||||
rows = await conn.fetch_all(query=query, values=values)
|
rows = await conn.fetch_all(query=query, values=values)
|
||||||
return (fromplain(model, row) for row in rows)
|
return (fromplain(model, row, serialized=True) for row in rows)
|
||||||
|
|
||||||
|
|
||||||
async def update(item):
|
async def update(item):
|
||||||
|
|
@ -295,7 +301,7 @@ async def update(item):
|
||||||
if getattr(item, "_is_lazy", False):
|
if getattr(item, "_is_lazy", False):
|
||||||
item._lazy_init()
|
item._lazy_init()
|
||||||
|
|
||||||
values = asplain(item)
|
values = asplain(item, serialize=True)
|
||||||
keys = ", ".join(f"{k}=:{k}" for k in values if k != "id")
|
keys = ", ".join(f"{k}=:{k}" for k in values if k != "id")
|
||||||
query = f"UPDATE {item._table} SET {keys} WHERE id=:id"
|
query = f"UPDATE {item._table} SET {keys} WHERE id=:id"
|
||||||
async with locked_connection() as conn:
|
async with locked_connection() as conn:
|
||||||
|
|
@ -303,7 +309,7 @@ async def update(item):
|
||||||
|
|
||||||
|
|
||||||
async def remove(item):
|
async def remove(item):
|
||||||
values = asplain(item, fields_={"id"})
|
values = asplain(item, filter_fields={"id"}, serialize=True)
|
||||||
query = f"DELETE FROM {item._table} WHERE id=:id"
|
query = f"DELETE FROM {item._table} WHERE id=:id"
|
||||||
async with locked_connection() as conn:
|
async with locked_connection() as conn:
|
||||||
await conn.execute(query=query, values=values)
|
await conn.execute(query=query, values=values)
|
||||||
|
|
@ -523,7 +529,8 @@ def mux(*tps: Type):
|
||||||
|
|
||||||
|
|
||||||
def demux(tp: Type[ModelType], row) -> ModelType:
|
def demux(tp: Type[ModelType], row) -> ModelType:
|
||||||
return fromplain(tp, {n: row[k] for k, (_, n) in sql_fieldmap(tp).items()})
|
d = {n: row[k] for k, (_, n) in sql_fieldmap(tp).items()}
|
||||||
|
return fromplain(tp, d, serialized=True)
|
||||||
|
|
||||||
|
|
||||||
def sql_in(column: str, values: list[T]) -> tuple[str, dict[str, T]]:
|
def sql_in(column: str, values: list[T]) -> tuple[str, dict[str, T]]:
|
||||||
|
|
@ -557,7 +564,7 @@ async def ratings_for_movies(
|
||||||
async with locked_connection() as conn:
|
async with locked_connection() as conn:
|
||||||
rows = await conn.fetch_all(query, values)
|
rows = await conn.fetch_all(query, values)
|
||||||
|
|
||||||
return (fromplain(Rating, row) for row in rows)
|
return (fromplain(Rating, row, serialized=True) for row in rows)
|
||||||
|
|
||||||
|
|
||||||
async def find_movies(
|
async def find_movies(
|
||||||
|
|
@ -624,7 +631,7 @@ async def find_movies(
|
||||||
async with locked_connection() as conn:
|
async with locked_connection() as conn:
|
||||||
rows = await conn.fetch_all(bindparams(query, values))
|
rows = await conn.fetch_all(bindparams(query, values))
|
||||||
|
|
||||||
movies = [fromplain(Movie, row) for row in rows]
|
movies = [fromplain(Movie, row, serialized=True) for row in rows]
|
||||||
|
|
||||||
if not user_ids:
|
if not user_ids:
|
||||||
return ((m, []) for m in movies)
|
return ((m, []) for m in movies)
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ from typing import (
|
||||||
Annotated,
|
Annotated,
|
||||||
Any,
|
Any,
|
||||||
ClassVar,
|
ClassVar,
|
||||||
|
Container,
|
||||||
Literal,
|
Literal,
|
||||||
Optional,
|
Optional,
|
||||||
Type,
|
Type,
|
||||||
|
|
@ -18,6 +19,9 @@ from typing import (
|
||||||
|
|
||||||
from .types import ULID
|
from .types import ULID
|
||||||
|
|
||||||
|
JSON = Union[int, float, str, None, list["JSON"], dict[str, "JSON"]]
|
||||||
|
JSONObject = dict[str, JSON]
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -48,7 +52,8 @@ def fields(class_or_instance):
|
||||||
yield f
|
yield f
|
||||||
|
|
||||||
|
|
||||||
def is_optional(tp: Type):
|
def is_optional(tp: Type) -> bool:
|
||||||
|
"""Return wether the given type is optional."""
|
||||||
if get_origin(tp) is not Union:
|
if get_origin(tp) is not Union:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
@ -56,15 +61,21 @@ def is_optional(tp: Type):
|
||||||
return len(args) == 2 and type(None) in args
|
return len(args) == 2 and type(None) in args
|
||||||
|
|
||||||
|
|
||||||
def optional_type(tp: Type):
|
def optional_type(tp: Type) -> Optional[Type]:
|
||||||
|
"""Return the wrapped type from an optional type.
|
||||||
|
|
||||||
|
For example this will return `int` for `Optional[int]`.
|
||||||
|
Since they're equivalent this also works for other optioning notations, like
|
||||||
|
`Union[int, None]` and `int | None`.
|
||||||
|
"""
|
||||||
if get_origin(tp) is not Union:
|
if get_origin(tp) is not Union:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
args = get_args(tp)
|
args = get_args(tp)
|
||||||
if len(args) != 2 or args[1] is not type(None):
|
if len(args) != 2 or type(None) not in args:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return args[0]
|
return args[0] if args[1] is type(None) else args[1]
|
||||||
|
|
||||||
|
|
||||||
def optional_fields(o):
|
def optional_fields(o):
|
||||||
|
|
@ -76,13 +87,30 @@ def optional_fields(o):
|
||||||
json_dump = partial(json.dumps, separators=(",", ":"))
|
json_dump = partial(json.dumps, separators=(",", ":"))
|
||||||
|
|
||||||
|
|
||||||
def asplain(o, *, fields_: set = None) -> dict[str, Any]:
|
def _id(x: T) -> T:
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def asplain(
|
||||||
|
o: object, *, filter_fields: Container[str] = None, serialize: bool = False
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Return the given model instance as `dict` with JSON compatible plain datatypes.
|
||||||
|
|
||||||
|
If `filter_fields` is given only matching field names will be included in
|
||||||
|
the resulting `dict`.
|
||||||
|
If `serialize` is `True`, collection types (lists, dicts, etc.) will be
|
||||||
|
serialized as strings; this can be useful to store them in a database. Be
|
||||||
|
sure to set `serialized=True` when using `fromplain` to successfully restore
|
||||||
|
the object.
|
||||||
|
"""
|
||||||
validate(o)
|
validate(o)
|
||||||
|
|
||||||
d = {}
|
dump = json_dump if serialize else _id
|
||||||
|
|
||||||
|
d: JSONObject = {}
|
||||||
for f in fields(o):
|
for f in fields(o):
|
||||||
|
|
||||||
if fields_ is not None and f.name not in fields_:
|
if filter_fields is not None and f.name not in filter_fields:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
target = f.type
|
target = f.type
|
||||||
|
|
@ -93,15 +121,24 @@ def asplain(o, *, fields_: set = None) -> dict[str, Any]:
|
||||||
target = otype
|
target = otype
|
||||||
|
|
||||||
v = getattr(o, f.name)
|
v = getattr(o, f.name)
|
||||||
if target is ULID:
|
if is_optional(f.type) and v is None:
|
||||||
|
d[f.name] = None
|
||||||
|
elif target is ULID:
|
||||||
|
assert isinstance(v, ULID)
|
||||||
d[f.name] = str(v)
|
d[f.name] = str(v)
|
||||||
elif target in {datetime}:
|
elif target in {datetime}:
|
||||||
|
assert isinstance(v, datetime)
|
||||||
d[f.name] = v.isoformat()
|
d[f.name] = v.isoformat()
|
||||||
elif target in {set}:
|
elif target in {set}:
|
||||||
d[f.name] = json_dump(list(sorted(v)))
|
assert isinstance(v, set)
|
||||||
|
d[f.name] = dump(list(sorted(v)))
|
||||||
elif target in {list}:
|
elif target in {list}:
|
||||||
d[f.name] = json_dump(list(v))
|
assert isinstance(v, list)
|
||||||
elif target in {bool, str, int, float, None}:
|
d[f.name] = dump(list(v))
|
||||||
|
elif target in {bool, str, int, float}:
|
||||||
|
assert isinstance(
|
||||||
|
v, target
|
||||||
|
), f"Type mismatch: {f.name} ({target} != {type(v)})"
|
||||||
d[f.name] = v
|
d[f.name] = v
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported value type: {f.name}: {type(v)}")
|
raise ValueError(f"Unsupported value type: {f.name}: {type(v)}")
|
||||||
|
|
@ -109,8 +146,16 @@ def asplain(o, *, fields_: set = None) -> dict[str, Any]:
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
def fromplain(cls: Type[T], d: dict[str, Any]) -> T:
|
def fromplain(cls: Type[T], d: dict[str, Any], *, serialized: bool = False) -> T:
|
||||||
dd = {}
|
"""Return an instance of the given model using the given data.
|
||||||
|
|
||||||
|
If `serialized` is `True`, collection types (lists, dicts, etc.) will be
|
||||||
|
deserialized from string. This is the opposite operation of `serialize` for
|
||||||
|
`asplain`.
|
||||||
|
"""
|
||||||
|
load = json.loads if serialized else _id
|
||||||
|
|
||||||
|
dd: JSONObject = {}
|
||||||
for f in fields(cls):
|
for f in fields(cls):
|
||||||
|
|
||||||
target = f.type
|
target = f.type
|
||||||
|
|
@ -127,7 +172,7 @@ def fromplain(cls: Type[T], d: dict[str, Any]) -> T:
|
||||||
elif isinstance(v, target):
|
elif isinstance(v, target):
|
||||||
dd[f.name] = v
|
dd[f.name] = v
|
||||||
elif target in {set, list}:
|
elif target in {set, list}:
|
||||||
dd[f.name] = target(json.loads(v))
|
dd[f.name] = target(load(v))
|
||||||
elif target in {datetime}:
|
elif target in {datetime}:
|
||||||
dd[f.name] = target.fromisoformat(v)
|
dd[f.name] = target.fromisoformat(v)
|
||||||
else:
|
else:
|
||||||
|
|
@ -138,7 +183,7 @@ def fromplain(cls: Type[T], d: dict[str, Any]) -> T:
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
|
||||||
def validate(o):
|
def validate(o: object) -> None:
|
||||||
for f in fields(o):
|
for f in fields(o):
|
||||||
vtype = type(getattr(o, f.name))
|
vtype = type(getattr(o, f.name))
|
||||||
if vtype is not f.type:
|
if vtype is not f.type:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue