From b831dc517e182a623359cc939a22970bbfd79866 Mon Sep 17 00:00:00 2001
From: Jim Crist-Harif <jcristharif@gmail.com>
Date: Fri, 24 Jan 2025 16:19:34 +0000
Subject: [PATCH] Ensure all method signatures are sklearn compatible

`sklearn` requires `fit`/`fit_transform`/... always take a `y`
parameter, even if it's ignored. This adds a test to ensure our
signatures match this rule, and fixes any cases where they didn't. This
makes it easier to include `cuml` estimators within sklearn pipelines.
---
 python/cuml/cuml/cluster/kmeans.pyx           |  6 +--
 python/cuml/cuml/feature_extraction/_tfidf.py |  6 +--
 python/cuml/cuml/manifold/t_sne.pyx           |  4 +-
 .../random_projection/random_projection.pyx   |  4 +-
 python/cuml/cuml/tests/test_base.py           | 47 ++++++++++++++++++-
 python/cuml/cuml/tests/test_tsne.py           | 30 +++++++-----
 6 files changed, 75 insertions(+), 22 deletions(-)

diff --git a/python/cuml/cuml/cluster/kmeans.pyx b/python/cuml/cuml/cluster/kmeans.pyx
index 6be09f6912..ed26df5cd6 100644
--- a/python/cuml/cuml/cluster/kmeans.pyx
+++ b/python/cuml/cuml/cluster/kmeans.pyx
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2019-2024, NVIDIA CORPORATION.
+# Copyright (c) 2019-2025, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -284,7 +284,7 @@ class KMeans(UniversalBase,
 
     @generate_docstring()
     @enable_device_interop
-    def fit(self, X, sample_weight=None, convert_dtype=True) -> "KMeans":
+    def fit(self, X, y=None, sample_weight=None, convert_dtype=True) -> "KMeans":
         """
         Compute k-means clustering with X.
 
@@ -422,7 +422,7 @@ class KMeans(UniversalBase,
                                        'description': 'Cluster indexes',
                                        'shape': '(n_samples, 1)'})
     @enable_device_interop
-    def fit_predict(self, X, sample_weight=None) -> CumlArray:
+    def fit_predict(self, X, y=None, sample_weight=None) -> CumlArray:
         """
         Compute cluster centers and predict cluster index for each sample.
 
diff --git a/python/cuml/cuml/feature_extraction/_tfidf.py b/python/cuml/cuml/feature_extraction/_tfidf.py
index 2cf5974119..f929ed2dfa 100644
--- a/python/cuml/cuml/feature_extraction/_tfidf.py
+++ b/python/cuml/cuml/feature_extraction/_tfidf.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2019-2024, NVIDIA CORPORATION.
+# Copyright (c) 2019-2025, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -182,7 +182,7 @@ def _set_idf_diag(self):
         del self.__df
 
     @cuml.internals.api_base_return_any_skipall
-    def fit(self, X) -> "TfidfTransformer":
+    def fit(self, X, y=None) -> "TfidfTransformer":
         """Learn the idf vector (global term weights).
 
         Parameters
@@ -251,7 +251,7 @@ def transform(self, X, copy=True):
         return X
 
     @cuml.internals.api_base_return_any_skipall
-    def fit_transform(self, X, copy=True):
+    def fit_transform(self, X, y=None, copy=True):
         """
         Fit TfidfTransformer to X, then transform X.
         Equivalent to fit(X).transform(X).
diff --git a/python/cuml/cuml/manifold/t_sne.pyx b/python/cuml/cuml/manifold/t_sne.pyx
index b984d47818..08ec39913a 100644
--- a/python/cuml/cuml/manifold/t_sne.pyx
+++ b/python/cuml/cuml/manifold/t_sne.pyx
@@ -413,7 +413,7 @@ class TSNE(UniversalBase,
                         X='dense_sparse',
                         convert_dtype_cast='np.float32')
     @enable_device_interop
-    def fit(self, X, convert_dtype=True, knn_graph=None) -> "TSNE":
+    def fit(self, X, y=None, convert_dtype=True, knn_graph=None) -> "TSNE":
         """
         Fit X into an embedded space.
 
@@ -578,7 +578,7 @@ class TSNE(UniversalBase,
                                        'shape': '(n_samples, n_components)'})
     @cuml.internals.api_base_fit_transform()
     @enable_device_interop
-    def fit_transform(self, X, convert_dtype=True,
+    def fit_transform(self, X, y=None, convert_dtype=True,
                       knn_graph=None) -> CumlArray:
         """
         Fit X into an embedded space and return that transformed output.
diff --git a/python/cuml/cuml/random_projection/random_projection.pyx b/python/cuml/cuml/random_projection/random_projection.pyx
index 81811a4849..48ef013e44 100644
--- a/python/cuml/cuml/random_projection/random_projection.pyx
+++ b/python/cuml/cuml/random_projection/random_projection.pyx
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2018-2024, NVIDIA CORPORATION.
+# Copyright (c) 2018-2025, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -322,7 +322,7 @@ cdef class BaseRandomProjection():
         return X_new
 
     @cuml.internals.api_base_return_array(get_output_type=False)
-    def fit_transform(self, X, convert_dtype=True):
+    def fit_transform(self, X, y=None, convert_dtype=True):
         return self.fit(X).transform(X, convert_dtype)
 
 
diff --git a/python/cuml/cuml/tests/test_base.py b/python/cuml/cuml/tests/test_base.py
index 0cd01acabb..d258020339 100644
--- a/python/cuml/cuml/tests/test_base.py
+++ b/python/cuml/cuml/tests/test_base.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2019-2024, NVIDIA CORPORATION.
+# Copyright (c) 2019-2025, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -196,3 +196,48 @@ def test_base_children__get_param_names(child_class: str):
                 continue
 
             assert name in param_names
+
+
+# We explicitly skip the models in `cuml.tsa` since they match the statsmodels
+# interface rather than the sklearn interface (https://github.com/rapidsai/cuml/issues/6258).
+# Also skip a few classes that don't match this interface intentionally, since their sklearn
+# equivalents are also exceptions.
+@pytest.mark.parametrize(
+    "cls",
+    [
+        cls
+        for cls in all_base_children.values()
+        if not cls.__module__.startswith("cuml.tsa.")
+        and cls
+        not in {
+            cuml.preprocessing.LabelBinarizer,
+            cuml.preprocessing.LabelEncoder,
+        }
+    ],
+)
+def test_sklearn_methods_with_required_y_parameter(cls):
+    optional_params = {
+        inspect.Parameter.KEYWORD_ONLY,
+        inspect.Parameter.POSITIONAL_OR_KEYWORD,
+        inspect.Parameter.VAR_KEYWORD,
+    }
+    for name in [
+        "fit",
+        "partial_fit",
+        "score",
+        "fit_transform",
+        "fit_predict",
+    ]:
+        if (method := getattr(cls, name, None)) is None:
+            # Method not defined, skip
+            continue
+        params = list(inspect.signature(method).parameters.values())
+        # Assert method has a 2nd parameter named y, which is required by sklearn
+        assert (
+            len(params) > 2 and params[2].name == "y"
+        ), f"`{name}` requires a `y` parameter, even if it's ignored"
+        # Check that all remaining parameters are optional
+        for param in params[3:]:
+            assert (
+                param.kind in optional_params
+            ), f"`{name}` parameter `{param.name}` must be optional"
diff --git a/python/cuml/cuml/tests/test_tsne.py b/python/cuml/cuml/tests/test_tsne.py
index fe119eb999..115ade2848 100644
--- a/python/cuml/cuml/tests/test_tsne.py
+++ b/python/cuml/cuml/tests/test_tsne.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2019-2024, NVIDIA CORPORATION.
+# Copyright (c) 2019-2025, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -75,7 +75,7 @@ def test_tsne_knn_graph_used(test_datasets, type_knn_graph, method):
     )
 
     # Perform tsne with normal knn_graph
-    Y = tsne.fit_transform(X, True, knn_graph)
+    Y = tsne.fit_transform(X, convert_dtype=True, knn_graph=knn_graph)
 
     trust_normal = trustworthiness(X, Y, n_neighbors=DEFAULT_N_NEIGHBORS)
 
@@ -97,16 +97,16 @@ def test_tsne_knn_graph_used(test_datasets, type_knn_graph, method):
     )
 
     # Perform tsne with garbage knn_graph
-    Y = tsne.fit_transform(X, True, knn_graph_garbage)
+    Y = tsne.fit_transform(X, convert_dtype=True, knn_graph=knn_graph_garbage)
 
     trust_garbage = trustworthiness(X, Y, n_neighbors=DEFAULT_N_NEIGHBORS)
     assert (trust_normal - trust_garbage) > 0.15
 
-    Y = tsne.fit_transform(X, True, knn_graph_garbage)
+    Y = tsne.fit_transform(X, convert_dtype=True, knn_graph=knn_graph_garbage)
     trust_garbage = trustworthiness(X, Y, n_neighbors=DEFAULT_N_NEIGHBORS)
     assert (trust_normal - trust_garbage) > 0.15
 
-    Y = tsne.fit_transform(X, True, knn_graph_garbage)
+    Y = tsne.fit_transform(X, convert_dtype=True, knn_graph=knn_graph_garbage)
     trust_garbage = trustworthiness(X, Y, n_neighbors=DEFAULT_N_NEIGHBORS)
     assert (trust_normal - trust_garbage) > 0.15
 
@@ -137,13 +137,17 @@ def test_tsne_knn_parameters(test_datasets, type_knn_graph, method):
         perplexity=DEFAULT_PERPLEXITY,
     )
 
-    embed = tsne.fit_transform(X, True, knn_graph)
+    embed = tsne.fit_transform(X, convert_dtype=True, knn_graph=knn_graph)
     validate_embedding(X, embed)
 
-    embed = tsne.fit_transform(X, True, knn_graph.tocoo())
+    embed = tsne.fit_transform(
+        X, convert_dtype=True, knn_graph=knn_graph.tocoo()
+    )
     validate_embedding(X, embed)
 
-    embed = tsne.fit_transform(X, True, knn_graph.tocsc())
+    embed = tsne.fit_transform(
+        X, convert_dtype=True, knn_graph=knn_graph.tocsc()
+    )
     validate_embedding(X, embed)
 
 
@@ -309,17 +313,21 @@ def test_tsne_knn_parameters_sparse(type_knn_graph, input_type, method):
 
     new_data = sp_prefix.csr_matrix(scipy.sparse.csr_matrix(digits))
 
-    Y = tsne.fit_transform(new_data, True, knn_graph)
+    Y = tsne.fit_transform(new_data, convert_dtype=True, knn_graph=knn_graph)
     if input_type == "cupy":
         Y = Y.get()
     validate_embedding(digits, Y, 0.85)
 
-    Y = tsne.fit_transform(new_data, True, knn_graph.tocoo())
+    Y = tsne.fit_transform(
+        new_data, convert_dtype=True, knn_graph=knn_graph.tocoo()
+    )
     if input_type == "cupy":
         Y = Y.get()
     validate_embedding(digits, Y, 0.85)
 
-    Y = tsne.fit_transform(new_data, True, knn_graph.tocsc())
+    Y = tsne.fit_transform(
+        new_data, convert_dtype=True, knn_graph=knn_graph.tocsc()
+    )
     if input_type == "cupy":
         Y = Y.get()
     validate_embedding(digits, Y, 0.85)