Description
Describe the issue:
There is an issue with using the `Truncated` wrapper on a `Normal` distribution and with how the resulting graph is handled by JAX (see the example below): sampling fails with the blackjax or numpyro samplers, while the nutpie and PyMC samplers work fine on the same model.
This issue also appeared in custom logp functions that call `normal_lcdf`.
The issue seems to be linked to `pt.erfc` and `pt.erfcx` in that function. A manual implementation of the log CDF using only `pt.erf` works without problems.
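The erf-only workaround comes with a numerical trade-off. The sketch below is plain Python using the standard-library `math` module, not PyTensor, and the helper names `normal_lcdf_erf` and `normal_lcdf_erfc` are hypothetical. It compares a log-CDF written with `erf` alone against one written with `erfc`. Both agree near the center (for example at z = -3), but the erf-only form loses the lower tail because `1 + erf(z / sqrt(2))` cancels to exactly zero in double precision.

```python
import math

SQRT2 = math.sqrt(2.0)


def normal_lcdf_erf(x, mu=0.0, sigma=1.0):
    """log Phi((x - mu) / sigma) written with erf only (hypothetical helper)."""
    z = (x - mu) / sigma
    cdf = 0.5 * (1.0 + math.erf(z / SQRT2))
    # Far enough into the lower tail, erf(z / sqrt(2)) rounds to exactly -1.0,
    # so 1 + erf(...) cancels to 0.0 and the log-CDF collapses to -inf.
    return math.log(cdf) if cdf > 0.0 else -math.inf


def normal_lcdf_erfc(x, mu=0.0, sigma=1.0):
    """Same quantity via Phi(z) = erfc(-z / sqrt(2)) / 2 (hypothetical helper)."""
    z = (x - mu) / sigma
    cdf = 0.5 * math.erfc(-z / SQRT2)
    # erfc keeps relative precision in the lower tail until it underflows itself.
    return math.log(cdf) if cdf > 0.0 else -math.inf


if __name__ == "__main__":
    for z in (0.0, -3.0, -10.0):
        print(f"z={z:6.1f}  erf-only={normal_lcdf_erf(z):.6f}  erfc={normal_lcdf_erfc(z):.6f}")
```

The `erfc` form only postpones the problem, since `erfc` itself underflows to zero somewhere around z = -38. Working with the scaled function `erfcx(t) = exp(t**2) * erfc(t)` and subtracting `t**2` in log space avoids that. This is presumably why PyMC's `normal_lcdf` reaches for `erfcx` in the lower tail, and why the JAX translation of `erfcx` matters here.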
@junpenglao and @ericmjl are aware of this issue via private discussions.
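To check whether the failure comes from the special functions themselves rather than from `pm.Truncated`, each op can be compiled on its own with PyTensor's JAX backend (`mode="JAX"`). The following is a minimal sketch, assuming `pytensor` and `jax` are installed; `jax_compile_status` is a hypothetical helper name. I have not checked which of these ops PyTensor routes through TensorFlow Probability, so the script only reports what each one does in a given environment.

```python
try:
    import pytensor
    import pytensor.tensor as pt
except ImportError:  # keep the script usable where PyTensor is absent
    pytensor = pt = None


def jax_compile_status(op_name):
    """Compile op_name(x) for a float64 scalar x with PyTensor's JAX backend.

    Hypothetical helper; op_name is an attribute of pytensor.tensor such as "erfcx".
    """
    if pytensor is None:
        return "skipped: pytensor is not installed"
    x = pt.dscalar("x")
    out = getattr(pt, op_name)(x)
    try:
        # The JAX linker should translate each op (via jax_funcify) while the
        # function is being built, so translation errors surface here.
        pytensor.function([x], out, mode="JAX")
    except Exception as err:
        return f"fails: {type(err).__name__}: {err}"
    return "ok"


if __name__ == "__main__":
    for name in ("erf", "erfc", "erfcx"):
        print(f"{name:6s} {jax_compile_status(name)}")
```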
Reproducible code example:
```python
import pymc as pm

obs = pm.draw(pm.Normal.dist(0, 1.0), draws=100)

with pm.Model():
    mu = pm.Normal("mu", 0, 1)
    sigma = pm.HalfNormal("sigma", 1.0)
    dist = pm.Normal.dist(mu, sigma)
    pm.Truncated(
        "obs",
        dist=dist,
        lower=-3.0,
        observed=obs,
    )
    pm.sample(
        tune=100,
        draws=500,
        chains=4,
        random_seed=None,
        nuts_sampler="numpyro",  # or "blackjax"
        init="jitter+adapt_diag",
    )
```

Error message:
```
Traceback (most recent call last):
  Cell marimo:///mnt/c/Users/u601825/PyCharmProjects/pymc-probabilistic-models/notebooks/marimo_notebooks/biologicals_factors_analysis.py#cell=cell-60, line 17, in <module>
    pm.sample(
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/pymc/sampling/mcmc.py", line 802, in sample
    return _sample_external_nuts(
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/pymc/sampling/mcmc.py", line 391, in _sample_external_nuts
    idata = pymc_jax.sample_jax_nuts(
            ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/pymc/sampling/jax.py", line 633, in sample_jax_nuts
    logp_fn = get_jaxified_logp(model, negative_logp=False)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/pymc/sampling/jax.py", line 140, in get_jaxified_logp
    logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/pymc/sampling/jax.py", line 133, in get_jaxified_graph
    return jax_funcify(fgraph)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/functools.py", line 912, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/pytensor/link/jax/dispatch/basic.py", line 56, in jax_funcify_FunctionGraph
    return fgraph_to_python(
           ^^^^^^^^^^^^^^^^^
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/pytensor/link/utils.py", line 736, in fgraph_to_python
    compiled_func = op_conversion_fn(
                    ^^^^^^^^^^^^^^^^^
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/functools.py", line 912, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/pytensor/link/jax/dispatch/elemwise.py", line 12, in jax_funcify_Elemwise
    base_fn = jax_funcify(scalar_op, node=node, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/functools.py", line 912, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/pytensor/link/jax/dispatch/scalar.py", line 275, in jax_funcify_from_tfp
    tfp_jax_op = try_import_tfp_jax_op(op)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/pytensor/link/jax/dispatch/scalar.py", line 43, in try_import_tfp_jax_op
    import tensorflow_probability.substrates.jax.math as tfp_jax_math
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/tensorflow_probability/substrates/jax/__init__.py", line 42, in <module>
    from tensorflow_probability.substrates.jax import bijectors
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/tensorflow_probability/substrates/jax/bijectors/__init__.py", line 19, in <module>
    from tensorflow_probability.substrates.jax.bijectors.absolute_value import AbsoluteValue
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/tensorflow_probability/substrates/jax/bijectors/absolute_value.py", line 17, in <module>
    from tensorflow_probability.python.internal.backend.jax.compat import v2 as tf
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/tensorflow_probability/python/internal/backend/jax/__init__.py", line 19, in <module>
    from tensorflow_probability.python.internal.backend.jax import compat
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/tensorflow_probability/python/internal/backend/jax/compat.py", line 17, in <module>
    from tensorflow_probability.python.internal.backend.jax import v1
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/tensorflow_probability/python/internal/backend/jax/v1.py", line 23, in <module>
    from tensorflow_probability.python.internal.backend.jax import linalg_impl
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/tensorflow_probability/python/internal/backend/jax/linalg_impl.py", line 23, in <module>
    from tensorflow_probability.python.internal.backend.jax import ops
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/tensorflow_probability/python/internal/backend/jax/ops.py", line 681, in <module>
    jax.interpreters.xla.pytype_aval_mappings[onp.ndarray])
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/jax/_src/deprecations.py", line 54, in getattr
    raise AttributeError(message)
AttributeError: jax.interpreters.xla.pytype_aval_mappings was deprecated in JAX v0.5.0 and removed in JAX v0.7.0. jax.core.pytype_aval_mappings can be used as a replacement in most cases.
```

PyMC version information:
PyMC v5.26.1
JAX v0.7.2
Python 3.12
PyTensor 2.35.1
Installed with conda (mamba)
OS: Ubuntu on WSL (Windows)
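The traceback above shows that the failure happens inside `jax_funcify_from_tfp`, while it imports `tensorflow_probability.substrates.jax.math`. That import chain then touches `jax.interpreters.xla.pytype_aval_mappings`, which JAX removed in v0.7.0 (the JAX listed here is v0.7.2). The installed TensorFlow Probability version is not listed above. A quick way to check an environment independently of PyMC is to attempt that same import. This is a minimal sketch; `tfp_jax_substrate_status` is a hypothetical helper name.

```python
import importlib


def tfp_jax_substrate_status():
    """Try the import that fails in the traceback and report what happens.

    Hypothetical helper: it imports the same module named in the traceback
    (tensorflow_probability.substrates.jax.math) without going through PyMC.
    """
    try:
        importlib.import_module("tensorflow_probability.substrates.jax.math")
    except ModuleNotFoundError as err:
        # TFP itself, or one of its dependencies such as jax, is not installed.
        return f"missing module: {err.name}"
    except Exception as err:
        # Any other failure, such as the AttributeError seen with JAX v0.7.x.
        return f"broken: {type(err).__name__}: {err}"
    return "ok"


if __name__ == "__main__":
    print(tfp_jax_substrate_status())
```

If this reports the same `AttributeError`, the problem lies in the TFP/JAX version pairing rather than in the model.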
Context for the issue:
This is not a critical issue for my work, but it would be nice to have it resolved since the Normal dirs