Skip to content

Commit 201f973

Browse files
committed
allow more type annotations in dataclass_array_container
1 parent 80813d7 commit 201f973

File tree

2 files changed

+34
-48
lines changed

2 files changed

+34
-48
lines changed

arraycontext/container/dataclass.py

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -51,47 +51,39 @@ def dataclass_array_container(cls: type) -> type:
5151
5252
Attributes that are not array containers are allowed. In order to decide
5353
whether an attribute is an array container, the declared attribute type
54-
is checked by the criteria from :func:`is_array_container_type`.
54+
is checked by the criteria from :func:`is_array_container_type`. This
55+
includes some support for type annotations:
56+
57+
* a :class:`typing.Union` of array containers is considered an array container.
58+
* other type annotations, e.g. :class:`typing.Optional`, are not considered
59+
array containers, even if they wrap one.
5560
"""
5661
from dataclasses import is_dataclass, Field
5762
assert is_dataclass(cls)
5863

59-
def is_array_field(f: Field) -> bool:
64+
def is_array_type(tp: type) -> bool:
6065
from arraycontext import Array
66+
return tp is Array or is_array_container_type(tp)
6167

62-
origin = get_origin(f.type)
63-
if origin is Union:
64-
if not all(
65-
arg is Array or is_array_container_type(arg)
66-
for arg in get_args(f.type)):
67-
raise TypeError(
68-
f"Field '{f.name}' union contains non-array container "
69-
"arguments. All arguments must be array containers.")
70-
else:
71-
return True
72-
68+
def is_array_field(f: Field) -> bool:
7369
if __debug__:
7470
if not f.init:
7571
raise ValueError(
76-
f"'init=False' field not allowed: '{f.name}'")
72+
f"Fields with 'init=False' not allowed: '{f.name}'")
7773

7874
if isinstance(f.type, str):
7975
raise TypeError(
80-
f"string annotation on field '{f.name}' not supported")
76+
f"String annotation on field '{f.name}' not supported")
8177

82-
from typing import _SpecialForm
83-
if isinstance(f.type, _SpecialForm):
84-
# NOTE: anything except a Union is not allowed
85-
raise TypeError(
86-
f"typing annotation not supported on field '{f.name}': "
87-
f"'{f.type!r}'")
78+
origin = get_origin(f.type)
79+
if origin is Union:
80+
return all(is_array_type(arg) for arg in get_args(f.type))
8881

89-
if not isinstance(f.type, type):
90-
raise TypeError(
91-
f"field '{f.name}' not an instance of 'type': "
92-
f"'{f.type!r}'")
82+
from typing import _GenericAlias, _SpecialForm
83+
if isinstance(f.type, (_GenericAlias, _SpecialForm)):
84+
return False
9385

94-
return f.type is Array or is_array_container_type(f.type)
86+
return is_array_type(f.type)
9587

9688
from pytools import partition
9789
array_fields, non_array_fields = partition(is_array_field, fields(cls))

test/test_utils.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ def test_pt_actx_key_stringification_uniqueness():
4747
# {{{ test_dataclass_array_container
4848

4949
def test_dataclass_array_container():
50-
from typing import Optional
5150
from dataclasses import dataclass, field
5251
from arraycontext import dataclass_array_container
5352

@@ -64,19 +63,6 @@ class ArrayContainerWithStringTypes:
6463

6564
# }}}
6665

67-
# {{{ optional fields
68-
69-
@dataclass
70-
class ArrayContainerWithOptional:
71-
x: np.ndarray
72-
y: Optional[np.ndarray]
73-
74-
with pytest.raises(TypeError):
75-
# NOTE: cannot have wrapped annotations (here by `Optional`)
76-
dataclass_array_container(ArrayContainerWithOptional)
77-
78-
# }}}
79-
8066
# {{{ field(init=False)
8167

8268
@dataclass
@@ -106,36 +92,44 @@ class ArrayContainerWithArray:
10692
# }}}
10793

10894

109-
# {{{ test_dataclass_container_unions
95+
# {{{ test_dataclass_container_type_annotations
11096

111-
def test_dataclass_container_unions():
97+
def test_dataclass_container_type_annotations():
11298
from dataclasses import dataclass
11399
from arraycontext import dataclass_array_container
114100

115-
from typing import Union
101+
from typing import Optional, Tuple, Union
116102
from arraycontext import Array
117103

118104
# {{{ union fields
119105

106+
@dataclass_array_container
120107
@dataclass
121108
class ArrayContainerWithUnion:
122109
x: np.ndarray
123110
y: Union[np.ndarray, Array]
124111

125-
dataclass_array_container(ArrayContainerWithUnion)
126-
127112
# }}}
128113

129114
# {{{ non-container union
130115

116+
@dataclass_array_container
131117
@dataclass
132118
class ArrayContainerWithWrongUnion:
133119
x: np.ndarray
134120
y: Union[np.ndarray, float]
135121

136-
with pytest.raises(TypeError):
137-
# NOTE: float is not an ArrayContainer, so y should fail
138-
dataclass_array_container(ArrayContainerWithWrongUnion)
122+
# }}}
123+
124+
# {{{ optional and other fields
125+
126+
@dataclass_array_container
127+
@dataclass
128+
class ArrayContainerWithAnnotations:
129+
x: np.ndarray
130+
y: Tuple[float, float]
131+
z: Optional[np.ndarray]
132+
w: str
139133

140134
# }}}
141135

0 commit comments

Comments
 (0)