|
53 | 53 | from ..model_selection import ShuffleSplit
|
54 | 54 | from ..model_selection._validation import _safe_split
|
55 | 55 | from ..metrics.pairwise import rbf_kernel, linear_kernel, pairwise_distances
|
| 56 | +from ..utils.validation import check_is_fitted |
56 | 57 |
|
57 | 58 | from . import shuffle
|
58 | 59 | from ._tags import (
|
@@ -307,6 +308,7 @@ def _yield_all_checks(estimator):
|
307 | 308 | yield check_dict_unchanged
|
308 | 309 | yield check_dont_overwrite_parameters
|
309 | 310 | yield check_fit_idempotent
|
| 311 | + yield check_fit_check_is_fitted |
310 | 312 | if not tags["no_validation"]:
|
311 | 313 | yield check_n_features_in
|
312 | 314 | yield check_fit1d
|
@@ -3501,6 +3503,45 @@ def check_fit_idempotent(name, estimator_orig):
|
3501 | 3503 | )
|
3502 | 3504 |
|
3503 | 3505 |
|
def check_fit_check_is_fitted(name, estimator_orig):
    # Verify both directions of the `check_is_fitted` contract:
    #   * a freshly cloned (never fit) estimator must be rejected, i.e.
    #     `check_is_fitted` raises NotFittedError -- unless the estimator's
    #     tags mark it as "stateless";
    #   * after `fit` has been called, `check_is_fitted` must pass.
    #
    # `name` is accepted for signature parity with the other checks and is
    # not used in the body.

    # Work on a clone so that the estimator handed in is never fit here.
    estimator = clone(estimator_orig)
    set_random_state(estimator)
    if "warm_start" in estimator.get_params():
        estimator.set_params(warm_start=False)

    # Small synthetic dataset from a fixed seed: 100 samples, 2 features,
    # drawn from a normal distribution centred on 100.
    rng = np.random.RandomState(42)
    n_samples = 100
    X = _pairwise_estimator_convert_X(
        rng.normal(loc=100, size=(n_samples, 2)), estimator
    )
    # Continuous targets for regressors, integer labels in {0, 1} otherwise.
    if is_regressor(estimator_orig):
        raw_y = rng.normal(size=n_samples)
    else:
        raw_y = rng.randint(low=0, high=2, size=n_samples)
    y = _enforce_estimator_tags_y(estimator, raw_y)

    is_stateless = _safe_tags(estimator).get("stateless", False)
    if not is_stateless:
        # stateless estimators (such as FunctionTransformer) are always "fit"!
        try:
            check_is_fitted(estimator)
        except NotFittedError:
            # Expected outcome: an unfitted estimator is rejected.
            pass
        else:
            raise AssertionError(
                f"{estimator.__class__.__name__} passes check_is_fitted before being"
                " fit!"
            )

    estimator.fit(X, y)

    try:
        check_is_fitted(estimator)
    except NotFittedError as exc:
        # Re-raise with a message that states the estimator *was* fit,
        # keeping the original error chained as the cause.
        raise NotFittedError(
            "Estimator fails to pass `check_is_fitted` even though it has been fit."
        ) from exc
| 3543 | + |
| 3544 | + |
3504 | 3545 | def check_n_features_in(name, estimator_orig):
|
3505 | 3546 | # Make sure that n_features_in_ attribute doesn't exist until fit is
|
3506 | 3547 | # called, and that its value is correct.
|
|
0 commit comments