Skip to content

Commit 741a21b

Browse files
authored
Merge pull request #520 from OP2/mesh-level-wrapper
Allow passing layer argument to kernel
2 parents f09cefb + d3ab20a commit 741a21b

File tree

3 files changed

+39
-3
lines changed

3 files changed

+39
-3
lines changed

pyop2/base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4040,6 +4040,14 @@ def __init__(self, kernel, iterset, *args, **kwargs):
40404040
self._kernel = kernel
40414041
self._is_layered = iterset._extruded
40424042
self._iteration_region = kwargs.get("iterate", None)
4043+
self._pass_layer_arg = kwargs.get("pass_layer_arg", False)
4044+
4045+
if self._pass_layer_arg:
4046+
if self.is_direct:
4047+
raise ValueError("Can't request layer arg for direct iteration")
4048+
if not self._is_layered:
4049+
raise ValueError("Can't request layer arg for non-extruded iteration")
4050+
40434051
# Are we only computing over owned set entities?
40444052
self._only_local = isinstance(iterset, LocalSet)
40454053

@@ -4389,6 +4397,10 @@ def par_loop(kernel, it_space, *args, **kwargs):
43894397
except the top layer, accessing data two adjacent (in
43904398
the extruded direction) cells at a time.
43914399
4400+
:kwarg pass_layer_arg: Should the wrapper pass the current layer
4401+
into the kernel (as an ``int``). Only makes sense for
4402+
indirect extruded iteration.
4403+
43924404
.. warning ::
43934405
It is the caller's responsibility that the number and type of all
43944406
:class:`base.Arg`\s passed to the :func:`par_loop` match those expected

pyop2/sequential.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,7 @@ def __init__(self, kernel, itspace, *args, **kwargs):
609609
self._args = args
610610
self._direct = kwargs.get('direct', False)
611611
self._iteration_region = kwargs.get('iterate', ALL)
612+
self._pass_layer_arg = kwargs.get('pass_layer_arg', False)
612613
# Copy the class variables, so we don't overwrite them
613614
self._cppargs = dcopy(type(self)._cppargs)
614615
self._libraries = dcopy(type(self)._libraries)
@@ -709,7 +710,8 @@ def generate_code(self):
709710
kernel_name=self._kernel._name,
710711
user_code=self._kernel._user_code,
711712
wrapper_name=self._wrapper_name,
712-
iteration_region=self._iteration_region)
713+
iteration_region=self._iteration_region,
714+
pass_layer_arg=self._pass_layer_arg)
713715
return self._code_dict
714716

715717
def set_argtypes(self, iterset, *args):
@@ -781,7 +783,8 @@ def prepare_arglist(self, iterset, *args):
781783
@cached_property
782784
def _jitmodule(self):
783785
return JITModule(self.kernel, self.it_space, *self.args,
784-
direct=self.is_direct, iterate=self.iteration_region)
786+
direct=self.is_direct, iterate=self.iteration_region,
787+
pass_layer_arg=self._pass_layer_arg)
785788

786789
@collective
787790
def _compute(self, part, fun, *arglist):
@@ -792,7 +795,7 @@ def _compute(self, part, fun, *arglist):
792795

793796
def wrapper_snippets(itspace, args,
794797
kernel_name=None, wrapper_name=None, user_code=None,
795-
iteration_region=ALL):
798+
iteration_region=ALL, pass_layer_arg=False):
796799
"""Generates code snippets for the wrapper,
797800
ready to be into a template.
798801
@@ -914,6 +917,10 @@ def extrusion_loop():
914917
_buf_gather[arg] = "\n".join([_itspace_loops, _buf_gather[arg], _itspace_loop_close])
915918
_kernel_args = ', '.join([arg.c_kernel_arg(count) if not arg._uses_itspace else _buf_name[arg]
916919
for count, arg in enumerate(args)])
920+
921+
if pass_layer_arg:
922+
_kernel_args += ", j_0"
923+
917924
_buf_gather = ";\n".join(_buf_gather.values())
918925
_buf_decl = ";\n".join(_buf_decl.values())
919926

test/unit/test_extrusion.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,23 @@ def test_direct_loop_inc(self, xtr_nodes):
377377
dat.dataset.set, dat(op2.INC))
378378
assert numpy.allclose(dat.data[:], 1.0)
379379

380+
def test_extruded_layer_arg(self, elements, field_map, dat_f):
381+
"""Tests that the layer argument is being passed when prompted
382+
to in the parloop."""
383+
384+
kernel_blah = """void kernel_blah(double* x[], int layer_arg){
385+
x[0][0] = layer_arg;
386+
}\n"""
387+
388+
op2.par_loop(op2.Kernel(kernel_blah, "kernel_blah"),
389+
elements, dat_f(op2.WRITE, field_map),
390+
pass_layer_arg=True)
391+
end = layers - 1
392+
start = 0
393+
ref = np.arange(start, end)
394+
assert [dat_f.data[end*n:end*(n+1)] == ref
395+
for n in range(len(dat_f.data) + 1)]
396+
380397
def test_write_data_field(self, elements, dat_coords, dat_field, coords_map, field_map, dat_f):
381398
kernel_wo = "void kernel_wo(double* x[]) { x[0][0] = 42.0; }\n"
382399

0 commit comments

Comments
 (0)