@@ -51,47 +51,39 @@ def dataclass_array_container(cls: type) -> type:
51
51
52
52
Attributes that are not array containers are allowed. In order to decide
53
53
whether an attribute is an array container, the declared attribute type
54
- is checked by the criteria from :func:`is_array_container_type`.
54
+ is checked by the criteria from :func:`is_array_container_type`. This
55
+ includes some support for type annotations:
56
+
57
+ * a :class:`typing.Union` of array containers is considered an array container.
58
+ * other type annotations, e.g. :class:`typing.Optional`, are not considered
59
+ array containers, even if they wrap one.
55
60
"""
56
61
from dataclasses import is_dataclass , Field
57
62
assert is_dataclass (cls )
58
63
59
- def is_array_field ( f : Field ) -> bool :
64
+ def is_array_type ( tp : type ) -> bool :
60
65
from arraycontext import Array
66
+ return tp is Array or is_array_container_type (tp )
61
67
62
- origin = get_origin (f .type )
63
- if origin is Union :
64
- if not all (
65
- arg is Array or is_array_container_type (arg )
66
- for arg in get_args (f .type )):
67
- raise TypeError (
68
- f"Field '{ f .name } ' union contains non-array container "
69
- "arguments. All arguments must be array containers." )
70
- else :
71
- return True
72
-
68
+ def is_array_field (f : Field ) -> bool :
73
69
if __debug__ :
74
70
if not f .init :
75
71
raise ValueError (
76
- f"'init=False' field not allowed: '{ f .name } '" )
72
+ f"Fields with 'init=False' not allowed: '{ f .name } '" )
77
73
78
74
if isinstance (f .type , str ):
79
75
raise TypeError (
80
- f"string annotation on field '{ f .name } ' not supported" )
76
+ f"String annotation on field '{ f .name } ' not supported" )
81
77
82
- from typing import _SpecialForm
83
- if isinstance (f .type , _SpecialForm ):
84
- # NOTE: anything except a Union is not allowed
85
- raise TypeError (
86
- f"typing annotation not supported on field '{ f .name } ': "
87
- f"'{ f .type !r} '" )
78
+ origin = get_origin (f .type )
79
+ if origin is Union :
80
+ return all (is_array_type (arg ) for arg in get_args (f .type ))
88
81
89
- if not isinstance (f .type , type ):
90
- raise TypeError (
91
- f"field '{ f .name } ' not an instance of 'type': "
92
- f"'{ f .type !r} '" )
82
+ from typing import _GenericAlias , _SpecialForm
83
+ if isinstance (f .type , (_GenericAlias , _SpecialForm )):
84
+ return False
93
85
94
- return f . type is Array or is_array_container_type (f .type )
86
+ return is_array_type (f .type )
95
87
96
88
from pytools import partition
97
89
array_fields , non_array_fields = partition (is_array_field , fields (cls ))
0 commit comments