Skip to content

Commit 55b4cbf

Browse files
Implement CupyArrayContext
Co-authored-by: Kaushik Kulkarni <[email protected]>
1 parent e53fa90 commit 55b4cbf

File tree

6 files changed

+354
-11
lines changed

6 files changed

+354
-11
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ jobs:
4949
. ./ci-support-v0
5050
build_py_project_in_conda_env
5151
python -m pip install mypy pytest
52+
conda install cupy
5253
./run-mypy.sh
5354
5455
pytest3_pocl:

arraycontext/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
Array, ArrayContext, ArrayOrContainer, ArrayOrContainerOrScalar,
5151
ArrayOrContainerOrScalarT, ArrayOrContainerT, ArrayT, Scalar, ScalarLike,
5252
tag_axes)
53+
from .impl.cupy import CupyArrayContext
5354
from .impl.jax import EagerJAXArrayContext
5455
from .impl.pyopencl import PyOpenCLArrayContext
5556
from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext
@@ -106,7 +107,9 @@
106107
"PytestArrayContextFactory",
107108
"PytestPyOpenCLArrayContextFactory",
108109
"pytest_generate_tests_for_array_contexts",
109-
"pytest_generate_tests_for_pyopencl_array_context"
110+
"pytest_generate_tests_for_pyopencl_array_context",
111+
112+
"CupyArrayContext",
110113
)
111114

112115

arraycontext/impl/cupy/__init__.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
"""
2+
.. currentmodule:: arraycontext
3+
4+
5+
A mod :`cupy`-based array context.
6+
7+
.. autoclass:: CupyArrayContext
8+
"""
9+
__copyright__ = """
10+
Copyright (C) 2024 University of Illinois Board of Trustees
11+
"""
12+
13+
__license__ = """
14+
Permission is hereby granted, free of charge, to any person obtaining a copy
15+
of this software and associated documentation files (the "Software"), to deal
16+
in the Software without restriction, including without limitation the rights
17+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18+
copies of the Software, and to permit persons to whom the Software is
19+
furnished to do so, subject to the following conditions:
20+
21+
The above copyright notice and this permission notice shall be included in
22+
all copies or substantial portions of the Software.
23+
24+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
27+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
29+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
30+
THE SOFTWARE.
31+
"""
32+
33+
from collections.abc import Mapping
34+
35+
36+
try:
37+
import cupy as cp # type: ignore[import-untyped]
38+
except ModuleNotFoundError:
39+
pass
40+
41+
import loopy as lp
42+
43+
from arraycontext.container.traversal import (
44+
rec_map_array_container, with_array_context)
45+
from arraycontext.context import ArrayContext
46+
47+
48+
class CupyArrayContext(ArrayContext):
49+
"""
50+
A :class:`ArrayContext` that uses :mod:`cupy.ndarray` to represent arrays
51+
52+
53+
.. automethod:: __init__
54+
"""
55+
def __init__(self):
56+
super().__init__()
57+
self._loopy_transform_cache: \
58+
Mapping["lp.TranslationUnit", "lp.TranslationUnit"] = {}
59+
60+
self.array_types = (cp.ndarray,)
61+
62+
def _get_fake_numpy_namespace(self):
63+
from .fake_numpy import CupyFakeNumpyNamespace
64+
return CupyFakeNumpyNamespace(self)
65+
66+
# {{{ ArrayContext interface
67+
68+
def clone(self):
69+
return type(self)()
70+
71+
def empty(self, shape, dtype):
72+
return cp.empty(shape, dtype=dtype)
73+
74+
def zeros(self, shape, dtype):
75+
return cp.zeros(shape, dtype)
76+
77+
def from_numpy(self, np_array):
78+
return cp.array(np_array)
79+
80+
def to_numpy(self, array):
81+
return cp.asnumpy(array)
82+
83+
def call_loopy(self, t_unit, **kwargs):
84+
t_unit = t_unit.copy(target=lp.ExecutableCTarget())
85+
try:
86+
t_unit = self._loopy_transform_cache[t_unit]
87+
except KeyError:
88+
orig_t_unit = t_unit
89+
t_unit = self.transform_loopy_program(t_unit)
90+
self._loopy_transform_cache[orig_t_unit] = t_unit
91+
del orig_t_unit
92+
93+
_, result = t_unit(**kwargs)
94+
95+
return result
96+
97+
def freeze(self, array):
98+
def _freeze(ary):
99+
return cp.asnumpy(ary)
100+
101+
return with_array_context(rec_map_array_container(_freeze, array), actx=None)
102+
103+
def thaw(self, array):
104+
def _thaw(ary):
105+
return cp.array(ary)
106+
107+
return with_array_context(rec_map_array_container(_thaw, array), actx=self)
108+
109+
# }}}
110+
111+
def transform_loopy_program(self, t_unit):
112+
raise ValueError("CupyArrayContext does not implement "
113+
"transform_loopy_program. Sub-classes are supposed "
114+
"to implement it.")
115+
116+
def tag(self, tags, array):
117+
# No tagging support in CupyArrayContext
118+
return array
119+
120+
def tag_axis(self, iaxis, tags, array):
121+
return array
122+
123+
def einsum(self, spec, *args, arg_names=None, tagged=()):
124+
return cp.einsum(spec, *args)
125+
126+
@property
127+
def permits_inplace_modification(self):
128+
return True
129+
130+
@property
131+
def supports_nonscalar_broadcasting(self):
132+
return True
133+
134+
@property
135+
def permits_advanced_indexing(self):
136+
return True

arraycontext/impl/cupy/fake_numpy.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
__copyright__ = """
2+
Copyright (C) 2024 University of Illinois Board of Trustees
3+
"""
4+
5+
__license__ = """
6+
Permission is hereby granted, free of charge, to any person obtaining a copy
7+
of this software and associated documentation files (the "Software"), to deal
8+
in the Software without restriction, including without limitation the rights
9+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10+
copies of the Software, and to permit persons to whom the Software is
11+
furnished to do so, subject to the following conditions:
12+
13+
The above copyright notice and this permission notice shall be included in
14+
all copies or substantial portions of the Software.
15+
16+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22+
THE SOFTWARE.
23+
"""
24+
from functools import partial, reduce
25+
26+
import cupy as cp # type: ignore[import-untyped]
27+
28+
from arraycontext.container import is_array_container
29+
from arraycontext.container.traversal import (
30+
multimap_reduce_array_container, rec_map_array_container,
31+
rec_map_reduce_array_container, rec_multimap_array_container,
32+
rec_multimap_reduce_array_container)
33+
from arraycontext.fake_numpy import (
34+
BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace)
35+
36+
37+
class CupyFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
38+
# Everything is implemented in the base class for now.
39+
pass
40+
41+
42+
_NUMPY_UFUNCS = {"abs", "sin", "cos", "tan", "arcsin", "arccos", "arctan",
43+
"sinh", "cosh", "tanh", "exp", "log", "log10", "isnan",
44+
"sqrt", "concatenate", "transpose",
45+
"ones_like", "maximum", "minimum", "where", "conj", "arctan2",
46+
}
47+
48+
49+
class CupyFakeNumpyNamespace(BaseFakeNumpyNamespace):
50+
"""
51+
A :mod:`numpy` mimic for :class:`CupyArrayContext`.
52+
"""
53+
def _get_fake_numpy_linalg_namespace(self):
54+
return CupyFakeNumpyLinalgNamespace(self._array_context)
55+
56+
def __getattr__(self, name):
57+
58+
if name in _NUMPY_UFUNCS:
59+
from functools import partial
60+
return partial(rec_multimap_array_container,
61+
getattr(cp, name))
62+
63+
raise NotImplementedError
64+
65+
def sum(self, a, axis=None, dtype=None):
66+
return rec_map_reduce_array_container(sum, partial(cp.sum,
67+
axis=axis,
68+
dtype=dtype),
69+
a)
70+
71+
def min(self, a, axis=None):
72+
return rec_map_reduce_array_container(
73+
partial(reduce, cp.minimum), partial(cp.amin, axis=axis), a)
74+
75+
def max(self, a, axis=None):
76+
return rec_map_reduce_array_container(
77+
partial(reduce, cp.maximum), partial(cp.amax, axis=axis), a)
78+
79+
def stack(self, arrays, axis=0):
80+
return rec_multimap_array_container(
81+
lambda *args: cp.stack(args, axis=axis),
82+
*arrays)
83+
84+
def broadcast_to(self, array, shape):
85+
return rec_map_array_container(partial(cp.broadcast_to, shape=shape), array)
86+
87+
# {{{ relational operators
88+
89+
def equal(self, x, y):
90+
return rec_multimap_array_container(cp.equal, x, y)
91+
92+
def not_equal(self, x, y):
93+
return rec_multimap_array_container(cp.not_equal, x, y)
94+
95+
def greater(self, x, y):
96+
return rec_multimap_array_container(cp.greater, x, y)
97+
98+
def greater_equal(self, x, y):
99+
return rec_multimap_array_container(cp.greater_equal, x, y)
100+
101+
def less(self, x, y):
102+
return rec_multimap_array_container(cp.less, x, y)
103+
104+
def less_equal(self, x, y):
105+
return rec_multimap_array_container(cp.less_equal, x, y)
106+
107+
# }}}
108+
109+
def ravel(self, a, order="C"):
110+
return rec_map_array_container(partial(cp.ravel, order=order), a)
111+
112+
def vdot(self, x, y, dtype=None):
113+
if dtype is not None:
114+
raise NotImplementedError("only 'dtype=None' supported.")
115+
116+
return rec_multimap_reduce_array_container(sum, cp.vdot, x, y)
117+
118+
def any(self, a):
119+
return rec_map_reduce_array_container(partial(reduce, cp.logical_or),
120+
lambda subary: cp.any(subary), a)
121+
122+
def all(self, a):
123+
return rec_map_reduce_array_container(partial(reduce, cp.logical_and),
124+
lambda subary: cp.all(subary), a)
125+
126+
def array_equal(self, a, b):
127+
if type(a) is not type(b):
128+
return False
129+
elif not is_array_container(a):
130+
if a.shape != b.shape:
131+
return False
132+
else:
133+
return cp.all(cp.equal(a, b))
134+
else:
135+
try:
136+
return multimap_reduce_array_container(partial(reduce,
137+
cp.logical_and),
138+
self.array_equal, a, b)
139+
except TypeError:
140+
return True
141+
142+
def zeros_like(self, ary):
143+
return rec_multimap_array_container(cp.zeros_like, ary)
144+
145+
def reshape(self, a, newshape, order="C"):
146+
return rec_map_array_container(
147+
lambda ary: ary.reshape(newshape, order=order),
148+
a)
149+
150+
def arange(self, *args, **kwargs):
151+
return cp.arange(*args, **kwargs)
152+
153+
def linspace(self, *args, **kwargs):
154+
return cp.linspace(*args, **kwargs)
155+
156+
# vim: fdm=marker

arraycontext/pytest.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,23 @@ def __str__(self):
224224
return "<PytatoJAXArrayContext>"
225225

226226

227+
class _PytestCupyArrayContextFactory(PytestArrayContextFactory):
228+
@classmethod
229+
def is_available(cls) -> bool:
230+
try:
231+
import cupy # type: ignore[import-untyped] # noqa: F401
232+
return True
233+
except ImportError:
234+
return False
235+
236+
def __call__(self):
237+
from arraycontext import CupyArrayContext
238+
return CupyArrayContext()
239+
240+
def __str__(self):
241+
return "<CupyArrayContext>"
242+
243+
227244
_ARRAY_CONTEXT_FACTORY_REGISTRY: \
228245
Dict[str, Type[PytestArrayContextFactory]] = {
229246
"pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass,
@@ -232,6 +249,7 @@ def __str__(self):
232249
"pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory,
233250
"pytato:jax": _PytestPytatoJaxArrayContextFactory,
234251
"eagerjax": _PytestEagerJaxArrayContextFactory,
252+
"cupy": _PytestCupyArrayContextFactory,
235253
}
236254

237255

0 commit comments

Comments
 (0)