Skip to content

Commit

Permalink
ENH check_is_fitted calls __is_fitted__ if available (scikit-learn#20657
Browse files Browse the repository at this point in the history
)
  • Loading branch information
adrinjalali authored Aug 20, 2021
1 parent dc43867 commit 3e7c04f
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 7 deletions.
6 changes: 6 additions & 0 deletions doc/whats_new/v1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,12 @@ Changelog
unavailable on the basis of state, in a more readable way.
:pr:`19948` by `Joel Nothman`_.

- |Enhancement| :func:`utils.validation.check_is_fitted` now uses
``__sklearn_is_fitted__`` if available, instead of checking for attributes ending with
an underscore. This also makes :class:`Pipeline` and
:class:`preprocessing.FunctionTransformer` pass
``check_is_fitted(estimator)``. :pr:`20657` by `Adrin Jalali`_.

- |Fix| Fixed a bug in :func:`utils.sparsefuncs.mean_variance_axis` where the
precision of the computed variance was very poor when the real variance is
exactly zero. :pr:`19766` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
Expand Down
14 changes: 14 additions & 0 deletions sklearn/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
from .utils.deprecation import deprecated
from .utils._tags import _safe_tags
from .utils.validation import check_memory
from .utils.validation import check_is_fitted
from .utils.fixes import delayed
from .exceptions import NotFittedError

from .utils.metaestimators import _BaseComposition

Expand Down Expand Up @@ -657,6 +659,18 @@ def n_features_in_(self):
# delegate to first step (which will call _check_is_fitted)
return self.steps[0][1].n_features_in_

def __sklearn_is_fitted__(self):
    """Indicate whether pipeline has been fit.

    Only the last step holding an actual estimator is inspected: if that
    step is fit, the steps before it are expected to be fit as well, which
    is cheaper than checking every step.

    Steps set to the placeholders ``'passthrough'`` or ``None`` are skipped,
    because they are not estimators and cannot be handed to
    ``check_is_fitted``. A pipeline made only of such placeholders has
    nothing to fit and is reported as fitted.

    Returns
    -------
    bool
        True if the last real estimator passes ``check_is_fitted`` (or if
        there is no real estimator at all), False otherwise.
    """
    # Walk backwards so the first real estimator found is the last one
    # of the pipeline.
    last_estimator = None
    for _, step in reversed(self.steps):
        if step is not None and step != "passthrough":
            last_estimator = step
            break

    if last_estimator is None:
        # Every step is a placeholder: there is no state to learn.
        return True

    try:
        check_is_fitted(last_estimator)
    except NotFittedError:
        return False
    return True

def _sk_visual_block_(self):
_, estimators = zip(*self.steps)

Expand Down
4 changes: 4 additions & 0 deletions sklearn/preprocessing/_function_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,5 +176,9 @@ def _transform(self, X, func=None, kw_args=None):

return func(X, **(kw_args if kw_args else {}))

def __sklearn_is_fitted__(self):
    """Report this transformer as fitted.

    FunctionTransformer is stateless, so it always counts as fitted.
    """
    return True

def _more_tags(self):
    """Return the estimator tags specific to this transformer.

    ``no_validation`` is the negation of ``self.validate`` and
    ``stateless`` is always True.
    """
    skip_validation = not self.validate
    return {"no_validation": skip_validation, "stateless": True}
16 changes: 15 additions & 1 deletion sklearn/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
MinimalRegressor,
MinimalTransformer,
)

from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted
from sklearn.base import clone, is_classifier, BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline, FeatureUnion, make_pipeline, make_union
from sklearn.svm import SVC
Expand Down Expand Up @@ -1361,3 +1362,16 @@ def test_search_cv_using_minimal_compatible_estimator(Predictor):
else:
assert_allclose(y_pred, y.mean())
assert model.score(X, y) == pytest.approx(r2_score(y, y_pred))


def test_pipeline_check_if_fitted():
    """A Pipeline fails check_is_fitted before fit and passes it after."""

    class Estimator(BaseEstimator):
        def fit(self, X, y):
            # A trailing-underscore attribute marks the estimator as fitted.
            self.fitted_ = True
            return self

    pipe = Pipeline([("clf", Estimator())])

    # Unfitted: the check must raise.
    with pytest.raises(NotFittedError):
        check_is_fitted(pipe)

    # Fitted: the check must pass silently.
    pipe.fit(iris.data, iris.target)
    check_is_fitted(pipe)
41 changes: 41 additions & 0 deletions sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from ..model_selection import ShuffleSplit
from ..model_selection._validation import _safe_split
from ..metrics.pairwise import rbf_kernel, linear_kernel, pairwise_distances
from ..utils.validation import check_is_fitted

from . import shuffle
from ._tags import (
Expand Down Expand Up @@ -307,6 +308,7 @@ def _yield_all_checks(estimator):
yield check_dict_unchanged
yield check_dont_overwrite_parameters
yield check_fit_idempotent
yield check_fit_check_is_fitted
if not tags["no_validation"]:
yield check_n_features_in
yield check_fit1d
Expand Down Expand Up @@ -3501,6 +3503,45 @@ def check_fit_idempotent(name, estimator_orig):
)


def check_fit_check_is_fitted(name, estimator_orig):
    """Check that ``check_is_fitted`` agrees with the estimator's fit state.

    The estimator must fail ``check_is_fitted`` before ``fit`` is called
    (unless it is tagged as stateless) and must pass it once it has been fit.

    Parameters
    ----------
    name : str
        Name of the estimator; not used by this check.
    estimator_orig : estimator instance
        Estimator to check. It is cloned, so the instance passed in is not
        fit by this function.

    Raises
    ------
    AssertionError
        If the unfitted estimator passes ``check_is_fitted``.
    NotFittedError
        If the fitted estimator still fails ``check_is_fitted``.
    """
    random_state = np.random.RandomState(42)

    estimator = clone(estimator_orig)
    set_random_state(estimator)
    if "warm_start" in estimator.get_params():
        estimator.set_params(warm_start=False)

    n_samples = 100
    X = _pairwise_estimator_convert_X(
        random_state.normal(loc=100, size=(n_samples, 2)), estimator
    )
    if is_regressor(estimator_orig):
        y = random_state.normal(size=n_samples)
    else:
        y = random_state.randint(low=0, high=2, size=n_samples)
    y = _enforce_estimator_tags_y(estimator, y)

    is_stateless = _safe_tags(estimator).get("stateless", False)
    if not is_stateless:
        # stateless estimators (such as FunctionTransformer) are always
        # "fit", so the pre-fit check only applies to the others.
        try:
            check_is_fitted(estimator)
        except NotFittedError:
            pass
        else:
            raise AssertionError(
                f"{estimator.__class__.__name__} passes check_is_fitted before being"
                " fit!"
            )

    estimator.fit(X, y)
    try:
        check_is_fitted(estimator)
    except NotFittedError as e:
        raise NotFittedError(
            "Estimator fails to pass `check_is_fitted` even though it has been fit."
        ) from e


def check_n_features_in(name, estimator_orig):
# Make sure that n_features_in_ attribute doesn't exist until fit is
# called, and that its value is correct.
Expand Down
27 changes: 27 additions & 0 deletions sklearn/utils/tests/test_estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from sklearn.utils.validation import check_array
from sklearn.utils import all_estimators
from sklearn.exceptions import SkipTestWarning
from sklearn.utils.metaestimators import available_if

from sklearn.utils.estimator_checks import (
_NotAnArray,
Expand All @@ -52,6 +53,7 @@
check_regressor_data_not_an_array,
check_outlier_corruption,
set_random_state,
check_fit_check_is_fitted,
)


Expand Down Expand Up @@ -1006,3 +1008,28 @@ def test_minimal_class_implementation_checks():
minimal_estimators = [MinimalTransformer(), MinimalRegressor(), MinimalClassifier()]
for estimator in minimal_estimators:
check_estimator(estimator)


def test_check_fit_check_is_fitted():
    """Exercise check_fit_check_is_fitted against three fit-state strategies."""

    class Estimator(BaseEstimator):
        def __init__(self, behavior="attribute"):
            self.behavior = behavior

        def fit(self, X, y, **kwargs):
            if self.behavior == "attribute":
                # Fitted state signalled by a trailing-underscore attribute.
                self.is_fitted_ = True
            elif self.behavior == "method":
                # Fitted state only visible through __sklearn_is_fitted__.
                self._is_fitted = True
            return self

        @available_if(lambda self: self.behavior in {"method", "always-true"})
        def __sklearn_is_fitted__(self):
            return self.behavior == "always-true" or hasattr(self, "_is_fitted")

    # An estimator claiming to be fitted before fit must be rejected.
    with raises(Exception, match="passes check_is_fitted before being fit"):
        check_fit_check_is_fitted("estimator", Estimator(behavior="always-true"))

    # Both well-behaved strategies must pass the check.
    for behavior in ("method", "attribute"):
        check_fit_check_is_fitted("estimator", Estimator(behavior=behavior))
16 changes: 15 additions & 1 deletion sklearn/utils/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
_get_feature_names,
)
from sklearn.utils.validation import _check_fit_params

from sklearn.base import BaseEstimator
import sklearn

from sklearn.exceptions import NotFittedError, PositiveSpectrumWarning
Expand Down Expand Up @@ -751,6 +751,20 @@ def test_check_symmetric():
assert_array_equal(output, arr_sym)


def test_check_is_fitted_with_is_fitted():
    """check_is_fitted relies on __sklearn_is_fitted__ when it is defined."""

    class Estimator(BaseEstimator):
        def fit(self, **kwargs):
            self._is_fitted = True
            return self

        def __sklearn_is_fitted__(self):
            # False while the private flag is missing, its value otherwise.
            return getattr(self, "_is_fitted", False)

    unfitted = Estimator()
    with pytest.raises(NotFittedError):
        check_is_fitted(unfitted)

    fitted = Estimator().fit()
    check_is_fitted(fitted)


def test_check_is_fitted():
# Check is TypeError raised when non estimator instance passed
with pytest.raises(TypeError):
Expand Down
13 changes: 8 additions & 5 deletions sklearn/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1142,8 +1142,9 @@ def check_is_fitted(estimator, attributes=None, *, msg=None, all_or_any=all):
fitted attributes (ending with a trailing underscore) and otherwise
raises a NotFittedError with the given message.
This utility is meant to be used internally by estimators themselves,
typically in their own predict / transform methods.
If an estimator does not set any attributes with a trailing underscore, it
can define a ``__sklearn_is_fitted__`` method returning a boolean to specify if the
estimator is fitted or not.
Parameters
----------
Expand Down Expand Up @@ -1194,13 +1195,15 @@ def check_is_fitted(estimator, attributes=None, *, msg=None, all_or_any=all):
if attributes is not None:
if not isinstance(attributes, (list, tuple)):
attributes = [attributes]
attrs = all_or_any([hasattr(estimator, attr) for attr in attributes])
fitted = all_or_any([hasattr(estimator, attr) for attr in attributes])
elif hasattr(estimator, "__sklearn_is_fitted__"):
fitted = estimator.__sklearn_is_fitted__()
else:
attrs = [
fitted = [
v for v in vars(estimator) if v.endswith("_") and not v.startswith("__")
]

if not attrs:
if not fitted:
raise NotFittedError(msg % {"name": type(estimator).__name__})


Expand Down

0 comments on commit 3e7c04f

Please sign in to comment.