diff --git a/sasmodels/kernelcl.py b/sasmodels/kernelcl.py index e69fe398..85458fad 100644 --- a/sasmodels/kernelcl.py +++ b/sasmodels/kernelcl.py @@ -289,8 +289,8 @@ def has_type(self, dtype): """ return self.context.get(dtype, None) is not None - def compile_program(self, name, source, dtype, fast, timestamp): - # type: (str, str, np.dtype, bool, float) -> cl.Program + def compile_program(self, name, source, dtype, fast, timestamp, kernel_names): + # type: (str, str, np.dtype, bool, float, list[str]) -> cl.Program """ Compile the program for the device in the given context. """ @@ -299,8 +299,8 @@ def compile_program(self, name, source, dtype, fast, timestamp): tag = generate.tag_source(source) key = "%s-%s-%s%s"%(name, dtype, tag, ("-fast" if fast else "")) # Check timestamp on program. - program, program_timestamp = self.compiled.get(key, (None, np.inf)) - if program_timestamp < timestamp: + program, compile_timestamp, kernels = self.compiled.get(key, (None, np.inf, [])) + if compile_timestamp < timestamp: del self.compiled[key] if key not in self.compiled: context = self.context[dtype] @@ -308,8 +308,9 @@ def compile_program(self, name, source, dtype, fast, timestamp): context.devices[0].name.strip()) program = compile_model(self.context[dtype], str(source), dtype, fast) - self.compiled[key] = (program, timestamp) - return program + kernels = [getattr(program, k) for k in kernel_names] + self.compiled[key] = (program, timestamp, kernels) + return kernels def _create_some_context(): @@ -457,20 +458,18 @@ def get_function(self, name): def _prepare_program(self): # type: (str) -> None env = environment() + variants = ['Iq', 'Iqxy', 'Imagnetic'] + kernel_names = [generate.kernel_name(self.info, k) for k in variants] timestamp = generate.ocl_timestamp(self.info) - program = env.compile_program( + kernels = env.compile_program( self.info.name, self.source['opencl'], self.dtype, self.fast, - timestamp) - variants = ['Iq', 'Iqxy', 'Imagnetic'] - names = [generate.kernel_name(self.info, k) for k in variants] - functions = [getattr(program, k) for k in names] - self._kernels = {k: v for k, v in zip(variants, functions)} - # Keep a handle to program so GC doesn't collect. - self._program = program - + timestamp, + kernel_names, + ) + self._kernels = {k: v for k, v in zip(variants, kernels)} # TODO: Check that we don't need a destructor for buffers which go out of scope. class GpuInput: diff --git a/sasmodels/kernelcuda.py b/sasmodels/kernelcuda.py index 361d4fdd..cf96de70 100644 --- a/sasmodels/kernelcuda.py +++ b/sasmodels/kernelcuda.py @@ -274,8 +274,8 @@ def has_type(self, dtype): """ return has_type(dtype) - def compile_program(self, name, source, dtype, fast, timestamp): - # type: (str, str, np.dtype, bool, float) -> SourceModule + def compile_program(self, name, source, dtype, fast, timestamp, kernel_names): + # type: (str, str, np.dtype, bool, float, list[str]) -> SourceModule """ Compile the program for the device in the given context. """ @@ -284,14 +284,15 @@ def compile_program(self, name, source, dtype, fast, timestamp): tag = generate.tag_source(source) key = "%s-%s-%s%s"%(name, dtype, tag, ("-fast" if fast else "")) # Check timestamp on program. - program, program_timestamp = self.compiled.get(key, (None, np.inf)) - if program_timestamp < timestamp: + program, compile_timestamp, kernels = self.compiled.get(key, (None, np.inf, [])) + if compile_timestamp < timestamp: del self.compiled[key] if key not in self.compiled: logging.info("building %s for CUDA", key) program = compile_model(str(source), dtype, fast) - self.compiled[key] = (program, timestamp) - return program + kernels = [getattr(program, k) for k in kernel_names] + self.compiled[key] = (program, timestamp, kernels) + return kernels class GpuModel(KernelModel): @@ -349,19 +350,17 @@ def get_function(self, name): def _prepare_program(self): # type: (str) -> None env = environment() + variants = ['Iq', 'Iqxy', 'Imagnetic'] + kernel_names = [generate.kernel_name(self.info, k) for k in variants] timestamp = generate.ocl_timestamp(self.info) - program = env.compile_program( + kernels = env.compile_program( self.info.name, self.source['opencl'], self.dtype, self.fast, - timestamp) - variants = ['Iq', 'Iqxy', 'Imagnetic'] - names = [generate.kernel_name(self.info, k) for k in variants] - functions = [program.get_function(k) for k in names] - self._kernels = {k: v for k, v in zip(variants, functions)} - # Keep a handle to program so GC doesn't collect. - self._program = program + timestamp, + kernel_names) + self._kernels = {k: v for k, v in zip(variants, kernels)} # TODO: Check that we don't need a destructor for buffers which go out of scope.