Robustness enhancements
In this item, we standardize the XLA loop APIs that currently have counterparts in JAX (`scan`, `fori-loop`, `while-loop`), bringing them to parity so that users can rely on a single UX to obtain the same benefits.

Scan RFC: #8620
While-loop/fori-loop RFC: #6941

There have been diverging variants that would all benefit from a single robust implementation.
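To ground the unification idea, here is a minimal pure-Python sketch (illustrative only, not the PyTorch/XLA or JAX implementation) of how a fori-loop and a scan can both be lowered onto a single while-loop core with XLA While's (condition, body, init) shape:

```python
def while_loop(cond, body, init):
    # Mirrors XLA While semantics: the carried state has a fixed structure;
    # `cond` maps state -> bool, `body` maps state -> state of the same shape.
    # Note the zero-trip case: if `cond(init)` is False, `init` is returned
    # unchanged (one of the edge cases the experimental APIs struggle with).
    state = init
    while cond(state):
        state = body(state)
    return state

def fori_loop(lower, upper, body, init):
    # fori-loop expressed on the while core: carry (i, value), count to `upper`.
    def cond(carry):
        i, _ = carry
        return i < upper
    def step(carry):
        i, v = carry
        return i + 1, body(i, v)
    _, final = while_loop(cond, step, (lower, init))
    return final

def scan(fn, init, xs):
    # scan expressed on the fori-loop: iterate over the input sequence,
    # threading the carry and collecting per-step outputs.
    ys = []
    def step(i, carry):
        carry, y = fn(carry, xs[i])
        ys.append(y)
        return carry
    final = fori_loop(0, len(xs), step, init)
    return final, ys

# cumulative sum via scan built on the shared while core
carry, outs = scan(lambda c, x: (c + x, c + x), 0, [1, 2, 3, 4])
# carry == 10, outs == [1, 3, 6, 10]
```

The point of the sketch is that both higher-level loops reduce to one core whose robustness fixes (zero trip counts, fixed carry structure) then apply everywhere.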
Given the ongoing robustness improvements to `scan` from @tengyifei, we should consider building upon it and unifying most of the core implementation. A subitem is to replace the gradient accumulation UX with `scan` or `while-loop` entirely, so that customers can use the existing UX to achieve the same benefits as JAX (AXLearn's gradient accumulation: https://github.com/apple/axlearn/pull/614/files).

Motivation
Currently, the while-loop and fori-loop experimental APIs (#6941) are limited in robustness (issues: #7839, #8566): for example, they cannot handle special scalar iteration values (0, 1), do not preserve the static XLA loop argument requirements (body, condition, and init) needed by https://openxla.org/xla/operation_semantics#while, and have gaps in device RNG initialization, among others. A separate subitem (for Neuron) is to reach parity between the `scan` work and the gradient accumulation experimental API, so that `scan_layers` and gradient accumulation benefit from the same XLA loop enhancements, making the stack less brittle, especially when nesting the two features. As it stands, the `scan` API is increasingly robust, so it would help to bridge the implementations.

cc: @tengyifei @ManfeiBai
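To make the grad-acc-as-scan direction concrete, here is a minimal pure-Python sketch; the function names and the toy gradient below are illustrative assumptions, not AXLearn's or PyTorch/XLA's actual API. The idea is that the loop carry is the running gradient sum and each scan step consumes one microbatch:

```python
def accumulate_grads(grad_fn, params, microbatches):
    # Scan-style gradient accumulation: the carry is the running gradient
    # sum; each step consumes one microbatch. Averaging over microbatches
    # at the end is a common convention (an assumption here, not mandated).
    acc = 0.0
    for batch in microbatches:
        acc = acc + grad_fn(params, batch)
    return acc / len(microbatches)

# Toy example: per-microbatch gradient mean(2 * w * x * x) for loss (w*x)^2.
def toy_grad(w, batch):
    return sum(2 * w * x * x for x in batch) / len(batch)

g = accumulate_grads(toy_grad, 0.5, [[1.0, 2.0], [3.0]])
```

When the Python `for` is replaced by the unified `scan`/`while-loop` primitive, the same accumulation compiles to a single XLA loop instead of an unrolled graph, which is what the parity work above would give gradient accumulation for free.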