
Standardize XLA loop APIs #8918


Open
rpsilva-aws opened this issue Apr 1, 2025 · 0 comments
Labels
enhancement New feature or request

rpsilva-aws commented Apr 1, 2025

Robustness enhancements

In this item, we standardize the XLA loop APIs that JAX currently supports (scan, fori-loop, while-loop), reaching parity with JAX and giving users who want the benefits of each loop construct a consistent UX.
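
For reference, the JAX loop primitives targeted for parity have the semantics modeled below. This is a pure-Python sketch adapted from the `jax.lax.scan`, `jax.lax.fori_loop` and `jax.lax.while_loop` docstrings, not the torch_xla API; it assumes plain lists in place of arrays and ignores pytrees, tracing and compilation.

```python
# Pure-Python reference semantics of JAX's three structured loop primitives.
# Lists stand in for arrays; the real APIs trace these bodies into XLA loops.
from typing import Any, Callable, List, Optional, Sequence, Tuple


def scan(f: Callable[[Any, Any], Tuple[Any, Any]], init: Any,
         xs: Optional[Sequence[Any]], length: Optional[int] = None) -> Tuple[Any, List[Any]]:
    """Thread a carry through f over xs, collecting the per-step outputs."""
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def fori_loop(lower: int, upper: int, body_fun: Callable[[int, Any], Any], init_val: Any) -> Any:
    """Apply body_fun(i, val) for every i in [lower, upper)."""
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


def while_loop(cond_fun: Callable[[Any], bool], body_fun: Callable[[Any], Any], init_val: Any) -> Any:
    """Apply body_fun while cond_fun holds; the trip count is data-dependent."""
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


# Usage
final, partials = scan(lambda c, x: (c + x, c + x), 0, [1, 2, 3, 4])
# final == 10, partials == [1, 3, 6, 10]
total = fori_loop(0, 5, lambda i, acc: acc + i, 0)  # 0+1+2+3+4 == 10
halved = while_loop(lambda v: v > 1, lambda v: v // 2, 100)  # 100 -> ... -> 1
```

scan and fori-loop have a trip count known before the loop starts, while while-loop does not, which is one reason a unified implementation has to handle both shapes.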

Scan RFC: #8620
While-loop/fori-loop RFC: #6941

These APIs have diverged into separate variants that would benefit greatly from sharing a single robust implementation, namely:

Given @tengyifei's ongoing robustness improvements to scan, we should consider building on it and unifying most of the core implementation around it. One subitem is to replace the gradient accumulation (grad acc) UX entirely with scan or while-loop, so that customers can use the existing UX to get the same benefits as in JAX (see AXLearn's gradient accumulation, https://github.com/apple/axlearn/pull/614/files).
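
To illustrate why gradient accumulation maps naturally onto a loop API: it can be written as a scan whose carry is the running gradient sum and whose `xs` are the microbatches. The sketch below is hypothetical and framework-free; `grad_mse` and `accumulate_grads` are illustrative names, a hand-derived gradient of a 1-D linear model stands in for autograd, and it is not the torch_xla or AXLearn API.

```python
# Hypothetical sketch: gradient accumulation as a scan over microbatches.
# A hand-derived gradient of a 1-D linear model stands in for autograd.
from typing import List, Tuple

Batch = List[Tuple[float, float]]  # (x, y) pairs


def scan(f, init, xs):
    """Minimal scan: thread a carry through f, collecting per-step outputs."""
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def grad_mse(w: float, batch: Batch) -> float:
    """d/dw of mean((w*x - y)^2) over the batch."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulate_grads(w: float, microbatches: List[Batch]) -> float:
    """Average per-microbatch gradients by scanning over the microbatches."""
    def step(grad_sum: float, mb: Batch) -> Tuple[float, float]:
        g = grad_mse(w, mb)
        return grad_sum + g, g  # carry the running sum, emit the per-step grad

    grad_sum, _ = scan(step, 0.0, microbatches)
    return grad_sum / len(microbatches)


# Usage: with equal-size microbatches, the accumulated gradient matches the
# full-batch gradient while only one microbatch is processed per step.
data = [(1.0, 2.0), (2.0, 3.0), (3.0, 7.0), (4.0, 8.0)]
micro = [data[0:2], data[2:4]]
w = 0.5
acc = accumulate_grads(w, micro)
full = grad_mse(w, data)
```

Because the carry has a fixed structure (one running sum) and the microbatch count is static, this fits the scan shape rather than requiring a data-dependent while-loop.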

Motivation

Currently, the experimental while-loop and fori-loop APIs (#6941) have limited robustness (see issues #7839 and #8566). For example, they cannot handle special scalar iteration values (0 and 1), and they have limitations around preserving the static XLA loop argument requirements (body, condition and init) needed for the While op (https://openxla.org/xla/operation_semantics#while) and around device RNG initialization, among others.

A separate subitem (for Neuron) is to bring the experimental gradient accumulation API to parity with the work on scan, so that scan_layers and gradient accumulation XLA loops benefit from the same enhancements. This should make them less brittle, especially when the two features are nested. Since the scan API is becoming increasingly robust, bridging the two implementations would help.
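
One way to reason about the 0- and 1-iteration edge cases while keeping the static (body, condition, init) structure of a While loop is to lower fori-loop onto a single while primitive whose entire state is one carry tuple. The sketch below is a hypothetical, framework-free model of that lowering; the function names are illustrative and do not correspond to the torch_xla implementation. It checks the edge cases against Python's `range` semantics.

```python
# Hypothetical sketch: lowering fori_loop onto a single while primitive whose
# only state is one carry tuple, mirroring the (condition, body, init) shape of
# a While loop. Names are illustrative, not the torch_xla API.
from typing import Any, Callable, Tuple

Carry = Tuple[int, int, Any]  # (loop index, upper bound, user value)


def while_loop(cond_fun: Callable[[Any], bool], body_fun: Callable[[Any], Any], init_val: Any) -> Any:
    """Reference while: the carry keeps the same structure every iteration."""
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


def fori_loop(lower: int, upper: int, body_fun: Callable[[int, Any], Any], init_val: Any) -> Any:
    """Express fori_loop as while_loop by threading the index through the carry."""
    def cond(carry: Carry) -> bool:
        i, hi, _ = carry
        return i < hi

    def body(carry: Carry) -> Carry:
        i, hi, val = carry
        return (i + 1, hi, body_fun(i, val))  # same tuple structure as init

    _, _, result = while_loop(cond, body, (lower, upper, init_val))
    return result


def reference(lower: int, upper: int, body_fun: Callable[[int, Any], Any], init_val: Any) -> Any:
    """Plain Python loop used as the ground truth."""
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


def add_index(i: int, acc: int) -> int:
    return acc + i


# Zero iterations, one iteration, an empty (lower > upper) range, and a longer loop.
cases = [(0, 0), (0, 1), (3, 4), (5, 2), (0, 10)]
results = {(lo, hi): fori_loop(lo, hi, add_index, 100) for lo, hi in cases}
```

With zero iterations the condition fails on the initial carry and the init value is returned unchanged; with one iteration the body runs exactly once, so both cases fall out of the same code path instead of needing special handling.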

cc: @tengyifei @ManfeiBai
