Skip to content

Commit 4ba3383

Browse files
committed
allow more type annotations in dataclass_array_container
1 parent ff1cd0c commit 4ba3383

File tree

2 files changed

+97
-74
lines changed

2 files changed

+97
-74
lines changed

arraycontext/container/dataclass.py

Lines changed: 81 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -30,75 +30,44 @@
3030
THE SOFTWARE.
3131
"""
3232

33-
from typing import Union, get_args
33+
from typing import Tuple, Union, get_args
3434
try:
3535
# NOTE: only available in python >= 3.8
3636
from typing import get_origin
3737
except ImportError:
3838
from typing_extensions import get_origin
3939

40-
from dataclasses import fields
40+
from dataclasses import is_dataclass, fields, Field
4141
from arraycontext.container import is_array_container_type
4242

4343

4444
# {{{ dataclass containers
4545

46-
def dataclass_array_container(cls: type) -> type:
47-
"""A class decorator that makes the class to which it is applied an
48-
:class:`ArrayContainer` by registering appropriate implementations of
49-
:func:`serialize_container` and :func:`deserialize_container`.
50-
*cls* must be a :func:`~dataclasses.dataclass`.
46+
def is_array_type(tp: type) -> bool:
47+
from arraycontext import Array
48+
return tp is Array or is_array_container_type(tp)
5149

52-
Attributes that are not array containers are allowed. In order to decide
53-
whether an attribute is an array container, the declared attribute type
54-
is checked by the criteria from :func:`is_array_container_type`.
55-
"""
56-
from dataclasses import is_dataclass, Field
57-
assert is_dataclass(cls)
5850

59-
def is_array_field(f: Field) -> bool:
60-
from arraycontext import Array
51+
def inject_container_serialization(
52+
cls: type,
53+
array_fields: Tuple[Field, ...],
54+
non_array_fields: Tuple[Field, ...],
55+
) -> type:
56+
"""Implements :func:`~arraycontext.serialize_container` and
57+
:func:`~arraycontext.deserialize_container` for the given class *cls*.
6158
62-
origin = get_origin(f.type)
63-
if origin is Union:
64-
if not all(
65-
arg is Array or is_array_container_type(arg)
66-
for arg in get_args(f.type)):
67-
raise TypeError(
68-
f"Field '{f.name}' union contains non-array container "
69-
"arguments. All arguments must be array containers.")
70-
else:
71-
return True
59+
This function modifies *cls* in place, so the returned value is the same
60+
object with additional functionality.
7261
73-
if __debug__:
74-
if not f.init:
75-
raise ValueError(
76-
f"'init=False' field not allowed: '{f.name}'")
62+
:arg array_fields: fields of the given dataclass *cls* which are considered
63+
array containers and should be serialized.
64+
:arg non_array_fields: remaining fields of the dataclass *cls* which are
65+
copied over from the template array in deserialization.
7766
78-
if isinstance(f.type, str):
79-
raise TypeError(
80-
f"string annotation on field '{f.name}' not supported")
81-
82-
from typing import _SpecialForm
83-
if isinstance(f.type, _SpecialForm):
84-
# NOTE: anything except a Union is not allowed
85-
raise TypeError(
86-
f"typing annotation not supported on field '{f.name}': "
87-
f"'{f.type!r}'")
88-
89-
if not isinstance(f.type, type):
90-
raise TypeError(
91-
f"field '{f.name}' not an instance of 'type': "
92-
f"'{f.type!r}'")
93-
94-
return f.type is Array or is_array_container_type(f.type)
95-
96-
from pytools import partition
97-
array_fields, non_array_fields = partition(is_array_field, fields(cls))
67+
:returns: the input class *cls*.
68+
"""
9869

99-
if not array_fields:
100-
raise ValueError(f"'{cls}' must have fields with array container type "
101-
"in order to use the 'dataclass_array_container' decorator")
70+
assert is_dataclass(cls)
10271

10372
serialize_expr = ", ".join(
10473
f"({f.name!r}, ary.{f.name})" for f in array_fields)
@@ -153,6 +122,66 @@ def _deserialize_init_arrays_code_{lower_cls_name}(
153122

154123
return cls
155124

125+
126+
def dataclass_array_container(cls: type) -> type:
127+
"""A class decorator that makes the class to which it is applied an
128+
:class:`ArrayContainer` by registering appropriate implementations of
129+
:func:`serialize_container` and :func:`deserialize_container`.
130+
*cls* must be a :func:`~dataclasses.dataclass`.
131+
132+
Attributes that are not array containers are allowed. In order to decide
133+
whether an attribute is an array container, the declared attribute type
134+
is checked by the criteria from :func:`is_array_container_type`. This
135+
includes some support for type annotations:
136+
137+
* a :class:`typing.Union` of array containers is considered an array container.
138+
* other type annotations, e.g. :class:`typing.Optional`, are not considered
139+
array containers, even if they wrap one.
140+
"""
141+
assert is_dataclass(cls)
142+
143+
def is_array_field(f: Field) -> bool:
144+
if __debug__:
145+
if not f.init:
146+
raise ValueError(
147+
f"Fields with 'init=False' not allowed: '{f.name}'")
148+
149+
if isinstance(f.type, str):
150+
raise TypeError(
151+
f"String annotation on field '{f.name}' not supported")
152+
153+
# NOTE: unions of array containers are treated seprately to allow
154+
# * unions of only array containers, e.g. Union[np.ndarray, Array], as
155+
# they can work seamlessly with arithmetic and traversal.
156+
# * `Optional[ArrayContainer]` is not allowed, since `None` is not
157+
# handled by `with_container_arithmetic`, which is the common case
158+
# for current container usage.
159+
#
160+
# Other type annotations, e.g. `Tuple[Container, Container]`, are also
161+
# not allowed, as they do not work with `with_container_arithmetic`.
162+
#
163+
# This is not set in stone, but mostly driven by current usage!
164+
165+
origin = get_origin(f.type)
166+
if origin is Union:
167+
# NOTE: `Optional` is caught in here as an alias for `Union[Anon, type]`
168+
return all(is_array_type(arg) for arg in get_args(f.type))
169+
170+
from typing import _GenericAlias, _SpecialForm # type: ignore[attr-defined]
171+
if isinstance(f.type, (_GenericAlias, _SpecialForm)):
172+
return False
173+
174+
return is_array_type(f.type)
175+
176+
from pytools import partition
177+
array_fields, non_array_fields = partition(is_array_field, fields(cls))
178+
179+
if not array_fields:
180+
raise ValueError(f"'{cls}' must have fields with array container type "
181+
"in order to use the 'dataclass_array_container' decorator")
182+
183+
return inject_container_serialization(cls, array_fields, non_array_fields)
184+
156185
# }}}
157186

158187
# vim: foldmethod=marker

test/test_utils.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ def test_pt_actx_key_stringification_uniqueness():
4747
# {{{ test_dataclass_array_container
4848

4949
def test_dataclass_array_container():
50-
from typing import Optional
5150
from dataclasses import dataclass, field
5251
from arraycontext import dataclass_array_container
5352

@@ -64,19 +63,6 @@ class ArrayContainerWithStringTypes:
6463

6564
# }}}
6665

67-
# {{{ optional fields
68-
69-
@dataclass
70-
class ArrayContainerWithOptional:
71-
x: np.ndarray
72-
y: Optional[np.ndarray]
73-
74-
with pytest.raises(TypeError):
75-
# NOTE: cannot have wrapped annotations (here by `Optional`)
76-
dataclass_array_container(ArrayContainerWithOptional)
77-
78-
# }}}
79-
8066
# {{{ field(init=False)
8167

8268
@dataclass
@@ -106,36 +92,44 @@ class ArrayContainerWithArray:
10692
# }}}
10793

10894

109-
# {{{ test_dataclass_container_unions
95+
# {{{ test_dataclass_container_type_annotations
11096

111-
def test_dataclass_container_unions():
97+
def test_dataclass_container_type_annotations():
11298
from dataclasses import dataclass
11399
from arraycontext import dataclass_array_container
114100

115-
from typing import Union
101+
from typing import Optional, Tuple, Union
116102
from arraycontext import Array
117103

118104
# {{{ union fields
119105

106+
@dataclass_array_container
120107
@dataclass
121108
class ArrayContainerWithUnion:
122109
x: np.ndarray
123110
y: Union[np.ndarray, Array]
124111

125-
dataclass_array_container(ArrayContainerWithUnion)
126-
127112
# }}}
128113

129114
# {{{ non-container union
130115

116+
@dataclass_array_container
131117
@dataclass
132118
class ArrayContainerWithWrongUnion:
133119
x: np.ndarray
134120
y: Union[np.ndarray, float]
135121

136-
with pytest.raises(TypeError):
137-
# NOTE: float is not an ArrayContainer, so y should fail
138-
dataclass_array_container(ArrayContainerWithWrongUnion)
122+
# }}}
123+
124+
# {{{ optional and other fields
125+
126+
@dataclass_array_container
127+
@dataclass
128+
class ArrayContainerWithAnnotations:
129+
x: np.ndarray
130+
y: Tuple[float, float]
131+
z: Optional[np.ndarray]
132+
w: str
139133

140134
# }}}
141135

0 commit comments

Comments
 (0)