fix support for union type expressions
This commit is contained in:
parent
7da3a094f1
commit
e84a6bc865
1 changed files with 5 additions and 3 deletions
|
|
@ -3,6 +3,7 @@ from dataclasses import dataclass, field
|
|||
from dataclasses import fields as _fields
|
||||
from datetime import datetime, timezone
|
||||
from functools import partial
|
||||
from types import UnionType
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
|
|
@ -52,7 +53,7 @@ def fields(class_or_instance):
|
|||
|
||||
def is_optional(tp: Type) -> bool:
|
||||
"""Return wether the given type is optional."""
|
||||
if get_origin(tp) is not Union:
|
||||
if not isinstance(tp, UnionType) and get_origin(tp) is not Union:
|
||||
return False
|
||||
|
||||
args = get_args(tp)
|
||||
|
|
@ -66,7 +67,7 @@ def optional_type(tp: Type) -> Type | None:
|
|||
Since they're equivalent this also works for other optioning notations, like
|
||||
`Union[int, None]` and `int | None`.
|
||||
"""
|
||||
if get_origin(tp) is not Union:
|
||||
if not isinstance(tp, UnionType) and get_origin(tp) is not Union:
|
||||
return None
|
||||
|
||||
args = get_args(tp)
|
||||
|
|
@ -184,7 +185,8 @@ def validate(o: object) -> None:
|
|||
vtype = type(getattr(o, f.name))
|
||||
if vtype is not f.type:
|
||||
if get_origin(f.type) is vtype or (
|
||||
get_origin(f.type) is Union and vtype in get_args(f.type)
|
||||
(isinstance(f.type, UnionType) or get_origin(f.type) is Union)
|
||||
and vtype in get_args(f.type)
|
||||
):
|
||||
continue
|
||||
raise ValueError(f"Invalid value type: {f.name}: {vtype}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue