From 89ecf9e725826917bf6cadaf3c6e2e92a792fd7c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Arne=20K=C3=BCderle?= <a.kuederle@gmail.com>
Date: Mon, 1 Jul 2024 16:33:05 +0200
Subject: [PATCH 1/2] Implemented custom wrapper around sklearn splitters

---
 docs/modules/validate.rst                            |   1 +
 examples/datasets/_01_datasets_basics.py             |  13 ++
 .../_04_advanced_cross_validation.py                 | 171 ++++++++++++++++++
 .../test_advanced_cross_validate_0.json              | 144 +++++++++++++++
 .../test_advanced_cross_validate_1.json              | 124 +++++++++++++
 tests/test_examples/test_all_examples.py             |  11 ++
 tests/test_pipelines/test_validate.py                | 102 +++++------
 tpcp/optimize/_optimize.py                           |  40 ++--
 tpcp/validate/__init__.py                            |   3 +-
 tpcp/validate/_cross_val_helper.py                   |  73 ++++++++
 tpcp/validate/_validate.py                           |  65 ++-----
 11 files changed, 616 insertions(+), 131 deletions(-)
 create mode 100644 examples/validation/_04_advanced_cross_validation.py
 create mode 100644 tests/test_examples/snapshot/test_advanced_cross_validate_0.json
 create mode 100644 tests/test_examples/snapshot/test_advanced_cross_validate_1.json
 create mode 100644 tpcp/validate/_cross_val_helper.py

diff --git a/docs/modules/validate.rst b/docs/modules/validate.rst
index 09bfbd4..a8cff14 100644
--- a/docs/modules/validate.rst
+++ b/docs/modules/validate.rst
@@ -14,6 +14,7 @@ Classes
     :toctree: generated/validate
     :template: class_with_private.rst
 
+    TpcpSplitter
     Scorer
     Aggregator
     MeanAggregator
diff --git a/examples/datasets/_01_datasets_basics.py b/examples/datasets/_01_datasets_basics.py
index 3aca4dd..e23c482 100644
--- a/examples/datasets/_01_datasets_basics.py
+++ b/examples/datasets/_01_datasets_basics.py
@@ -221,6 +221,19 @@ def create_index(self):
     # We only print the train set here
     print(final_subset[train], end="\n\n")
 
+# %%
+# Instead of doing this manually, we also provide a custom splitter that does this for you.
+# It allows us to directly put the dataset into the `split` method of `cross_validate` and use higher-level semantics
+# to specify the grouping and stratification.
+from tpcp.validate import TpcpSplitter
+
+cv = TpcpSplitter(GroupKFold(n_splits=2), groupby=["participant", "recording"])
+
+for train, test in cv.split(final_subset):
+    # We only print the train set here
+    print(final_subset[train], end="\n\n")
+
+
 # %%
 # Creating labels also works for datasets that are already grouped.
 # But, the columns that should be contained in the label must be a subset of the groupby columns in this case.
diff --git a/examples/validation/_04_advanced_cross_validation.py b/examples/validation/_04_advanced_cross_validation.py
new file mode 100644
index 0000000..96a3733
--- /dev/null
+++ b/examples/validation/_04_advanced_cross_validation.py
@@ -0,0 +1,171 @@
+"""
+Advanced cross-validation
+-------------------------
+In many real-world datasets, a normal k-fold cross-validation might not be ideal, as it assumes that each data point
+is fully independent of all others.
+This is often not the case, as our dataset might contain multiple data points from the same participant.
+Furthermore, we might have multiple "stratification" variables that we want to keep balanced across the folds.
+For example, different clinical conditions or different measurement devices.
+
+These two concepts of "grouping" and "stratification" are sometimes complicated to understand, and certain (though
+common) cases are not supported by the standard sklearn cross-validation splitters without "abusing" the API.
+For this reason, we provide dedicated support for this in tpcp to tackle these cases with a little more confidence.
+"""
+# %%
+# Let's start by re-creating the simple example from the normal cross-validation example.
+#
+# Dataset
+# +++++++
+from pathlib import Path
+
+from examples.datasets.datasets_final_ecg import ECGExampleData
+
+try:
+    HERE = Path(__file__).parent
+except NameError:
+    HERE = Path().resolve()
+data_path = HERE.parent.parent / "example_data/ecg_mit_bih_arrhythmia/data"
+example_data = ECGExampleData(data_path)
+
+# %%
+# Pipeline
+# ++++++++
+import pandas as pd
+
+from examples.algorithms.algorithms_qrs_detection_final import QRSDetector
+from tpcp import Parameter, Pipeline, cf
+
+
+class MyPipeline(Pipeline):
+    algorithm: Parameter[QRSDetector]
+
+    r_peak_positions_: pd.Series
+
+    def __init__(self, algorithm: QRSDetector = cf(QRSDetector())):
+        self.algorithm = algorithm
+
+    def run(self, datapoint: ECGExampleData):
+        # Note: We need to clone the algorithm instance, to make sure we don't leak any data between runs.
+        algo = self.algorithm.clone()
+        algo.detect(datapoint.data, datapoint.sampling_rate_hz)
+
+        self.r_peak_positions_ = algo.r_peak_positions_
+        return self
+
+
+# %%
+# The Scorer
+# ++++++++++
+from examples.algorithms.algorithms_qrs_detection_final import match_events_with_reference, precision_recall_f1_score
+
+
+def score(pipeline: MyPipeline, datapoint: ECGExampleData):
+    # We use the `safe_run` wrapper instead of just `run`. This is always a good idea.
+    # We don't need to clone the pipeline here, as GridSearch will already clone the pipeline internally and `run`
+    # will clone it again.
+    pipeline = pipeline.safe_run(datapoint)
+    tolerance_s = 0.02  # We just use 20 ms for this example
+    matches = match_events_with_reference(
+        pipeline.r_peak_positions_.to_numpy(),
+        datapoint.r_peak_positions_.to_numpy(),
+        tolerance=tolerance_s * datapoint.sampling_rate_hz,
+    )
+    precision, recall, f1_score = precision_recall_f1_score(matches)
+    return {"precision": precision, "recall": recall, "f1_score": f1_score}
+
+
+# %%
+# Stratification
+# ++++++++++++++
+# With this setup done, we can have a closer look at the dataset.
+example_data
+
+# %%
+# The index has two columns, one indicating the participant group and one indicating the participant id.
+# In this simple example, all groups appear the same number of times and the index is ordered in a way that
+# each fold will likely get a balanced number of participants from each group.
+#
+# To show the impact of grouping and stratification, we take a subset of the data that removes some participants from
+# "group_1" to create an imbalance.
+data_imbalanced = example_data.get_subset(index=example_data.index.query("participant not in ['114', '121']"))
+
+# %%
+# Running a simple cross-validation with 2 folds will put all group-1 participants into the test data of the first
+# fold:
+#
+# Note that we skip optimization of the pipeline to keep the example simple and fast.
+from sklearn.model_selection import KFold
+
+from tpcp.optimize import DummyOptimize
+from tpcp.validate import cross_validate
+
+cv = KFold(n_splits=2)
+
+pipe = MyPipeline()
+optimizable_pipe = DummyOptimize(pipe)
+
+results = cross_validate(optimizable_pipe, data_imbalanced, scoring=score, cv=cv)
+result_df = pd.DataFrame(results)
+
+# %%
+# We can see that the test data of the first fold contains only participants from group 1.
+result_df["test_data_labels"].explode() + +# %% +# This works fine when the groups are just "additional information", and are unlikely to affect the data within. +# For example, if the groups just reflect in which hospital the data was collected. +# However, when the group reflect information that is likely to affect the data (e.g. a relevant medical indication), +# we need to make sure that the actual group probabilities are remain the same in all folds. +# This can be done through stratification. +# +# .. note:: It is important to understand that "stratification" is not "balancing" the groups. +# Group balancing should never be done during data splitting, as it will change the data distribution in your +# test set, which will no longer reflect the real-world distribution. +# +# To stratify by the "patient group" we can use the `TpcpSplitter` class. +# We will provide it with a base splitter that enables stratification (in this case a `StratifiedKFold` splitter) and +# the column(s) to stratify by. +from sklearn.model_selection import StratifiedKFold + +from tpcp.validate import TpcpSplitter + +cv = TpcpSplitter(base_splitter=StratifiedKFold(n_splits=2), stratify="patient_group") + +results = cross_validate(optimizable_pipe, data_imbalanced, scoring=score, cv=cv) +result_df_stratified = pd.DataFrame(results) +result_df_stratified["test_data_labels"].explode() + +# %% +# Now we can see that the groups are balanced in each fold and both folds get one of the remaining group 1 participants. +# +# Grouping +# ++++++++ +# Where stratification ensures that the distribution of a specific column is the same in all folds, grouping ensures +# that all data of one group is always either in the train or the test set, but never split across it. +# This is useful, when we have data points that are somehow correlated and the existence of data points from the same +# group in both the train and the test set of the same fold could hence be considered a "leak". +# +# A typical example for this is when we have multiple data points from the same participant. +# In our case here, we will use the "patient_group" as grouping variable for demonstration purposes, as we don't have multiple +# data points per participant. +# +# Note, that we use the "non-subsampled" example data here. +from sklearn.model_selection import GroupKFold + +cv = TpcpSplitter(base_splitter=GroupKFold(n_splits=2), groupby="patient_group") + +results = cross_validate(optimizable_pipe, example_data, scoring=score, cv=cv) +result_df_grouped = pd.DataFrame(results) +result_df_grouped["test_data_labels"].explode() + +# %% +# We can see that this forces the creation of unequal sice splits to ensure that the groups are kept together. +# This is important to keep in mind when using grouping, as it can lead to unequally sized test sets. +# +# Combining Grouping and Stratification +# +++++++++++++++++++++++++++++++++++++ +# Of course, we can also combine grouping and stratification. +# A typical example would be to stratify by clinical condition and group by participant. +# This is also easily possible with the `TpcpSplitter` class by providing both arguments. +# +# For the dataset that we have here, this does of course not make much sense, so we are not going to show an example +# here. 
diff --git a/tests/test_examples/snapshot/test_advanced_cross_validate_0.json b/tests/test_examples/snapshot/test_advanced_cross_validate_0.json
new file mode 100644
index 0000000..5992e41
--- /dev/null
+++ b/tests/test_examples/snapshot/test_advanced_cross_validate_0.json
@@ -0,0 +1,144 @@
+{
+  "schema":{
+    "fields":[
+      {
+        "name":"index",
+        "type":"integer"
+      },
+      {
+        "name":"fold_id",
+        "type":"integer"
+      },
+      {
+        "name":"test_data_labels",
+        "type":"string"
+      }
+    ],
+    "primaryKey":[
+      "index"
+    ],
+    "pandas_version":"1.4.0"
+  },
+  "data":[
+    {
+      "index":0,
+      "fold_id":0,
+      "test_data_labels":"group_1"
+    },
+    {
+      "index":1,
+      "fold_id":0,
+      "test_data_labels":"100"
+    },
+    {
+      "index":2,
+      "fold_id":0,
+      "test_data_labels":"group_3"
+    },
+    {
+      "index":3,
+      "fold_id":0,
+      "test_data_labels":"104"
+    },
+    {
+      "index":4,
+      "fold_id":0,
+      "test_data_labels":"group_1"
+    },
+    {
+      "index":5,
+      "fold_id":0,
+      "test_data_labels":"105"
+    },
+    {
+      "index":6,
+      "fold_id":0,
+      "test_data_labels":"group_3"
+    },
+    {
+      "index":7,
+      "fold_id":0,
+      "test_data_labels":"108"
+    },
+    {
+      "index":8,
+      "fold_id":0,
+      "test_data_labels":"group_1"
+    },
+    {
+      "index":9,
+      "fold_id":0,
+      "test_data_labels":"114"
+    },
+    {
+      "index":10,
+      "fold_id":0,
+      "test_data_labels":"group_3"
+    },
+    {
+      "index":11,
+      "fold_id":0,
+      "test_data_labels":"119"
+    },
+    {
+      "index":12,
+      "fold_id":0,
+      "test_data_labels":"group_1"
+    },
+    {
+      "index":13,
+      "fold_id":0,
+      "test_data_labels":"121"
+    },
+    {
+      "index":14,
+      "fold_id":0,
+      "test_data_labels":"group_3"
+    },
+    {
+      "index":15,
+      "fold_id":0,
+      "test_data_labels":"200"
+    },
+    {
+      "index":16,
+      "fold_id":1,
+      "test_data_labels":"group_2"
+    },
+    {
+      "index":17,
+      "fold_id":1,
+      "test_data_labels":"102"
+    },
+    {
+      "index":18,
+      "fold_id":1,
+      "test_data_labels":"group_2"
+    },
+    {
+      "index":19,
+      "fold_id":1,
+      "test_data_labels":"106"
+    },
+    {
+      "index":20,
+      "fold_id":1,
+      "test_data_labels":"group_2"
+    },
+    {
+      "index":21,
+      "fold_id":1,
+      "test_data_labels":"116"
+    },
+    {
+      "index":22,
+      "fold_id":1,
+      "test_data_labels":"group_2"
+    },
+    {
+      "index":23,
+      "fold_id":1,
+      "test_data_labels":"123"
+    }
+  ]
+}
\ No newline at end of file
diff --git a/tests/test_examples/snapshot/test_advanced_cross_validate_1.json b/tests/test_examples/snapshot/test_advanced_cross_validate_1.json
new file mode 100644
index 0000000..ccfed52
--- /dev/null
+++ b/tests/test_examples/snapshot/test_advanced_cross_validate_1.json
@@ -0,0 +1,124 @@
+{
+  "schema":{
+    "fields":[
+      {
+        "name":"index",
+        "type":"integer"
+      },
+      {
+        "name":"fold_id",
+        "type":"integer"
+      },
+      {
+        "name":"test_data_labels",
+        "type":"string"
+      }
+    ],
+    "primaryKey":[
+      "index"
+    ],
+    "pandas_version":"1.4.0"
+  },
+  "data":[
+    {
+      "index":0,
+      "fold_id":0,
+      "test_data_labels":"group_1"
+    },
+    {
+      "index":1,
+      "fold_id":0,
+      "test_data_labels":"100"
+    },
+    {
+      "index":2,
+      "fold_id":0,
+      "test_data_labels":"group_2"
+    },
+    {
+      "index":3,
+      "fold_id":0,
+      "test_data_labels":"102"
+    },
+    {
+      "index":4,
+      "fold_id":0,
+      "test_data_labels":"group_3"
+    },
+    {
+      "index":5,
+      "fold_id":0,
+      "test_data_labels":"104"
+    },
+    {
+      "index":6,
+      "fold_id":0,
+      "test_data_labels":"group_2"
+    },
+    {
+      "index":7,
+      "fold_id":0,
+      "test_data_labels":"106"
+    },
+    {
+      "index":8,
+      "fold_id":0,
+      "test_data_labels":"group_3"
+    },
+    {
+      "index":9,
+      "fold_id":0,
+      "test_data_labels":"108"
+    },
+    {
+      "index":10,
+      "fold_id":1,
+      "test_data_labels":"group_1"
+    },
+    {
+      "index":11,
+      "fold_id":1,
+      "test_data_labels":"105"
+    },
+    {
+      "index":12,
"fold_id":1, + "test_data_labels":"group_2" + }, + { + "index":13, + "fold_id":1, + "test_data_labels":"116" + }, + { + "index":14, + "fold_id":1, + "test_data_labels":"group_3" + }, + { + "index":15, + "fold_id":1, + "test_data_labels":"119" + }, + { + "index":16, + "fold_id":1, + "test_data_labels":"group_2" + }, + { + "index":17, + "fold_id":1, + "test_data_labels":"123" + }, + { + "index":18, + "fold_id":1, + "test_data_labels":"group_3" + }, + { + "index":19, + "fold_id":1, + "test_data_labels":"200" + } + ] +} \ No newline at end of file diff --git a/tests/test_examples/test_all_examples.py b/tests/test_examples/test_all_examples.py index 560558c..0114516 100644 --- a/tests/test_examples/test_all_examples.py +++ b/tests/test_examples/test_all_examples.py @@ -69,6 +69,17 @@ def test_cross_validate(): assert_almost_equal(results["test_f1_score"], [0.9770585, 0.7108303, 0.9250665]) +def test_advanced_cross_validate(snapshot): + from examples.validation._04_advanced_cross_validation import result_df_grouped, result_df_stratified + + snapshot.assert_match( + result_df_grouped["test_data_labels"].explode().explode().to_frame().rename_axis("fold_id").reset_index() + ) + snapshot.assert_match( + result_df_stratified["test_data_labels"].explode().explode().to_frame().rename_axis("fold_id").reset_index() + ) + + def test_optuna(): from examples.parameter_optimization._04_custom_optuna_optimizer import opti, opti_early_stop diff --git a/tests/test_pipelines/test_validate.py b/tests/test_pipelines/test_validate.py index d533f22..a0cded5 100644 --- a/tests/test_pipelines/test_validate.py +++ b/tests/test_pipelines/test_validate.py @@ -15,7 +15,7 @@ from tpcp import Dataset, OptimizableParameter, OptimizablePipeline from tpcp.exceptions import OptimizationError, TestError from tpcp.optimize import DummyOptimize, Optimize -from tpcp.validate import cross_validate, validate +from tpcp.validate import TpcpSplitter, cross_validate, validate from tpcp.validate._scorer import Scorer, _validate_scorer @@ -227,56 +227,6 @@ def test_returned_optimizer_per_fold_independent(self): for o in optimizers: assert o is not optimizer - @pytest.mark.parametrize("propagate", (True, False)) - def test_propagate_groups(self, propagate): - pipeline = DummyOptimizablePipeline() - dataset = DummyGroupedDataset() - groups = dataset.create_string_group_labels("v1") - # With 3 splits, each group get its own split -> so basically only "a", only "b", and only "c" - cv = GroupKFold(n_splits=3) - - dummy_results = Optimize(pipeline).optimize(dataset) - with patch.object(Optimize, "optimize", return_value=dummy_results) as mock: - cross_validate( - Optimize(pipeline), dataset, cv=cv, scoring=lambda x, y: 1, groups=groups, propagate_groups=propagate - ) - - assert mock.call_count == 3 - for call, label in zip(mock.call_args_list, "cba"): - train_labels = "abc".replace(label, "") - if propagate: - assert set(np.unique(call[1]["groups"])) == set(train_labels) - else: - assert "groups" not in call[1] - - @pytest.mark.parametrize("propagate", (True, False)) - def test_propagate_mock_labels(self, propagate): - pipeline = DummyOptimizablePipeline() - dataset = DummyGroupedDataset() - groups = dataset.create_string_group_labels("v1") - # With 5 folds, we expect exactly on "a", one "b", and one "c" in each fold - cv = StratifiedKFold(n_splits=5) - - dummy_results = Optimize(pipeline).optimize(dataset) - with patch.object(Optimize, "optimize", return_value=dummy_results) as mock: - cross_validate( - Optimize(pipeline), - dataset, - cv=cv, 
-                scoring=lambda x, y: 1,
-                mock_labels=groups,
-                propagate_mock_labels=propagate,
-                propagate_groups=False,
-            )
-
-        assert mock.call_count == 5
-        for call in mock.call_args_list:
-            if propagate:
-                assert len(np.unique(call[1]["mock_labels"])) == 3
-                assert set(np.unique(call[1]["mock_labels"])) == set("abc")
-            else:
-                assert "mock_labels" not in call[1]
-
     @pytest.mark.parametrize("error_fold", (0, 2))
     def test_cross_validate_opti_error(self, error_fold):
         with pytest.raises(OptimizationError) as e:
@@ -344,3 +294,53 @@ def test_cross_validate_optimizer_are_cloned(self):
 
         assert len(results["optimizer"]) == 5
         assert len({id(o) for o in results["optimizer"]}) == 5
+
+
+class TestTpcpSplitter:
+    def test_normal_k_fold(self):
+        ds = DummyGroupedDataset()
+        splitter = TpcpSplitter(base_splitter=KFold(n_splits=5))
+        # This should be identical to just calling the splitter directly
+        splits_expected = list(KFold(n_splits=5).split(ds))
+
+        splits = list(splitter.split(ds))
+
+        for (train_expected, test_expected), (train, test) in zip(splits_expected, splits):
+            assert train_expected.tolist() == train.tolist()
+            assert test_expected.tolist() == test.tolist()
+
+    def test_normal_k_fold_with_groupby_ignored(self):
+        ds = DummyGroupedDataset()
+        splitter = TpcpSplitter(base_splitter=KFold(n_splits=5), groupby="v1")
+        # This should be identical to just calling the splitter directly
+        splits_expected = list(KFold(n_splits=5).split(ds))
+
+        splits = list(splitter.split(ds))
+
+        for (train_expected, test_expected), (train, test) in zip(splits_expected, splits):
+            assert train_expected.tolist() == train.tolist()
+            assert test_expected.tolist() == test.tolist()
+
+    def test_normal_group_k_fold(self):
+        ds = DummyGroupedDataset()
+        splitter = TpcpSplitter(base_splitter=GroupKFold(n_splits=3), groupby="v1")
+        # This should be identical to just calling the splitter directly
+        splits_expected = list(GroupKFold(n_splits=3).split(ds, groups=ds.create_string_group_labels("v1")))
+
+        splits = list(splitter.split(ds))
+
+        for (train_expected, test_expected), (train, test) in zip(splits_expected, splits):
+            assert train_expected.tolist() == train.tolist()
+            assert test_expected.tolist() == test.tolist()
+
+    def test_normal_stratified_k_fold(self):
+        ds = DummyGroupedDataset()
+        splitter = TpcpSplitter(base_splitter=StratifiedKFold(n_splits=3), stratify="v1")
+        # This should be identical to just calling the splitter directly
+        splits_expected = list(StratifiedKFold(n_splits=3).split(ds, y=ds.create_string_group_labels("v1")))
+
+        splits = list(splitter.split(ds))
+
+        for (train_expected, test_expected), (train, test) in zip(splits_expected, splits):
+            assert train_expected.tolist() == train.tolist()
+            assert test_expected.tolist() == test.tolist()
diff --git a/tpcp/optimize/_optimize.py b/tpcp/optimize/_optimize.py
index 06ea289..df45c5c 100644
--- a/tpcp/optimize/_optimize.py
+++ b/tpcp/optimize/_optimize.py
@@ -21,7 +21,7 @@
 from joblib import Memory, Parallel
 from numpy.ma import MaskedArray
 from scipy.stats import rankdata
-from sklearn.model_selection import BaseCrossValidator, ParameterGrid, check_cv
+from sklearn.model_selection import BaseCrossValidator, ParameterGrid
 from tqdm.auto import tqdm
 from typing_extensions import Self
 
@@ -45,6 +45,7 @@
 from tpcp._utils._score import _optimize_and_score, _score
 from tpcp.exceptions import PotentialUserErrorWarning
 from tpcp.parallel import delayed
+from tpcp.validate import TpcpSplitter
 from tpcp.validate._scorer import ScorerTypes, _validate_scorer
 
 if TYPE_CHECKING:
@@ -533,10 +534,12 @@ class GridSearchCV(
         with a minus sign, e.g. `-rmse`.
         In case of a single score, use `-score` to select the value with the lowest score.
     cv
-        An integer specifying the number of folds in a K-Fold cross validation or a valid cross validation helper.
-        The default (`None`) will result in a 5-fold cross validation.
-        For further inputs check the `sklearn`
-        `documentation <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html>`_.
+        The cross-validation strategy to use.
+        For simple use cases, the same inputs as for the sklearn cross-validation functions are supported.
+        For further inputs check the `sklearn` `documentation
+        <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_validate.html>`_.
+
+        For more complex use cases like grouping or stratification, the :class:`~tpcp.validate.TpcpSplitter` can be used.
     pure_parameters
         .. warning:: Do not use this option unless you fully understand it!
 
@@ -584,10 +587,6 @@ class GridSearchCV(
     ----------------
     dataset
         The dataset instance passed to the optimize method
-    groups
-        The groups passed to the optimize method
-    mock_labels
-        The mock labels passed to the optimize method
 
     Attributes
     ----------
@@ -658,7 +657,7 @@ class GridSearchCV(
     parameter_grid: ParameterGrid
     scoring: ScorerTypes[OptimizablePipelineT, DatasetT, T]
     return_optimized: Union[bool, str]
-    cv: Optional[Union[int, BaseCrossValidator, Iterator]]
+    cv: Optional[Union[TpcpSplitter, int, BaseCrossValidator, Iterator]]
     pure_parameters: Union[bool, list[str]]
     return_train_score: bool
     verbose: int
@@ -668,9 +667,6 @@ class GridSearchCV(
     safe_optimize: bool
     optimize_with_info: bool
 
-    groups: Optional[list[Union[str, tuple[str, ...]]]]
-    mock_labels: Optional[list[Union[str, tuple[str, ...]]]]
-
     cv_results_: dict[str, Any]
     best_params_: dict[str, Any]
     best_index_: int
@@ -709,29 +705,21 @@ def __init__(
         self.safe_optimize = safe_optimize
         self.optimize_with_info = optimize_with_info
 
-    def optimize(self, dataset: DatasetT, *, groups=None, mock_labels=None, **optimize_params) -> Self:
+    def optimize(self, dataset: DatasetT, **optimize_params) -> Self:
         """Run the GridSearchCV on the given dataset.
 
         Parameters
        ----------
         dataset
             The dataset to optimize on.
-        groups
-            An optional set of group labels that are passed to the cross-validation helper.
-        mock_labels
-            An optional set of mocked labels that are passed to the cross-validation helper as the `y` parameter.
-            This can be helpful in combination with the `Stratified*Fold` cross-validation helpers, that use the `y`
-            parameter to stratify the folds.
-
         """
         self.dataset = dataset
-        self.groups = groups
-        self.mock_labels = mock_labels
 
         scoring = _validate_scorer(self.scoring, self.pipeline)
 
-        cv_checked: BaseCrossValidator = check_cv(self.cv, None, classifier=True)
-        n_splits = cv_checked.get_n_splits(dataset, mock_labels, groups=groups)
+        cv = self.cv if isinstance(self.cv, TpcpSplitter) else TpcpSplitter(self.cv)
+
+        n_splits = cv.get_n_splits(dataset)
 
         # We need to wrap our pipeline for a consistent interface.
         # In the future we might be able to allow objects with optimizer Interface as input directly.
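
With this change, `GridSearchCV` normalizes whatever is passed as `cv` into a `TpcpSplitter` before splitting. A
minimal sketch of what this enables, assuming the constructor keywords mirror the attributes shown above
(`MyOptimizablePipeline` and the parameter name `algorithm__some_para` are hypothetical; `score` and `example_data`
are the names from the example file in this patch):

    from sklearn.model_selection import ParameterGrid, StratifiedKFold
    from tpcp.optimize import GridSearchCV
    from tpcp.validate import TpcpSplitter

    # The wrapper is passed through unchanged; a plain splitter or an int would be wrapped automatically.
    gs = GridSearchCV(
        pipeline=MyOptimizablePipeline(),
        parameter_grid=ParameterGrid({"algorithm__some_para": [1, 2]}),
        scoring=score,
        cv=TpcpSplitter(StratifiedKFold(n_splits=2), stratify="patient_group"),
    )
    gs.optimize(example_data)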
@@ -762,7 +750,7 @@ def optimize(self, dataset: DatasetT, *, groups=None, mock_labels=None, **optimi
         combinations = list(
             product(
                 enumerate(split_parameters),
-                enumerate(cv_checked.split(dataset, mock_labels, groups=groups)),
+                enumerate(cv.split(dataset)),
             )
         )
diff --git a/tpcp/validate/__init__.py b/tpcp/validate/__init__.py
index b39bfaa..55474bd 100644
--- a/tpcp/validate/__init__.py
+++ b/tpcp/validate/__init__.py
@@ -1,5 +1,6 @@
 """Module for all helper methods to evaluate algorithms."""
+from tpcp.validate._cross_val_helper import TpcpSplitter
 from tpcp.validate._scorer import Aggregator, MeanAggregator, NoAgg, Scorer
 from tpcp.validate._validate import cross_validate, validate
 
-__all__ = ["Scorer", "NoAgg", "Aggregator", "MeanAggregator", "cross_validate", "validate"]
+__all__ = ["Scorer", "NoAgg", "Aggregator", "MeanAggregator", "cross_validate", "validate", "TpcpSplitter"]
diff --git a/tpcp/validate/_cross_val_helper.py b/tpcp/validate/_cross_val_helper.py
new file mode 100644
index 0000000..4393ad6
--- /dev/null
+++ b/tpcp/validate/_cross_val_helper.py
@@ -0,0 +1,73 @@
+from collections.abc import Iterator
+from typing import Optional, Union
+
+from sklearn.model_selection import BaseCrossValidator, check_cv
+
+from tpcp import BaseTpcpObject, Dataset
+
+
+class TpcpSplitter(BaseTpcpObject):
+    """Wrapper around sklearn cross-validation splitters to support grouping and stratification with tpcp Datasets.
+
+    This wrapper can be used instead of a sklearn-style splitter with all methods that support a ``cv`` parameter.
+    Whenever you want to do more complicated cv logic (like grouping or stratification), this wrapper is the way to go.
+
+    .. warning:: We don't validate if the selected ``base_splitter`` does anything useful with the provided
+        ``groupby`` and ``stratify`` information.
+        This wrapper just ensures that the information is correctly extracted from the dataset and passed to the
+        ``split`` method of the ``base_splitter``.
+        So if you are using a normal ``KFold`` splitter, the ``groupby`` and ``stratify`` arguments will have no effect.
+
+    Parameters
+    ----------
+    base_splitter
+        The base splitter to use. Can be an integer (for ``KFold``), an iterator, or any other valid sklearn splitter.
+        The default is None, which will use the sklearn default ``KFold`` splitter with 5 splits.
+    groupby
+        The column(s) to group by. If None, no grouping is done.
+        Must be a subset of the columns in the dataset.
+
+        This will generate a set of unique string labels with the same shape as the dataset.
+        These will be passed to the base splitter as the ``groups`` parameter.
+        It is up to the base splitter to decide what to do with the generated labels.
+    stratify
+        The column(s) to stratify by. If None, no stratification is done.
+        Must be a subset of the columns in the dataset.
+
+        This will generate a set of unique string labels with the same shape as the dataset.
+        These will be passed to the base splitter as the ``y`` parameter, acting as "mock" target labels, as sklearn
+        only supports stratification on classification outcome targets.
+        It is up to the base splitter to decide what to do with the generated labels.
+ + """ + + def __init__( + self, + base_splitter: Optional[Union[int, BaseCrossValidator, Iterator]] = None, + *, + groupby: Optional[Union[str, list[str]]] = None, + stratify: Optional[Union[str, list[str]]] = None, + ): + self.base_splitter = base_splitter + self.stratify = stratify + self.groupby = groupby + + def _get_splitter(self): + return check_cv(self.base_splitter, y=None, classifier=True) + + def _get_labels(self, dataset: Dataset, labels: Union[None, str, list[str]]): + if labels: + return dataset.create_string_group_labels(labels) + return None + + def split(self, dataset: Dataset) -> Iterator[tuple[list[int], list[int]]]: + """Split the dataset into train and test sets.""" + return self._get_splitter().split( + dataset, y=self._get_labels(dataset, self.stratify), groups=self._get_labels(dataset, self.groupby) + ) + + def get_n_splits(self, dataset: Dataset) -> int: + """Get the number of splits.""" + return self._get_splitter().get_n_splits( + dataset, y=self._get_labels(dataset, self.stratify), groups=self._get_labels(dataset, self.groupby) + ) diff --git a/tpcp/validate/_validate.py b/tpcp/validate/_validate.py index 91e248d..0c3b26f 100644 --- a/tpcp/validate/_validate.py +++ b/tpcp/validate/_validate.py @@ -5,7 +5,7 @@ import numpy as np from joblib import Parallel -from sklearn.model_selection import BaseCrossValidator, check_cv +from sklearn.model_selection import BaseCrossValidator from tqdm.auto import tqdm from tpcp import Dataset, Pipeline @@ -14,6 +14,7 @@ from tpcp._utils._general import _aggregate_final_results, _normalize_score_results, _passthrough from tpcp._utils._score import _optimize_and_score, _score from tpcp.parallel import delayed +from tpcp.validate._cross_val_helper import TpcpSplitter from tpcp.validate._scorer import Scorer, _validate_scorer @@ -21,15 +22,11 @@ def cross_validate( optimizable: BaseOptimize, dataset: Dataset, *, - groups: Optional[list[Union[str, tuple[str, ...]]]] = None, - mock_labels: Optional[list[Union[str, tuple[str, ...]]]] = None, scoring: Optional[Callable] = None, - cv: Optional[Union[int, BaseCrossValidator, Iterator]] = None, + cv: Optional[Union[TpcpSplitter, int, BaseCrossValidator, Iterator]] = None, n_jobs: Optional[int] = None, verbose: int = 0, optimize_params: Optional[dict[str, Any]] = None, - propagate_groups: bool = True, - propagate_mock_labels: bool = True, pre_dispatch: Union[str, int] = "2*n_jobs", return_train_score: bool = False, return_optimizer: bool = False, @@ -47,34 +44,17 @@ def cross_validate( :class:`~tpcp.Pipeline` wrapped in an `Optimize` object (:class:`~tpcp.OptimizablePipeline`). dataset A :class:`~tpcp.Dataset` containing all information. - groups - Group labels for samples used by the cross validation helper, in case a grouped CV is used (e.g. - :class:`~sklearn.model_selection.GroupKFold`). - Check the documentation of the :class:`~tpcp.Dataset` class and the respective example for - information on how to generate group labels for tpcp datasets. - - The groups will be passed to the optimizers `optimize` method under the same name, if `propagate_groups` is - True. - mock_labels - The value of `mock_labels` is passed as the `y` parameter to the cross-validation helper's `split` method. - This can be helpful, if you want to use stratified cross validation. - Usually, the stratified CV classes use `y` (i.e. the label) to stratify the data. - However, in tpcp, we don't have a dedicated `y` as data and labels are both stored in a single datastructure. - If you want to stratify the data (e.g. 
based on patient cohorts), you can create your own list of labels/groups - that should be used for stratification and pass it to `mock_labels` instead. - - The labels will be passed to the optimizers `optimize` method under the same name, if - `propagate_mock_labels` is True (similar to how groups are handled). scoring A callable that can score a single data point given a pipeline. This function should return either a single score or a dictionary of scores. If scoring is `None` the default `score` method of the optimizable is used instead. cv - An integer specifying the number of folds in a K-Fold cross validation or a valid cross validation helper. - The default (`None`) will result in a 5-fold cross validation. - For further inputs check the `sklearn` - `documentation + The cross-validation strategy to use. + For simple use-cases the same input as for the sklearn cross-validation function are supported. + For further inputs check the `sklearn` `documentation <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_validate.html>`_. + + For more complex usecases like grouping or stratification, the :class:`~tpcp.TpcpSplitter` can be used. n_jobs Number of jobs to run in parallel. One job is created per CV fold. @@ -84,16 +64,6 @@ def cross_validate( At the moment this only effects `Parallel`. optimize_params Additional parameter that are forwarded to the `optimize` method. - propagate_groups - In case your optimizable is a cross validation based optimize (e.g. :class:`~tpcp.optimize.GridSearchCv`) and - you are using a grouped cross validation, you probably want to use the same grouped CV for the outer and the - inner cross validation. - If `propagate_groups` is True, the group labels belonging to the training of each fold are passed to the - `optimize` method of the optimizable. - This only has an effect if `groups` are specified. - propagate_mock_labels - For the same reason as `propagate_groups`, you might also want to forward the value provided for - `mock_labels` to the optimization workflow. pre_dispatch The number of jobs that should be pre dispatched. For an explanation see the documentation of :class:`~joblib.Parallel`. @@ -143,18 +113,11 @@ def cross_validate( instance. """ - cv_checked: BaseCrossValidator = check_cv(cv, None, classifier=True) - scoring = _validate_scorer(scoring, optimizable.pipeline) - optimize_params = optimize_params or {} - if propagate_groups is True and "groups" in optimize_params: - raise ValueError( - "You can not use `propagate_groups` and specify `groups` in `optimize_params`. " - "The latter would overwrite the prior. " - "Most likely you only want to use `propagate_groups`." 
-        )
-
-    splits = list(cv_checked.split(dataset, mock_labels, groups=groups))
+    cv = cv if isinstance(cv, TpcpSplitter) else TpcpSplitter(base_splitter=cv)
+
+    splits = list(cv.split(dataset))
 
     pbar = partial(tqdm, total=len(splits), desc="CV Folds") if progress_bar else _passthrough
 
@@ -170,11 +133,7 @@ def cross_validate(
             scoring,
             dataset[train],
             dataset[test],
-            optimize_params={
-                **_propagate_values("groups", propagate_groups, groups, train),
-                **_propagate_values("mock_labels", propagate_mock_labels, mock_labels, train),
-                **optimize_params,
-            },
+            optimize_params=optimize_params,
             hyperparameters=None,
             pure_parameters=None,
             return_train_score=return_train_score,

From 6d8c3be84726624c70ce5d5dd350667874812cb5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Arne=20K=C3=BCderle?= <a.kuederle@gmail.com>
Date: Mon, 1 Jul 2024 16:36:59 +0200
Subject: [PATCH 2/2] TpcpSplitter -> DatasetSplitter

---
 CHANGELOG.md                                         | 10 ++++++++++
 docs/modules/validate.rst                            |  2 +-
 examples/datasets/_01_datasets_basics.py             |  4 ++--
 examples/validation/_04_advanced_cross_validation.py |  6 +++---
 tests/test_pipelines/test_validate.py                | 10 +++++-----
 tpcp/optimize/_optimize.py                           |  6 +++---
 tpcp/validate/__init__.py                            |  4 ++--
 tpcp/validate/_cross_val_helper.py                   |  2 +-
 tpcp/validate/_validate.py                           |  6 +++---
 9 files changed, 30 insertions(+), 20 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0bcf2f9..e9dc005 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,16 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) (+ the Migration Guide),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [1.0.0] - [unreleased]
+
+### BREAKING CHANGE
+
+- Instead of the (annoying) `mock_label` and `group_label` arguments, all functions that take a cv splitter as input
+  can now take an instance of the new `DatasetSplitter` class, which elegantly handles grouping and stratification and
+  also removes the need to forward the `mock_label` and `group_label` arguments to the underlying optimizer.
+  The `mock_label` and `group_label` arguments have been removed without a deprecation period.
+  (https://github.com/mad-lab-fau/tpcp/pull/114)
+
 ## [0.34.0] - 2024-06-28
 
 ### Added
diff --git a/docs/modules/validate.rst b/docs/modules/validate.rst
index a8cff14..60a466a 100644
--- a/docs/modules/validate.rst
+++ b/docs/modules/validate.rst
@@ -14,7 +14,7 @@ Classes
     :toctree: generated/validate
     :template: class_with_private.rst
 
-    TpcpSplitter
+    DatasetSplitter
     Scorer
     Aggregator
     MeanAggregator
diff --git a/examples/datasets/_01_datasets_basics.py b/examples/datasets/_01_datasets_basics.py
index e23c482..6187213 100644
--- a/examples/datasets/_01_datasets_basics.py
+++ b/examples/datasets/_01_datasets_basics.py
@@ -225,9 +225,9 @@ def create_index(self):
 # Instead of doing this manually, we also provide a custom splitter that does this for you.
 # It allows us to directly put the dataset into the `split` method of `cross_validate` and use higher-level semantics
 # to specify the grouping and stratification.
-from tpcp.validate import TpcpSplitter
+from tpcp.validate import DatasetSplitter
 
-cv = TpcpSplitter(GroupKFold(n_splits=2), groupby=["participant", "recording"])
+cv = DatasetSplitter(GroupKFold(n_splits=2), groupby=["participant", "recording"])
 
 for train, test in cv.split(final_subset):
     # We only print the train set here
diff --git a/examples/validation/_04_advanced_cross_validation.py b/examples/validation/_04_advanced_cross_validation.py
index 96a3733..4164559 100644
--- a/examples/validation/_04_advanced_cross_validation.py
+++ b/examples/validation/_04_advanced_cross_validation.py
@@ -126,9 +126,9 @@ def score(pipeline: MyPipeline, datapoint: ECGExampleData):
 # the column(s) to stratify by.
 from sklearn.model_selection import StratifiedKFold
 
-from tpcp.validate import TpcpSplitter
+from tpcp.validate import DatasetSplitter
 
-cv = TpcpSplitter(base_splitter=StratifiedKFold(n_splits=2), stratify="patient_group")
+cv = DatasetSplitter(base_splitter=StratifiedKFold(n_splits=2), stratify="patient_group")
 
 results = cross_validate(optimizable_pipe, data_imbalanced, scoring=score, cv=cv)
 result_df_stratified = pd.DataFrame(results)
@@ -151,7 +151,7 @@ def score(pipeline: MyPipeline, datapoint: ECGExampleData):
 # Note that we use the "non-subsampled" example data here.
 from sklearn.model_selection import GroupKFold
 
-cv = TpcpSplitter(base_splitter=GroupKFold(n_splits=2), groupby="patient_group")
+cv = DatasetSplitter(base_splitter=GroupKFold(n_splits=2), groupby="patient_group")
 
 results = cross_validate(optimizable_pipe, example_data, scoring=score, cv=cv)
 result_df_grouped = pd.DataFrame(results)
diff --git a/tests/test_pipelines/test_validate.py b/tests/test_pipelines/test_validate.py
index a0cded5..cff81b1 100644
--- a/tests/test_pipelines/test_validate.py
+++ b/tests/test_pipelines/test_validate.py
@@ -15,7 +15,7 @@
 from tpcp import Dataset, OptimizableParameter, OptimizablePipeline
 from tpcp.exceptions import OptimizationError, TestError
 from tpcp.optimize import DummyOptimize, Optimize
-from tpcp.validate import TpcpSplitter, cross_validate, validate
+from tpcp.validate import DatasetSplitter, cross_validate, validate
 from tpcp.validate._scorer import Scorer, _validate_scorer
 
 
@@ -299,7 +299,7 @@ def test_cross_validate_optimizer_are_cloned(self):
 class TestTpcpSplitter:
     def test_normal_k_fold(self):
         ds = DummyGroupedDataset()
-        splitter = TpcpSplitter(base_splitter=KFold(n_splits=5))
+        splitter = DatasetSplitter(base_splitter=KFold(n_splits=5))
         # This should be identical to just calling the splitter directly
         splits_expected = list(KFold(n_splits=5).split(ds))
 
@@ -311,7 +311,7 @@ def test_normal_k_fold(self):
 
     def test_normal_k_fold_with_groupby_ignored(self):
         ds = DummyGroupedDataset()
-        splitter = TpcpSplitter(base_splitter=KFold(n_splits=5), groupby="v1")
+        splitter = DatasetSplitter(base_splitter=KFold(n_splits=5), groupby="v1")
         # This should be identical to just calling the splitter directly
         splits_expected = list(KFold(n_splits=5).split(ds))
 
@@ -323,7 +323,7 @@ def test_normal_k_fold_with_groupby_ignored(self):
 
     def test_normal_group_k_fold(self):
         ds = DummyGroupedDataset()
-        splitter = TpcpSplitter(base_splitter=GroupKFold(n_splits=3), groupby="v1")
+        splitter = DatasetSplitter(base_splitter=GroupKFold(n_splits=3), groupby="v1")
         # This should be identical to just calling the splitter directly
         splits_expected = list(GroupKFold(n_splits=3).split(ds, groups=ds.create_string_group_labels("v1")))
 
@@ -335,7 +335,7 @@ def test_normal_group_k_fold(self):
 
     def test_normal_stratified_k_fold(self):
         ds = DummyGroupedDataset()
-        splitter = TpcpSplitter(base_splitter=StratifiedKFold(n_splits=3), stratify="v1")
+        splitter = DatasetSplitter(base_splitter=StratifiedKFold(n_splits=3), stratify="v1")
         # This should be identical to just calling the splitter directly
         splits_expected = list(StratifiedKFold(n_splits=3).split(ds, y=ds.create_string_group_labels("v1")))
 
diff --git a/tpcp/optimize/_optimize.py b/tpcp/optimize/_optimize.py
index df45c5c..c72ad93 100644
--- a/tpcp/optimize/_optimize.py
+++ b/tpcp/optimize/_optimize.py
@@ -45,7 +45,7 @@
 from tpcp._utils._score import _optimize_and_score, _score
 from tpcp.exceptions import PotentialUserErrorWarning
 from tpcp.parallel import delayed
-from tpcp.validate import TpcpSplitter
+from tpcp.validate import DatasetSplitter
 from tpcp.validate._scorer import ScorerTypes, _validate_scorer
 
 if TYPE_CHECKING:
@@ -657,7 +657,7 @@ class GridSearchCV(
     parameter_grid: ParameterGrid
     scoring: ScorerTypes[OptimizablePipelineT, DatasetT, T]
     return_optimized: Union[bool, str]
-    cv: Optional[Union[TpcpSplitter, int, BaseCrossValidator, Iterator]]
+    cv: Optional[Union[DatasetSplitter, int, BaseCrossValidator, Iterator]]
     pure_parameters: Union[bool, list[str]]
     return_train_score: bool
     verbose: int
@@ -717,7 +717,7 @@ def optimize(self, dataset: DatasetT, **optimize_params) -> Self:
 
         scoring = _validate_scorer(self.scoring, self.pipeline)
 
-        cv = self.cv if isinstance(self.cv, TpcpSplitter) else TpcpSplitter(self.cv)
+        cv = self.cv if isinstance(self.cv, DatasetSplitter) else DatasetSplitter(self.cv)
 
         n_splits = cv.get_n_splits(dataset)
 
diff --git a/tpcp/validate/__init__.py b/tpcp/validate/__init__.py
index 55474bd..2097e98 100644
--- a/tpcp/validate/__init__.py
+++ b/tpcp/validate/__init__.py
@@ -1,6 +1,6 @@
 """Module for all helper methods to evaluate algorithms."""
-from tpcp.validate._cross_val_helper import TpcpSplitter
+from tpcp.validate._cross_val_helper import DatasetSplitter
 from tpcp.validate._scorer import Aggregator, MeanAggregator, NoAgg, Scorer
 from tpcp.validate._validate import cross_validate, validate
 
-__all__ = ["Scorer", "NoAgg", "Aggregator", "MeanAggregator", "cross_validate", "validate", "TpcpSplitter"]
+__all__ = ["Scorer", "NoAgg", "Aggregator", "MeanAggregator", "cross_validate", "validate", "DatasetSplitter"]
diff --git a/tpcp/validate/_cross_val_helper.py b/tpcp/validate/_cross_val_helper.py
index 4393ad6..83ab209 100644
--- a/tpcp/validate/_cross_val_helper.py
+++ b/tpcp/validate/_cross_val_helper.py
@@ -6,7 +6,7 @@
 from tpcp import BaseTpcpObject, Dataset
 
 
-class TpcpSplitter(BaseTpcpObject):
+class DatasetSplitter(BaseTpcpObject):
     """Wrapper around sklearn cross-validation splitters to support grouping and stratification with tpcp Datasets.
 
     This wrapper can be used instead of a sklearn-style splitter with all methods that support a ``cv`` parameter.
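
For reference, the renamed class keeps the `split`/`get_n_splits` interface defined in the first commit. A minimal
sketch of using it directly (`example_data` and the "patient_group" column are borrowed from the ECG example above;
this is illustrative, not part of the patch):

    from sklearn.model_selection import GroupKFold
    from tpcp.validate import DatasetSplitter

    splitter = DatasetSplitter(base_splitter=GroupKFold(n_splits=2), groupby="patient_group")
    print(splitter.get_n_splits(example_data))  # -> 2
    for train_idx, test_idx in splitter.split(example_data):
        # Indexing the dataset with the integer indices yields the train part of the fold as a new subset.
        print(example_data[train_idx])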
diff --git a/tpcp/validate/_validate.py b/tpcp/validate/_validate.py
index 0c3b26f..31c1ebb 100644
--- a/tpcp/validate/_validate.py
+++ b/tpcp/validate/_validate.py
@@ -14,7 +14,7 @@
 from tpcp._utils._general import _aggregate_final_results, _normalize_score_results, _passthrough
 from tpcp._utils._score import _optimize_and_score, _score
 from tpcp.parallel import delayed
-from tpcp.validate._cross_val_helper import TpcpSplitter
+from tpcp.validate._cross_val_helper import DatasetSplitter
 from tpcp.validate._scorer import Scorer, _validate_scorer
 
 
@@ -23,7 +23,7 @@ def cross_validate(
     dataset: Dataset,
     *,
     scoring: Optional[Callable] = None,
-    cv: Optional[Union[TpcpSplitter, int, BaseCrossValidator, Iterator]] = None,
+    cv: Optional[Union[DatasetSplitter, int, BaseCrossValidator, Iterator]] = None,
     n_jobs: Optional[int] = None,
     verbose: int = 0,
     optimize_params: Optional[dict[str, Any]] = None,
@@ -115,7 +115,7 @@ def cross_validate(
     """
     scoring = _validate_scorer(scoring, optimizable.pipeline)
 
-    cv = cv if isinstance(cv, TpcpSplitter) else TpcpSplitter(base_splitter=cv)
+    cv = cv if isinstance(cv, DatasetSplitter) else DatasetSplitter(base_splitter=cv)
 
     splits = list(cv.split(dataset))
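
Because of the normalization in this last hunk, the following calls should all end up with the same unstratified,
ungrouped 5-fold splitting (a sketch relying on `check_cv`'s documented defaults; `optimizable_pipe`, `example_data`,
and `score` are the names from the example file in this patch):

    from sklearn.model_selection import KFold
    from tpcp.validate import DatasetSplitter, cross_validate

    # An int, a raw sklearn splitter, and the wrapper are all normalized to a DatasetSplitter internally.
    for cv in (5, KFold(n_splits=5), DatasetSplitter(base_splitter=KFold(n_splits=5))):
        results = cross_validate(optimizable_pipe, example_data, scoring=score, cv=cv)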