diff --git a/unwind/models.py b/unwind/models.py index 8922e5b..030b87e 100644 --- a/unwind/models.py +++ b/unwind/models.py @@ -3,6 +3,7 @@ from dataclasses import dataclass, field from dataclasses import fields as _fields from datetime import datetime, timezone from functools import partial +from types import UnionType from typing import ( Annotated, Any, @@ -52,7 +53,7 @@ def fields(class_or_instance): def is_optional(tp: Type) -> bool: """Return wether the given type is optional.""" - if get_origin(tp) is not Union: + if not isinstance(tp, UnionType) and get_origin(tp) is not Union: return False args = get_args(tp) @@ -66,7 +67,7 @@ def optional_type(tp: Type) -> Type | None: Since they're equivalent this also works for other optioning notations, like `Union[int, None]` and `int | None`. """ - if get_origin(tp) is not Union: + if not isinstance(tp, UnionType) and get_origin(tp) is not Union: return None args = get_args(tp) @@ -184,7 +185,8 @@ def validate(o: object) -> None: vtype = type(getattr(o, f.name)) if vtype is not f.type: if get_origin(f.type) is vtype or ( - get_origin(f.type) is Union and vtype in get_args(f.type) + (isinstance(f.type, UnionType) or get_origin(f.type) is Union) + and vtype in get_args(f.type) ): continue raise ValueError(f"Invalid value type: {f.name}: {vtype}")