Skip to content

Commit

Permalink
[FIX] amplitude and ls property types (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
a-ws-m committed Oct 22, 2021
1 parent 7072967 commit ba6c627
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions tests/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_reload(tmp_path: Path, kernel_type: Type[AmpAndLengthScaleFn]):
loaded_weights = loaded_kernel.get_weights()
assert loaded_weights == example_weights

loaded_amp = loaded_kernel.amplitude
loaded_ls = loaded_kernel.length_scale
loaded_amp = loaded_kernel.amplitude.numpy()
loaded_ls = loaded_kernel.length_scale.numpy()
assert loaded_amp == pytest.approx(orig_amplitude)
assert loaded_ls == pytest.approx(orig_length_scale)
8 changes: 4 additions & 4 deletions unlocknn/kernel_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,14 @@ def __init__(self, **kwargs):
)

@property
def amplitude(self) -> float:
def amplitude(self) -> tf.Tensor:
"""Get the current kernel amplitude."""
return tf.nn.softplus(0.1 * self._amplitude_basis).numpy().item()
return tf.nn.softplus(0.1 * self._amplitude_basis)

@property
def length_scale(self) -> float:
def length_scale(self) -> tf.Tensor:
"""Get the current kernel length scale."""
return tf.nn.softplus(5.0 * self._length_scale_basis).numpy().item()
return tf.nn.softplus(5.0 * self._length_scale_basis)


class RBFKernelFn(AmpAndLengthScaleFn):
Expand Down

0 comments on commit ba6c627

Please sign in to comment.