-
Hi, this is not an answer, but I have a similar issue. I had hoped that jaxopt.Bisection would let me find a slightly complicated root in a way that is compatible with jax.vmap and jax.grad, and later jax.hessian. I was told: "Most solvers in JAXopt don't work out of the box with scalars (maybe they should?) So typically, one needs to use an array of size 1 instead." I need a workaround or a bit of clarity. jaxopt.Bisection produced the correct root in several individual tests, but when I later tried to vmap a broader piece of code, the problem turned out to lie in the data types passed to jaxopt.Bisection. At this stage of my code almost everything is a JAX device array with dtype=jnp.float64, including single-valued arrays. The version I expected to work did very well on a single data point:
It turns out that .item() is not compatible with JAX tracing, so it fails under jax.vmap or jax.grad, although it succeeds on a single case. Alternatively, skipping .item() and passing a jnp.array to Bisection causes boolean errors. In the single-case runs, all of the inputs are JAX device arrays holding a single float64 entry. When applying vmap, I map over a JAX array of regime values to return an array of answers. I am struggling to find a way past this apparent data type problem with jaxopt.Bisection. Thanks,
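For comparison, here is a minimal hand-written bisection in plain JAX that never calls .item() and never branches on a traced value, so it can be wrapped in jax.vmap. It is only a sketch and not JAXopt's implementation: the names `residual` and `bisect_root`, the bracket [0, 3], and the fixed iteration count are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax


def residual(x, theta):
    # Hypothetical test problem: the positive root is sqrt(theta).
    return x**2 - theta


def bisect_root(theta, lower=0.0, upper=3.0, num_iters=60):
    """Fixed-iteration bisection; assumes residual changes sign on [lower, upper]."""
    lo0 = jnp.asarray(lower, dtype=float)
    hi0 = jnp.asarray(upper, dtype=float)
    sign_lo = jnp.sign(residual(lo0, theta))

    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        same_side = jnp.sign(residual(mid, theta)) == sign_lo
        # jnp.where selects values, so no traced boolean is forced into a Python bool.
        return jnp.where(same_side, mid, lo), jnp.where(same_side, hi, mid)

    lo, hi = lax.fori_loop(0, num_iters, body, (lo0, hi0))
    return 0.5 * (lo + hi)


# Map the solver over an array of parameter values.
thetas = jnp.array([1.0, 2.25, 4.0])
roots = jax.vmap(bisect_root)(thetas)
```

Note that differentiating straight through this loop would not give a useful gradient with respect to theta, because the bracket endpoints are chosen by comparisons; derivatives of the root need an implicit-differentiation rule on top.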
-
Half of my problem is fixed. For the single case, it works once the dtype is removed from t_initial, which now lets alt_pytree accept 0-dim JAX arrays (scalars).
-
Applying vmap returns an error:
state = self.init_state(init_params, *args, **kwargs)
File "/opt/anaconda3/lib/python3.8/site-packages/jaxopt/_src/bisection.py", line 105, in init_state
File "/opt/anaconda3/lib/python3.8/site-packages/jax/core.py", line 634, in bool
File "/opt/anaconda3/lib/python3.8/site-packages/jax/core.py", line 1267, in error
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<BatchTrace(level=1/1)> with
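The traceback above shows a traced boolean being converted with bool() inside init_state, which is what a Python `if` does to its condition. The sketch below reproduces that class of error in isolation and shows the traceable alternative. Both functions are hypothetical illustrations, not JAXopt code. If your JAXopt version exposes a `check_bracket` option on `Bisection` (an assumption worth checking against its documentation), disabling it may skip this kind of check.

```python
import jax
import jax.numpy as jnp


def check_bracket_python(lower_value, upper_value):
    # A Python `if` calls bool() on the comparison, which needs a concrete value.
    if lower_value * upper_value > 0:
        raise ValueError("interval does not bracket a root")
    return lower_value


def check_bracket_traced(lower_value, upper_value):
    # Return the comparison as an array instead of branching on it.
    return lower_value * upper_value <= 0


lo_vals = jnp.array([-1.0, -2.0, 3.0])
hi_vals = jnp.array([1.0, 5.0, 4.0])

try:
    jax.vmap(check_bracket_python)(lo_vals, hi_vals)
    failed = False
except jax.errors.ConcretizationTypeError:
    # Under vmap the comparison is a batched tracer with no single concrete value.
    failed = True

is_bracketed = jax.vmap(check_bracket_traced)(lo_vals, hi_vals)
```

The first function works when called on concrete scalars, which matches the "works for a single case, fails under vmap" behaviour described in this thread.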
-
I restructured and rewrote the entire code block so that jax.vmap is applied more directly to a block containing jaxopt.Bisection. I still get the same boolean concretization error from within jaxopt.Bisection. The Bisection solver works for a single input set but does not appear to work with vmap.
-
@yyang97 Indeed, second-order derivatives via implicit diff are not supported yet in JAXopt :( I'm curious, what is your use case?
-
My project involves computing the Hessian of a root-finding solution, and I just found that jaxopt.Bisection() does not support that. Did I implement it incorrectly, or are Hessians simply not supported?
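One possible workaround outside JAXopt is to write the root solver as a jax.custom_jvp function whose derivative rule comes from the implicit function theorem. Because that rule is ordinary JAX code that calls the solver again, it can itself be differentiated, which gives second derivatives. This is a sketch under assumptions, not JAXopt's mechanism: `residual`, `solve_root`, the bracket [0, 10], and the scalar parameter are all made up for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax


def residual(x, theta):
    # Hypothetical problem, increasing in x on [0, 10]; the root is sqrt(theta).
    return x**2 - theta


@jax.custom_jvp
def solve_root(theta):
    # Plain bisection on a fixed bracket; JAX never differentiates through this loop.
    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        below = residual(mid, theta) < 0
        return jnp.where(below, mid, lo), jnp.where(below, hi, mid)

    init = (jnp.asarray(0.0, dtype=float), jnp.asarray(10.0, dtype=float))
    lo, hi = lax.fori_loop(0, 60, body, init)
    return 0.5 * (lo + hi)


@solve_root.defjvp
def solve_root_jvp(primals, tangents):
    (theta,), (theta_dot,) = primals, tangents
    x = solve_root(theta)
    # Implicit function theorem: dx/dtheta = -(dF/dtheta) / (dF/dx) at the root.
    dF_dx = jax.grad(residual, argnums=0)(x, theta)
    dF_dtheta = jax.grad(residual, argnums=1)(x, theta)
    return x, -(dF_dtheta / dF_dx) * theta_dot


theta = jnp.asarray(4.0)
first = jax.grad(solve_root)(theta)      # analytic value: 1 / (2 * sqrt(theta))
second = jax.hessian(solve_root)(theta)  # analytic value: -1 / (4 * theta**1.5)
```

The accuracy of both derivatives depends on how closely the solver has converged, since the rule is evaluated at the computed root rather than the exact one.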