Skip to content

Commit 3e7c04f

Browse files
authored
ENH check_is_fitted calls __is_fitted__ if available (scikit-learn#20657)
1 parent dc43867 commit 3e7c04f

File tree

8 files changed

+130
-7
lines changed

8 files changed

+130
-7
lines changed

doc/whats_new/v1.0.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,12 @@ Changelog
789789
unavailable on the basis of state, in a more readable way.
790790
:pr:`19948` by `Joel Nothman`_.
791791

792+
- |Enhancement| :func:`utils.validation.check_is_fitted` now uses
793+
``__sklearn_is_fitted__`` if available, instead of checking for attributes ending with
794+
an underscore. This also makes :class:`Pipeline` and
795+
:class:`preprocessing.FunctionTransformer` pass
796+
``check_is_fitted(estimator)``. :pr:`20657` by `Adrin Jalali`_.
797+
792798
- |Fix| Fixed a bug in :func:`utils.sparsefuncs.mean_variance_axis` where the
793799
precision of the computed variance was very poor when the real variance is
794800
exactly zero. :pr:`19766` by :user:`Jérémie du Boisberranger <jeremiedbb>`.

sklearn/pipeline.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626
from .utils.deprecation import deprecated
2727
from .utils._tags import _safe_tags
2828
from .utils.validation import check_memory
29+
from .utils.validation import check_is_fitted
2930
from .utils.fixes import delayed
31+
from .exceptions import NotFittedError
3032

3133
from .utils.metaestimators import _BaseComposition
3234

@@ -657,6 +659,18 @@ def n_features_in_(self):
657659
# delegate to first step (which will call _check_is_fitted)
658660
return self.steps[0][1].n_features_in_
659661

662+
def __sklearn_is_fitted__(self):
663+
"""Indicate whether pipeline has been fit."""
664+
try:
665+
# check if the last step of the pipeline is fitted
666+
# we only check the last step since if the last step is fit, it
667+
# means the previous steps should also be fit. This is faster than
668+
# checking if every step of the pipeline is fit.
669+
check_is_fitted(self.steps[-1][1])
670+
return True
671+
except NotFittedError:
672+
return False
673+
660674
def _sk_visual_block_(self):
661675
_, estimators = zip(*self.steps)
662676

sklearn/preprocessing/_function_transformer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,5 +176,9 @@ def _transform(self, X, func=None, kw_args=None):
176176

177177
return func(X, **(kw_args if kw_args else {}))
178178

179+
def __sklearn_is_fitted__(self):
180+
"""Return True since FunctionTransfomer is stateless."""
181+
return True
182+
179183
def _more_tags(self):
180184
return {"no_validation": not self.validate, "stateless": True}

sklearn/tests/test_pipeline.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
MinimalRegressor,
2222
MinimalTransformer,
2323
)
24-
24+
from sklearn.exceptions import NotFittedError
25+
from sklearn.utils.validation import check_is_fitted
2526
from sklearn.base import clone, is_classifier, BaseEstimator, TransformerMixin
2627
from sklearn.pipeline import Pipeline, FeatureUnion, make_pipeline, make_union
2728
from sklearn.svm import SVC
@@ -1361,3 +1362,16 @@ def test_search_cv_using_minimal_compatible_estimator(Predictor):
13611362
else:
13621363
assert_allclose(y_pred, y.mean())
13631364
assert model.score(X, y) == pytest.approx(r2_score(y, y_pred))
1365+
1366+
1367+
def test_pipeline_check_if_fitted():
1368+
class Estimator(BaseEstimator):
1369+
def fit(self, X, y):
1370+
self.fitted_ = True
1371+
return self
1372+
1373+
pipeline = Pipeline([("clf", Estimator())])
1374+
with pytest.raises(NotFittedError):
1375+
check_is_fitted(pipeline)
1376+
pipeline.fit(iris.data, iris.target)
1377+
check_is_fitted(pipeline)

sklearn/utils/estimator_checks.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from ..model_selection import ShuffleSplit
5454
from ..model_selection._validation import _safe_split
5555
from ..metrics.pairwise import rbf_kernel, linear_kernel, pairwise_distances
56+
from ..utils.validation import check_is_fitted
5657

5758
from . import shuffle
5859
from ._tags import (
@@ -307,6 +308,7 @@ def _yield_all_checks(estimator):
307308
yield check_dict_unchanged
308309
yield check_dont_overwrite_parameters
309310
yield check_fit_idempotent
311+
yield check_fit_check_is_fitted
310312
if not tags["no_validation"]:
311313
yield check_n_features_in
312314
yield check_fit1d
@@ -3501,6 +3503,45 @@ def check_fit_idempotent(name, estimator_orig):
35013503
)
35023504

35033505

3506+
def check_fit_check_is_fitted(name, estimator_orig):
3507+
# Make sure that the estimator doesn't pass check_is_fitted before calling fit
3508+
# and that it passes check_is_fitted once it has been fit.
3509+
3510+
rng = np.random.RandomState(42)
3511+
3512+
estimator = clone(estimator_orig)
3513+
set_random_state(estimator)
3514+
if "warm_start" in estimator.get_params():
3515+
estimator.set_params(warm_start=False)
3516+
3517+
n_samples = 100
3518+
X = rng.normal(loc=100, size=(n_samples, 2))
3519+
X = _pairwise_estimator_convert_X(X, estimator)
3520+
if is_regressor(estimator_orig):
3521+
y = rng.normal(size=n_samples)
3522+
else:
3523+
y = rng.randint(low=0, high=2, size=n_samples)
3524+
y = _enforce_estimator_tags_y(estimator, y)
3525+
3526+
if not _safe_tags(estimator).get("stateless", False):
3527+
# stateless estimators (such as FunctionTransformer) are always "fit"!
3528+
try:
3529+
check_is_fitted(estimator)
3530+
raise AssertionError(
3531+
f"{estimator.__class__.__name__} passes check_is_fitted before being"
3532+
" fit!"
3533+
)
3534+
except NotFittedError:
3535+
pass
3536+
estimator.fit(X, y)
3537+
try:
3538+
check_is_fitted(estimator)
3539+
except NotFittedError as e:
3540+
raise NotFittedError(
3541+
"Estimator fails to pass `check_is_fitted` even though it has been fit."
3542+
) from e
3543+
3544+
35043545
def check_n_features_in(name, estimator_orig):
35053546
# Make sure that n_features_in_ attribute doesn't exist until fit is
35063547
# called, and that its value is correct.

sklearn/utils/tests/test_estimator_checks.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from sklearn.utils.validation import check_array
3535
from sklearn.utils import all_estimators
3636
from sklearn.exceptions import SkipTestWarning
37+
from sklearn.utils.metaestimators import available_if
3738

3839
from sklearn.utils.estimator_checks import (
3940
_NotAnArray,
@@ -52,6 +53,7 @@
5253
check_regressor_data_not_an_array,
5354
check_outlier_corruption,
5455
set_random_state,
56+
check_fit_check_is_fitted,
5557
)
5658

5759

@@ -1006,3 +1008,28 @@ def test_minimal_class_implementation_checks():
10061008
minimal_estimators = [MinimalTransformer(), MinimalRegressor(), MinimalClassifier()]
10071009
for estimator in minimal_estimators:
10081010
check_estimator(estimator)
1011+
1012+
1013+
def test_check_fit_check_is_fitted():
1014+
class Estimator(BaseEstimator):
1015+
def __init__(self, behavior="attribute"):
1016+
self.behavior = behavior
1017+
1018+
def fit(self, X, y, **kwargs):
1019+
if self.behavior == "attribute":
1020+
self.is_fitted_ = True
1021+
elif self.behavior == "method":
1022+
self._is_fitted = True
1023+
return self
1024+
1025+
@available_if(lambda self: self.behavior in {"method", "always-true"})
1026+
def __sklearn_is_fitted__(self):
1027+
if self.behavior == "always-true":
1028+
return True
1029+
return hasattr(self, "_is_fitted")
1030+
1031+
with raises(Exception, match="passes check_is_fitted before being fit"):
1032+
check_fit_check_is_fitted("estimator", Estimator(behavior="always-true"))
1033+
1034+
check_fit_check_is_fitted("estimator", Estimator(behavior="method"))
1035+
check_fit_check_is_fitted("estimator", Estimator(behavior="attribute"))

sklearn/utils/tests/test_validation.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
_get_feature_names,
5353
)
5454
from sklearn.utils.validation import _check_fit_params
55-
55+
from sklearn.base import BaseEstimator
5656
import sklearn
5757

5858
from sklearn.exceptions import NotFittedError, PositiveSpectrumWarning
@@ -751,6 +751,20 @@ def test_check_symmetric():
751751
assert_array_equal(output, arr_sym)
752752

753753

754+
def test_check_is_fitted_with_is_fitted():
755+
class Estimator(BaseEstimator):
756+
def fit(self, **kwargs):
757+
self._is_fitted = True
758+
return self
759+
760+
def __sklearn_is_fitted__(self):
761+
return hasattr(self, "_is_fitted") and self._is_fitted
762+
763+
with pytest.raises(NotFittedError):
764+
check_is_fitted(Estimator())
765+
check_is_fitted(Estimator().fit())
766+
767+
754768
def test_check_is_fitted():
755769
# Check is TypeError raised when non estimator instance passed
756770
with pytest.raises(TypeError):

sklearn/utils/validation.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,8 +1142,9 @@ def check_is_fitted(estimator, attributes=None, *, msg=None, all_or_any=all):
11421142
fitted attributes (ending with a trailing underscore) and otherwise
11431143
raises a NotFittedError with the given message.
11441144
1145-
This utility is meant to be used internally by estimators themselves,
1146-
typically in their own predict / transform methods.
1145+
If an estimator does not set any attributes with a trailing underscore, it
1146+
can define a ``__sklearn_is_fitted__`` method returning a boolean to specify if the
1147+
estimator is fitted or not.
11471148
11481149
Parameters
11491150
----------
@@ -1194,13 +1195,15 @@ def check_is_fitted(estimator, attributes=None, *, msg=None, all_or_any=all):
11941195
if attributes is not None:
11951196
if not isinstance(attributes, (list, tuple)):
11961197
attributes = [attributes]
1197-
attrs = all_or_any([hasattr(estimator, attr) for attr in attributes])
1198+
fitted = all_or_any([hasattr(estimator, attr) for attr in attributes])
1199+
elif hasattr(estimator, "__sklearn_is_fitted__"):
1200+
fitted = estimator.__sklearn_is_fitted__()
11981201
else:
1199-
attrs = [
1202+
fitted = [
12001203
v for v in vars(estimator) if v.endswith("_") and not v.startswith("__")
12011204
]
12021205

1203-
if not attrs:
1206+
if not fitted:
12041207
raise NotFittedError(msg % {"name": type(estimator).__name__})
12051208

12061209

0 commit comments

Comments
 (0)