fix support for union type expressions
This commit is contained in:
parent
7da3a094f1
commit
e84a6bc865
1 changed files with 5 additions and 3 deletions
|
|
@ -3,6 +3,7 @@ from dataclasses import dataclass, field
|
||||||
from dataclasses import fields as _fields
|
from dataclasses import fields as _fields
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from types import UnionType
|
||||||
from typing import (
|
from typing import (
|
||||||
Annotated,
|
Annotated,
|
||||||
Any,
|
Any,
|
||||||
|
|
@ -52,7 +53,7 @@ def fields(class_or_instance):
|
||||||
|
|
||||||
def is_optional(tp: Type) -> bool:
|
def is_optional(tp: Type) -> bool:
|
||||||
"""Return wether the given type is optional."""
|
"""Return wether the given type is optional."""
|
||||||
if get_origin(tp) is not Union:
|
if not isinstance(tp, UnionType) and get_origin(tp) is not Union:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
args = get_args(tp)
|
args = get_args(tp)
|
||||||
|
|
@ -66,7 +67,7 @@ def optional_type(tp: Type) -> Type | None:
|
||||||
Since they're equivalent this also works for other optioning notations, like
|
Since they're equivalent this also works for other optioning notations, like
|
||||||
`Union[int, None]` and `int | None`.
|
`Union[int, None]` and `int | None`.
|
||||||
"""
|
"""
|
||||||
if get_origin(tp) is not Union:
|
if not isinstance(tp, UnionType) and get_origin(tp) is not Union:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
args = get_args(tp)
|
args = get_args(tp)
|
||||||
|
|
@ -184,7 +185,8 @@ def validate(o: object) -> None:
|
||||||
vtype = type(getattr(o, f.name))
|
vtype = type(getattr(o, f.name))
|
||||||
if vtype is not f.type:
|
if vtype is not f.type:
|
||||||
if get_origin(f.type) is vtype or (
|
if get_origin(f.type) is vtype or (
|
||||||
get_origin(f.type) is Union and vtype in get_args(f.type)
|
(isinstance(f.type, UnionType) or get_origin(f.type) is Union)
|
||||||
|
and vtype in get_args(f.type)
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
raise ValueError(f"Invalid value type: {f.name}: {vtype}")
|
raise ValueError(f"Invalid value type: {f.name}: {vtype}")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue