@@ -156,7 +156,7 @@ def _compute(self, part, fun, *arglist):
156
156
fun (part .offset , part .offset + part .size , * arglist )
157
157
158
158
159
- def generate_cell_wrapper (itspace , args , kernel_name = None , wrapper_name = None ):
159
+ def generate_cell_wrapper (itspace , args , forward_args = (), kernel_name = None , wrapper_name = None ):
160
160
direct = all (a .map is None for a in args )
161
161
snippets = host .wrapper_snippets (itspace , args , kernel_name = kernel_name , wrapper_name = wrapper_name )
162
162
@@ -170,7 +170,10 @@ def generate_cell_wrapper(itspace, args, kernel_name=None, wrapper_name=None):
170
170
snippets ['nlayers_arg' ] = ""
171
171
snippets ['extr_pos_loop' ] = ""
172
172
173
- template = """static inline void %(wrapper_name)s(%(wrapper_args)s%(const_args)s%(nlayers_arg)s, int cell)
173
+ snippets ['wrapper_fargs' ] = "" .join ("{1} farg{0}, " .format (i , arg ) for i , arg in enumerate (forward_args ))
174
+ snippets ['kernel_fargs' ] = "" .join ("farg{0}, " .format (i ) for i in xrange (len (forward_args )))
175
+
176
+ template = """static inline void %(wrapper_name)s(%(wrapper_fargs)s%(wrapper_args)s%(const_args)s%(nlayers_arg)s, int cell)
174
177
{
175
178
%(user_code)s
176
179
%(wrapper_decs)s;
@@ -186,7 +189,7 @@ def generate_cell_wrapper(itspace, args, kernel_name=None, wrapper_name=None):
186
189
%(map_bcs_m)s;
187
190
%(buffer_decl)s;
188
191
%(buffer_gather)s
189
- %(kernel_name)s(%(kernel_args)s);
192
+ %(kernel_name)s(%(kernel_fargs)s%( kernel_args)s);
190
193
%(itset_loop_body)s
191
194
%(map_bcs_p)s;
192
195
}
0 commit comments