|
30 | 30 | THE SOFTWARE.
|
31 | 31 | """
|
32 | 32 |
|
33 |
| -from typing import Union, get_args |
| 33 | +from typing import Tuple, Union, get_args |
34 | 34 | try:
|
35 | 35 | # NOTE: only available in python >= 3.8
|
36 | 36 | from typing import get_origin
|
37 | 37 | except ImportError:
|
38 | 38 | from typing_extensions import get_origin
|
39 | 39 |
|
40 |
| -from dataclasses import fields |
| 40 | +from dataclasses import is_dataclass, fields, Field |
41 | 41 | from arraycontext.container import is_array_container_type
|
42 | 42 |
|
43 | 43 |
|
44 | 44 | # {{{ dataclass containers
|
45 | 45 |
|
46 |
| -def dataclass_array_container(cls: type) -> type: |
47 |
| - """A class decorator that makes the class to which it is applied an |
48 |
| - :class:`ArrayContainer` by registering appropriate implementations of |
49 |
| - :func:`serialize_container` and :func:`deserialize_container`. |
50 |
| - *cls* must be a :func:`~dataclasses.dataclass`. |
| 46 | +def is_array_type(tp: type) -> bool: |
| 47 | + from arraycontext import Array |
| 48 | + return tp is Array or is_array_container_type(tp) |
51 | 49 |
|
52 |
| - Attributes that are not array containers are allowed. In order to decide |
53 |
| - whether an attribute is an array container, the declared attribute type |
54 |
| - is checked by the criteria from :func:`is_array_container_type`. |
55 |
| - """ |
56 |
| - from dataclasses import is_dataclass, Field |
57 |
| - assert is_dataclass(cls) |
58 | 50 |
|
59 |
| - def is_array_field(f: Field) -> bool: |
60 |
| - from arraycontext import Array |
| 51 | +def inject_container_serialization( |
| 52 | + cls: type, |
| 53 | + array_fields: Tuple[Field, ...], |
| 54 | + non_array_fields: Tuple[Field, ...], |
| 55 | + ) -> type: |
| 56 | + """Implements :func:`~arraycontext.serialize_container` and |
| 57 | + :func:`~arraycontext.deserialize_container` for the given class *cls*. |
61 | 58 |
|
62 |
| - origin = get_origin(f.type) |
63 |
| - if origin is Union: |
64 |
| - if not all( |
65 |
| - arg is Array or is_array_container_type(arg) |
66 |
| - for arg in get_args(f.type)): |
67 |
| - raise TypeError( |
68 |
| - f"Field '{f.name}' union contains non-array container " |
69 |
| - "arguments. All arguments must be array containers.") |
70 |
| - else: |
71 |
| - return True |
| 59 | + This function modifies *cls* in place, so the returned value is the same |
| 60 | + object with additional functionality. |
72 | 61 |
|
73 |
| - if __debug__: |
74 |
| - if not f.init: |
75 |
| - raise ValueError( |
76 |
| - f"'init=False' field not allowed: '{f.name}'") |
| 62 | + :arg array_fields: fields of the given dataclass *cls* which are considered |
| 63 | + array containers and should be serialized. |
| 64 | + :arg non_array_fields: remaining fields of the dataclass *cls* which are |
| 65 | + copied over from the template array in deserialization. |
77 | 66 |
|
78 |
| - if isinstance(f.type, str): |
79 |
| - raise TypeError( |
80 |
| - f"string annotation on field '{f.name}' not supported") |
81 |
| - |
82 |
| - from typing import _SpecialForm |
83 |
| - if isinstance(f.type, _SpecialForm): |
84 |
| - # NOTE: anything except a Union is not allowed |
85 |
| - raise TypeError( |
86 |
| - f"typing annotation not supported on field '{f.name}': " |
87 |
| - f"'{f.type!r}'") |
88 |
| - |
89 |
| - if not isinstance(f.type, type): |
90 |
| - raise TypeError( |
91 |
| - f"field '{f.name}' not an instance of 'type': " |
92 |
| - f"'{f.type!r}'") |
93 |
| - |
94 |
| - return f.type is Array or is_array_container_type(f.type) |
95 |
| - |
96 |
| - from pytools import partition |
97 |
| - array_fields, non_array_fields = partition(is_array_field, fields(cls)) |
| 67 | + :returns: the input class *cls*. |
| 68 | + """ |
98 | 69 |
|
99 |
| - if not array_fields: |
100 |
| - raise ValueError(f"'{cls}' must have fields with array container type " |
101 |
| - "in order to use the 'dataclass_array_container' decorator") |
| 70 | + assert is_dataclass(cls) |
102 | 71 |
|
103 | 72 | serialize_expr = ", ".join(
|
104 | 73 | f"({f.name!r}, ary.{f.name})" for f in array_fields)
|
@@ -153,6 +122,66 @@ def _deserialize_init_arrays_code_{lower_cls_name}(
|
153 | 122 |
|
154 | 123 | return cls
|
155 | 124 |
|
| 125 | + |
| 126 | +def dataclass_array_container(cls: type) -> type: |
| 127 | + """A class decorator that makes the class to which it is applied an |
| 128 | + :class:`ArrayContainer` by registering appropriate implementations of |
| 129 | + :func:`serialize_container` and :func:`deserialize_container`. |
| 130 | + *cls* must be a :func:`~dataclasses.dataclass`. |
| 131 | +
|
| 132 | + Attributes that are not array containers are allowed. In order to decide |
| 133 | + whether an attribute is an array container, the declared attribute type |
| 134 | + is checked by the criteria from :func:`is_array_container_type`. This |
| 135 | + includes some support for type annotations: |
| 136 | +
|
| 137 | + * a :class:`typing.Union` of array containers is considered an array container. |
| 138 | + * other type annotations, e.g. :class:`typing.Optional`, are not considered |
| 139 | + array containers, even if they wrap one. |
| 140 | + """ |
| 141 | + assert is_dataclass(cls) |
| 142 | + |
| 143 | + def is_array_field(f: Field) -> bool: |
| 144 | + if __debug__: |
| 145 | + if not f.init: |
| 146 | + raise ValueError( |
| 147 | + f"Fields with 'init=False' not allowed: '{f.name}'") |
| 148 | + |
| 149 | + if isinstance(f.type, str): |
| 150 | + raise TypeError( |
| 151 | + f"String annotation on field '{f.name}' not supported") |
| 152 | + |
| 153 | + # NOTE: unions of array containers are treated seprately to allow |
| 154 | + # * unions of only array containers, e.g. Union[np.ndarray, Array], as |
| 155 | + # they can work seamlessly with arithmetic and traversal. |
| 156 | + # * `Optional[ArrayContainer]` is not allowed, since `None` is not |
| 157 | + # handled by `with_container_arithmetic`, which is the common case |
| 158 | + # for current container usage. |
| 159 | + # |
| 160 | + # Other type annotations, e.g. `Tuple[Container, Container]`, are also |
| 161 | + # not allowed, as they do not work with `with_container_arithmetic`. |
| 162 | + # |
| 163 | + # This is not set in stone, but mostly driven by current usage! |
| 164 | + |
| 165 | + origin = get_origin(f.type) |
| 166 | + if origin is Union: |
| 167 | + # NOTE: `Optional` is caught in here as an alias for `Union[Anon, type]` |
| 168 | + return all(is_array_type(arg) for arg in get_args(f.type)) |
| 169 | + |
| 170 | + from typing import _GenericAlias, _SpecialForm # type: ignore[attr-defined] |
| 171 | + if isinstance(f.type, (_GenericAlias, _SpecialForm)): |
| 172 | + return False |
| 173 | + |
| 174 | + return is_array_type(f.type) |
| 175 | + |
| 176 | + from pytools import partition |
| 177 | + array_fields, non_array_fields = partition(is_array_field, fields(cls)) |
| 178 | + |
| 179 | + if not array_fields: |
| 180 | + raise ValueError(f"'{cls}' must have fields with array container type " |
| 181 | + "in order to use the 'dataclass_array_container' decorator") |
| 182 | + |
| 183 | + return inject_container_serialization(cls, array_fields, non_array_fields) |
| 184 | + |
156 | 185 | # }}}
|
157 | 186 |
|
158 | 187 | # vim: foldmethod=marker
|
0 commit comments