diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index 72bd0243..55deb90c 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -173,6 +173,8 @@ def is_array_container_type(cls: type) -> bool: function will say that :class:`numpy.ndarray` is an array container type, only object arrays *actually are* array containers. """ + assert isinstance(cls, type), f"must pass a {type!r}, not a '{cls!r}'" + return ( cls is ArrayContainer or (serialize_container.dispatch(cls) diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py index 2891f60e..65492076 100644 --- a/arraycontext/container/dataclass.py +++ b/arraycontext/container/dataclass.py @@ -46,13 +46,34 @@ def dataclass_array_container(cls: type) -> type: whether an attribute is an array container, the declared attribute type is checked by the criteria from :func:`is_array_container_type`. """ - from dataclasses import is_dataclass + from dataclasses import is_dataclass, Field assert is_dataclass(cls) - array_fields = [ - f for f in fields(cls) if is_array_container_type(f.type)] - non_array_fields = [ - f for f in fields(cls) if not is_array_container_type(f.type)] + def is_array_field(f: Field) -> bool: + if __debug__: + if not f.init: + raise ValueError( + f"'init=False' field not allowed: '{f.name}'") + + if isinstance(f.type, str): + raise TypeError( + f"string annotation on field '{f.name}' not supported") + + from typing import _SpecialForm + if isinstance(f.type, _SpecialForm): + raise TypeError( + f"typing annotation not supported on field '{f.name}': " + f"'{f.type!r}'") + + if not isinstance(f.type, type): + raise TypeError( + f"field '{f.name}' not an instance of 'type': " + f"'{f.type!r}'") + + return is_array_container_type(f.type) + + from pytools import partition + array_fields, non_array_fields = partition(is_array_field, fields(cls)) if not array_fields: raise ValueError(f"'{cls}' must have fields with array container type " diff --git a/test/test_utils.py b/test/test_utils.py index 2228152f..08b6c3a5 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -22,11 +22,16 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +import pytest + +import numpy as np import logging logger = logging.getLogger(__name__) +# {{{ test_pt_actx_key_stringification_uniqueness + def test_pt_actx_key_stringification_uniqueness(): from arraycontext.impl.pytato.compile import _ary_container_key_stringifier @@ -36,13 +41,63 @@ def test_pt_actx_key_stringification_uniqueness(): assert (_ary_container_key_stringifier(("tup", 3, "endtup")) != _ary_container_key_stringifier(((3,),))) +# }}} + + +# {{{ test_dataclass_array_container + +def test_dataclass_array_container(): + from typing import Optional + from dataclasses import dataclass, field + from arraycontext import dataclass_array_container + + # {{{ string fields + + @dataclass + class ArrayContainerWithStringTypes: + x: np.ndarray + y: "np.ndarray" + + with pytest.raises(TypeError): + # NOTE: cannot have string annotations in container + dataclass_array_container(ArrayContainerWithStringTypes) + + # }}} + + # {{{ optional fields + + @dataclass + class ArrayContainerWithOptional: + x: np.ndarray + y: Optional[np.ndarray] + + with pytest.raises(TypeError): + # NOTE: cannot have wrapped annotations (here by `Optional`) + dataclass_array_container(ArrayContainerWithOptional) + + # }}} + + # {{{ field(init=False) + + @dataclass + class ArrayContainerWithInitFalse: + x: np.ndarray + y: np.ndarray = field(default=np.zeros(42), init=False, repr=False) + + with pytest.raises(ValueError): + # NOTE: init=False fields are not allowed + dataclass_array_container(ArrayContainerWithInitFalse) + + # }}} + +# }}} + if __name__ == "__main__": import sys if len(sys.argv) > 1: exec(sys.argv[1]) else: - from pytest import main - main([__file__]) + pytest.main([__file__]) # vim: fdm=marker