Skip to content

Commit 026fac2

Browse files
committed
allow to exclude fields in dataclass_array_container
1 parent 9314073 commit 026fac2

File tree

2 files changed

+138
-67
lines changed

2 files changed

+138
-67
lines changed

arraycontext/container/dataclass.py

Lines changed: 107 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -31,85 +31,125 @@
3131
"""
3232

3333
from dataclasses import fields
34+
from typing import Callable, Optional, Tuple, Union, overload
35+
3436
from arraycontext.container import is_array_container_type
3537

3638

3739
# {{{ dataclass containers
3840

39-
def dataclass_array_container(cls: type) -> type:
40-
"""A class decorator that makes the class to which it is applied a
41+
@overload
42+
def dataclass_array_container(
43+
cls: None = None,
44+
excluded_fields: Optional[Tuple[str, ...]] = None,
45+
) -> Callable[[type], type]:
46+
...
47+
48+
49+
@overload
50+
def dataclass_array_container(
51+
cls: type,
52+
excluded_fields: Optional[Tuple[str, ...]] = None,
53+
) -> type:
54+
...
55+
56+
57+
def dataclass_array_container(
58+
cls: Optional[type] = None,
59+
excluded_fields: Optional[Tuple[str, ...]] = None,
60+
) -> Union[type, Callable[[type], type]]:
61+
"""A class decorator that makes the class to which it is applied an
4162
:class:`ArrayContainer` by registering appropriate implementations of
4263
:func:`serialize_container` and :func:`deserialize_container`.
4364
*cls* must be a :func:`~dataclasses.dataclass`.
4465
4566
Attributes that are not array containers are allowed. In order to decide
4667
whether an attribute is an array container, the declared attribute type
47-
is checked by the criteria from :func:`is_array_container`.
68+
is checked by the criteria from :func:`is_array_container`. Additional
69+
attributes can be excluded manually using *excluded_fields*
4870
"""
49-
from dataclasses import is_dataclass
50-
assert is_dataclass(cls)
51-
52-
array_fields = [
53-
f for f in fields(cls) if is_array_container_type(f.type)]
54-
non_array_fields = [
55-
f for f in fields(cls) if not is_array_container_type(f.type)]
56-
57-
if not array_fields:
58-
raise ValueError(f"'{cls}' must have fields with array container type "
59-
"in order to use the 'dataclass_array_container' decorator")
60-
61-
serialize_expr = ", ".join(
62-
f"({f.name!r}, ary.{f.name})" for f in array_fields)
63-
template_kwargs = ", ".join(
64-
f"{f.name}=template.{f.name}" for f in non_array_fields)
65-
66-
lower_cls_name = cls.__name__.lower()
67-
68-
serialize_init_code = ", ".join(f"{f.name!r}: f'{{instance_name}}.{f.name}'"
69-
for f in array_fields)
70-
deserialize_init_code = ", ".join([
71-
f"{f.name}={{args[{f.name!r}]}}" for f in array_fields
72-
] + [
73-
f"{f.name}={{template_instance_name}}.{f.name}"
74-
for f in non_array_fields
75-
])
76-
77-
from pytools.codegen import remove_common_indentation
78-
serialize_code = remove_common_indentation(f"""
79-
from typing import Any, Iterable, Tuple
80-
from arraycontext import serialize_container, deserialize_container
81-
82-
@serialize_container.register(cls)
83-
def _serialize_{lower_cls_name}(ary: cls) -> Iterable[Tuple[Any, Any]]:
84-
return ({serialize_expr},)
85-
86-
@deserialize_container.register(cls)
87-
def _deserialize_{lower_cls_name}(
88-
template: cls, iterable: Iterable[Tuple[Any, Any]]) -> cls:
89-
return cls(**dict(iterable), {template_kwargs})
90-
91-
# support for with_container_arithmetic
92-
93-
def _serialize_init_arrays_code_{lower_cls_name}(cls, instance_name):
94-
return {{
95-
{serialize_init_code}
96-
}}
97-
cls._serialize_init_arrays_code = classmethod(
98-
_serialize_init_arrays_code_{lower_cls_name})
99-
100-
def _deserialize_init_arrays_code_{lower_cls_name}(
101-
cls, template_instance_name, args):
102-
return f"{deserialize_init_code}"
103-
104-
cls._deserialize_init_arrays_code = classmethod(
105-
_deserialize_init_arrays_code_{lower_cls_name})
106-
""")
107-
108-
exec_dict = {"cls": cls, "_MODULE_SOURCE_CODE": serialize_code}
109-
exec(compile(serialize_code, f"<container serialization for {cls.__name__}>",
110-
"exec"), exec_dict)
111-
112-
return cls
71+
if excluded_fields is None:
72+
excluded_fields = ()
73+
74+
def wrap(cls: type) -> type:
75+
assert excluded_fields is not None
76+
77+
from dataclasses import is_dataclass
78+
assert is_dataclass(cls)
79+
80+
cls_fields = fields(cls)
81+
assert all(any(f == cf.name for cf in cls_fields) for f in excluded_fields)
82+
83+
container_fields = [
84+
f for f in cls_fields
85+
if f.name not in excluded_fields and is_array_container_type(f.type)
86+
]
87+
non_container_fields = [
88+
f for f in fields(cls)
89+
if f.name in excluded_fields or not is_array_container_type(f.type)
90+
]
91+
92+
if not container_fields:
93+
raise ValueError(f"'{cls}' must have fields with array container type "
94+
"in order to use the 'dataclass_array_container' decorator")
95+
96+
serialize_expr = ", ".join(
97+
f"({f.name!r}, ary.{f.name})" for f in container_fields)
98+
template_kwargs = ", ".join(
99+
f"{f.name}=template.{f.name}" for f in non_container_fields)
100+
101+
lower_cls_name = cls.__name__.lower()
102+
103+
serialize_init_code = ", ".join(f"{f.name!r}: f'{{instance_name}}.{f.name}'"
104+
for f in container_fields)
105+
deserialize_init_code = ", ".join([
106+
f"{f.name}={{args[{f.name!r}]}}" for f in container_fields
107+
] + [
108+
f"{f.name}={{template_instance_name}}.{f.name}"
109+
for f in non_container_fields
110+
])
111+
112+
from pytools.codegen import remove_common_indentation
113+
serialize_code = remove_common_indentation(f"""
114+
from typing import Any, Iterable, Tuple
115+
from arraycontext import serialize_container, deserialize_container
116+
117+
@serialize_container.register(cls)
118+
def _serialize_{lower_cls_name}(ary: cls) -> Iterable[Tuple[Any, Any]]:
119+
return ({serialize_expr},)
120+
121+
@deserialize_container.register(cls)
122+
def _deserialize_{lower_cls_name}(
123+
template: cls, iterable: Iterable[Tuple[Any, Any]]) -> cls:
124+
return cls(**dict(iterable), {template_kwargs})
125+
126+
# support for with_container_arithmetic
127+
128+
def _serialize_init_arrays_code_{lower_cls_name}(cls, instance_name):
129+
return {{
130+
{serialize_init_code}
131+
}}
132+
cls._serialize_init_arrays_code = classmethod(
133+
_serialize_init_arrays_code_{lower_cls_name})
134+
135+
def _deserialize_init_arrays_code_{lower_cls_name}(
136+
cls, template_instance_name, args):
137+
return f"{deserialize_init_code}"
138+
139+
cls._deserialize_init_arrays_code = classmethod(
140+
_deserialize_init_arrays_code_{lower_cls_name})
141+
""")
142+
143+
exec_dict = {"cls": cls, "_MODULE_SOURCE_CODE": serialize_code}
144+
exec(compile(serialize_code, f"<container serialization for {cls.__name__}>",
145+
"exec"), exec_dict)
146+
147+
return cls
148+
149+
if cls is not None:
150+
return wrap(cls)
151+
else:
152+
return wrap
113153

114154
# }}}
115155

test/test_arraycontext.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,6 +1093,37 @@ def _actx_allows_scalar_broadcast(actx):
10931093
# }}}
10941094

10951095

1096+
# {{{ test_dataclass_excluded_fields
1097+
1098+
1099+
@with_container_arithmetic(
1100+
bcast_obj_array=True,
1101+
bcast_numpy_array=True,
1102+
rel_comparison=True,
1103+
_cls_has_array_context_attr=True)
1104+
@dataclass_array_container(excluded_fields=("excluded",))
1105+
@dataclass(frozen=True)
1106+
class ExcludedFoo:
1107+
x: DOFArray
1108+
excluded: DOFArray
1109+
1110+
@property
1111+
def array_context(self):
1112+
return self.x.array_context
1113+
1114+
1115+
def test_dataclass_excluded_fields(actx_factory):
1116+
actx = actx_factory()
1117+
1118+
ary = DOFArray(actx, (actx.zeros(3, dtype=np.float64) + 41,))
1119+
x = ExcludedFoo(x=ary, excluded=ary)
1120+
1121+
r = x + 2 * x
1122+
assert actx.to_numpy(actx.np.linalg.norm(r.excluded - x.excluded)) < 1.0e-15
1123+
1124+
# }}}
1125+
1126+
10961127
if __name__ == "__main__":
10971128
import sys
10981129
if len(sys.argv) > 1:

0 commit comments

Comments
 (0)