diff --git a/tpcp/optimize/_optimize.py b/tpcp/optimize/_optimize.py index c72ad93..b990e83 100644 --- a/tpcp/optimize/_optimize.py +++ b/tpcp/optimize/_optimize.py @@ -721,13 +721,7 @@ def optimize(self, dataset: DatasetT, **optimize_params) -> Self: n_splits = cv.get_n_splits(dataset) - # We need to wrap our pipeline for a consistent interface. - # In the future we might be able to allow objects with optimizer Interface as input directly. - optimizer = Optimize( - self.pipeline, - safe_optimize=self.safe_optimize, - optimize_with_info=self.optimize_with_info, - ) + optimizer = self._wrap_pipeline() # For each para combi, we separate the pure parameters (parameters that do not affect the optimization) and # the hyperparameters. @@ -833,6 +827,16 @@ def optimize(self, dataset: DatasetT, **optimize_params) -> Self: return self + def _wrap_pipeline(self): + # We need to wrap our pipeline for a consistent interface. + # In the future we might be able to allow objects with optimizer Interface as input directly. + optimizer = Optimize( + self.pipeline, + safe_optimize=self.safe_optimize, + optimize_with_info=self.optimize_with_info, + ) + return optimizer + def _format_results(self, candidate_params, n_splits, out, more_results=None): # noqa: C901 """Format the final result dict.