Skip to content

fix: Accept X and y as positional argument with as_dict=True in train_test_split #1570

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from

Conversation

amltarek
Copy link

@amltarek amltarek commented Apr 20, 2025

closes #1544

I fixed the issue by updating the handling of keyword arguments when as_dict=True is used. Now, if all datasets are passed as keywords, the function directly returns them as a dictionary without extra processing. This makes the behavior more intuitive and avoids redundancy. I also tested the fix through test_train_split_test.py to ensure it works correctly.
Screenshot from 2025-04-20 22-46-40

Copy link
Contributor

@auguste-probabl auguste-probabl left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add some tests? It would also help to demonstrate what the new behaviour looks like.

@amltarek
Copy link
Author

@auguste-probabl
I added the tests and restored the function's documentation. Could you check it and let me know if there are any updates I can make?

Copy link
Contributor

github-actions bot commented Apr 23, 2025

Documentation preview @ d6f9b44

@glemaitre glemaitre changed the title fix:Passing all datasets by keyword makes it annoying to use as_dict=True #1544 fix: Passing all datasets by keyword makes it annoying to use as_dict=True #1544 Apr 23, 2025
@glemaitre glemaitre changed the title fix: Passing all datasets by keyword makes it annoying to use as_dict=True #1544 fix: Passing all datasets by keyword makes it annoying to use as_dict=True Apr 23, 2025
@glemaitre glemaitre changed the title fix: Passing all datasets by keyword makes it annoying to use as_dict=True fix: Accept X and y as positional argument with as_dict=True in train_test_split Apr 23, 2025
if y is not None:
new_arrays.append(y)
keys += ["y"]

if as_dict and arrays:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you remove this? If both are true, it will cause a conflict.


new_arrays = list(keyword_arrays.values())

if X is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this repeated inside and outside the if? It is redundant.

if X is not None:
new_arrays.append(X)
keys.append("X")
if y is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this repeated inside and outside the if? It is redundant.

@@ -167,21 +177,20 @@ class labels.
stratify=stratify,
)

if X is None:
X = arrays[0] if len(arrays) == 1 else arrays[-2]
if X is None and len(arrays) >= 1:
Copy link
Contributor

@nkapila6 nkapila6 Apr 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Incorrect value for case when len(arrays)>1?

@amltarek
Copy link
Author

@auguste-probabl I have implemented the changes as you requested and also tested it. Additionally, I modified another function, test_train_test_split_dict_kwargs(), because it was throwing an error when data was passed without keyword arguments while return_dict=True. This issue has now been fixed. Furthermore, I addressed all the changes requested by @nkapila6.

Copy link
Contributor

@auguste-probabl auguste-probabl Apr 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test with

arr1 = [[1]] * 20
arr2 = [0] * 10 + [1] * 10
train_test_split(arr2, z=arr1, as_dict=True)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is tested through two functions named( test_train_test_split_check_dict()) and test_train_test_split_dict_kwargs().

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit different: right now we test either all arguments passed by keyword, or all arguments passed by position. I'd like to also test the combination of both (one array passed by position, one array passed by keyword).

@amltarek
Copy link
Author

I have removed all the duplicate functions. Can you check the code? @auguste-probabl

Comment on lines 244 to 251
result = train_test_split(
X=X,
y=y,
sample_weights=weights,
test_size=0.2,
as_dict=True,
random_state=0,
)
Copy link
Contributor

@auguste-probabl auguste-probabl Apr 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test feels a bit redundant, but

train_test_split(
    X,
    y,
    sample_weights=weights,
    ...
)

(i.e. a mix of positional and keyword arguments) would be interesting. See also my other comment

@auguste-probabl
Copy link
Contributor

@amltarek Can you resolve comments when it's clear that you have addressed them? It helps review your code.

@auguste-probabl
Copy link
Contributor

Also please sign your commits

Copy link
Contributor

Coverage

Coverage Report for backend
FileStmtsMissCoverMissing
venv/lib/python3.12/site-packages/skore
   __init__.py220100% 
   _config.py280100% 
   exceptions.py440%4–23
venv/lib/python3.12/site-packages/skore/persistence
   __init__.py00100% 
venv/lib/python3.12/site-packages/skore/persistence/item
   __init__.py55198%97
   altair_chart_item.py19191%14
   item.py22195%86
   matplotlib_figure_item.py36195%19
   media_item.py220100% 
   numpy_array_item.py27194%16
   pandas_dataframe_item.py29194%14
   pandas_series_item.py29194%14
   pickle_item.py220100% 
   pillow_image_item.py25193%15
   plotly_figure_item.py20192%14
   polars_dataframe_item.py27194%14
   polars_series_item.py22192%14
   primitive_item.py23291%13–15
   sklearn_base_estimator_item.py29194%15
venv/lib/python3.12/site-packages/skore/persistence/repository
   __init__.py20100% 
   item_repository.py59591%15–16, 202–203, 226
venv/lib/python3.12/site-packages/skore/persistence/storage
   __init__.py40100% 
   abstract_storage.py220100% 
   disk_cache_storage.py33195%44
   in_memory_storage.py200100% 
venv/lib/python3.12/site-packages/skore/project
   __init__.py20100% 
   project.py83298%280, 392
venv/lib/python3.12/site-packages/skore/sklearn
   __init__.py60100% 
   _base.py1711492%45, 58, 126, 129, 182–191, 203–>209, 224, 227–228
   find_ml_task.py61099%136–>145
   types.py130100% 
venv/lib/python3.12/site-packages/skore/sklearn/_comparison
   __init__.py50100% 
   metrics_accessor.py165297%163, 164–>166, 1278
   report.py67197%17, 249–>252
venv/lib/python3.12/site-packages/skore/sklearn/_cross_validation
   __init__.py50100% 
   metrics_accessor.py190099%153–>155, 155–>157
   report.py110198%23
venv/lib/python3.12/site-packages/skore/sklearn/_estimator
   __init__.py70100% 
   feature_importance_accessor.py133099%483–>489, 569–>578
   metrics_accessor.py3441096%174–183, 211–>220, 219, 249, 260–>262, 290, 317–321, 336, 371, 372–>374
   report.py148198%24, 253–>255
venv/lib/python3.12/site-packages/skore/sklearn/_plot
   __init__.py20100% 
   base.py60100% 
   style.py280100% 
   utils.py122595%51, 75–77, 81
venv/lib/python3.12/site-packages/skore/sklearn/_plot/metrics
   __init__.py40100% 
   precision_recall_curve.py173199%660
   prediction_error.py1640100% 
   roc_curve.py176199%649
venv/lib/python3.12/site-packages/skore/sklearn/train_test_split
   __init__.py00100% 
   train_test_split.py57393%16, 161, 177
venv/lib/python3.12/site-packages/skore/sklearn/train_test_split/warning
   __init__.py80100% 
   high_class_imbalance_too_few_examples_warning.py17190%79
   high_class_imbalance_warning.py180100% 
   random_state_unset_warning.py12188%15
   shuffle_true_warning.py10183%46
   stratify_is_set_warning.py12188%15
   time_based_column_warning.py23286%17, 73
   train_test_split_warning.py40100% 
venv/lib/python3.12/site-packages/skore/utils
   __init__.py60100% 
   _accessor.py46197%102
   _environment.py27097%30–>35
   _fixes.py80100% 
   _index.py50100% 
   _logger.py22485%15–19
   _measure_time.py100100% 
   _parallel.py38388%23–33, 124
   _patch.py13553%21–37
   _progress_bar.py360100% 
   _show_versions.py330100% 
TOTAL31918496% 

Tests Skipped Failures Errors Time
816 8 💤 0 ❌ 0 🔥 53.579s ⏱️

@amltarek
Copy link
Author

@auguste-probabl Can you check the updates?

Comment on lines +232 to +237
def test_empty_input():
"""Tests that passing empty lists for X and y raises a ValueError."""
X = []
y = []
with pytest.raises(ValueError):
train_test_split(X, y)
Copy link
Contributor

@auguste-probabl auguste-probabl Apr 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this test is needed; this behaviour is not specific to our function, but rather to sklearn's.

Comment on lines +194 to +197
assert "X_train" in result
assert "X_test" in result
assert "y_train" in result
assert "y_test" in result
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need for this

>>> # When using positional arguments and as_dict=True
>>> # the first argument is assumed to be X, the second y
>>> train_test_split(
... [[1], [2], [3], [4]], [0, 1, 0, 1], as_dict=True, random_state=0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
... [[1], [2], [3], [4]], [0, 1, 0, 1], as_dict=True, random_state=0
... [[1], [2], [3], [4]], [0, 1, 0, 1], as_dict=True

No need for random_state since we don't check the output arrays. You can also remove random_state in the previous doctest

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also test what happens with

train_test_split(X, X=X)

I think there should be an error like

X cannot be passed both by position and by keyword.

Same for y.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Passing all datasets by keyword makes it annoying to use as_dict=True
4 participants