Skip to content

Commit 8745974

Browse files
authored
Fix failing VI test due to pytest change (pymc-devs#7144)
1 parent 3693198 commit 8745974

File tree

1 file changed

+19
-10
lines changed

1 file changed

+19
-10
lines changed

tests/variational/test_inference.py

+19-10
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import io
1616
import operator
17+
import warnings
1718

1819
from contextlib import nullcontext
1920

@@ -196,18 +197,26 @@ def test_fit_start(inference_spec, simple_model):
196197

197198
# Minibatch data can't be extracted into the `observed_data` group in the final InferenceData
198199
[observed_value] = [simple_model.rvs_to_values[obs] for obs in simple_model.observed_RVs]
199-
if observed_value.name.startswith("minibatch"):
200-
warn_ctxt = pytest.warns(
201-
UserWarning, match="Could not extract data from symbolic observation"
202-
)
203-
else:
204-
warn_ctxt = nullcontext()
205200

206-
try:
207-
with warn_ctxt:
201+
# We can`t use pytest.warns here because after version 8.0 it`s still check for warning when
202+
# exception raised and test failed instead being skipped
203+
warning_raised = False
204+
expected_warning = observed_value.name.startswith("minibatch")
205+
with warnings.catch_warnings(record=True) as record:
206+
warnings.simplefilter("always")
207+
try:
208208
trace = inference.fit(n=0).sample(10000)
209-
except NotImplementedInference as e:
210-
pytest.skip(str(e))
209+
except NotImplementedInference as e:
210+
pytest.skip(str(e))
211+
212+
if expected_warning:
213+
assert len(record) > 0
214+
for item in record:
215+
assert issubclass(item.category, UserWarning)
216+
assert "Could not extract data from symbolic observation" in str(item.message)
217+
if not expected_warning:
218+
assert not record
219+
211220
np.testing.assert_allclose(np.mean(trace.posterior["mu"]), mu_init, rtol=0.05)
212221
if has_start_sigma:
213222
np.testing.assert_allclose(np.std(trace.posterior["mu"]), mu_sigma_init, rtol=0.05)

0 commit comments

Comments
 (0)