diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst
index 8f81a80959e0d..7d8175a3b5046 100644
--- a/doc/whats_new/v1.0.rst
+++ b/doc/whats_new/v1.0.rst
@@ -789,6 +789,12 @@ Changelog
   unavailable on the basis of state, in a more readable way. :pr:`19948` by
   `Joel Nothman`_.

+- |Enhancement| :func:`utils.validation.check_is_fitted` now uses
+  ``__sklearn_is_fitted__`` if available, instead of checking for attributes
+  ending with an underscore. This also makes :class:`Pipeline` and
+  :class:`preprocessing.FunctionTransformer` pass
+  ``check_is_fitted(estimator)``. :pr:`20657` by `Adrin Jalali`_.
+
 - |Fix| Fixed a bug in :func:`utils.sparsefuncs.mean_variance_axis` where the
   precision of the computed variance was very poor when the real variance is
   exactly zero. :pr:`19766` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py
index 0814632721ba4..35f0fa3768b45 100644
--- a/sklearn/pipeline.py
+++ b/sklearn/pipeline.py
@@ -26,7 +26,9 @@
 from .utils.deprecation import deprecated
 from .utils._tags import _safe_tags
 from .utils.validation import check_memory
+from .utils.validation import check_is_fitted
 from .utils.fixes import delayed
+from .exceptions import NotFittedError
 from .utils.metaestimators import _BaseComposition
@@ -657,6 +659,18 @@ def n_features_in_(self):
         # delegate to first step (which will call _check_is_fitted)
         return self.steps[0][1].n_features_in_

+    def __sklearn_is_fitted__(self):
+        """Indicate whether the pipeline has been fit."""
+        try:
+            # check if the last step of the pipeline is fitted
+            # we only check the last step since if the last step is fit, it
+            # means the previous steps should also be fit. This is faster than
+            # checking if every step of the pipeline is fit.
+            check_is_fitted(self.steps[-1][1])
+            return True
+        except NotFittedError:
+            return False
+
     def _sk_visual_block_(self):
         _, estimators = zip(*self.steps)
diff --git a/sklearn/preprocessing/_function_transformer.py b/sklearn/preprocessing/_function_transformer.py
index 345cc96bb1c2e..202b6ec2f6cdd 100644
--- a/sklearn/preprocessing/_function_transformer.py
+++ b/sklearn/preprocessing/_function_transformer.py
@@ -176,5 +176,9 @@ def _transform(self, X, func=None, kw_args=None):
         return func(X, **(kw_args if kw_args else {}))

+    def __sklearn_is_fitted__(self):
+        """Return True since FunctionTransformer is stateless."""
+        return True
+
     def _more_tags(self):
         return {"no_validation": not self.validate, "stateless": True}
diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py
index 4176e1a65f4b2..4ec5c7f081a15 100644
--- a/sklearn/tests/test_pipeline.py
+++ b/sklearn/tests/test_pipeline.py
@@ -21,7 +21,8 @@
     MinimalRegressor,
     MinimalTransformer,
 )
-
+from sklearn.exceptions import NotFittedError
+from sklearn.utils.validation import check_is_fitted
 from sklearn.base import clone, is_classifier, BaseEstimator, TransformerMixin
 from sklearn.pipeline import Pipeline, FeatureUnion, make_pipeline, make_union
 from sklearn.svm import SVC
@@ -1361,3 +1362,16 @@ def test_search_cv_using_minimal_compatible_estimator(Predictor):
     else:
         assert_allclose(y_pred, y.mean())
         assert model.score(X, y) == pytest.approx(r2_score(y, y_pred))
+
+
+def test_pipeline_check_if_fitted():
+    class Estimator(BaseEstimator):
+        def fit(self, X, y):
+            self.fitted_ = True
+            return self
+
+    pipeline = Pipeline([("clf", Estimator())])
+    with pytest.raises(NotFittedError):
+        check_is_fitted(pipeline)
+    pipeline.fit(iris.data, iris.target)
+    check_is_fitted(pipeline)
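Usage sketch (editorial illustration only, not part of the patch): with the changes above, a stateless FunctionTransformer passes check_is_fitted immediately via __sklearn_is_fitted__, while a Pipeline passes only once its last step is fitted. The particular estimators and dataset below (StandardScaler, SVC, load_iris) are just convenient stand-ins.

import numpy as np
from sklearn.datasets import load_iris
from sklearn.exceptions import NotFittedError
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer, StandardScaler
from sklearn.svm import SVC
from sklearn.utils.validation import check_is_fitted

X, y = load_iris(return_X_y=True)

# FunctionTransformer.__sklearn_is_fitted__ returns True, so the stateless
# transformer is considered fitted without calling fit.
check_is_fitted(FunctionTransformer(np.log1p))

# Pipeline.__sklearn_is_fitted__ delegates to the last step: unfitted raises
# NotFittedError, fitted passes silently.
pipe = make_pipeline(StandardScaler(), SVC())
try:
    check_is_fitted(pipe)
except NotFittedError:
    pass  # expected before fit
pipe.fit(X, y)
check_is_fitted(pipe)  # no exception once the final step is fitted
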
diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index 8a51d87682dd4..cf88785807d06 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -53,6 +53,7 @@
 from ..model_selection import ShuffleSplit
 from ..model_selection._validation import _safe_split
 from ..metrics.pairwise import rbf_kernel, linear_kernel, pairwise_distances
+from ..utils.validation import check_is_fitted

 from . import shuffle
 from ._tags import (
@@ -307,6 +308,7 @@ def _yield_all_checks(estimator):
     yield check_dict_unchanged
     yield check_dont_overwrite_parameters
     yield check_fit_idempotent
+    yield check_fit_check_is_fitted
     if not tags["no_validation"]:
         yield check_n_features_in
     yield check_fit1d
@@ -3501,6 +3503,45 @@ def check_fit_idempotent(name, estimator_orig):
     )


+def check_fit_check_is_fitted(name, estimator_orig):
+    # Make sure that the estimator does not pass check_is_fitted before
+    # calling fit, and that it passes check_is_fitted once it has been fit.
+
+    rng = np.random.RandomState(42)
+
+    estimator = clone(estimator_orig)
+    set_random_state(estimator)
+    if "warm_start" in estimator.get_params():
+        estimator.set_params(warm_start=False)
+
+    n_samples = 100
+    X = rng.normal(loc=100, size=(n_samples, 2))
+    X = _pairwise_estimator_convert_X(X, estimator)
+    if is_regressor(estimator_orig):
+        y = rng.normal(size=n_samples)
+    else:
+        y = rng.randint(low=0, high=2, size=n_samples)
+    y = _enforce_estimator_tags_y(estimator, y)
+
+    if not _safe_tags(estimator).get("stateless", False):
+        # stateless estimators (such as FunctionTransformer) are always "fit"!
+        try:
+            check_is_fitted(estimator)
+            raise AssertionError(
+                f"{estimator.__class__.__name__} passes check_is_fitted before being"
+                " fit!"
+            )
+        except NotFittedError:
+            pass
+    estimator.fit(X, y)
+    try:
+        check_is_fitted(estimator)
+    except NotFittedError as e:
+        raise NotFittedError(
+            "Estimator fails to pass `check_is_fitted` even though it has been fit."
+        ) from e
+
+
 def check_n_features_in(name, estimator_orig):
     # Make sure that n_features_in_ attribute doesn't exist until fit is
     # called, and that its value is correct.
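For reference (a hedged sketch, not part of the patch): once the new common check is importable from sklearn.utils.estimator_checks, as the test changes below assume, it can also be exercised directly against a single estimator; check_estimator and parametrize_with_checks pick it up automatically through _yield_all_checks.

from sklearn.linear_model import LogisticRegression
from sklearn.utils.estimator_checks import check_fit_check_is_fitted

# The check clones the estimator, asserts that check_is_fitted raises
# NotFittedError before fit (unless the estimator is tagged "stateless"),
# fits it on small synthetic data, and asserts check_is_fitted then passes.
check_fit_check_is_fitted("LogisticRegression", LogisticRegression())
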
diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py
index 3d565ca5895ef..b50d263245f32 100644
--- a/sklearn/utils/tests/test_estimator_checks.py
+++ b/sklearn/utils/tests/test_estimator_checks.py
@@ -34,6 +34,7 @@
 from sklearn.utils.validation import check_array
 from sklearn.utils import all_estimators
 from sklearn.exceptions import SkipTestWarning
+from sklearn.utils.metaestimators import available_if

 from sklearn.utils.estimator_checks import (
     _NotAnArray,
@@ -52,6 +53,7 @@
     check_regressor_data_not_an_array,
     check_outlier_corruption,
     set_random_state,
+    check_fit_check_is_fitted,
 )
@@ -1006,3 +1008,28 @@ def test_minimal_class_implementation_checks():
     minimal_estimators = [MinimalTransformer(), MinimalRegressor(), MinimalClassifier()]
     for estimator in minimal_estimators:
         check_estimator(estimator)
+
+
+def test_check_fit_check_is_fitted():
+    class Estimator(BaseEstimator):
+        def __init__(self, behavior="attribute"):
+            self.behavior = behavior
+
+        def fit(self, X, y, **kwargs):
+            if self.behavior == "attribute":
+                self.is_fitted_ = True
+            elif self.behavior == "method":
+                self._is_fitted = True
+            return self
+
+        @available_if(lambda self: self.behavior in {"method", "always-true"})
+        def __sklearn_is_fitted__(self):
+            if self.behavior == "always-true":
+                return True
+            return hasattr(self, "_is_fitted")
+
+    with raises(Exception, match="passes check_is_fitted before being fit"):
+        check_fit_check_is_fitted("estimator", Estimator(behavior="always-true"))
+
+    check_fit_check_is_fitted("estimator", Estimator(behavior="method"))
+    check_fit_check_is_fitted("estimator", Estimator(behavior="attribute"))
diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py
index 9d88a06149e61..96b9cc32b4909 100644
--- a/sklearn/utils/tests/test_validation.py
+++ b/sklearn/utils/tests/test_validation.py
@@ -52,7 +52,7 @@
     _get_feature_names,
 )
 from sklearn.utils.validation import _check_fit_params
-
+from sklearn.base import BaseEstimator
 import sklearn

 from sklearn.exceptions import NotFittedError, PositiveSpectrumWarning
@@ -751,6 +751,20 @@ def test_check_symmetric():
     assert_array_equal(output, arr_sym)


+def test_check_is_fitted_with_is_fitted():
+    class Estimator(BaseEstimator):
+        def fit(self, **kwargs):
+            self._is_fitted = True
+            return self
+
+        def __sklearn_is_fitted__(self):
+            return hasattr(self, "_is_fitted") and self._is_fitted
+
+    with pytest.raises(NotFittedError):
+        check_is_fitted(Estimator())
+    check_is_fitted(Estimator().fit())
+
+
 def test_check_is_fitted():
     # Check is TypeError raised when non estimator instance passed
     with pytest.raises(TypeError):
diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py
index b5b485c5837ab..a5d2c0e4ec794 100644
--- a/sklearn/utils/validation.py
+++ b/sklearn/utils/validation.py
@@ -1142,8 +1142,9 @@ def check_is_fitted(estimator, attributes=None, *, msg=None, all_or_any=all):
     fitted attributes (ending with a trailing underscore) and otherwise
     raises a NotFittedError with the given message.

-    This utility is meant to be used internally by estimators themselves,
-    typically in their own predict / transform methods.
+    If an estimator does not set any attributes with a trailing underscore, it
+    can define a ``__sklearn_is_fitted__`` method returning a boolean to
+    specify if the estimator is fitted or not.

     Parameters
     ----------
@@ -1194,13 +1195,15 @@ def check_is_fitted(estimator, attributes=None, *, msg=None, all_or_any=all):
     if attributes is not None:
         if not isinstance(attributes, (list, tuple)):
             attributes = [attributes]
-        attrs = all_or_any([hasattr(estimator, attr) for attr in attributes])
+        fitted = all_or_any([hasattr(estimator, attr) for attr in attributes])
+    elif hasattr(estimator, "__sklearn_is_fitted__"):
+        fitted = estimator.__sklearn_is_fitted__()
     else:
-        attrs = [
+        fitted = [
             v for v in vars(estimator) if v.endswith("_") and not v.startswith("__")
         ]

-    if not attrs:
+    if not fitted:
         raise NotFittedError(msg % {"name": type(estimator).__name__})
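
To make the new lookup order concrete (an illustrative sketch using a hypothetical AlwaysFitted estimator, not code from the patch): explicit attributes are still checked first, __sklearn_is_fitted__ is consulted next, and scanning for attributes with a trailing underscore remains the fallback.

from sklearn.base import BaseEstimator
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted

class AlwaysFitted(BaseEstimator):
    """Hypothetical estimator that reports itself as fitted."""

    def fit(self, X=None, y=None):
        return self

    def __sklearn_is_fitted__(self):
        return True

est = AlwaysFitted()

# No attributes given: __sklearn_is_fitted__ is used, so this passes.
check_is_fitted(est)

# Explicit attributes take precedence and bypass __sklearn_is_fitted__.
try:
    check_is_fitted(est, attributes=["coef_"])
except NotFittedError:
    print("explicit attribute check wins over __sklearn_is_fitted__")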