Skip to content

Commit 9235e7b

Browse files
committed
feat: expose step_size_jitter option
1 parent eeab711 commit 9235e7b

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

src/wrapper.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -784,6 +784,35 @@ impl PyNutsSettings {
784784
};
785785
Ok(())
786786
}
787+
788+
#[getter(step_size_jitter)]
789+
fn step_size_jitter(&self) -> Option<f64> {
790+
match &self.inner {
791+
Settings::LowRank(inner) => inner.adapt_options.step_size_settings.jitter,
792+
Settings::Diag(inner) => inner.adapt_options.step_size_settings.jitter,
793+
Settings::Transforming(inner) => inner.adapt_options.step_size_settings.jitter,
794+
}
795+
}
796+
797+
#[setter(step_size_jitter)]
798+
fn set_step_size_jitter(&mut self, mut val: Option<f64>) -> PyResult<()> {
799+
if let Some(val) = val {
800+
if val < 0.0 {
801+
return Err(PyValueError::new_err("step_size_jitter must be positive"));
802+
}
803+
}
804+
if let Some(jitter) = val {
805+
if jitter == 0.0 {
806+
val = None;
807+
}
808+
}
809+
match &mut self.inner {
810+
Settings::LowRank(inner) => inner.adapt_options.step_size_settings.jitter = val,
811+
Settings::Diag(inner) => inner.adapt_options.step_size_settings.jitter = val,
812+
Settings::Transforming(inner) => inner.adapt_options.step_size_settings.jitter = val,
813+
}
814+
Ok(())
815+
}
787816
}
788817

789818
pub(crate) enum SamplerState {

0 commit comments

Comments
 (0)