30
30
THE SOFTWARE.
31
31
"""
32
32
33
- from typing import Union , get_args
33
+ from typing import Tuple , Union , get_args
34
34
try :
35
35
# NOTE: only available in python >= 3.8
36
36
from typing import get_origin
37
37
except ImportError :
38
38
from typing_extensions import get_origin
39
39
40
- from dataclasses import fields
40
+ from dataclasses import Field , is_dataclass , fields
41
41
from arraycontext .container import is_array_container_type
42
42
43
43
44
44
# {{{ dataclass containers
45
45
46
+ def is_array_type (tp : type ) -> bool :
47
+ from arraycontext import Array
48
+ return tp is Array or is_array_container_type (tp )
49
+
50
+
46
51
def dataclass_array_container (cls : type ) -> type :
47
52
"""A class decorator that makes the class to which it is applied an
48
53
:class:`ArrayContainer` by registering appropriate implementations of
@@ -51,24 +56,37 @@ def dataclass_array_container(cls: type) -> type:
51
56
52
57
Attributes that are not array containers are allowed. In order to decide
53
58
whether an attribute is an array container, the declared attribute type
54
- is checked by the criteria from :func:`is_array_container_type`.
59
+ is checked by the criteria from :func:`is_array_container_type`. This
60
+ includes some support for type annotations:
61
+
62
+ * a :class:`typing.Union` of array containers is considered an array container.
63
+ * other type annotations, e.g. :class:`typing.Optional`, are not considered
64
+ array containers, even if they wrap one.
55
65
"""
56
- from dataclasses import is_dataclass , Field
66
+
57
67
assert is_dataclass (cls )
58
68
59
69
def is_array_field (f : Field ) -> bool :
60
- from arraycontext import Array
70
+ # NOTE: unions of array containers are treated separately to handle
71
+ # unions of only array containers, e.g. `Union[np.ndarray, Array]`, as
72
+ # they can work seamlessly with arithmetic and traversal.
73
+ #
74
+ # `Optional[ArrayContainer]` is not allowed, since `None` is not
75
+ # handled by `with_container_arithmetic`, which is the common case
76
+ # for current container usage. Other type annotations, e.g.
77
+ # `Tuple[Container, Container]`, are also not allowed, as they do not
78
+ # work with `with_container_arithmetic`.
79
+ #
80
+ # This is not set in stone, but mostly driven by current usage!
61
81
62
82
origin = get_origin (f .type )
63
83
if origin is Union :
64
- if not all (
65
- arg is Array or is_array_container_type ( arg )
66
- for arg in get_args ( f . type )) :
84
+ if all (is_array_type ( arg ) for arg in get_args ( f . type )):
85
+ return True
86
+ else :
67
87
raise TypeError (
68
88
f"Field '{ f .name } ' union contains non-array container "
69
89
"arguments. All arguments must be array containers." )
70
- else :
71
- return True
72
90
73
91
if __debug__ :
74
92
if not f .init :
@@ -79,8 +97,12 @@ def is_array_field(f: Field) -> bool:
79
97
raise TypeError (
80
98
f"string annotation on field '{ f .name } ' not supported" )
81
99
82
- from typing import _SpecialForm
83
- if isinstance (f .type , _SpecialForm ):
100
+ # NOTE:
101
+ # * `_BaseGenericAlias` catches `List`, `Tuple`, etc.
102
+ # * `_SpecialForm` catches `Any`, `Literal`, etc.
103
+ from typing import ( # type: ignore[attr-defined]
104
+ _BaseGenericAlias , _SpecialForm )
105
+ if isinstance (f .type , (_BaseGenericAlias , _SpecialForm )):
84
106
# NOTE: anything except a Union is not allowed
85
107
raise TypeError (
86
108
f"typing annotation not supported on field '{ f .name } ': "
@@ -91,7 +113,7 @@ def is_array_field(f: Field) -> bool:
91
113
f"field '{ f .name } ' not an instance of 'type': "
92
114
f"'{ f .type !r} '" )
93
115
94
- return f . type is Array or is_array_container_type (f .type )
116
+ return is_array_type (f .type )
95
117
96
118
from pytools import partition
97
119
array_fields , non_array_fields = partition (is_array_field , fields (cls ))
@@ -100,6 +122,27 @@ def is_array_field(f: Field) -> bool:
100
122
raise ValueError (f"'{ cls } ' must have fields with array container type "
101
123
"in order to use the 'dataclass_array_container' decorator" )
102
124
125
+ return inject_dataclass_serialization (cls , array_fields , non_array_fields )
126
+
127
+
128
+ def inject_dataclass_serialization (
129
+ cls : type ,
130
+ array_fields : Tuple [Field , ...],
131
+ non_array_fields : Tuple [Field , ...]) -> type :
132
+ """Implements :func:`~arraycontext.serialize_container` and
133
+ :func:`~arraycontext.deserialize_container` for the given dataclass *cls*.
134
+
135
+ This function modifies *cls* in place, so the returned value is the same
136
+ object with additional functionality.
137
+
138
+ :arg array_fields: fields of the given dataclass *cls* which are considered
139
+ array containers and should be serialized.
140
+ :arg non_array_fields: remaining fields of the dataclass *cls* which are
141
+ copied over from the template array in deserialization.
142
+ """
143
+
144
+ assert is_dataclass (cls )
145
+
103
146
serialize_expr = ", " .join (
104
147
f"({ f .name !r} , ary.{ f .name } )" for f in array_fields )
105
148
template_kwargs = ", " .join (
0 commit comments