|
| 1 | +__copyright__ = """ |
| 2 | +Copyright (C) 2024 University of Illinois Board of Trustees |
| 3 | +""" |
| 4 | + |
| 5 | +__license__ = """ |
| 6 | +Permission is hereby granted, free of charge, to any person obtaining a copy |
| 7 | +of this software and associated documentation files (the "Software"), to deal |
| 8 | +in the Software without restriction, including without limitation the rights |
| 9 | +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
| 10 | +copies of the Software, and to permit persons to whom the Software is |
| 11 | +furnished to do so, subject to the following conditions: |
| 12 | +
|
| 13 | +The above copyright notice and this permission notice shall be included in |
| 14 | +all copies or substantial portions of the Software. |
| 15 | +
|
| 16 | +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 17 | +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 18 | +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 19 | +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 20 | +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 21 | +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN |
| 22 | +THE SOFTWARE. |
| 23 | +""" |
| 24 | +from functools import partial, reduce |
| 25 | + |
| 26 | +import cupy as cp # type: ignore[import-untyped] |
| 27 | + |
| 28 | +from arraycontext.container import is_array_container |
| 29 | +from arraycontext.container.traversal import ( |
| 30 | + multimap_reduce_array_container, rec_map_array_container, |
| 31 | + rec_map_reduce_array_container, rec_multimap_array_container, |
| 32 | + rec_multimap_reduce_array_container) |
| 33 | +from arraycontext.fake_numpy import ( |
| 34 | + BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace) |
| 35 | + |
| 36 | + |
| 37 | +class CupyFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): |
| 38 | + # Everything is implemented in the base class for now. |
| 39 | + pass |
| 40 | + |
| 41 | + |
| 42 | +_NUMPY_UFUNCS = {"abs", "sin", "cos", "tan", "arcsin", "arccos", "arctan", |
| 43 | + "sinh", "cosh", "tanh", "exp", "log", "log10", "isnan", |
| 44 | + "sqrt", "concatenate", "transpose", |
| 45 | + "ones_like", "maximum", "minimum", "where", "conj", "arctan2", |
| 46 | + } |
| 47 | + |
| 48 | + |
| 49 | +class CupyFakeNumpyNamespace(BaseFakeNumpyNamespace): |
| 50 | + """ |
| 51 | + A :mod:`numpy` mimic for :class:`CupyArrayContext`. |
| 52 | + """ |
| 53 | + def _get_fake_numpy_linalg_namespace(self): |
| 54 | + return CupyFakeNumpyLinalgNamespace(self._array_context) |
| 55 | + |
| 56 | + def __getattr__(self, name): |
| 57 | + |
| 58 | + if name in _NUMPY_UFUNCS: |
| 59 | + from functools import partial |
| 60 | + return partial(rec_multimap_array_container, |
| 61 | + getattr(cp, name)) |
| 62 | + |
| 63 | + raise NotImplementedError |
| 64 | + |
| 65 | + def sum(self, a, axis=None, dtype=None): |
| 66 | + return rec_map_reduce_array_container(sum, partial(cp.sum, |
| 67 | + axis=axis, |
| 68 | + dtype=dtype), |
| 69 | + a) |
| 70 | + |
| 71 | + def min(self, a, axis=None): |
| 72 | + return rec_map_reduce_array_container( |
| 73 | + partial(reduce, cp.minimum), partial(cp.amin, axis=axis), a) |
| 74 | + |
| 75 | + def max(self, a, axis=None): |
| 76 | + return rec_map_reduce_array_container( |
| 77 | + partial(reduce, cp.maximum), partial(cp.amax, axis=axis), a) |
| 78 | + |
| 79 | + def stack(self, arrays, axis=0): |
| 80 | + return rec_multimap_array_container( |
| 81 | + lambda *args: cp.stack(args, axis=axis), |
| 82 | + *arrays) |
| 83 | + |
| 84 | + def broadcast_to(self, array, shape): |
| 85 | + return rec_map_array_container(partial(cp.broadcast_to, shape=shape), array) |
| 86 | + |
| 87 | + # {{{ relational operators |
| 88 | + |
| 89 | + def equal(self, x, y): |
| 90 | + return rec_multimap_array_container(cp.equal, x, y) |
| 91 | + |
| 92 | + def not_equal(self, x, y): |
| 93 | + return rec_multimap_array_container(cp.not_equal, x, y) |
| 94 | + |
| 95 | + def greater(self, x, y): |
| 96 | + return rec_multimap_array_container(cp.greater, x, y) |
| 97 | + |
| 98 | + def greater_equal(self, x, y): |
| 99 | + return rec_multimap_array_container(cp.greater_equal, x, y) |
| 100 | + |
| 101 | + def less(self, x, y): |
| 102 | + return rec_multimap_array_container(cp.less, x, y) |
| 103 | + |
| 104 | + def less_equal(self, x, y): |
| 105 | + return rec_multimap_array_container(cp.less_equal, x, y) |
| 106 | + |
| 107 | + # }}} |
| 108 | + |
| 109 | + def ravel(self, a, order="C"): |
| 110 | + return rec_map_array_container(partial(cp.ravel, order=order), a) |
| 111 | + |
| 112 | + def vdot(self, x, y, dtype=None): |
| 113 | + if dtype is not None: |
| 114 | + raise NotImplementedError("only 'dtype=None' supported.") |
| 115 | + |
| 116 | + return rec_multimap_reduce_array_container(sum, cp.vdot, x, y) |
| 117 | + |
| 118 | + def any(self, a): |
| 119 | + return rec_map_reduce_array_container(partial(reduce, cp.logical_or), |
| 120 | + lambda subary: cp.any(subary), a) |
| 121 | + |
| 122 | + def all(self, a): |
| 123 | + return rec_map_reduce_array_container(partial(reduce, cp.logical_and), |
| 124 | + lambda subary: cp.all(subary), a) |
| 125 | + |
| 126 | + def array_equal(self, a, b): |
| 127 | + if type(a) is not type(b): |
| 128 | + return False |
| 129 | + elif not is_array_container(a): |
| 130 | + if a.shape != b.shape: |
| 131 | + return False |
| 132 | + else: |
| 133 | + return cp.all(cp.equal(a, b)) |
| 134 | + else: |
| 135 | + try: |
| 136 | + return multimap_reduce_array_container(partial(reduce, |
| 137 | + cp.logical_and), |
| 138 | + self.array_equal, a, b) |
| 139 | + except TypeError: |
| 140 | + return True |
| 141 | + |
| 142 | + def zeros_like(self, ary): |
| 143 | + return rec_multimap_array_container(cp.zeros_like, ary) |
| 144 | + |
| 145 | + def reshape(self, a, newshape, order="C"): |
| 146 | + return rec_map_array_container( |
| 147 | + lambda ary: ary.reshape(newshape, order=order), |
| 148 | + a) |
| 149 | + |
| 150 | + def arange(self, *args, **kwargs): |
| 151 | + return cp.arange(*args, **kwargs) |
| 152 | + |
| 153 | + def linspace(self, *args, **kwargs): |
| 154 | + return cp.linspace(*args, **kwargs) |
| 155 | + |
| 156 | +# vim: fdm=marker |
0 commit comments