
Commit d5e0617

Add training infos getters at Python level (#196)
1 parent 970de65 commit d5e0617

File tree: 2 files changed (+36, -3 lines)

python/egobox/tests/test_gpmix.py

Lines changed: 10 additions & 3 deletions
@@ -22,11 +22,11 @@ def griewank(x):
 
 class TestGpMix(unittest.TestCase):
     def setUp(self):
-        xt = np.array([[0.0, 1.0, 2.0, 3.0, 4.0]]).T
-        yt = np.array([[0.0, 1.0, 1.5, 0.9, 1.0]]).T
+        self.xt = np.array([[0.0, 1.0, 2.0, 3.0, 4.0]]).T
+        self.yt = np.array([[0.0, 1.0, 1.5, 0.9, 1.0]]).T
 
         gpmix = egx.GpMix()  # or egx.Gpx.builder()
-        self.gpx = gpmix.fit(xt, yt)
+        self.gpx = gpmix.fit(self.xt, self.yt)
 
     def test_gpx_kriging(self):
         gpx = self.gpx
@@ -76,6 +76,12 @@ def test_gpx_save_load(self):
             0.0, gpx2.predict_var(np.array([[1.1]])).item(), delta=1e-3
         )
 
+    def test_training_params(self):
+        self.assertEquals(self.gpx.dims(), (1, 1))
+        (xdata, ydata) = self.gpx.training_data()
+        np.testing.assert_array_equal(xdata, self.xt)
+        np.testing.assert_array_equal(ydata, self.yt)
+
     def test_kpls_griewank(self):
         lb = -600
         ub = 600
@@ -106,6 +112,7 @@ def test_kpls_griewank(self):
         for builder in builders:
             gpx = builder.fit(x_train, y_train)
             y_pred = gpx.predict(x_test)
+            self.assertEqual(100, gpx.dims()[0])
            error = np.linalg.norm(y_pred - y_test) / np.linalg.norm(y_test)
            print(" RMS error: " + str(error))

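For reference, a minimal standalone sketch of how the two new getters can be called from Python, mirroring the test above (it assumes an egobox build that includes this commit, imported as egx):

import numpy as np
import egobox as egx

# Fit a 1D GP on the same five points used in setUp() above
xt = np.array([[0.0, 1.0, 2.0, 3.0, 4.0]]).T
yt = np.array([[0.0, 1.0, 1.5, 0.9, 1.0]]).T
gpx = egx.GpMix().fit(xt, yt)

# dims() returns the couple (nx, ny): one input and one output here
assert gpx.dims() == (1, 1)

# training_data() returns the (nt, nx) inputs and (nt, ny) outputs the model was fitted on
xdata, ydata = gpx.training_data()
assert xdata.shape == (5, 1) and ydata.shape == (5, 1)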
src/gp_mix.rs

Lines changed: 26 additions & 0 deletions
@@ -10,6 +10,7 @@
 //! See the [tutorial notebook](https://github.com/relf/egobox/doc/Gpx_Tutorial.ipynb) for usage.
 //!
 use crate::types::*;
+use egobox_gp::metrics::CrossValScore;
 use egobox_moe::{Clustered, MixtureGpSurrogate, ThetaTuning};
 #[allow(unused_imports)] // Avoid linting problem
 use egobox_moe::{GpMixture, GpSurrogate, GpSurrogateExt};
@@ -356,6 +357,31 @@ impl Gpx {
             .into_pyarray_bound(py)
     }
 
+    /// Get the input and output dimensions of the surrogate
+    ///
+    /// Returns
+    ///     the couple (nx, ny)
+    ///
+    fn dims(&self) -> (usize, usize) {
+        self.0.dims()
+    }
+
+    /// Get the nt training data points used to fit the surrogate
+    ///
+    /// Returns
+    ///     the couple (ndarray[nt, nx], ndarray[nt, ny])
+    ///
+    fn training_data<'py>(
+        &self,
+        py: Python<'py>,
+    ) -> (Bound<'py, PyArray2<f64>>, Bound<'py, PyArray2<f64>>) {
+        let (xdata, ydata) = self.0.training_data();
+        (
+            xdata.to_owned().into_pyarray_bound(py),
+            ydata.to_owned().into_pyarray_bound(py),
+        )
+    }
+
     /// Get optimized thetas hyperparameters (ie once GP experts are fitted)
     ///
     /// Returns

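A note on the binding above (my reading of the diff, not stated in the commit): training_data() copies the model's training set through to_owned().into_pyarray_bound(py), so the NumPy arrays returned to Python own their data. Mutating them should not affect the fitted surrogate; a quick check under that assumption, continuing the Python sketch above:

xdata, _ = gpx.training_data()
xdata[0, 0] = 99.0  # modifies only the returned copy
assert gpx.training_data()[0][0, 0] == 0.0  # a fresh call still returns the original value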