From e84a6bc865a1f116f800567a24a85ed7b28b6fd2 Mon Sep 17 00:00:00 2001 From: ducklet Date: Sat, 4 Feb 2023 12:46:30 +0100 Subject: [PATCH] fix support for union type expressions --- unwind/models.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/unwind/models.py b/unwind/models.py index 8922e5b..030b87e 100644 --- a/unwind/models.py +++ b/unwind/models.py @@ -3,6 +3,7 @@ from dataclasses import dataclass, field from dataclasses import fields as _fields from datetime import datetime, timezone from functools import partial +from types import UnionType from typing import ( Annotated, Any, @@ -52,7 +53,7 @@ def fields(class_or_instance): def is_optional(tp: Type) -> bool: """Return wether the given type is optional.""" - if get_origin(tp) is not Union: + if not isinstance(tp, UnionType) and get_origin(tp) is not Union: return False args = get_args(tp) @@ -66,7 +67,7 @@ def optional_type(tp: Type) -> Type | None: Since they're equivalent this also works for other optioning notations, like `Union[int, None]` and `int | None`. """ - if get_origin(tp) is not Union: + if not isinstance(tp, UnionType) and get_origin(tp) is not Union: return None args = get_args(tp) @@ -184,7 +185,8 @@ def validate(o: object) -> None: vtype = type(getattr(o, f.name)) if vtype is not f.type: if get_origin(f.type) is vtype or ( - get_origin(f.type) is Union and vtype in get_args(f.type) + (isinstance(f.type, UnionType) or get_origin(f.type) is Union) + and vtype in get_args(f.type) ): continue raise ValueError(f"Invalid value type: {f.name}: {vtype}")