Hi! I know that in JAX we can "wrap" device arrays to make them hashable, so that they can be treated as static arguments (`static_argnums` / `static_argnames` in `jax.jit`).
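For reference, here is a minimal sketch of that wrapping trick. It assumes the wrapped array is concrete (not traced) when the function is called. `HashableArrayWrapper` and `weighted_sum` are hypothetical names. Because the hash is computed from the array's bytes, each distinct array value triggers its own compilation.

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp


class HashableArrayWrapper:
    """Hypothetical wrapper that makes a concrete array hashable.

    Hash and equality depend on the array's contents, so JAX can use the
    wrapper as a static argument and cache one compilation per distinct value.
    """

    def __init__(self, val):
        # Keep a host-side NumPy copy; static arguments must be concrete.
        self.val = np.asarray(val)

    def __hash__(self):
        return hash((self.val.shape, self.val.dtype.str, self.val.tobytes()))

    def __eq__(self, other):
        return (
            isinstance(other, HashableArrayWrapper)
            and self.val.shape == other.val.shape
            and self.val.dtype == other.val.dtype
            and bool(np.array_equal(self.val, other.val))
        )


@partial(jax.jit, static_argnums=1)
def weighted_sum(x, weights):
    # weights.val is a plain NumPy array here and is baked into the compiled
    # computation as a constant; a different value forces a recompile.
    return jnp.sum(x * weights.val)
```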
My question is: how can I retain this "hack" when calling a jaxopt optimizer? Is this possible?
For context: imagine we are optimizing a large binomial likelihood with a lot of data for trials and successes.
We would want to pass those in as constants, but still treat them as arrays (for linear algebra operations).
Is there a best practice to do this?
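A minimal sketch of that binomial setup follows. It assumes a logit-linear model with a hypothetical design matrix `X` and passes `trials` and `successes` as ordinary (traced) arguments rather than static ones. Plain gradient descent stands in for a jaxopt solver. `binomial_nll`, `fit`, `lr`, and `steps` are illustrative names, not part of any library.

```python
import jax
import jax.numpy as jnp


def binomial_nll(params, X, trials, successes):
    """Negative log-likelihood of successes ~ Binomial(trials, sigmoid(X @ params)).

    Assumed (hypothetical) logit-linear model. The log binomial coefficient
    is dropped because it does not depend on params.
    """
    logits = X @ params
    log_p = -jax.nn.softplus(-logits)   # log(sigmoid(logits))
    log_1mp = -jax.nn.softplus(logits)  # log(1 - sigmoid(logits))
    return -jnp.sum(successes * log_p + (trials - successes) * log_1mp)


# X, trials and successes are regular traced arguments, not static ones:
# one compilation is reused for any data with the same shapes and dtypes.
nll_and_grad = jax.jit(jax.value_and_grad(binomial_nll))


def fit(X, trials, successes, lr=0.05, steps=500):
    # Plain gradient descent as a stand-in for a jaxopt solver;
    # lr and steps are hypothetical defaults, not tuned values.
    params = jnp.zeros(X.shape[1])
    for _ in range(steps):
        _, grad = nll_and_grad(params, X, trials, successes)
        params = params - lr * grad
    return params
```

As far as I know, jaxopt solvers take extra arguments in `run(init_params, *args, **kwargs)` and forward them to the objective, so the same function could be used as `jaxopt.LBFGS(fun=binomial_nll).run(init_params, X, trials, successes)` without making the data static.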