fix support for union type expressions

This commit is contained in:
ducklet 2023-02-04 12:46:30 +01:00
parent 7da3a094f1
commit e84a6bc865

View file

@ -3,6 +3,7 @@ from dataclasses import dataclass, field
from dataclasses import fields as _fields from dataclasses import fields as _fields
from datetime import datetime, timezone from datetime import datetime, timezone
from functools import partial from functools import partial
from types import UnionType
from typing import ( from typing import (
Annotated, Annotated,
Any, Any,
@ -52,7 +53,7 @@ def fields(class_or_instance):
def is_optional(tp: Type) -> bool: def is_optional(tp: Type) -> bool:
"""Return wether the given type is optional.""" """Return wether the given type is optional."""
if get_origin(tp) is not Union: if not isinstance(tp, UnionType) and get_origin(tp) is not Union:
return False return False
args = get_args(tp) args = get_args(tp)
@ -66,7 +67,7 @@ def optional_type(tp: Type) -> Type | None:
Since they're equivalent this also works for other optioning notations, like Since they're equivalent this also works for other optioning notations, like
`Union[int, None]` and `int | None`. `Union[int, None]` and `int | None`.
""" """
if get_origin(tp) is not Union: if not isinstance(tp, UnionType) and get_origin(tp) is not Union:
return None return None
args = get_args(tp) args = get_args(tp)
@ -184,7 +185,8 @@ def validate(o: object) -> None:
vtype = type(getattr(o, f.name)) vtype = type(getattr(o, f.name))
if vtype is not f.type: if vtype is not f.type:
if get_origin(f.type) is vtype or ( if get_origin(f.type) is vtype or (
get_origin(f.type) is Union and vtype in get_args(f.type) (isinstance(f.type, UnionType) or get_origin(f.type) is Union)
and vtype in get_args(f.type)
): ):
continue continue
raise ValueError(f"Invalid value type: {f.name}: {vtype}") raise ValueError(f"Invalid value type: {f.name}: {vtype}")