From 89ecf9e725826917bf6cadaf3c6e2e92a792fd7c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Arne=20K=C3=BCderle?= <a.kuederle@gmail.com>
Date: Mon, 1 Jul 2024 16:33:05 +0200
Subject: [PATCH 1/2] Implemented custom wrapper around sklearn splitters

---
 docs/modules/validate.rst                            |   1 +
 examples/datasets/_01_datasets_basics.py             |  13 ++
 .../_04_advanced_cross_validation.py                 | 171 ++++++++++++++++++
 .../test_advanced_cross_validate_0.json              | 144 +++++++++++++++
 .../test_advanced_cross_validate_1.json              | 124 +++++++++++++
 tests/test_examples/test_all_examples.py             |  11 ++
 tests/test_pipelines/test_validate.py                | 102 +++++------
 tpcp/optimize/_optimize.py                           |  40 ++--
 tpcp/validate/__init__.py                            |   3 +-
 tpcp/validate/_cross_val_helper.py                   |  73 ++++++++
 tpcp/validate/_validate.py                           |  65 ++-----
 11 files changed, 616 insertions(+), 131 deletions(-)
 create mode 100644 examples/validation/_04_advanced_cross_validation.py
 create mode 100644 tests/test_examples/snapshot/test_advanced_cross_validate_0.json
 create mode 100644 tests/test_examples/snapshot/test_advanced_cross_validate_1.json
 create mode 100644 tpcp/validate/_cross_val_helper.py

diff --git a/docs/modules/validate.rst b/docs/modules/validate.rst
index 09bfbd4..a8cff14 100644
--- a/docs/modules/validate.rst
+++ b/docs/modules/validate.rst
@@ -14,6 +14,7 @@ Classes
     :toctree: generated/validate
     :template: class_with_private.rst
 
+    TpcpSplitter
     Scorer
     Aggregator
     MeanAggregator
diff --git a/examples/datasets/_01_datasets_basics.py b/examples/datasets/_01_datasets_basics.py
index 3aca4dd..e23c482 100644
--- a/examples/datasets/_01_datasets_basics.py
+++ b/examples/datasets/_01_datasets_basics.py
@@ -221,6 +221,19 @@ def create_index(self):
     # We only print the train set here
     print(final_subset[train], end="\n\n")
 
+# %%
+# Instead of doing this manually, we also provide a custom splitter that does this for you.
+# It allows us to directly put the dataset into the `split` method of `cross_validate` and use higher-level semantics
+# to specify the grouping and stratification.
+from tpcp.validate import TpcpSplitter
+
+cv = TpcpSplitter(GroupKFold(n_splits=2), groupby=["participant", "recording"])
+
+for train, test in cv.split(final_subset):
+    # We only print the train set here
+    print(final_subset[train], end="\n\n")
+
+
 # %%
 # Creating labels also works for datasets that are already grouped.
 # But, the columns that should be contained in the label must be a subset of the groupby columns in this case.
diff --git a/examples/validation/_04_advanced_cross_validation.py b/examples/validation/_04_advanced_cross_validation.py
new file mode 100644
index 0000000..96a3733
--- /dev/null
+++ b/examples/validation/_04_advanced_cross_validation.py
@@ -0,0 +1,171 @@
+"""
+Advanced cross-validation
+-------------------------
+In many real-world datasets, a normal k-fold cross-validation might not be ideal, as it assumes that each data point
+is fully independent of all others.
+This is often not the case, as our dataset might contain multiple data points from the same participant.
+Furthermore, we might have multiple "stratification" variables that we want to keep balanced across the folds.
+For example, different clinical conditions or different measurement devices.
+
+These two concepts of "grouping" and "stratification" are sometimes complicated to understand, and certain (though
+common) cases are not supported by the standard sklearn cross-validation splitters without "abusing" the API.
+For this reason, we provide dedicated support for this in tpcp to tackle these cases with a little more confidence.
+"""
+# %%
+# Let's start by re-creating the simple example from the normal cross-validation example.
+#
+# Dataset
+# +++++++
+from pathlib import Path
+
+from examples.datasets.datasets_final_ecg import ECGExampleData
+
+try:
+    HERE = Path(__file__).parent
+except NameError:
+    HERE = Path().resolve()
+data_path = HERE.parent.parent / "example_data/ecg_mit_bih_arrhythmia/data"
+example_data = ECGExampleData(data_path)
+
+# %%
+# Pipeline
+# ++++++++
+import pandas as pd
+
+from examples.algorithms.algorithms_qrs_detection_final import QRSDetector
+from tpcp import Parameter, Pipeline, cf
+
+
+class MyPipeline(Pipeline):
+    algorithm: Parameter[QRSDetector]
+
+    r_peak_positions_: pd.Series
+
+    def __init__(self, algorithm: QRSDetector = cf(QRSDetector())):
+        self.algorithm = algorithm
+
+    def run(self, datapoint: ECGExampleData):
+        # Note: We need to clone the algorithm instance, to make sure we don't leak any data between runs.
+        algo = self.algorithm.clone()
+        algo.detect(datapoint.data, datapoint.sampling_rate_hz)
+
+        self.r_peak_positions_ = algo.r_peak_positions_
+        return self
+
+
+# %%
+# The Scorer
+# ++++++++++
+from examples.algorithms.algorithms_qrs_detection_final import match_events_with_reference, precision_recall_f1_score
+
+
+def score(pipeline: MyPipeline, datapoint: ECGExampleData):
+    # We use the `safe_run` wrapper instead of just `run`. This is always a good idea.
+    # We don't need to clone the pipeline here, as GridSearch will already clone the pipeline internally and `run`
+    # will clone it again.
+    pipeline = pipeline.safe_run(datapoint)
+    tolerance_s = 0.02  # We just use 20 ms for this example
+    matches = match_events_with_reference(
+        pipeline.r_peak_positions_.to_numpy(),
+        datapoint.r_peak_positions_.to_numpy(),
+        tolerance=tolerance_s * datapoint.sampling_rate_hz,
+    )
+    precision, recall, f1_score = precision_recall_f1_score(matches)
+    return {"precision": precision, "recall": recall, "f1_score": f1_score}
+
+
+# %%
+# Stratification
+# ++++++++++++++
+# With this setup done, we can have a closer look at the dataset.
+example_data
+
+# %%
+# The index has two columns, one indicating the participant group and one indicating the participant id.
+# In this simple example, all groups appear the same number of times and the index is ordered in a way that
+# each fold will likely get a balanced number of participants from each group.
+#
+# To show the impact of grouping and stratification, we take a subset of the data that removes some participants from
+# "group_1" to create an imbalance.
+data_imbalanced = example_data.get_subset(index=example_data.index.query("participant not in ['114', '121']"))
+
+# %%
+# Running a simple cross-validation with 2 folds will put all group-1 participants into the test data of the first
+# fold:
+#
+# Note that we skip optimization of the pipeline to keep the example simple and fast.
+from sklearn.model_selection import KFold
+
+from tpcp.optimize import DummyOptimize
+from tpcp.validate import cross_validate
+
+cv = KFold(n_splits=2)
+
+pipe = MyPipeline()
+optimizable_pipe = DummyOptimize(pipe)
+
+results = cross_validate(optimizable_pipe, data_imbalanced, scoring=score, cv=cv)
+result_df = pd.DataFrame(results)
+
+# %%
+# We can see that the test data of the first fold contains only participants from group 1.
+result_df["test_data_labels"].explode() + +# %% +# This works fine when the groups are just "additional information", and are unlikely to affect the data within. +# For example, if the groups just reflect in which hospital the data was collected. +# However, when the group reflect information that is likely to affect the data (e.g. a relevant medical indication), +# we need to make sure that the actual group probabilities are remain the same in all folds. +# This can be done through stratification. +# +# .. note:: It is important to understand that "stratification" is not "balancing" the groups. +# Group balancing should never be done during data splitting, as it will change the data distribution in your +# test set, which will no longer reflect the real-world distribution. +# +# To stratify by the "patient group" we can use the `TpcpSplitter` class. +# We will provide it with a base splitter that enables stratification (in this case a `StratifiedKFold` splitter) and +# the column(s) to stratify by. +from sklearn.model_selection import StratifiedKFold + +from tpcp.validate import TpcpSplitter + +cv = TpcpSplitter(base_splitter=StratifiedKFold(n_splits=2), stratify="patient_group") + +results = cross_validate(optimizable_pipe, data_imbalanced, scoring=score, cv=cv) +result_df_stratified = pd.DataFrame(results) +result_df_stratified["test_data_labels"].explode() + +# %% +# Now we can see that the groups are balanced in each fold and both folds get one of the remaining group 1 participants. +# +# Grouping +# ++++++++ +# Where stratification ensures that the distribution of a specific column is the same in all folds, grouping ensures +# that all data of one group is always either in the train or the test set, but never split across it. +# This is useful, when we have data points that are somehow correlated and the existence of data points from the same +# group in both the train and the test set of the same fold could hence be considered a "leak". +# +# A typical example for this is when we have multiple data points from the same participant. +# In our case here, we will use the "patient_group" as grouping variable for demonstration purposes, as we don't have multiple +# data points per participant. +# +# Note, that we use the "non-subsampled" example data here. +from sklearn.model_selection import GroupKFold + +cv = TpcpSplitter(base_splitter=GroupKFold(n_splits=2), groupby="patient_group") + +results = cross_validate(optimizable_pipe, example_data, scoring=score, cv=cv) +result_df_grouped = pd.DataFrame(results) +result_df_grouped["test_data_labels"].explode() + +# %% +# We can see that this forces the creation of unequal sice splits to ensure that the groups are kept together. +# This is important to keep in mind when using grouping, as it can lead to unequally sized test sets. +# +# Combining Grouping and Stratification +# +++++++++++++++++++++++++++++++++++++ +# Of course, we can also combine grouping and stratification. +# A typical example would be to stratify by clinical condition and group by participant. +# This is also easily possible with the `TpcpSplitter` class by providing both arguments. +# +# For the dataset that we have here, this does of course not make much sense, so we are not going to show an example +# here. 
diff --git a/tests/test_examples/snapshot/test_advanced_cross_validate_0.json b/tests/test_examples/snapshot/test_advanced_cross_validate_0.json
new file mode 100644
index 0000000..5992e41
--- /dev/null
+++ b/tests/test_examples/snapshot/test_advanced_cross_validate_0.json
@@ -0,0 +1,144 @@
+{
+  "schema":{
+    "fields":[
+      {
+        "name":"index",
+        "type":"integer"
+      },
+      {
+        "name":"fold_id",
+        "type":"integer"
+      },
+      {
+        "name":"test_data_labels",
+        "type":"string"
+      }
+    ],
+    "primaryKey":[
+      "index"
+    ],
+    "pandas_version":"1.4.0"
+  },
+  "data":[
+    {
+      "index":0,
+      "fold_id":0,
+      "test_data_labels":"group_1"
+    },
+    {
+      "index":1,
+      "fold_id":0,
+      "test_data_labels":"100"
+    },
+    {
+      "index":2,
+      "fold_id":0,
+      "test_data_labels":"group_3"
+    },
+    {
+      "index":3,
+      "fold_id":0,
+      "test_data_labels":"104"
+    },
+    {
+      "index":4,
+      "fold_id":0,
+      "test_data_labels":"group_1"
+    },
+    {
+      "index":5,
+      "fold_id":0,
+      "test_data_labels":"105"
+    },
+    {
+      "index":6,
+      "fold_id":0,
+      "test_data_labels":"group_3"
+    },
+    {
+      "index":7,
+      "fold_id":0,
+      "test_data_labels":"108"
+    },
+    {
+      "index":8,
+      "fold_id":0,
+      "test_data_labels":"group_1"
+    },
+    {
+      "index":9,
+      "fold_id":0,
+      "test_data_labels":"114"
+    },
+    {
+      "index":10,
+      "fold_id":0,
+      "test_data_labels":"group_3"
+    },
+    {
+      "index":11,
+      "fold_id":0,
+      "test_data_labels":"119"
+    },
+    {
+      "index":12,
+      "fold_id":0,
+      "test_data_labels":"group_1"
+    },
+    {
+      "index":13,
+      "fold_id":0,
+      "test_data_labels":"121"
+    },
+    {
+      "index":14,
+      "fold_id":0,
+      "test_data_labels":"group_3"
+    },
+    {
+      "index":15,
+      "fold_id":0,
+      "test_data_labels":"200"
+    },
+    {
+      "index":16,
+      "fold_id":1,
+      "test_data_labels":"group_2"
+    },
+    {
+      "index":17,
+      "fold_id":1,
+      "test_data_labels":"102"
+    },
+    {
+      "index":18,
+      "fold_id":1,
+      "test_data_labels":"group_2"
+    },
+    {
+      "index":19,
+      "fold_id":1,
+      "test_data_labels":"106"
+    },
+    {
+      "index":20,
+      "fold_id":1,
+      "test_data_labels":"group_2"
+    },
+    {
+      "index":21,
+      "fold_id":1,
+      "test_data_labels":"116"
+    },
+    {
+      "index":22,
+      "fold_id":1,
+      "test_data_labels":"group_2"
+    },
+    {
+      "index":23,
+      "fold_id":1,
+      "test_data_labels":"123"
+    }
+  ]
+}
\ No newline at end of file
diff --git a/tests/test_examples/snapshot/test_advanced_cross_validate_1.json b/tests/test_examples/snapshot/test_advanced_cross_validate_1.json
new file mode 100644
index 0000000..ccfed52
--- /dev/null
+++ b/tests/test_examples/snapshot/test_advanced_cross_validate_1.json
@@ -0,0 +1,124 @@
+{
+  "schema":{
+    "fields":[
+      {
+        "name":"index",
+        "type":"integer"
+      },
+      {
+        "name":"fold_id",
+        "type":"integer"
+      },
+      {
+        "name":"test_data_labels",
+        "type":"string"
+      }
+    ],
+    "primaryKey":[
+      "index"
+    ],
+    "pandas_version":"1.4.0"
+  },
+  "data":[
+    {
+      "index":0,
+      "fold_id":0,
+      "test_data_labels":"group_1"
+    },
+    {
+      "index":1,
+      "fold_id":0,
+      "test_data_labels":"100"
+    },
+    {
+      "index":2,
+      "fold_id":0,
+      "test_data_labels":"group_2"
+    },
+    {
+      "index":3,
+      "fold_id":0,
+      "test_data_labels":"102"
+    },
+    {
+      "index":4,
+      "fold_id":0,
+      "test_data_labels":"group_3"
+    },
+    {
+      "index":5,
+      "fold_id":0,
+      "test_data_labels":"104"
+    },
+    {
+      "index":6,
+      "fold_id":0,
+      "test_data_labels":"group_2"
+    },
+    {
+      "index":7,
+      "fold_id":0,
+      "test_data_labels":"106"
+    },
+    {
+      "index":8,
+      "fold_id":0,
+      "test_data_labels":"group_3"
+    },
+    {
+      "index":9,
+      "fold_id":0,
+      "test_data_labels":"108"
+    },
+    {
+      "index":10,
+      "fold_id":1,
+      "test_data_labels":"group_1"
+    },
+    {
+      "index":11,
+      "fold_id":1,
+      "test_data_labels":"105"
+    },
+    {
+      "index":12,
"fold_id":1, + "test_data_labels":"group_2" + }, + { + "index":13, + "fold_id":1, + "test_data_labels":"116" + }, + { + "index":14, + "fold_id":1, + "test_data_labels":"group_3" + }, + { + "index":15, + "fold_id":1, + "test_data_labels":"119" + }, + { + "index":16, + "fold_id":1, + "test_data_labels":"group_2" + }, + { + "index":17, + "fold_id":1, + "test_data_labels":"123" + }, + { + "index":18, + "fold_id":1, + "test_data_labels":"group_3" + }, + { + "index":19, + "fold_id":1, + "test_data_labels":"200" + } + ] +} \ No newline at end of file diff --git a/tests/test_examples/test_all_examples.py b/tests/test_examples/test_all_examples.py index 560558c..0114516 100644 --- a/tests/test_examples/test_all_examples.py +++ b/tests/test_examples/test_all_examples.py @@ -69,6 +69,17 @@ def test_cross_validate(): assert_almost_equal(results["test_f1_score"], [0.9770585, 0.7108303, 0.9250665]) +def test_advanced_cross_validate(snapshot): + from examples.validation._04_advanced_cross_validation import result_df_grouped, result_df_stratified + + snapshot.assert_match( + result_df_grouped["test_data_labels"].explode().explode().to_frame().rename_axis("fold_id").reset_index() + ) + snapshot.assert_match( + result_df_stratified["test_data_labels"].explode().explode().to_frame().rename_axis("fold_id").reset_index() + ) + + def test_optuna(): from examples.parameter_optimization._04_custom_optuna_optimizer import opti, opti_early_stop diff --git a/tests/test_pipelines/test_validate.py b/tests/test_pipelines/test_validate.py index d533f22..a0cded5 100644 --- a/tests/test_pipelines/test_validate.py +++ b/tests/test_pipelines/test_validate.py @@ -15,7 +15,7 @@ from tpcp import Dataset, OptimizableParameter, OptimizablePipeline from tpcp.exceptions import OptimizationError, TestError from tpcp.optimize import DummyOptimize, Optimize -from tpcp.validate import cross_validate, validate +from tpcp.validate import TpcpSplitter, cross_validate, validate from tpcp.validate._scorer import Scorer, _validate_scorer @@ -227,56 +227,6 @@ def test_returned_optimizer_per_fold_independent(self): for o in optimizers: assert o is not optimizer - @pytest.mark.parametrize("propagate", (True, False)) - def test_propagate_groups(self, propagate): - pipeline = DummyOptimizablePipeline() - dataset = DummyGroupedDataset() - groups = dataset.create_string_group_labels("v1") - # With 3 splits, each group get its own split -> so basically only "a", only "b", and only "c" - cv = GroupKFold(n_splits=3) - - dummy_results = Optimize(pipeline).optimize(dataset) - with patch.object(Optimize, "optimize", return_value=dummy_results) as mock: - cross_validate( - Optimize(pipeline), dataset, cv=cv, scoring=lambda x, y: 1, groups=groups, propagate_groups=propagate - ) - - assert mock.call_count == 3 - for call, label in zip(mock.call_args_list, "cba"): - train_labels = "abc".replace(label, "") - if propagate: - assert set(np.unique(call[1]["groups"])) == set(train_labels) - else: - assert "groups" not in call[1] - - @pytest.mark.parametrize("propagate", (True, False)) - def test_propagate_mock_labels(self, propagate): - pipeline = DummyOptimizablePipeline() - dataset = DummyGroupedDataset() - groups = dataset.create_string_group_labels("v1") - # With 5 folds, we expect exactly on "a", one "b", and one "c" in each fold - cv = StratifiedKFold(n_splits=5) - - dummy_results = Optimize(pipeline).optimize(dataset) - with patch.object(Optimize, "optimize", return_value=dummy_results) as mock: - cross_validate( - Optimize(pipeline), - dataset, - cv=cv, 
-                scoring=lambda x, y: 1,
-                mock_labels=groups,
-                propagate_mock_labels=propagate,
-                propagate_groups=False,
-            )
-
-        assert mock.call_count == 5
-        for call in mock.call_args_list:
-            if propagate:
-                assert len(np.unique(call[1]["mock_labels"])) == 3
-                assert set(np.unique(call[1]["mock_labels"])) == set("abc")
-            else:
-                assert "mock_labels" not in call[1]
-
     @pytest.mark.parametrize("error_fold", (0, 2))
     def test_cross_validate_opti_error(self, error_fold):
         with pytest.raises(OptimizationError) as e:
@@ -344,3 +294,53 @@ def test_cross_validate_optimizer_are_cloned(self):
 
         assert len(results["optimizer"]) == 5
         assert len({id(o) for o in results["optimizer"]}) == 5
+
+
+class TestTpcpSplitter:
+    def test_normal_k_fold(self):
+        ds = DummyGroupedDataset()
+        splitter = TpcpSplitter(base_splitter=KFold(n_splits=5))
+        # This should be identical to just calling the splitter directly
+        splits_expected = list(KFold(n_splits=5).split(ds))
+
+        splits = list(splitter.split(ds))
+
+        for (train_expected, test_expected), (train, test) in zip(splits_expected, splits):
+            assert train_expected.tolist() == train.tolist()
+            assert test_expected.tolist() == test.tolist()
+
+    def test_normal_k_fold_with_groupby_ignored(self):
+        ds = DummyGroupedDataset()
+        splitter = TpcpSplitter(base_splitter=KFold(n_splits=5), groupby="v1")
+        # This should be identical to just calling the splitter directly
+        splits_expected = list(KFold(n_splits=5).split(ds))
+
+        splits = list(splitter.split(ds))
+
+        for (train_expected, test_expected), (train, test) in zip(splits_expected, splits):
+            assert train_expected.tolist() == train.tolist()
+            assert test_expected.tolist() == test.tolist()
+
+    def test_normal_group_k_fold(self):
+        ds = DummyGroupedDataset()
+        splitter = TpcpSplitter(base_splitter=GroupKFold(n_splits=3), groupby="v1")
+        # This should be identical to just calling the splitter directly
+        splits_expected = list(GroupKFold(n_splits=3).split(ds, groups=ds.create_string_group_labels("v1")))
+
+        splits = list(splitter.split(ds))
+
+        for (train_expected, test_expected), (train, test) in zip(splits_expected, splits):
+            assert train_expected.tolist() == train.tolist()
+            assert test_expected.tolist() == test.tolist()
+
+    def test_normal_stratified_k_fold(self):
+        ds = DummyGroupedDataset()
+        splitter = TpcpSplitter(base_splitter=StratifiedKFold(n_splits=3), stratify="v1")
+        # This should be identical to just calling the splitter directly
+        splits_expected = list(StratifiedKFold(n_splits=3).split(ds, y=ds.create_string_group_labels("v1")))
+
+        splits = list(splitter.split(ds))
+
+        for (train_expected, test_expected), (train, test) in zip(splits_expected, splits):
+            assert train_expected.tolist() == train.tolist()
+            assert test_expected.tolist() == test.tolist()
diff --git a/tpcp/optimize/_optimize.py b/tpcp/optimize/_optimize.py
index 06ea289..df45c5c 100644
--- a/tpcp/optimize/_optimize.py
+++ b/tpcp/optimize/_optimize.py
@@ -21,7 +21,7 @@
 from joblib import Memory, Parallel
 from numpy.ma import MaskedArray
 from scipy.stats import rankdata
-from sklearn.model_selection import BaseCrossValidator, ParameterGrid, check_cv
+from sklearn.model_selection import BaseCrossValidator, ParameterGrid
 from tqdm.auto import tqdm
 from typing_extensions import Self
 
@@ -45,6 +45,7 @@
 from tpcp._utils._score import _optimize_and_score, _score
 from tpcp.exceptions import PotentialUserErrorWarning
 from tpcp.parallel import delayed
+from tpcp.validate import TpcpSplitter
 from tpcp.validate._scorer import ScorerTypes, _validate_scorer
 
 if TYPE_CHECKING:
@@ -533,10 +534,12 @@ class GridSearchCV(
         with a minus sign, e.g. `-rmse`.
         In case of a single score, use `-score` to select the value with the lowest score.
     cv
-        An integer specifying the number of folds in a K-Fold cross validation or a valid cross validation helper.
-        The default (`None`) will result in a 5-fold cross validation.
-        For further inputs check the `sklearn`
-        `documentation <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html>`_.
+        The cross-validation strategy to use.
+        For simple use cases, the same inputs as for the sklearn cross-validation functions are supported.
+        For further inputs check the `sklearn` `documentation
+        <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_validate.html>`_.
+
+        For more complex use cases like grouping or stratification, the :class:`~tpcp.validate.TpcpSplitter` can be used.
     pure_parameters
         .. warning:: Do not use this option unless you fully understand it!
 
@@ -584,10 +587,6 @@ class GridSearchCV(
     ----------------
     dataset
         The dataset instance passed to the optimize method
-    groups
-        The groups passed to the optimize method
-    mock_labels
-        The mock labels passed to the optimize method
 
     Attributes
     ----------
@@ -658,7 +657,7 @@ class GridSearchCV(
     parameter_grid: ParameterGrid
     scoring: ScorerTypes[OptimizablePipelineT, DatasetT, T]
     return_optimized: Union[bool, str]
-    cv: Optional[Union[int, BaseCrossValidator, Iterator]]
+    cv: Optional[Union[TpcpSplitter, int, BaseCrossValidator, Iterator]]
     pure_parameters: Union[bool, list[str]]
     return_train_score: bool
     verbose: int
@@ -668,9 +667,6 @@ class GridSearchCV(
     safe_optimize: bool
     optimize_with_info: bool
 
-    groups: Optional[list[Union[str, tuple[str, ...]]]]
-    mock_labels: Optional[list[Union[str, tuple[str, ...]]]]
-
     cv_results_: dict[str, Any]
     best_params_: dict[str, Any]
     best_index_: int
@@ -709,29 +705,21 @@ def __init__(
         self.safe_optimize = safe_optimize
         self.optimize_with_info = optimize_with_info
 
-    def optimize(self, dataset: DatasetT, *, groups=None, mock_labels=None, **optimize_params) -> Self:
+    def optimize(self, dataset: DatasetT, **optimize_params) -> Self:
         """Run the GridSearchCV on the given dataset.
 
         Parameters
        ----------
         dataset
             The dataset to optimize on.
-        groups
-            An optional set of group labels that are passed to the cross-validation helper.
-        mock_labels
-            An optional set of mocked labels that are passed to the cross-validation helper as the `y` parameter.
-            This can be helpful in combination with the `Stratified*Fold` cross-validation helpers, that use the `y`
-            parameter to stratify the folds.
-
         """
         self.dataset = dataset
-        self.groups = groups
-        self.mock_labels = mock_labels
 
         scoring = _validate_scorer(self.scoring, self.pipeline)
 
-        cv_checked: BaseCrossValidator = check_cv(self.cv, None, classifier=True)
-        n_splits = cv_checked.get_n_splits(dataset, mock_labels, groups=groups)
+        cv = self.cv if isinstance(self.cv, TpcpSplitter) else TpcpSplitter(self.cv)
+
+        n_splits = cv.get_n_splits(dataset)
 
         # We need to wrap our pipeline for a consistent interface.
         # In the future we might be able to allow objects with optimizer Interface as input directly.
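
With this change, `GridSearchCV` normalizes whatever is passed as `cv` into a `TpcpSplitter` before splitting. A
minimal sketch of what this enables, assuming the constructor keywords mirror the attributes shown above
(`MyOptimizablePipeline` and the parameter name `algorithm__some_para` are hypothetical; `score` and `example_data`
are the names from the example file in this patch):

    from sklearn.model_selection import ParameterGrid, StratifiedKFold
    from tpcp.optimize import GridSearchCV
    from tpcp.validate import TpcpSplitter

    # The wrapper is passed through unchanged; a plain splitter or an int would be wrapped automatically.
    gs = GridSearchCV(
        pipeline=MyOptimizablePipeline(),
        parameter_grid=ParameterGrid({"algorithm__some_para": [1, 2]}),
        scoring=score,
        cv=TpcpSplitter(StratifiedKFold(n_splits=2), stratify="patient_group"),
    )
    gs.optimize(example_data)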
@@ -762,7 +750,7 @@ def optimize(self, dataset: DatasetT, *, groups=None, mock_labels=None, **optimi
         combinations = list(
             product(
                 enumerate(split_parameters),
-                enumerate(cv_checked.split(dataset, mock_labels, groups=groups)),
+                enumerate(cv.split(dataset)),
             )
         )
diff --git a/tpcp/validate/__init__.py b/tpcp/validate/__init__.py
index b39bfaa..55474bd 100644
--- a/tpcp/validate/__init__.py
+++ b/tpcp/validate/__init__.py
@@ -1,5 +1,6 @@
 """Module for all helper methods to evaluate algorithms."""
+from tpcp.validate._cross_val_helper import TpcpSplitter
 from tpcp.validate._scorer import Aggregator, MeanAggregator, NoAgg, Scorer
 from tpcp.validate._validate import cross_validate, validate
 
-__all__ = ["Scorer", "NoAgg", "Aggregator", "MeanAggregator", "cross_validate", "validate"]
+__all__ = ["Scorer", "NoAgg", "Aggregator", "MeanAggregator", "cross_validate", "validate", "TpcpSplitter"]
diff --git a/tpcp/validate/_cross_val_helper.py b/tpcp/validate/_cross_val_helper.py
new file mode 100644
index 0000000..4393ad6
--- /dev/null
+++ b/tpcp/validate/_cross_val_helper.py
@@ -0,0 +1,73 @@
+from collections.abc import Iterator
+from typing import Optional, Union
+
+from sklearn.model_selection import BaseCrossValidator, check_cv
+
+from tpcp import BaseTpcpObject, Dataset
+
+
+class TpcpSplitter(BaseTpcpObject):
+    """Wrapper around sklearn cross-validation splitters to support grouping and stratification with tpcp Datasets.
+
+    This wrapper can be used instead of a sklearn-style splitter with all methods that support a ``cv`` parameter.
+    Whenever you want to do more complicated cv logic (like grouping or stratification), this wrapper is the way to go.
+
+    .. warning:: We don't validate if the selected ``base_splitter`` does anything useful with the provided
+        ``groupby`` and ``stratify`` information.
+        This wrapper just ensures that the information is correctly extracted from the dataset and passed to the
+        ``split`` method of the ``base_splitter``.
+        So if you are using a normal ``KFold`` splitter, the ``groupby`` and ``stratify`` arguments will have no effect.
+
+    Parameters
+    ----------
+    base_splitter
+        The base splitter to use. Can be an integer (for ``KFold``), an iterator, or any other valid sklearn splitter.
+        The default is None, which will use the sklearn default ``KFold`` splitter with 5 splits.
+    groupby
+        The column(s) to group by. If None, no grouping is done.
+        Must be a subset of the columns in the dataset.
+
+        This will generate a set of unique string labels with the same shape as the dataset.
+        These will be passed to the base splitter as the ``groups`` parameter.
+        It is up to the base splitter to decide what to do with the generated labels.
+    stratify
+        The column(s) to stratify by. If None, no stratification is done.
+        Must be a subset of the columns in the dataset.
+
+        This will generate a set of unique string labels with the same shape as the dataset.
+        These will be passed to the base splitter as the ``y`` parameter, acting as "mock" target labels, as sklearn
+        only supports stratification on classification outcome targets.
+        It is up to the base splitter to decide what to do with the generated labels.
+ + """ + + def __init__( + self, + base_splitter: Optional[Union[int, BaseCrossValidator, Iterator]] = None, + *, + groupby: Optional[Union[str, list[str]]] = None, + stratify: Optional[Union[str, list[str]]] = None, + ): + self.base_splitter = base_splitter + self.stratify = stratify + self.groupby = groupby + + def _get_splitter(self): + return check_cv(self.base_splitter, y=None, classifier=True) + + def _get_labels(self, dataset: Dataset, labels: Union[None, str, list[str]]): + if labels: + return dataset.create_string_group_labels(labels) + return None + + def split(self, dataset: Dataset) -> Iterator[tuple[list[int], list[int]]]: + """Split the dataset into train and test sets.""" + return self._get_splitter().split( + dataset, y=self._get_labels(dataset, self.stratify), groups=self._get_labels(dataset, self.groupby) + ) + + def get_n_splits(self, dataset: Dataset) -> int: + """Get the number of splits.""" + return self._get_splitter().get_n_splits( + dataset, y=self._get_labels(dataset, self.stratify), groups=self._get_labels(dataset, self.groupby) + ) diff --git a/tpcp/validate/_validate.py b/tpcp/validate/_validate.py index 91e248d..0c3b26f 100644 --- a/tpcp/validate/_validate.py +++ b/tpcp/validate/_validate.py @@ -5,7 +5,7 @@ import numpy as np from joblib import Parallel -from sklearn.model_selection import BaseCrossValidator, check_cv +from sklearn.model_selection import BaseCrossValidator from tqdm.auto import tqdm from tpcp import Dataset, Pipeline @@ -14,6 +14,7 @@ from tpcp._utils._general import _aggregate_final_results, _normalize_score_results, _passthrough from tpcp._utils._score import _optimize_and_score, _score from tpcp.parallel import delayed +from tpcp.validate._cross_val_helper import TpcpSplitter from tpcp.validate._scorer import Scorer, _validate_scorer @@ -21,15 +22,11 @@ def cross_validate( optimizable: BaseOptimize, dataset: Dataset, *, - groups: Optional[list[Union[str, tuple[str, ...]]]] = None, - mock_labels: Optional[list[Union[str, tuple[str, ...]]]] = None, scoring: Optional[Callable] = None, - cv: Optional[Union[int, BaseCrossValidator, Iterator]] = None, + cv: Optional[Union[TpcpSplitter, int, BaseCrossValidator, Iterator]] = None, n_jobs: Optional[int] = None, verbose: int = 0, optimize_params: Optional[dict[str, Any]] = None, - propagate_groups: bool = True, - propagate_mock_labels: bool = True, pre_dispatch: Union[str, int] = "2*n_jobs", return_train_score: bool = False, return_optimizer: bool = False, @@ -47,34 +44,17 @@ def cross_validate( :class:`~tpcp.Pipeline` wrapped in an `Optimize` object (:class:`~tpcp.OptimizablePipeline`). dataset A :class:`~tpcp.Dataset` containing all information. - groups - Group labels for samples used by the cross validation helper, in case a grouped CV is used (e.g. - :class:`~sklearn.model_selection.GroupKFold`). - Check the documentation of the :class:`~tpcp.Dataset` class and the respective example for - information on how to generate group labels for tpcp datasets. - - The groups will be passed to the optimizers `optimize` method under the same name, if `propagate_groups` is - True. - mock_labels - The value of `mock_labels` is passed as the `y` parameter to the cross-validation helper's `split` method. - This can be helpful, if you want to use stratified cross validation. - Usually, the stratified CV classes use `y` (i.e. the label) to stratify the data. - However, in tpcp, we don't have a dedicated `y` as data and labels are both stored in a single datastructure. - If you want to stratify the data (e.g. 
based on patient cohorts), you can create your own list of labels/groups - that should be used for stratification and pass it to `mock_labels` instead. - - The labels will be passed to the optimizers `optimize` method under the same name, if - `propagate_mock_labels` is True (similar to how groups are handled). scoring A callable that can score a single data point given a pipeline. This function should return either a single score or a dictionary of scores. If scoring is `None` the default `score` method of the optimizable is used instead. cv - An integer specifying the number of folds in a K-Fold cross validation or a valid cross validation helper. - The default (`None`) will result in a 5-fold cross validation. - For further inputs check the `sklearn` - `documentation + The cross-validation strategy to use. + For simple use-cases the same input as for the sklearn cross-validation function are supported. + For further inputs check the `sklearn` `documentation <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_validate.html>`_. + + For more complex usecases like grouping or stratification, the :class:`~tpcp.TpcpSplitter` can be used. n_jobs Number of jobs to run in parallel. One job is created per CV fold. @@ -84,16 +64,6 @@ def cross_validate( At the moment this only effects `Parallel`. optimize_params Additional parameter that are forwarded to the `optimize` method. - propagate_groups - In case your optimizable is a cross validation based optimize (e.g. :class:`~tpcp.optimize.GridSearchCv`) and - you are using a grouped cross validation, you probably want to use the same grouped CV for the outer and the - inner cross validation. - If `propagate_groups` is True, the group labels belonging to the training of each fold are passed to the - `optimize` method of the optimizable. - This only has an effect if `groups` are specified. - propagate_mock_labels - For the same reason as `propagate_groups`, you might also want to forward the value provided for - `mock_labels` to the optimization workflow. pre_dispatch The number of jobs that should be pre dispatched. For an explanation see the documentation of :class:`~joblib.Parallel`. @@ -143,18 +113,11 @@ def cross_validate( instance. """ - cv_checked: BaseCrossValidator = check_cv(cv, None, classifier=True) - scoring = _validate_scorer(scoring, optimizable.pipeline) - optimize_params = optimize_params or {} - if propagate_groups is True and "groups" in optimize_params: - raise ValueError( - "You can not use `propagate_groups` and specify `groups` in `optimize_params`. " - "The latter would overwrite the prior. " - "Most likely you only want to use `propagate_groups`." 
-        )
-
-    splits = list(cv_checked.split(dataset, mock_labels, groups=groups))
+    cv = cv if isinstance(cv, TpcpSplitter) else TpcpSplitter(base_splitter=cv)
+
+    splits = list(cv.split(dataset))
 
     pbar = partial(tqdm, total=len(splits), desc="CV Folds") if progress_bar else _passthrough
 
@@ -170,11 +133,7 @@ def cross_validate(
             scoring,
             dataset[train],
             dataset[test],
-            optimize_params={
-                **_propagate_values("groups", propagate_groups, groups, train),
-                **_propagate_values("mock_labels", propagate_mock_labels, mock_labels, train),
-                **optimize_params,
-            },
+            optimize_params=optimize_params,
             hyperparameters=None,
             pure_parameters=None,
             return_train_score=return_train_score,

From 6d8c3be84726624c70ce5d5dd350667874812cb5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Arne=20K=C3=BCderle?= <a.kuederle@gmail.com>
Date: Mon, 1 Jul 2024 16:36:59 +0200
Subject: [PATCH 2/2] TpcpSplitter -> DatasetSplitter

---
 CHANGELOG.md                                         | 10 ++++++++++
 docs/modules/validate.rst                            |  2 +-
 examples/datasets/_01_datasets_basics.py             |  4 ++--
 examples/validation/_04_advanced_cross_validation.py |  6 +++---
 tests/test_pipelines/test_validate.py                | 10 +++++-----
 tpcp/optimize/_optimize.py                           |  6 +++---
 tpcp/validate/__init__.py                            |  4 ++--
 tpcp/validate/_cross_val_helper.py                   |  2 +-
 tpcp/validate/_validate.py                           |  6 +++---
 9 files changed, 30 insertions(+), 20 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0bcf2f9..e9dc005 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,16 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) (+ the Migration Guide),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [1.0.0] - [unreleased]
+
+### BREAKING CHANGE
+
+- Instead of the (annoying) `mock_label` and `group_label` arguments, all functions that take a cv splitter as input
+  can now take an instance of the new `DatasetSplitter` class, which elegantly handles grouping and stratification and
+  also removes the need to forward the `mock_label` and `group_label` arguments to the underlying optimizer.
+  The `mock_label` and `group_label` arguments have been removed without a deprecation period.
+  (https://github.com/mad-lab-fau/tpcp/pull/114)
+
 ## [0.34.0] - 2024-06-28
 
 ### Added
diff --git a/docs/modules/validate.rst b/docs/modules/validate.rst
index a8cff14..60a466a 100644
--- a/docs/modules/validate.rst
+++ b/docs/modules/validate.rst
@@ -14,7 +14,7 @@ Classes
     :toctree: generated/validate
     :template: class_with_private.rst
 
-    TpcpSplitter
+    DatasetSplitter
     Scorer
     Aggregator
     MeanAggregator
diff --git a/examples/datasets/_01_datasets_basics.py b/examples/datasets/_01_datasets_basics.py
index e23c482..6187213 100644
--- a/examples/datasets/_01_datasets_basics.py
+++ b/examples/datasets/_01_datasets_basics.py
@@ -225,9 +225,9 @@ def create_index(self):
 # Instead of doing this manually, we also provide a custom splitter that does this for you.
 # It allows us to directly put the dataset into the `split` method of `cross_validate` and use higher-level semantics
 # to specify the grouping and stratification.
-from tpcp.validate import TpcpSplitter
+from tpcp.validate import DatasetSplitter
 
-cv = TpcpSplitter(GroupKFold(n_splits=2), groupby=["participant", "recording"])
+cv = DatasetSplitter(GroupKFold(n_splits=2), groupby=["participant", "recording"])
 
 for train, test in cv.split(final_subset):
     # We only print the train set here
diff --git a/examples/validation/_04_advanced_cross_validation.py b/examples/validation/_04_advanced_cross_validation.py
index 96a3733..4164559 100644
--- a/examples/validation/_04_advanced_cross_validation.py
+++ b/examples/validation/_04_advanced_cross_validation.py
@@ -126,9 +126,9 @@ def score(pipeline: MyPipeline, datapoint: ECGExampleData):
 # the column(s) to stratify by.
 from sklearn.model_selection import StratifiedKFold
 
-from tpcp.validate import TpcpSplitter
+from tpcp.validate import DatasetSplitter
 
-cv = TpcpSplitter(base_splitter=StratifiedKFold(n_splits=2), stratify="patient_group")
+cv = DatasetSplitter(base_splitter=StratifiedKFold(n_splits=2), stratify="patient_group")
 
 results = cross_validate(optimizable_pipe, data_imbalanced, scoring=score, cv=cv)
 result_df_stratified = pd.DataFrame(results)
@@ -151,7 +151,7 @@ def score(pipeline: MyPipeline, datapoint: ECGExampleData):
 # Note that we use the "non-subsampled" example data here.
 from sklearn.model_selection import GroupKFold
 
-cv = TpcpSplitter(base_splitter=GroupKFold(n_splits=2), groupby="patient_group")
+cv = DatasetSplitter(base_splitter=GroupKFold(n_splits=2), groupby="patient_group")
 
 results = cross_validate(optimizable_pipe, example_data, scoring=score, cv=cv)
 result_df_grouped = pd.DataFrame(results)
diff --git a/tests/test_pipelines/test_validate.py b/tests/test_pipelines/test_validate.py
index a0cded5..cff81b1 100644
--- a/tests/test_pipelines/test_validate.py
+++ b/tests/test_pipelines/test_validate.py
@@ -15,7 +15,7 @@
 from tpcp import Dataset, OptimizableParameter, OptimizablePipeline
 from tpcp.exceptions import OptimizationError, TestError
 from tpcp.optimize import DummyOptimize, Optimize
-from tpcp.validate import TpcpSplitter, cross_validate, validate
+from tpcp.validate import DatasetSplitter, cross_validate, validate
 from tpcp.validate._scorer import Scorer, _validate_scorer
 
 
@@ -299,7 +299,7 @@ def test_cross_validate_optimizer_are_cloned(self):
 class TestTpcpSplitter:
     def test_normal_k_fold(self):
         ds = DummyGroupedDataset()
-        splitter = TpcpSplitter(base_splitter=KFold(n_splits=5))
+        splitter = DatasetSplitter(base_splitter=KFold(n_splits=5))
         # This should be identical to just calling the splitter directly
         splits_expected = list(KFold(n_splits=5).split(ds))
 
@@ -311,7 +311,7 @@ def test_normal_k_fold(self):
 
     def test_normal_k_fold_with_groupby_ignored(self):
         ds = DummyGroupedDataset()
-        splitter = TpcpSplitter(base_splitter=KFold(n_splits=5), groupby="v1")
+        splitter = DatasetSplitter(base_splitter=KFold(n_splits=5), groupby="v1")
         # This should be identical to just calling the splitter directly
         splits_expected = list(KFold(n_splits=5).split(ds))
 
@@ -323,7 +323,7 @@ def test_normal_k_fold_with_groupby_ignored(self):
 
     def test_normal_group_k_fold(self):
         ds = DummyGroupedDataset()
-        splitter = TpcpSplitter(base_splitter=GroupKFold(n_splits=3), groupby="v1")
+        splitter = DatasetSplitter(base_splitter=GroupKFold(n_splits=3), groupby="v1")
         # This should be identical to just calling the splitter directly
         splits_expected = list(GroupKFold(n_splits=3).split(ds, groups=ds.create_string_group_labels("v1")))
 
@@ -335,7 +335,7 @@ def test_normal_group_k_fold(self):
 
     def test_normal_stratified_k_fold(self):
         ds = DummyGroupedDataset()
-        splitter = TpcpSplitter(base_splitter=StratifiedKFold(n_splits=3), stratify="v1")
+        splitter = DatasetSplitter(base_splitter=StratifiedKFold(n_splits=3), stratify="v1")
         # This should be identical to just calling the splitter directly
         splits_expected = list(StratifiedKFold(n_splits=3).split(ds, y=ds.create_string_group_labels("v1")))
 
diff --git a/tpcp/optimize/_optimize.py b/tpcp/optimize/_optimize.py
index df45c5c..c72ad93 100644
--- a/tpcp/optimize/_optimize.py
+++ b/tpcp/optimize/_optimize.py
@@ -45,7 +45,7 @@
 from tpcp._utils._score import _optimize_and_score, _score
 from tpcp.exceptions import PotentialUserErrorWarning
 from tpcp.parallel import delayed
-from tpcp.validate import TpcpSplitter
+from tpcp.validate import DatasetSplitter
 from tpcp.validate._scorer import ScorerTypes, _validate_scorer
 
 if TYPE_CHECKING:
@@ -657,7 +657,7 @@ class GridSearchCV(
     parameter_grid: ParameterGrid
     scoring: ScorerTypes[OptimizablePipelineT, DatasetT, T]
     return_optimized: Union[bool, str]
-    cv: Optional[Union[TpcpSplitter, int, BaseCrossValidator, Iterator]]
+    cv: Optional[Union[DatasetSplitter, int, BaseCrossValidator, Iterator]]
     pure_parameters: Union[bool, list[str]]
     return_train_score: bool
     verbose: int
@@ -717,7 +717,7 @@ def optimize(self, dataset: DatasetT, **optimize_params) -> Self:
 
         scoring = _validate_scorer(self.scoring, self.pipeline)
 
-        cv = self.cv if isinstance(self.cv, TpcpSplitter) else TpcpSplitter(self.cv)
+        cv = self.cv if isinstance(self.cv, DatasetSplitter) else DatasetSplitter(self.cv)
 
         n_splits = cv.get_n_splits(dataset)
 
diff --git a/tpcp/validate/__init__.py b/tpcp/validate/__init__.py
index 55474bd..2097e98 100644
--- a/tpcp/validate/__init__.py
+++ b/tpcp/validate/__init__.py
@@ -1,6 +1,6 @@
 """Module for all helper methods to evaluate algorithms."""
-from tpcp.validate._cross_val_helper import TpcpSplitter
+from tpcp.validate._cross_val_helper import DatasetSplitter
 from tpcp.validate._scorer import Aggregator, MeanAggregator, NoAgg, Scorer
 from tpcp.validate._validate import cross_validate, validate
 
-__all__ = ["Scorer", "NoAgg", "Aggregator", "MeanAggregator", "cross_validate", "validate", "TpcpSplitter"]
+__all__ = ["Scorer", "NoAgg", "Aggregator", "MeanAggregator", "cross_validate", "validate", "DatasetSplitter"]
diff --git a/tpcp/validate/_cross_val_helper.py b/tpcp/validate/_cross_val_helper.py
index 4393ad6..83ab209 100644
--- a/tpcp/validate/_cross_val_helper.py
+++ b/tpcp/validate/_cross_val_helper.py
@@ -6,7 +6,7 @@
 from tpcp import BaseTpcpObject, Dataset
 
 
-class TpcpSplitter(BaseTpcpObject):
+class DatasetSplitter(BaseTpcpObject):
     """Wrapper around sklearn cross-validation splitters to support grouping and stratification with tpcp Datasets.
 
     This wrapper can be used instead of a sklearn-style splitter with all methods that support a ``cv`` parameter.
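
For reference, the renamed class keeps the `split`/`get_n_splits` interface defined in the first commit. A minimal
sketch of using it directly (`example_data` and the "patient_group" column are borrowed from the ECG example above;
this is illustrative, not part of the patch):

    from sklearn.model_selection import GroupKFold
    from tpcp.validate import DatasetSplitter

    splitter = DatasetSplitter(base_splitter=GroupKFold(n_splits=2), groupby="patient_group")
    print(splitter.get_n_splits(example_data))  # -> 2
    for train_idx, test_idx in splitter.split(example_data):
        # Indexing the dataset with the integer indices yields the train part of the fold as a new subset.
        print(example_data[train_idx])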
diff --git a/tpcp/validate/_validate.py b/tpcp/validate/_validate.py
index 0c3b26f..31c1ebb 100644
--- a/tpcp/validate/_validate.py
+++ b/tpcp/validate/_validate.py
@@ -14,7 +14,7 @@
 from tpcp._utils._general import _aggregate_final_results, _normalize_score_results, _passthrough
 from tpcp._utils._score import _optimize_and_score, _score
 from tpcp.parallel import delayed
-from tpcp.validate._cross_val_helper import TpcpSplitter
+from tpcp.validate._cross_val_helper import DatasetSplitter
 from tpcp.validate._scorer import Scorer, _validate_scorer
 
 
@@ -23,7 +23,7 @@ def cross_validate(
     dataset: Dataset,
     *,
     scoring: Optional[Callable] = None,
-    cv: Optional[Union[TpcpSplitter, int, BaseCrossValidator, Iterator]] = None,
+    cv: Optional[Union[DatasetSplitter, int, BaseCrossValidator, Iterator]] = None,
     n_jobs: Optional[int] = None,
     verbose: int = 0,
     optimize_params: Optional[dict[str, Any]] = None,
@@ -115,7 +115,7 @@ def cross_validate(
     """
     scoring = _validate_scorer(scoring, optimizable.pipeline)
 
-    cv = cv if isinstance(cv, TpcpSplitter) else TpcpSplitter(base_splitter=cv)
+    cv = cv if isinstance(cv, DatasetSplitter) else DatasetSplitter(base_splitter=cv)
 
     splits = list(cv.split(dataset))
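
Because of the normalization in this last hunk, the following calls should all end up with the same unstratified,
ungrouped 5-fold splitting (a sketch relying on `check_cv`'s documented defaults; `optimizable_pipe`, `example_data`,
and `score` are the names from the example file in this patch):

    from sklearn.model_selection import KFold
    from tpcp.validate import DatasetSplitter, cross_validate

    # An int, a raw sklearn splitter, and the wrapper are all normalized to a DatasetSplitter internally.
    for cv in (5, KFold(n_splits=5), DatasetSplitter(base_splitter=KFold(n_splits=5))):
        results = cross_validate(optimizable_pipe, example_data, scoring=score, cv=cv)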