From ba6c627e04de10a64ecfe855f5bdf717c7d2ed28 Mon Sep 17 00:00:00 2001 From: a-ws-m Date: Fri, 22 Oct 2021 18:47:30 +0100 Subject: [PATCH] [FIX] amplitude and ls property types (#32) --- tests/test_kernels.py | 4 ++-- unlocknn/kernel_layers.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_kernels.py b/tests/test_kernels.py index 2c90cbc..4d10319 100644 --- a/tests/test_kernels.py +++ b/tests/test_kernels.py @@ -58,7 +58,7 @@ def test_reload(tmp_path: Path, kernel_type: Type[AmpAndLengthScaleFn]): loaded_weights = loaded_kernel.get_weights() assert loaded_weights == example_weights - loaded_amp = loaded_kernel.amplitude - loaded_ls = loaded_kernel.length_scale + loaded_amp = loaded_kernel.amplitude.numpy() + loaded_ls = loaded_kernel.length_scale.numpy() assert loaded_amp == pytest.approx(orig_amplitude) assert loaded_ls == pytest.approx(orig_length_scale) diff --git a/unlocknn/kernel_layers.py b/unlocknn/kernel_layers.py index 208fc93..56e3801 100644 --- a/unlocknn/kernel_layers.py +++ b/unlocknn/kernel_layers.py @@ -110,14 +110,14 @@ def __init__(self, **kwargs): ) @property - def amplitude(self) -> float: + def amplitude(self) -> tf.Tensor: """Get the current kernel amplitude.""" - return tf.nn.softplus(0.1 * self._amplitude_basis).numpy().item() + return tf.nn.softplus(0.1 * self._amplitude_basis) @property - def length_scale(self) -> float: + def length_scale(self) -> tf.Tensor: """Get the current kernel length scale.""" - return tf.nn.softplus(5.0 * self._length_scale_basis).numpy().item() + return tf.nn.softplus(5.0 * self._length_scale_basis) class RBFKernelFn(AmpAndLengthScaleFn):