
BUG: blackjax and numpyro sampling issue with Truncated Normal #7980

@guillaumeberthon


Describe the issue:

There is an issue with using the `Truncated` wrapper on a `Normal` distribution and with how the resulting model is handled by JAX (see the example below). Sampling fails with the blackjax and numpyro samplers, while the nutpie and default PyMC samplers work fine on the same model.
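For reference, this is the math a lower-truncated Normal likelihood computes. It is a sketch of the math, not necessarily the exact expression PyMC builds. The normalizing term is a Normal log-CDF, which is presumably how helpers like `normal_lcdf` end up in the model's logp graph:

```math
\log p(x \mid \mu, \sigma) = \log \mathcal{N}(x \mid \mu, \sigma) - \log\left[1 - \Phi\left(\frac{a - \mu}{\sigma}\right)\right] = \log \mathcal{N}(x \mid \mu, \sigma) - \log \Phi\left(\frac{\mu - a}{\sigma}\right), \qquad x \ge a
```

Here $a$ is the lower bound (`lower=-3.0` in the example) and $\Phi$ is the standard Normal CDF. For $x < a$ the log-density is $-\infty$.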

The same issue also appears in custom logp functions that call `normal_lcdf`.
It seems to be linked to `pt.erfc` and `pt.erfcx` in that function: a manual implementation of the log-CDF using only `pt.erf` works without issues.
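A minimal sketch of the erf-only workaround described above. It is written with Python's standard `math` module so it runs on its own; in a PyTensor graph, `math.erf`/`math.log` would become `pt.erf`/`pt.log`. The names `normal_lcdf_erf`, `normal_logpdf` and `truncated_normal_logp` are hypothetical helpers, not PyMC API. The trade-off is precision in the far lower tail: once `erf(z / sqrt(2))` rounds to `-1.0`, the CDF evaluates to exactly `0.0`. That is presumably why PyMC's `normal_lcdf` relies on `erfc`/`erfcx` instead.

```python
import math

SQRT2 = math.sqrt(2.0)


def normal_lcdf_erf(z):
    """Hypothetical helper: log Phi(z) for a standard Normal, using only erf."""
    cdf = 0.5 * (1.0 + math.erf(z / SQRT2))
    # Far in the lower tail erf(z / sqrt(2)) rounds to -1.0, so cdf becomes 0.0.
    return math.log(cdf) if cdf > 0.0 else -math.inf


def normal_logpdf(x, mu=0.0, sigma=1.0):
    """Log-density of Normal(mu, sigma) at x."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def truncated_normal_logp(x, mu, sigma, lower):
    """Hypothetical helper: log-density of Normal(mu, sigma) truncated below at `lower`."""
    if x < lower:
        return -math.inf
    # log(1 - Phi((lower - mu) / sigma)) == log Phi((mu - lower) / sigma)
    log_norm = normal_lcdf_erf((mu - lower) / sigma)
    return normal_logpdf(x, mu, sigma) - log_norm
```

For example, `truncated_normal_logp(0.5, mu=0.0, sigma=1.0, lower=-3.0)` evaluates the likelihood term for one observation of the model in the report.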

@junpenglao and @ericmjl are aware of this issue via private discussions.

Reproducible code example:

import pymc as pm

obs = pm.draw(pm.Normal.dist(0, 1.0), draws=100)

with pm.Model():
    mu = pm.Normal("mu", 0, 1)
    sigma = pm.HalfNormal("sigma", 1.0)

    dist = pm.Normal.dist(mu, sigma)

    # Lower-truncated Normal likelihood
    pm.Truncated(
        "obs",
        dist=dist,
        lower=-3.0,
        observed=obs,
    )

    pm.sample(
        tune=100,
        draws=500,
        chains=4,
        random_seed=None,
        nuts_sampler="numpyro",  # or "blackjax"
        init="jitter+adapt_diag",
    )

Error message:

Traceback (most recent call last):
  Cell marimo:///mnt/c/Users/u601825/PyCharmProjects/pymc-probabilistic-models/notebooks/marimo_notebooks/biologicals_factors_analysis.py#cell=cell-60, line 17, in <module>
    pm.sample(
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/pymc/sampling/mcmc.py", line 802, in sample
    return _sample_external_nuts(
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/pymc/sampling/mcmc.py", line 391, in _sample_external_nuts
    idata = pymc_jax.sample_jax_nuts(
            ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/pymc/sampling/jax.py", line 633, in sample_jax_nuts
    logp_fn = get_jaxified_logp(model, negative_logp=False)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/pymc/sampling/jax.py", line 140, in get_jaxified_logp
    logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/pymc/sampling/jax.py", line 133, in get_jaxified_graph
    return jax_funcify(fgraph)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/functools.py", line 912, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/pytensor/link/jax/dispatch/basic.py", line 56, in jax_funcify_FunctionGraph
    return fgraph_to_python(
           ^^^^^^^^^^^^^^^^^
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/pytensor/link/utils.py", line 736, in fgraph_to_python
    compiled_func = op_conversion_fn(
                    ^^^^^^^^^^^^^^^^^
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/functools.py", line 912, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/pytensor/link/jax/dispatch/elemwise.py", line 12, in jax_funcify_Elemwise
    base_fn = jax_funcify(scalar_op, node=node, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/functools.py", line 912, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/pytensor/link/jax/dispatch/scalar.py", line 275, in jax_funcify_from_tfp
    tfp_jax_op = try_import_tfp_jax_op(op)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/pytensor/link/jax/dispatch/scalar.py", line 43, in try_import_tfp_jax_op
    import tensorflow_probability.substrates.jax.math as tfp_jax_math
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/tensorflow_probability/substrates/jax/__init__.py", line 42, in <module>
    from tensorflow_probability.substrates.jax import bijectors
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/tensorflow_probability/substrates/jax/bijectors/__init__.py", line 19, in <module>
    from tensorflow_probability.substrates.jax.bijectors.absolute_value import AbsoluteValue
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/tensorflow_probability/substrates/jax/bijectors/absolute_value.py", line 17, in <module>
    from tensorflow_probability.python.internal.backend.jax.compat import v2 as tf
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/tensorflow_probability/python/internal/backend/jax/__init__.py", line 19, in <module>
    from tensorflow_probability.python.internal.backend.jax import compat
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/tensorflow_probability/python/internal/backend/jax/compat.py", line 17, in <module>
    from tensorflow_probability.python.internal.backend.jax import v1
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/tensorflow_probability/python/internal/backend/jax/v1.py", line 23, in <module>
    from tensorflow_probability.python.internal.backend.jax import linalg_impl
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/tensorflow_probability/python/internal/backend/jax/linalg_impl.py", line 23, in <module>
    from tensorflow_probability.python.internal.backend.jax import ops
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/tensorflow_probability/python/internal/backend/jax/ops.py", line 681, in <module>
    jax.interpreters.xla.pytype_aval_mappings[onp.ndarray])
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gberthon/mambaforge/envs/PyMC_WSL/lib/python3.12/site-packages/jax/_src/deprecations.py", line 54, in getattr
    raise AttributeError(message)
AttributeError: jax.interpreters.xla.pytype_aval_mappings was deprecated in JAX v0.5.0 and removed in JAX v0.7.0. jax.core.pytype_aval_mappings can be used as a replacement in most cases.

PyMC version information:

PyMC v5.26.1
JAX v0.7.2
Python 3.12
PyTensor 2.35.1

Installation: conda (mamba)
OS: Ubuntu on WSL (Windows)
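The traceback ends while importing `tensorflow_probability.substrates.jax` (pulled in by PyTensor's `try_import_tfp_jax_op`). The error states that `jax.interpreters.xla.pytype_aval_mappings` was removed in JAX v0.7.0, and the environment above has JAX v0.7.2. Below is a minimal stdlib-only sketch for checking whether an environment has that combination. The helper names are hypothetical, and the 0.7 threshold is taken from the error message above.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def jax_lacks_xla_alias(jax_version):
    """True if this JAX version is >= 0.7.0, where the error says the alias was removed."""
    major, minor = (int(part) for part in jax_version.split(".")[:2])
    return (major, minor) >= (0, 7)


if __name__ == "__main__":
    versions = installed_versions(["jax", "tensorflow-probability", "pytensor", "pymc"])
    print(versions)
    if versions["jax"] and versions["tensorflow-probability"] and jax_lacks_xla_alias(versions["jax"]):
        print("JAX >= 0.7 together with TFP: the TFP JAX substrate import may fail as above.")
```

Only version strings whose first two components are plain integers (such as `0.7.2`) are handled by this sketch.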

Context for the issue:

This is not a critical issue for my work, but it would be nice to have it resolved since the Normal dirs
