|
14 | 14 |
|
15 | 15 | import io
|
16 | 16 | import operator
|
| 17 | +import warnings |
17 | 18 |
|
18 | 19 | from contextlib import nullcontext
|
19 | 20 |
|
@@ -196,18 +197,26 @@ def test_fit_start(inference_spec, simple_model):
|
196 | 197 |
|
197 | 198 | # Minibatch data can't be extracted into the `observed_data` group in the final InferenceData
|
198 | 199 | [observed_value] = [simple_model.rvs_to_values[obs] for obs in simple_model.observed_RVs]
|
199 |
| - if observed_value.name.startswith("minibatch"): |
200 |
| - warn_ctxt = pytest.warns( |
201 |
| - UserWarning, match="Could not extract data from symbolic observation" |
202 |
| - ) |
203 |
| - else: |
204 |
| - warn_ctxt = nullcontext() |
205 | 200 |
|
206 |
| - try: |
207 |
| - with warn_ctxt: |
| 201 | + # We can`t use pytest.warns here because after version 8.0 it`s still check for warning when |
| 202 | + # exception raised and test failed instead being skipped |
| 203 | + warning_raised = False |
| 204 | + expected_warning = observed_value.name.startswith("minibatch") |
| 205 | + with warnings.catch_warnings(record=True) as record: |
| 206 | + warnings.simplefilter("always") |
| 207 | + try: |
208 | 208 | trace = inference.fit(n=0).sample(10000)
|
209 |
| - except NotImplementedInference as e: |
210 |
| - pytest.skip(str(e)) |
| 209 | + except NotImplementedInference as e: |
| 210 | + pytest.skip(str(e)) |
| 211 | + |
| 212 | + if expected_warning: |
| 213 | + assert len(record) > 0 |
| 214 | + for item in record: |
| 215 | + assert issubclass(item.category, UserWarning) |
| 216 | + assert "Could not extract data from symbolic observation" in str(item.message) |
| 217 | + if not expected_warning: |
| 218 | + assert not record |
| 219 | + |
211 | 220 | np.testing.assert_allclose(np.mean(trace.posterior["mu"]), mu_init, rtol=0.05)
|
212 | 221 | if has_start_sigma:
|
213 | 222 | np.testing.assert_allclose(np.std(trace.posterior["mu"]), mu_sigma_init, rtol=0.05)
|
|
0 commit comments