Skip to content

Commit b686676

Browse files
alexfiklinducer
authored andcommitted
split dataclass_array_container for easier modification
1 parent ff1cd0c commit b686676

File tree

1 file changed

+56
-13
lines changed

1 file changed

+56
-13
lines changed

arraycontext/container/dataclass.py

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,24 @@
3030
THE SOFTWARE.
3131
"""
3232

33-
from typing import Union, get_args
33+
from typing import Tuple, Union, get_args
3434
try:
3535
# NOTE: only available in python >= 3.8
3636
from typing import get_origin
3737
except ImportError:
3838
from typing_extensions import get_origin
3939

40-
from dataclasses import fields
40+
from dataclasses import Field, is_dataclass, fields
4141
from arraycontext.container import is_array_container_type
4242

4343

4444
# {{{ dataclass containers
4545

46+
def is_array_type(tp: type) -> bool:
47+
from arraycontext import Array
48+
return tp is Array or is_array_container_type(tp)
49+
50+
4651
def dataclass_array_container(cls: type) -> type:
4752
"""A class decorator that makes the class to which it is applied an
4853
:class:`ArrayContainer` by registering appropriate implementations of
@@ -51,24 +56,37 @@ def dataclass_array_container(cls: type) -> type:
5156
5257
Attributes that are not array containers are allowed. In order to decide
5358
whether an attribute is an array container, the declared attribute type
54-
is checked by the criteria from :func:`is_array_container_type`.
59+
is checked by the criteria from :func:`is_array_container_type`. This
60+
includes some support for type annotations:
61+
62+
* a :class:`typing.Union` of array containers is considered an array container.
63+
* other type annotations, e.g. :class:`typing.Optional`, are not considered
64+
array containers, even if they wrap one.
5565
"""
56-
from dataclasses import is_dataclass, Field
66+
5767
assert is_dataclass(cls)
5868

5969
def is_array_field(f: Field) -> bool:
60-
from arraycontext import Array
70+
# NOTE: unions of array containers are treated separately to handle
71+
# unions of only array containers, e.g. `Union[np.ndarray, Array]`, as
72+
# they can work seamlessly with arithmetic and traversal.
73+
#
74+
# `Optional[ArrayContainer]` is not allowed, since `None` is not
75+
# handled by `with_container_arithmetic`, which is the common case
76+
# for current container usage. Other type annotations, e.g.
77+
# `Tuple[Container, Container]`, are also not allowed, as they do not
78+
# work with `with_container_arithmetic`.
79+
#
80+
# This is not set in stone, but mostly driven by current usage!
6181

6282
origin = get_origin(f.type)
6383
if origin is Union:
64-
if not all(
65-
arg is Array or is_array_container_type(arg)
66-
for arg in get_args(f.type)):
84+
if all(is_array_type(arg) for arg in get_args(f.type)):
85+
return True
86+
else:
6787
raise TypeError(
6888
f"Field '{f.name}' union contains non-array container "
6989
"arguments. All arguments must be array containers.")
70-
else:
71-
return True
7290

7391
if __debug__:
7492
if not f.init:
@@ -79,8 +97,12 @@ def is_array_field(f: Field) -> bool:
7997
raise TypeError(
8098
f"string annotation on field '{f.name}' not supported")
8199

82-
from typing import _SpecialForm
83-
if isinstance(f.type, _SpecialForm):
100+
# NOTE:
101+
# * `_BaseGenericAlias` catches `List`, `Tuple`, etc.
102+
# * `_SpecialForm` catches `Any`, `Literal`, etc.
103+
from typing import ( # type: ignore[attr-defined]
104+
_BaseGenericAlias, _SpecialForm)
105+
if isinstance(f.type, (_BaseGenericAlias, _SpecialForm)):
84106
# NOTE: anything except a Union is not allowed
85107
raise TypeError(
86108
f"typing annotation not supported on field '{f.name}': "
@@ -91,7 +113,7 @@ def is_array_field(f: Field) -> bool:
91113
f"field '{f.name}' not an instance of 'type': "
92114
f"'{f.type!r}'")
93115

94-
return f.type is Array or is_array_container_type(f.type)
116+
return is_array_type(f.type)
95117

96118
from pytools import partition
97119
array_fields, non_array_fields = partition(is_array_field, fields(cls))
@@ -100,6 +122,27 @@ def is_array_field(f: Field) -> bool:
100122
raise ValueError(f"'{cls}' must have fields with array container type "
101123
"in order to use the 'dataclass_array_container' decorator")
102124

125+
return inject_dataclass_serialization(cls, array_fields, non_array_fields)
126+
127+
128+
def inject_dataclass_serialization(
129+
cls: type,
130+
array_fields: Tuple[Field, ...],
131+
non_array_fields: Tuple[Field, ...]) -> type:
132+
"""Implements :func:`~arraycontext.serialize_container` and
133+
:func:`~arraycontext.deserialize_container` for the given dataclass *cls*.
134+
135+
This function modifies *cls* in place, so the returned value is the same
136+
object with additional functionality.
137+
138+
:arg array_fields: fields of the given dataclass *cls* which are considered
139+
array containers and should be serialized.
140+
:arg non_array_fields: remaining fields of the dataclass *cls* which are
141+
copied over from the template array in deserialization.
142+
"""
143+
144+
assert is_dataclass(cls)
145+
103146
serialize_expr = ", ".join(
104147
f"({f.name!r}, ary.{f.name})" for f in array_fields)
105148
template_kwargs = ", ".join(

0 commit comments

Comments
 (0)