diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 6f7308db..22b74fb6 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -44,7 +44,7 @@ serialize_container, deserialize_container, register_multivector_as_array_container) from .container.arithmetic import with_container_arithmetic -from .container.dataclass import dataclass_array_container +from .container.dataclass import dataclass_array_container, ExcludedField from .container.traversal import ( map_array_container, @@ -85,7 +85,7 @@ "serialize_container", "deserialize_container", "register_multivector_as_array_container", "with_container_arithmetic", - "dataclass_array_container", + "dataclass_array_container", "ExcludedField", "map_array_container", "multimap_array_container", "rec_map_array_container", "rec_multimap_array_container", diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py index 203246cd..d28fb566 100644 --- a/arraycontext/container/dataclass.py +++ b/arraycontext/container/dataclass.py @@ -2,6 +2,8 @@ """ .. currentmodule:: arraycontext + +.. autoclass:: ExcludedField .. autofunction:: dataclass_array_container """ @@ -31,11 +33,34 @@ """ from dataclasses import fields + +# NOTE: these are internal attributes and mypy says they do not exist +from typing import _GenericAlias # type: ignore[attr-defined] +try: + from typing import _AnnotatedAlias, get_args # type: ignore[attr-defined] +except ImportError: + from typing_extensions import ( # type: ignore[attr-defined] + _AnnotatedAlias, get_args) + from arraycontext.container import is_array_container_type # {{{ dataclass containers +class ExcludedField: + """Can be used to annotate dataclass fields to be excluded from the container. + + This can be done using :class:`typing.Annotated` as follows + + .. code:: python + + @dataclass + class MyClass: + x: np.ndarray + y: Annotated[np.ndarray, ExcludedField] + """ + + def dataclass_array_container(cls: type) -> type: """A class decorator that makes the class to which it is applied an :class:`ArrayContainer` by registering appropriate implementations of @@ -45,11 +70,21 @@ def dataclass_array_container(cls: type) -> type: Attributes that are not array containers are allowed. In order to decide whether an attribute is an array container, the declared attribute type is checked by the criteria from :func:`is_array_container_type`. + + To explicitly exclude fields from the container serialization (that would + otherwise be recognized as array containers), use :class:`typing.Annotated` + and :class:`ExcludedField`. """ + from dataclasses import is_dataclass, Field assert is_dataclass(cls) def is_array_field(f: Field) -> bool: + # FIXME: is there a nicer way to recognize that we hit Annotated? + if isinstance(f.type, _AnnotatedAlias): + if any(arg is ExcludedField for arg in get_args(f.type)): + return False + if __debug__: if not f.init: raise ValueError( @@ -59,8 +94,7 @@ def is_array_field(f: Field) -> bool: raise TypeError( f"string annotation on field '{f.name}' not supported") - from typing import _SpecialForm - if isinstance(f.type, _SpecialForm): + if isinstance(f.type, _GenericAlias): raise TypeError( f"typing annotation not supported on field '{f.name}': " f"'{f.type!r}'") diff --git a/setup.py b/setup.py index 62ff4a7b..13613afe 100644 --- a/setup.py +++ b/setup.py @@ -43,6 +43,7 @@ def main(): "pytest>=2.3", "loopy>=2019.1", "dataclasses; python_version<'3.7'", + "typing_extensions; python_version<'3.9'", "types-dataclasses", ], package_data={"arraycontext": ["py.typed"]}, diff --git a/test/test_utils.py b/test/test_utils.py index 08b6c3a5..ce5bd5e1 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -93,6 +93,40 @@ class ArrayContainerWithInitFalse: # }}} +# {{{ test_dataclass_excluded_fields + +def test_dataclass_excluded_fields(): + from dataclasses import dataclass + from arraycontext import dataclass_array_container, ExcludedField + + try: + from typing import Annotated + except ImportError: + from typing_extensions import Annotated + + @dataclass_array_container + @dataclass(frozen=True) + class ExcludedFoo: + x: np.ndarray + y: np.ndarray + excluded: Annotated[np.ndarray, ExcludedField] + + ary = np.array([42], dtype=object) + c0 = ExcludedFoo(x=ary, y=ary, excluded=ary) + + from arraycontext import serialize_container + iterable = serialize_container(c0) + assert len(iterable) == 2 + + from arraycontext import deserialize_container + c1 = deserialize_container(c0, iterable) + assert np.linalg.norm(c0.x - c1.x) < 1.0e-15 + assert np.linalg.norm(c0.y - c1.y) < 1.0e-15 + + +# }}} + + if __name__ == "__main__": import sys if len(sys.argv) > 1: