Skip to content

Allow to manually exclude fields in dataclass_array_container #96

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions arraycontext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
serialize_container, deserialize_container,
register_multivector_as_array_container)
from .container.arithmetic import with_container_arithmetic
from .container.dataclass import dataclass_array_container
from .container.dataclass import dataclass_array_container, ExcludedField

from .container.traversal import (
map_array_container,
Expand Down Expand Up @@ -85,7 +85,7 @@
"serialize_container", "deserialize_container",
"register_multivector_as_array_container",
"with_container_arithmetic",
"dataclass_array_container",
"dataclass_array_container", "ExcludedField",

"map_array_container", "multimap_array_container",
"rec_map_array_container", "rec_multimap_array_container",
Expand Down
38 changes: 36 additions & 2 deletions arraycontext/container/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

"""
.. currentmodule:: arraycontext

.. autoclass:: ExcludedField
.. autofunction:: dataclass_array_container
"""

Expand Down Expand Up @@ -31,11 +33,34 @@
"""

from dataclasses import fields

# NOTE: these are internal attributes and mypy says they do not exist
from typing import _GenericAlias # type: ignore[attr-defined]
try:
from typing import _AnnotatedAlias, get_args # type: ignore[attr-defined]
except ImportError:
from typing_extensions import ( # type: ignore[attr-defined]
_AnnotatedAlias, get_args)

from arraycontext.container import is_array_container_type


# {{{ dataclass containers

class ExcludedField:
"""Can be used to annotate dataclass fields to be excluded from the container.

This can be done using :class:`typing.Annotated` as follows

.. code:: python

@dataclass
class MyClass:
x: np.ndarray
y: Annotated[np.ndarray, ExcludedField]
"""


def dataclass_array_container(cls: type) -> type:
"""A class decorator that makes the class to which it is applied an
:class:`ArrayContainer` by registering appropriate implementations of
Expand All @@ -45,11 +70,21 @@ def dataclass_array_container(cls: type) -> type:
Attributes that are not array containers are allowed. In order to decide
whether an attribute is an array container, the declared attribute type
is checked by the criteria from :func:`is_array_container_type`.

To explicitly exclude fields from the container serialization (that would
otherwise be recognized as array containers), use :class:`typing.Annotated`
and :class:`ExcludedField`.
"""

from dataclasses import is_dataclass, Field
assert is_dataclass(cls)

def is_array_field(f: Field) -> bool:
# FIXME: is there a nicer way to recognize that we hit Annotated?
if isinstance(f.type, _AnnotatedAlias):
if any(arg is ExcludedField for arg in get_args(f.type)):
return False

if __debug__:
if not f.init:
raise ValueError(
Expand All @@ -59,8 +94,7 @@ def is_array_field(f: Field) -> bool:
raise TypeError(
f"string annotation on field '{f.name}' not supported")

from typing import _SpecialForm
if isinstance(f.type, _SpecialForm):
if isinstance(f.type, _GenericAlias):
raise TypeError(
f"typing annotation not supported on field '{f.name}': "
f"'{f.type!r}'")
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def main():
"pytest>=2.3",
"loopy>=2019.1",
"dataclasses; python_version<'3.7'",
"typing_extensions; python_version<'3.9'",
"types-dataclasses",
],
package_data={"arraycontext": ["py.typed"]},
Expand Down
34 changes: 34 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,40 @@ class ArrayContainerWithInitFalse:
# }}}


# {{{ test_dataclass_excluded_fields

def test_dataclass_excluded_fields():
from dataclasses import dataclass
from arraycontext import dataclass_array_container, ExcludedField

try:
from typing import Annotated
except ImportError:
from typing_extensions import Annotated

@dataclass_array_container
@dataclass(frozen=True)
class ExcludedFoo:
x: np.ndarray
y: np.ndarray
excluded: Annotated[np.ndarray, ExcludedField]

ary = np.array([42], dtype=object)
c0 = ExcludedFoo(x=ary, y=ary, excluded=ary)

from arraycontext import serialize_container
iterable = serialize_container(c0)
assert len(iterable) == 2

from arraycontext import deserialize_container
c1 = deserialize_container(c0, iterable)
assert np.linalg.norm(c0.x - c1.x) < 1.0e-15
assert np.linalg.norm(c0.y - c1.y) < 1.0e-15


# }}}


if __name__ == "__main__":
import sys
if len(sys.argv) > 1:
Expand Down