Implemented custom wrapper around sklearn splitters #114

Merged · 2 commits · Jul 1, 2024
10 changes: 10 additions & 0 deletions CHANGELOG.md
@@ -4,6 +4,16 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) (+ the Migration Guide),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [1.0.0] - [unreleased]

### BREAKING CHANGE

- Instead of the (annoying) `mock_label` and `group_label` arguments, all functions that take a cv-splitter as input
  can now take an instance of the new `DatasetSplitter` class, which elegantly handles grouping and stratification and
  removes the need to forward the `mock_label` and `group_label` arguments to the underlying optimizer.
  The `mock_label` and `group_label` arguments have been removed without deprecation.
  (https://github.com/mad-lab-fau/tpcp/pull/114)
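
  A minimal migration sketch (names are placeholders; assuming you previously grouped by a "participant" column via
  `group_label`):

  ```python
  from sklearn.model_selection import GroupKFold

  from tpcp.validate import DatasetSplitter

  # Grouping is now configured on the splitter wrapper instead of via a `group_label` argument.
  cv = DatasetSplitter(GroupKFold(n_splits=2), groupby="participant")
  # `cv` can be passed to any function that previously accepted a plain cv-splitter.
  ```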

## [0.34.0] - 2024-06-28

### Added
1 change: 1 addition & 0 deletions docs/modules/validate.rst
@@ -14,6 +14,7 @@ Classes
:toctree: generated/validate
:template: class_with_private.rst

DatasetSplitter
Scorer
Aggregator
MeanAggregator
13 changes: 13 additions & 0 deletions examples/datasets/_01_datasets_basics.py
@@ -221,6 +221,19 @@ def create_index(self):
# We only print the train set here
print(final_subset[train], end="\n\n")

# %%
# Instead of doing this manually, we also provide a custom splitter that does this for you.
# It can be passed directly to `cross_validate` (or used on its own via its `split` method) and lets us use
# higher-level semantics to specify the grouping and stratification.
from tpcp.validate import DatasetSplitter

cv = DatasetSplitter(GroupKFold(n_splits=2), groupby=["participant", "recording"])

for train, test in cv.split(final_subset):
    # We only print the train set here
    print(final_subset[train], end="\n\n")
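
# %%
# Besides grouping, the same splitter also supports stratification via its `stratify` argument.
# A minimal sketch (construction only; actually splitting requires every value of the stratification column to
# appear at least `n_splits` times):
from sklearn.model_selection import StratifiedKFold

cv = DatasetSplitter(StratifiedKFold(n_splits=2), stratify="participant")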


# %%
# Creating labels also works for datasets that are already grouped.
# But, the columns that should be contained in the label must be a subset of the groupby columns in this case.
Expand Down
171 changes: 171 additions & 0 deletions examples/validation/_04_advanced_cross_validation.py
@@ -0,0 +1,171 @@
"""
Advanced cross-validation
-------------------------
In many real-world datasets, a normal k-fold cross-validation might not be ideal, as it assumes that all data points
are fully independent of each other.
This is often not the case, as our dataset might contain multiple data points from the same participant.
Furthermore, we might have multiple "stratification" variables that we want to keep balanced across the folds.
For example, different clinical conditions or different measurement devices.

These two concepts of "grouping" and "stratification" are sometimes complicated to understand, and certain (even though
common) cases are not supported by the standard sklearn cross-validation splitters without "abusing" the API.
For this reason, tpcp provides dedicated support to tackle these cases with a little more confidence.
"""
# %%
# Let's start by re-creating the simple example from the normal cross-validation example.
#
# Dataset
# +++++++
from pathlib import Path

from examples.datasets.datasets_final_ecg import ECGExampleData

try:
    HERE = Path(__file__).parent
except NameError:
    HERE = Path().resolve()
data_path = HERE.parent.parent / "example_data/ecg_mit_bih_arrhythmia/data"
example_data = ECGExampleData(data_path)

# %%
# Pipeline
# ++++++++
import pandas as pd

from examples.algorithms.algorithms_qrs_detection_final import QRSDetector
from tpcp import Parameter, Pipeline, cf


class MyPipeline(Pipeline):
    algorithm: Parameter[QRSDetector]

    r_peak_positions_: pd.Series

    def __init__(self, algorithm: QRSDetector = cf(QRSDetector())):
        self.algorithm = algorithm

    def run(self, datapoint: ECGExampleData):
        # Note: We need to clone the algorithm instance to make sure we don't leak any data between runs.
        algo = self.algorithm.clone()
        algo.detect(datapoint.data, datapoint.sampling_rate_hz)

        self.r_peak_positions_ = algo.r_peak_positions_
        return self


# %%
# The Scorer
# ++++++++++
from examples.algorithms.algorithms_qrs_detection_final import match_events_with_reference, precision_recall_f1_score


def score(pipeline: MyPipeline, datapoint: ECGExampleData):
    # We use the `safe_run` wrapper instead of just `run`. This is always a good idea.
    # We don't need to clone the pipeline here, as `cross_validate` will already clone the pipeline internally and
    # `run` will clone it again.
    pipeline = pipeline.safe_run(datapoint)
    tolerance_s = 0.02  # We just use 20 ms for this example
    matches = match_events_with_reference(
        pipeline.r_peak_positions_.to_numpy(),
        datapoint.r_peak_positions_.to_numpy(),
        tolerance=tolerance_s * datapoint.sampling_rate_hz,
    )
    precision, recall, f1_score = precision_recall_f1_score(matches)
    return {"precision": precision, "recall": recall, "f1_score": f1_score}


# %%
# Stratification
# ++++++++++++++
# With this setup done, we can have a closer look at the dataset.
example_data

# %%
# The index has two columns, one indicating the participant group and one indicating the participant id.
# In this simple example, all groups appear the same number of times, and the index is ordered in a way that
# each fold will likely get a balanced number of participants from each group.
#
# To show the impact of grouping and stratification, we take a subset of the data that removes some participants from
# "group_1" to create an imbalance.
data_imbalanced = example_data.get_subset(index=example_data.index.query("participant not in ['114', '121']"))

# %%
# Running a simple cross-validation with 2 folds will put all group-1 participants into the test data of the first
# fold:
#
# Note that we skip the optimization of the pipeline to keep the example simple and fast.
from sklearn.model_selection import KFold

from tpcp.optimize import DummyOptimize
from tpcp.validate import cross_validate

cv = KFold(n_splits=2)

pipe = MyPipeline()
optimizable_pipe = DummyOptimize(pipe)

results = cross_validate(optimizable_pipe, data_imbalanced, scoring=score, cv=cv)
result_df = pd.DataFrame(results)

# %%
# We can see that the test data of the first fold contains only participants from group 1.
result_df["test_data_labels"].explode()

# %%
# This works fine when the groups are just "additional information" and are unlikely to affect the data within.
# For example, if the groups just reflect in which hospital the data was collected.
# However, when the groups reflect information that is likely to affect the data (e.g. a relevant medical indication),
# we need to make sure that the actual group probabilities remain the same in all folds.
# This can be done through stratification.
#
# .. note:: It is important to understand that "stratification" is not "balancing" the groups.
#    Group balancing should never be done during data splitting, as it would change the data distribution in your
#    test set, which would then no longer reflect the real-world distribution.
#
# To stratify by the "patient group" we can use the `DatasetSplitter` class.
# We provide it with a base splitter that supports stratification (in this case a `StratifiedKFold` splitter) and
# the column(s) to stratify by.
from sklearn.model_selection import StratifiedKFold

from tpcp.validate import DatasetSplitter

cv = DatasetSplitter(base_splitter=StratifiedKFold(n_splits=2), stratify="patient_group")

results = cross_validate(optimizable_pipe, data_imbalanced, scoring=score, cv=cv)
result_df_stratified = pd.DataFrame(results)
result_df_stratified["test_data_labels"].explode()

# %%
# Now we can see that the group distribution is the same in each fold, and each fold gets one of the remaining
# group-1 participants.
#
# Grouping
# ++++++++
# Where stratification ensures that the distribution of a specific column is the same in all folds, grouping ensures
# that all data of one group is always either in the train or the test set, but never split across both.
# This is useful when we have data points that are somehow correlated, and the existence of data points from the same
# group in both the train and the test set of the same fold could hence be considered a "leak".
#
# A typical example for this is when we have multiple data points from the same participant.
# In our case here, we will use the "patient_group" as the grouping variable for demonstration purposes, as we don't
# have multiple data points per participant.
#
# Note that we use the "non-subsampled" example data here.
from sklearn.model_selection import GroupKFold

cv = DatasetSplitter(base_splitter=GroupKFold(n_splits=2), groupby="patient_group")

results = cross_validate(optimizable_pipe, example_data, scoring=score, cv=cv)
result_df_grouped = pd.DataFrame(results)
result_df_grouped["test_data_labels"].explode()

# %%
# We can see that this forces the creation of unequally sized splits to ensure that the groups are kept together.
# This is important to keep in mind when using grouping, as it can lead to unequally sized test sets.
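#
# As a quick sanity check, we can confirm that the grouping worked as intended: no "patient_group" value appears in
# both the train and the test set of the same fold (a minimal sketch, relying on the positional split indices that
# sklearn-style splitters return).
groups = example_data.index["patient_group"]
for train, test in cv.split(example_data):
    # Every group must end up entirely on one side of the split.
    assert set(groups.iloc[train]).isdisjoint(set(groups.iloc[test]))

# %%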
#
# Combining Grouping and Stratification
# +++++++++++++++++++++++++++++++++++++
# Of course, we can also combine grouping and stratification.
# A typical example would be to stratify by clinical condition and group by participant.
# This is also easily possible with the `DatasetSplitter` class by providing both arguments.
#
# For the dataset that we have here, this does of course not make much sense, so we only sketch how such a splitter
# could be constructed:
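from sklearn.model_selection import StratifiedGroupKFold

# A minimal sketch (construction only): `StratifiedGroupKFold` is the sklearn splitter that supports grouping and
# stratification at the same time. We group by "participant" and stratify by "patient_group", matching the columns
# used above.
cv = DatasetSplitter(
    base_splitter=StratifiedGroupKFold(n_splits=2),
    groupby="participant",
    stratify="patient_group",
)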
144 changes: 144 additions & 0 deletions tests/test_examples/snapshot/test_advanced_cross_validate_0.json
@@ -0,0 +1,144 @@
{
"schema":{
"fields":[
{
"name":"index",
"type":"integer"
},
{
"name":"fold_id",
"type":"integer"
},
{
"name":"test_data_labels",
"type":"string"
}
],
"primaryKey":[
"index"
],
"pandas_version":"1.4.0"
},
"data":[
{
"index":0,
"fold_id":0,
"test_data_labels":"group_1"
},
{
"index":1,
"fold_id":0,
"test_data_labels":"100"
},
{
"index":2,
"fold_id":0,
"test_data_labels":"group_3"
},
{
"index":3,
"fold_id":0,
"test_data_labels":"104"
},
{
"index":4,
"fold_id":0,
"test_data_labels":"group_1"
},
{
"index":5,
"fold_id":0,
"test_data_labels":"105"
},
{
"index":6,
"fold_id":0,
"test_data_labels":"group_3"
},
{
"index":7,
"fold_id":0,
"test_data_labels":"108"
},
{
"index":8,
"fold_id":0,
"test_data_labels":"group_1"
},
{
"index":9,
"fold_id":0,
"test_data_labels":"114"
},
{
"index":10,
"fold_id":0,
"test_data_labels":"group_3"
},
{
"index":11,
"fold_id":0,
"test_data_labels":"119"
},
{
"index":12,
"fold_id":0,
"test_data_labels":"group_1"
},
{
"index":13,
"fold_id":0,
"test_data_labels":"121"
},
{
"index":14,
"fold_id":0,
"test_data_labels":"group_3"
},
{
"index":15,
"fold_id":0,
"test_data_labels":"200"
},
{
"index":16,
"fold_id":1,
"test_data_labels":"group_2"
},
{
"index":17,
"fold_id":1,
"test_data_labels":"102"
},
{
"index":18,
"fold_id":1,
"test_data_labels":"group_2"
},
{
"index":19,
"fold_id":1,
"test_data_labels":"106"
},
{
"index":20,
"fold_id":1,
"test_data_labels":"group_2"
},
{
"index":21,
"fold_id":1,
"test_data_labels":"116"
},
{
"index":22,
"fold_id":1,
"test_data_labels":"group_2"
},
{
"index":23,
"fold_id":1,
"test_data_labels":"123"
}
]
}