Skip to content

Commit 0083db2

Browse files
committed
allow to exclude fields in dataclass_array_container
1 parent f6400a4 commit 0083db2

File tree

4 files changed

+74
-5
lines changed

4 files changed

+74
-5
lines changed

arraycontext/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
serialize_container, deserialize_container,
4545
register_multivector_as_array_container)
4646
from .container.arithmetic import with_container_arithmetic
47-
from .container.dataclass import dataclass_array_container
47+
from .container.dataclass import dataclass_array_container, ExcludedField
4848

4949
from .container.traversal import (
5050
map_array_container,
@@ -85,7 +85,7 @@
8585
"serialize_container", "deserialize_container",
8686
"register_multivector_as_array_container",
8787
"with_container_arithmetic",
88-
"dataclass_array_container",
88+
"dataclass_array_container", "ExcludedField",
8989

9090
"map_array_container", "multimap_array_container",
9191
"rec_map_array_container", "rec_multimap_array_container",

arraycontext/container/dataclass.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
"""
44
.. currentmodule:: arraycontext
5+
6+
.. autoclass:: ExcludedField
57
.. autofunction:: dataclass_array_container
68
"""
79

@@ -31,25 +33,58 @@
3133
"""
3234

3335
from dataclasses import fields
36+
37+
try:
38+
# NOTE: mypy fails with `Module "typing" has no attribute "get_args"`
39+
from typing import ( # type: ignore[attr-defined]
40+
_AnnotatedAlias, _GenericAlias, get_args)
41+
except ImportError:
42+
from typing_extensions import ( # type: ignore[attr-defined]
43+
_AnnotatedAlias, _GenericAlias, get_args)
44+
3445
from arraycontext.container import is_array_container_type
3546

3647

3748
# {{{ dataclass containers
3849

50+
class ExcludedField:
51+
"""Can be used to annotate dataclass fields to be excluded from the container.
52+
53+
This can be done using :class:`typing.Annotated` as follows
54+
55+
.. code:: python
56+
57+
@dataclass
58+
class MyClass:
59+
x: np.ndarray
60+
y: Annotated[np.ndarray, ExcludedField]
61+
"""
62+
63+
3964
def dataclass_array_container(cls: type) -> type:
40-
"""A class decorator that makes the class to which it is applied a
65+
"""A class decorator that makes the class to which it is applied an
4166
:class:`ArrayContainer` by registering appropriate implementations of
4267
:func:`serialize_container` and :func:`deserialize_container`.
4368
*cls* must be a :func:`~dataclasses.dataclass`.
4469
4570
Attributes that are not array containers are allowed. In order to decide
4671
whether an attribute is an array container, the declared attribute type
4772
is checked by the criteria from :func:`is_array_container_type`.
73+
74+
To explicitly exclude fields from the container serialization (that would
75+
otherwise be recognized as array containers), use :class:`typing.Annotated`
76+
and :class:`ExcludedField`.
4877
"""
78+
4979
from dataclasses import is_dataclass, Field
5080
assert is_dataclass(cls)
5181

5282
def is_array_field(f: Field) -> bool:
83+
# FIXME: is there a nicer way to recognize that we hit Annotated?
84+
if isinstance(f.type, _AnnotatedAlias):
85+
if any(arg is ExcludedField for arg in get_args(f.type)):
86+
return False
87+
5388
if __debug__:
5489
if not f.init:
5590
raise ValueError(
@@ -59,8 +94,7 @@ def is_array_field(f: Field) -> bool:
5994
raise TypeError(
6095
f"string annotation on field '{f.name}' not supported")
6196

62-
from typing import _SpecialForm
63-
if isinstance(f.type, _SpecialForm):
97+
if isinstance(f.type, _GenericAlias):
6498
raise TypeError(
6599
f"typing annotation not supported on field '{f.name}': "
66100
f"'{f.type!r}'")

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def main():
4343
"pytest>=2.3",
4444
"loopy>=2019.1",
4545
"dataclasses; python_version<'3.7'",
46+
"typing_extensions; python_version<'3.9'",
4647
"types-dataclasses",
4748
],
4849
package_data={"arraycontext": ["py.typed"]},

test/test_utils.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,40 @@ class ArrayContainerWithInitFalse:
9393
# }}}
9494

9595

96+
# {{{ test_dataclass_excluded_fields
97+
98+
def test_dataclass_excluded_fields():
99+
from dataclasses import dataclass
100+
from arraycontext import dataclass_array_container, ExcludedField
101+
102+
try:
103+
from typing import Annotated
104+
except ImportError:
105+
from typing_extensions import Annotated
106+
107+
@dataclass_array_container
108+
@dataclass(frozen=True)
109+
class ExcludedFoo:
110+
x: np.ndarray
111+
y: np.ndarray
112+
excluded: Annotated[np.ndarray, ExcludedField]
113+
114+
ary = np.array([42], dtype=object)
115+
c0 = ExcludedFoo(x=ary, y=ary, excluded=ary)
116+
117+
from arraycontext import serialize_container
118+
iterable = serialize_container(c0)
119+
assert len(iterable) == 2
120+
121+
from arraycontext import deserialize_container
122+
c1 = deserialize_container(c0, iterable)
123+
assert np.linalg.norm(c0.x - c1.x) < 1.0e-15
124+
assert np.linalg.norm(c0.y - c1.y) < 1.0e-15
125+
126+
127+
# }}}
128+
129+
96130
if __name__ == "__main__":
97131
import sys
98132
if len(sys.argv) > 1:

0 commit comments

Comments
 (0)