|
31 | 31 | """
|
32 | 32 |
|
33 | 33 | from dataclasses import fields
|
| 34 | +from typing import Callable, Optional, Tuple, Union, overload |
| 35 | + |
34 | 36 | from arraycontext.container import is_array_container_type
|
35 | 37 |
|
36 | 38 |
|
37 | 39 | # {{{ dataclass containers
|
38 | 40 |
|
39 |
| -def dataclass_array_container(cls: type) -> type: |
40 |
| - """A class decorator that makes the class to which it is applied a |
| 41 | +@overload |
| 42 | +def dataclass_array_container( |
| 43 | + cls: None = None, |
| 44 | + excluded_fields: Optional[Tuple[str, ...]] = None, |
| 45 | + ) -> Callable[[type], type]: |
| 46 | + ... |
| 47 | + |
| 48 | + |
| 49 | +@overload |
| 50 | +def dataclass_array_container( |
| 51 | + cls: type, |
| 52 | + excluded_fields: Optional[Tuple[str, ...]] = None, |
| 53 | + ) -> type: |
| 54 | + ... |
| 55 | + |
| 56 | + |
| 57 | +def dataclass_array_container( |
| 58 | + cls: Optional[type] = None, |
| 59 | + excluded_fields: Optional[Tuple[str, ...]] = None, |
| 60 | + ) -> Union[type, Callable[[type], type]]: |
| 61 | + """A class decorator that makes the class to which it is applied an |
41 | 62 | :class:`ArrayContainer` by registering appropriate implementations of
|
42 | 63 | :func:`serialize_container` and :func:`deserialize_container`.
|
43 | 64 | *cls* must be a :func:`~dataclasses.dataclass`.
|
44 | 65 |
|
45 | 66 | Attributes that are not array containers are allowed. In order to decide
|
46 | 67 | whether an attribute is an array container, the declared attribute type
|
47 |
| - is checked by the criteria from :func:`is_array_container`. |
| 68 | + is checked by the criteria from :func:`is_array_container`. Additional |
| 69 | + attributes can be excluded manually using *excluded_fields* |
48 | 70 | """
|
49 |
| - from dataclasses import is_dataclass |
50 |
| - assert is_dataclass(cls) |
51 |
| - |
52 |
| - array_fields = [ |
53 |
| - f for f in fields(cls) if is_array_container_type(f.type)] |
54 |
| - non_array_fields = [ |
55 |
| - f for f in fields(cls) if not is_array_container_type(f.type)] |
56 |
| - |
57 |
| - if not array_fields: |
58 |
| - raise ValueError(f"'{cls}' must have fields with array container type " |
59 |
| - "in order to use the 'dataclass_array_container' decorator") |
60 |
| - |
61 |
| - serialize_expr = ", ".join( |
62 |
| - f"({f.name!r}, ary.{f.name})" for f in array_fields) |
63 |
| - template_kwargs = ", ".join( |
64 |
| - f"{f.name}=template.{f.name}" for f in non_array_fields) |
65 |
| - |
66 |
| - lower_cls_name = cls.__name__.lower() |
67 |
| - |
68 |
| - serialize_init_code = ", ".join(f"{f.name!r}: f'{{instance_name}}.{f.name}'" |
69 |
| - for f in array_fields) |
70 |
| - deserialize_init_code = ", ".join([ |
71 |
| - f"{f.name}={{args[{f.name!r}]}}" for f in array_fields |
72 |
| - ] + [ |
73 |
| - f"{f.name}={{template_instance_name}}.{f.name}" |
74 |
| - for f in non_array_fields |
75 |
| - ]) |
76 |
| - |
77 |
| - from pytools.codegen import remove_common_indentation |
78 |
| - serialize_code = remove_common_indentation(f""" |
79 |
| - from typing import Any, Iterable, Tuple |
80 |
| - from arraycontext import serialize_container, deserialize_container |
81 |
| -
|
82 |
| - @serialize_container.register(cls) |
83 |
| - def _serialize_{lower_cls_name}(ary: cls) -> Iterable[Tuple[Any, Any]]: |
84 |
| - return ({serialize_expr},) |
85 |
| -
|
86 |
| - @deserialize_container.register(cls) |
87 |
| - def _deserialize_{lower_cls_name}( |
88 |
| - template: cls, iterable: Iterable[Tuple[Any, Any]]) -> cls: |
89 |
| - return cls(**dict(iterable), {template_kwargs}) |
90 |
| -
|
91 |
| - # support for with_container_arithmetic |
92 |
| -
|
93 |
| - def _serialize_init_arrays_code_{lower_cls_name}(cls, instance_name): |
94 |
| - return {{ |
95 |
| - {serialize_init_code} |
96 |
| - }} |
97 |
| - cls._serialize_init_arrays_code = classmethod( |
98 |
| - _serialize_init_arrays_code_{lower_cls_name}) |
99 |
| -
|
100 |
| - def _deserialize_init_arrays_code_{lower_cls_name}( |
101 |
| - cls, template_instance_name, args): |
102 |
| - return f"{deserialize_init_code}" |
103 |
| -
|
104 |
| - cls._deserialize_init_arrays_code = classmethod( |
105 |
| - _deserialize_init_arrays_code_{lower_cls_name}) |
106 |
| - """) |
107 |
| - |
108 |
| - exec_dict = {"cls": cls, "_MODULE_SOURCE_CODE": serialize_code} |
109 |
| - exec(compile(serialize_code, f"<container serialization for {cls.__name__}>", |
110 |
| - "exec"), exec_dict) |
111 |
| - |
112 |
| - return cls |
| 71 | + if excluded_fields is None: |
| 72 | + excluded_fields = () |
| 73 | + |
| 74 | + def wrap(cls: type) -> type: |
| 75 | + assert excluded_fields is not None |
| 76 | + |
| 77 | + from dataclasses import is_dataclass |
| 78 | + assert is_dataclass(cls) |
| 79 | + |
| 80 | + cls_fields = fields(cls) |
| 81 | + assert all(any(f == cf.name for cf in cls_fields) for f in excluded_fields) |
| 82 | + |
| 83 | + container_fields = [ |
| 84 | + f for f in cls_fields |
| 85 | + if f.name not in excluded_fields and is_array_container_type(f.type) |
| 86 | + ] |
| 87 | + non_container_fields = [ |
| 88 | + f for f in fields(cls) |
| 89 | + if f.name in excluded_fields or not is_array_container_type(f.type) |
| 90 | + ] |
| 91 | + |
| 92 | + if not container_fields: |
| 93 | + raise ValueError(f"'{cls}' must have fields with array container type " |
| 94 | + "in order to use the 'dataclass_array_container' decorator") |
| 95 | + |
| 96 | + serialize_expr = ", ".join( |
| 97 | + f"({f.name!r}, ary.{f.name})" for f in container_fields) |
| 98 | + template_kwargs = ", ".join( |
| 99 | + f"{f.name}=template.{f.name}" for f in non_container_fields) |
| 100 | + |
| 101 | + lower_cls_name = cls.__name__.lower() |
| 102 | + |
| 103 | + serialize_init_code = ", ".join(f"{f.name!r}: f'{{instance_name}}.{f.name}'" |
| 104 | + for f in container_fields) |
| 105 | + deserialize_init_code = ", ".join([ |
| 106 | + f"{f.name}={{args[{f.name!r}]}}" for f in container_fields |
| 107 | + ] + [ |
| 108 | + f"{f.name}={{template_instance_name}}.{f.name}" |
| 109 | + for f in non_container_fields |
| 110 | + ]) |
| 111 | + |
| 112 | + from pytools.codegen import remove_common_indentation |
| 113 | + serialize_code = remove_common_indentation(f""" |
| 114 | + from typing import Any, Iterable, Tuple |
| 115 | + from arraycontext import serialize_container, deserialize_container |
| 116 | +
|
| 117 | + @serialize_container.register(cls) |
| 118 | + def _serialize_{lower_cls_name}(ary: cls) -> Iterable[Tuple[Any, Any]]: |
| 119 | + return ({serialize_expr},) |
| 120 | +
|
| 121 | + @deserialize_container.register(cls) |
| 122 | + def _deserialize_{lower_cls_name}( |
| 123 | + template: cls, iterable: Iterable[Tuple[Any, Any]]) -> cls: |
| 124 | + return cls(**dict(iterable), {template_kwargs}) |
| 125 | +
|
| 126 | + # support for with_container_arithmetic |
| 127 | +
|
| 128 | + def _serialize_init_arrays_code_{lower_cls_name}(cls, instance_name): |
| 129 | + return {{ |
| 130 | + {serialize_init_code} |
| 131 | + }} |
| 132 | + cls._serialize_init_arrays_code = classmethod( |
| 133 | + _serialize_init_arrays_code_{lower_cls_name}) |
| 134 | +
|
| 135 | + def _deserialize_init_arrays_code_{lower_cls_name}( |
| 136 | + cls, template_instance_name, args): |
| 137 | + return f"{deserialize_init_code}" |
| 138 | +
|
| 139 | + cls._deserialize_init_arrays_code = classmethod( |
| 140 | + _deserialize_init_arrays_code_{lower_cls_name}) |
| 141 | + """) |
| 142 | + |
| 143 | + exec_dict = {"cls": cls, "_MODULE_SOURCE_CODE": serialize_code} |
| 144 | + exec(compile(serialize_code, f"<container serialization for {cls.__name__}>", |
| 145 | + "exec"), exec_dict) |
| 146 | + |
| 147 | + return cls |
| 148 | + |
| 149 | + if cls is not None: |
| 150 | + return wrap(cls) |
| 151 | + else: |
| 152 | + return wrap |
113 | 153 |
|
114 | 154 | # }}}
|
115 | 155 |
|
|
0 commit comments