diff --git a/python/cuml/cuml/cluster/kmeans.pyx b/python/cuml/cuml/cluster/kmeans.pyx index 6be09f6912..ed26df5cd6 100644 --- a/python/cuml/cuml/cluster/kmeans.pyx +++ b/python/cuml/cuml/cluster/kmeans.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2019-2024, NVIDIA CORPORATION. +# Copyright (c) 2019-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -284,7 +284,7 @@ class KMeans(UniversalBase, @generate_docstring() @enable_device_interop - def fit(self, X, sample_weight=None, convert_dtype=True) -> "KMeans": + def fit(self, X, y=None, sample_weight=None, convert_dtype=True) -> "KMeans": """ Compute k-means clustering with X. @@ -422,7 +422,7 @@ class KMeans(UniversalBase, 'description': 'Cluster indexes', 'shape': '(n_samples, 1)'}) @enable_device_interop - def fit_predict(self, X, sample_weight=None) -> CumlArray: + def fit_predict(self, X, y=None, sample_weight=None) -> CumlArray: """ Compute cluster centers and predict cluster index for each sample. diff --git a/python/cuml/cuml/feature_extraction/_tfidf.py b/python/cuml/cuml/feature_extraction/_tfidf.py index 2cf5974119..f929ed2dfa 100644 --- a/python/cuml/cuml/feature_extraction/_tfidf.py +++ b/python/cuml/cuml/feature_extraction/_tfidf.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2019-2024, NVIDIA CORPORATION. +# Copyright (c) 2019-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -182,7 +182,7 @@ def _set_idf_diag(self): del self.__df @cuml.internals.api_base_return_any_skipall - def fit(self, X) -> "TfidfTransformer": + def fit(self, X, y=None) -> "TfidfTransformer": """Learn the idf vector (global term weights). 
Parameters @@ -251,7 +251,7 @@ def transform(self, X, copy=True): return X @cuml.internals.api_base_return_any_skipall - def fit_transform(self, X, copy=True): + def fit_transform(self, X, y=None, copy=True): """ Fit TfidfTransformer to X, then transform X. Equivalent to fit(X).transform(X). diff --git a/python/cuml/cuml/manifold/t_sne.pyx b/python/cuml/cuml/manifold/t_sne.pyx index b984d47818..08ec39913a 100644 --- a/python/cuml/cuml/manifold/t_sne.pyx +++ b/python/cuml/cuml/manifold/t_sne.pyx @@ -413,7 +413,7 @@ class TSNE(UniversalBase, X='dense_sparse', convert_dtype_cast='np.float32') @enable_device_interop - def fit(self, X, convert_dtype=True, knn_graph=None) -> "TSNE": + def fit(self, X, y=None, convert_dtype=True, knn_graph=None) -> "TSNE": """ Fit X into an embedded space. @@ -578,7 +578,7 @@ class TSNE(UniversalBase, 'shape': '(n_samples, n_components)'}) @cuml.internals.api_base_fit_transform() @enable_device_interop - def fit_transform(self, X, convert_dtype=True, + def fit_transform(self, X, y=None, convert_dtype=True, knn_graph=None) -> CumlArray: """ Fit X into an embedded space and return that transformed output. diff --git a/python/cuml/cuml/random_projection/random_projection.pyx b/python/cuml/cuml/random_projection/random_projection.pyx index 81811a4849..48ef013e44 100644 --- a/python/cuml/cuml/random_projection/random_projection.pyx +++ b/python/cuml/cuml/random_projection/random_projection.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2018-2024, NVIDIA CORPORATION. +# Copyright (c) 2018-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -322,7 +322,7 @@ cdef class BaseRandomProjection(): return X_new @cuml.internals.api_base_return_array(get_output_type=False) - def fit_transform(self, X, convert_dtype=True): + def fit_transform(self, X, y=None, convert_dtype=True): return self.fit(X).transform(X, convert_dtype) diff --git a/python/cuml/cuml/tests/test_base.py b/python/cuml/cuml/tests/test_base.py index 0cd01acabb..d258020339 100644 --- a/python/cuml/cuml/tests/test_base.py +++ b/python/cuml/cuml/tests/test_base.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2024, NVIDIA CORPORATION. +# Copyright (c) 2019-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -196,3 +196,48 @@ def test_base_children__get_param_names(child_class: str): continue assert name in param_names + + +# We explicitly skip the models in `cuml.tsa` since they match the statsmodels +# interface rather than the sklearn interface (https://github.com/rapidsai/cuml/issues/6258). +# Also skip a few classes that don't match this interface intentionally, since their sklearn +# equivalents are also exceptions. 
+@pytest.mark.parametrize(
+    "cls",
+    [
+        cls
+        for cls in all_base_children.values()
+        if not cls.__module__.startswith("cuml.tsa.")
+        and cls
+        not in {
+            cuml.preprocessing.LabelBinarizer,
+            cuml.preprocessing.LabelEncoder,
+        }
+    ],
+)
+def test_sklearn_methods_with_required_y_parameter(cls):
+    """Check sklearn-style methods take `y` after `X`, then only optionals."""
+    # `*args` and positional-only parameters stay rejected, as before.
+    allowed_kinds = {
+        inspect.Parameter.KEYWORD_ONLY,
+        inspect.Parameter.POSITIONAL_OR_KEYWORD,
+        inspect.Parameter.VAR_KEYWORD,
+    }
+    names = ["fit", "partial_fit", "score", "fit_transform", "fit_predict"]
+    for name in names:
+        if (method := getattr(cls, name, None)) is None:
+            # Method not defined, skip
+            continue
+        params = list(inspect.signature(method).parameters.values())
+        # Looked up on the class, so `self` and `X` precede sklearn's `y`
+        assert (
+            len(params) > 2 and params[2].name == "y"
+        ), f"`{name}` requires a `y` parameter, even if it's ignored"
+        for param in params[3:]:
+            # A parameter kind alone does not make it optional: anything other
+            # than `**kwargs` must also declare a default value.
+            has_default = param.default is not inspect.Parameter.empty
+            is_var_kw = param.kind is inspect.Parameter.VAR_KEYWORD
+            assert param.kind in allowed_kinds and (
+                has_default or is_var_kw
+            ), f"`{name}` parameter `{param.name}` must be optional"
diff --git a/python/cuml/cuml/tests/test_tsne.py b/python/cuml/cuml/tests/test_tsne.py
index fe119eb999..115ade2848 100644
--- a/python/cuml/cuml/tests/test_tsne.py
+++ b/python/cuml/cuml/tests/test_tsne.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2019-2024, NVIDIA CORPORATION.
+# Copyright (c) 2019-2025, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -75,7 +75,7 @@ def test_tsne_knn_graph_used(test_datasets, type_knn_graph, method): ) # Perform tsne with normal knn_graph - Y = tsne.fit_transform(X, True, knn_graph) + Y = tsne.fit_transform(X, convert_dtype=True, knn_graph=knn_graph) trust_normal = trustworthiness(X, Y, n_neighbors=DEFAULT_N_NEIGHBORS) @@ -97,16 +97,16 @@ def test_tsne_knn_graph_used(test_datasets, type_knn_graph, method): ) # Perform tsne with garbage knn_graph - Y = tsne.fit_transform(X, True, knn_graph_garbage) + Y = tsne.fit_transform(X, convert_dtype=True, knn_graph=knn_graph_garbage) trust_garbage = trustworthiness(X, Y, n_neighbors=DEFAULT_N_NEIGHBORS) assert (trust_normal - trust_garbage) > 0.15 - Y = tsne.fit_transform(X, True, knn_graph_garbage) + Y = tsne.fit_transform(X, convert_dtype=True, knn_graph=knn_graph_garbage) trust_garbage = trustworthiness(X, Y, n_neighbors=DEFAULT_N_NEIGHBORS) assert (trust_normal - trust_garbage) > 0.15 - Y = tsne.fit_transform(X, True, knn_graph_garbage) + Y = tsne.fit_transform(X, convert_dtype=True, knn_graph=knn_graph_garbage) trust_garbage = trustworthiness(X, Y, n_neighbors=DEFAULT_N_NEIGHBORS) assert (trust_normal - trust_garbage) > 0.15 @@ -137,13 +137,17 @@ def test_tsne_knn_parameters(test_datasets, type_knn_graph, method): perplexity=DEFAULT_PERPLEXITY, ) - embed = tsne.fit_transform(X, True, knn_graph) + embed = tsne.fit_transform(X, convert_dtype=True, knn_graph=knn_graph) validate_embedding(X, embed) - embed = tsne.fit_transform(X, True, knn_graph.tocoo()) + embed = tsne.fit_transform( + X, convert_dtype=True, knn_graph=knn_graph.tocoo() + ) validate_embedding(X, embed) - embed = tsne.fit_transform(X, True, knn_graph.tocsc()) + embed = tsne.fit_transform( + X, convert_dtype=True, knn_graph=knn_graph.tocsc() + ) validate_embedding(X, embed) @@ -309,17 +313,21 @@ def test_tsne_knn_parameters_sparse(type_knn_graph, input_type, method): new_data = sp_prefix.csr_matrix(scipy.sparse.csr_matrix(digits)) - Y = tsne.fit_transform(new_data, True, 
knn_graph) + Y = tsne.fit_transform(new_data, convert_dtype=True, knn_graph=knn_graph) if input_type == "cupy": Y = Y.get() validate_embedding(digits, Y, 0.85) - Y = tsne.fit_transform(new_data, True, knn_graph.tocoo()) + Y = tsne.fit_transform( + new_data, convert_dtype=True, knn_graph=knn_graph.tocoo() + ) if input_type == "cupy": Y = Y.get() validate_embedding(digits, Y, 0.85) - Y = tsne.fit_transform(new_data, True, knn_graph.tocsc()) + Y = tsne.fit_transform( + new_data, convert_dtype=True, knn_graph=knn_graph.tocsc() + ) if input_type == "cupy": Y = Y.get() validate_embedding(digits, Y, 0.85)