Improve torch_xla.compile documentation #9194

Merged · 1 commit · May 30, 2025
166 changes: 101 additions & 65 deletions docs/source/learn/eager.md
# PyTorch/XLA Compile API and Its Interaction with Eager Mode

## Overview

PyTorch/XLA integrates PyTorch with the XLA compiler to optimize deep learning
workloads across various hardware accelerators. Currently, PyTorch/XLA uses the
LazyTensor tracing mode by default, where operations are recorded into a
computation graph for deferred compilation and execution (triggered by
`torch_xla.sync()`), as shown in the following code:

```python
import torch
import torch_xla
import torchvision
# ... (device, model, and input setup collapsed in the diff view) ...
res = model(input)
torch_xla.sync()
```
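
The diff view collapses the setup lines of the snippet above. A minimal
self-contained sketch of the same pattern, assuming a torchvision ResNet-18 and
a random input batch (neither is visible in the collapsed lines), could look
like this:

```python
import torch
import torch_xla
import torchvision

# Assumed setup; the exact model and shapes in the original are collapsed.
device = torch_xla.device()
model = torchvision.models.resnet18().to(device)
input = torch.randn(64, 3, 224, 224).to(device)

# In LazyTensor mode this call only records (traces) operations.
res = model(input)

# Compilation and device execution happen here.
torch_xla.sync()
```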

While this approach enables performance optimizations, it introduces significant
usability challenges.

## Challenges with LazyTensor Mode

- **Ambiguity**: Developers struggle to distinguish between tracing and
execution phases, complicating development and debugging.

- **Recompilation Overhead**: Whenever any part of the captured graph changes,
`torch_xla.sync()` will recompile the whole graph. Changes in non-core
operations (e.g., data preprocessing) thus trigger expensive recompilations.

- **Debugging Difficulty**: Identifying the cause of recompilations is
challenging due to the opaque nature of graph-building processes.

## Eager Mode and `torch_xla.compile`

To address these issues, PyTorch/XLA introduces an experimental eager mode
(enabled via `torch_xla.experimental.eager_mode(True)`) and the
`torch_xla.compile` API. This shift aligns PyTorch/XLA more closely with
native PyTorch, prioritizing developer experience while preserving
performance. Eager mode is likely to become the default in future releases.

- **Eager Mode**: Executes operations immediately, enhancing flexibility and
debugging but at a performance cost.

- **torch_xla.compile**: A decorator or wrapper that explicitly marks code
(e.g., a model or function) for XLA compilation within an eager context,
providing clear boundaries and immediate feedback.

**Note that `torch_xla.compile` is independently useful, even outside of eager
mode, providing benefits such as preventing dataloading operations from leaking
into the training loop graph by capturing them into a separate graph, and
catching accidental graph breaks when `full_graph=True` is specified.**
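
As a hedged illustration of the `full_graph=True` behavior mentioned above (the
wrapped function here is a hypothetical example, not from the original doc):

```python
import torch
import torch_xla

torch_xla.experimental.eager_mode(True)

# Hypothetical function whose operations should stay in a single XLA graph.
def scaled_sum(x, y):
    return (x * 2 + y).sum()

# With full_graph=True, torch_xla.compile reports an error if tracing the
# function produces more than one graph (an accidental graph break).
scaled_sum = torch_xla.compile(scaled_sum, full_graph=True)

device = torch_xla.device()
out = scaled_sum(torch.ones(4, device=device), torch.zeros(4, device=device))
```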

## How `torch_xla.compile` works

Let's look at a basic example of `torch_xla.compile`:

```python
import torch
import torch_xla
import torchvision
# ... (eager mode, device, model, and torch_xla.compile setup collapsed in the diff view) ...
input = torch.randn(64, 3, 224, 224).to(device)
res = compiled_model(input)
```
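
For completeness, a self-contained version of this example, with the collapsed
setup filled in by assumption to mirror the lazy-mode snippet above, might be:

```python
import torch
import torch_xla
import torchvision

# Enable eager mode globally; only code wrapped in torch_xla.compile is traced.
torch_xla.experimental.eager_mode(True)

device = torch_xla.device()
model = torchvision.models.resnet18().to(device)

# Mark the model for XLA compilation; everything else runs eagerly.
compiled_model = torch_xla.compile(model)

input = torch.randn(64, 3, 224, 224).to(device)
res = compiled_model(input)
```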

The implementation of `torch_xla.compile` can be summarized as follows:

1. **Disables Eager Mode**: Temporarily switches to tracing to build a
computation graph.

2. **Traces Operations**: Records operations for XLA optimization.

3. **Compiles and Executes**: Triggers compilation and execution via an
internal `torch_xla.sync()` call.

4. **Re-enables Eager Mode**: Resumes eager execution after compilation.

This "eager-to-lazy-to-eager" transition abstracts synchronization complexity,
balancing flexibility and performance.
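
Conceptually, the wrapper behaves roughly like the sketch below. This is only an
illustration of the four steps, not the library's actual implementation:

```python
import torch_xla

def compile_like(fn):
    """Conceptual stand-in for torch_xla.compile (illustration only)."""
    def wrapper(*args, **kwargs):
        torch_xla.experimental.eager_mode(False)      # 1. disable eager mode
        try:
            result = fn(*args, **kwargs)              # 2. trace the target function
            torch_xla.sync()                          # 3. compile and execute the graph
        finally:
            torch_xla.experimental.eager_mode(True)   # 4. re-enable eager mode
        return result
    return wrapper
```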

## `torch_xla.compile` vs. `torch.compile`

The PyTorch ecosystem offers multiple compilation APIs, and understanding their
distinct roles, especially within PyTorch/XLA, is crucial for optimal
performance and development.

- `torch_xla.compile` is optimized for PyTorch/XLA training workflows. Designed
to work efficiently with the XLA backend for iterative training, it is the
recommended API for compiling training loops due to its observed performance
advantages. The best practice is to enclose the complete training step (forward
pass, loss calculation, backward pass, and optimizer step) in a `step_fn` and
then compile that function, as shown below:

```python
torch_xla.experimental.eager_mode(True)

def step_fn(model, data, target, loss_fn, optimizer):
    # ... training-step body collapsed in the diff view:
    # forward pass, loss calculation, backward pass, optimizer step ...
    ...

step_fn = torch_xla.compile(step_fn)
```
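
Because the diff collapses the body of `step_fn`, here is a self-contained
sketch of the recommended pattern; the toy model, optimizer, and fake data are
assumptions for illustration only:

```python
import torch
import torch.nn as nn
import torch_xla

torch_xla.experimental.eager_mode(True)

# Hypothetical model, optimizer, loss, and fake data.
device = torch_xla.device()
model = nn.Linear(128, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
data = torch.randn(32, 128).to(device)
target = torch.randint(0, 10, (32,), device=device)

def step_fn(model, data, target, loss_fn, optimizer):
    # Complete training step: forward pass, loss, backward pass, optimizer update.
    optimizer.zero_grad()
    logits = model(data)
    loss = loss_fn(logits, target)
    loss.backward()
    optimizer.step()
    return loss

step_fn = torch_xla.compile(step_fn)

for _ in range(3):
    loss = step_fn(model, data, target, loss_fn, optimizer)
```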

- `torch.compile` is PyTorch's general-purpose compilation API designed to
accelerate PyTorch models across various backends. For PyTorch/XLA, it uses the
`openxla` backend. We recommend `torch.compile` for PyTorch/XLA inference
because it lowers tracing overhead, leading to more efficient static inference
graphs. To use it with XLA, simply specify `backend="openxla"`.

```python
torch_xla.experimental.eager_mode(True)
compiled_model = torch.compile(model, backend="openxla")
```
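
A minimal end-to-end inference sketch under these assumptions (the model and
input are hypothetical) could look like this:

```python
import torch
import torch_xla
import torchvision

torch_xla.experimental.eager_mode(True)

device = torch_xla.device()
model = torchvision.models.resnet18().to(device).eval()

# torch.compile with the openxla backend handles tracing and compilation.
compiled_model = torch.compile(model, backend="openxla")

input = torch.randn(1, 3, 224, 224).to(device)
with torch.no_grad():
    out = compiled_model(input)
```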

The long-term aim is for `torch.compile` to be the single compilation API for
both training and inference on XLA.

## Performance Benchmarks

To quantify the performance impact of `torch_xla.compile` and eager mode, a
2-layer decoder-only model (similar to Llama2) was trained with fake data for
300 steps on a single chip of a v4-8 TPU. The observed throughput, measured in
tokens per second, illustrates the impact of the different execution modes:

| Mode                        | token/s |
|-----------------------------|---------|
| Tracing mode (baseline)     | 147     |
| Eager mode                  | 65      |
| Eager + `torch_xla.compile` | 147     |

Eager mode with `torch_xla.compile` matches the performance of traditional
LazyTensor tracing mode at `147` tokens/s, demonstrating a better user
experience without performance loss.

Pure eager mode's performance is model-dependent; it achieves ~45% of the fully
compiled model's performance for decoder-only models. However, for ResNet50,
pure eager mode was significantly slower (about 1% of compiled mode). For more
information, see [train_decoder_only_base.py](https://github.com/pytorch/xla/blob/master/examples/train_decoder_only_base.py)
and [eager example](https://github.com/pytorch/xla/tree/master/examples/eager).
This varying overhead means pure eager mode is not intended for main training or
inference loops. Its utility lies in non-core tasks like data preprocessing,
random number generation, custom utilities, or debugging, where immediate
execution is prioritized over throughput.
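
A small hedged sketch of this division of labor, with hypothetical preprocessing
and a toy compiled core function, might look like:

```python
import torch
import torch_xla

torch_xla.experimental.eager_mode(True)
device = torch_xla.device()

# Non-core work (preprocessing, random-number generation) runs eagerly and
# therefore never leaks into the compiled graph.
raw = torch.randn(32, 128).to(device)
batch = (raw - raw.mean()) / (raw.std() + 1e-6)

# Only the core computation is wrapped for compilation.
def core_fn(x):
    return (x @ x.T).sum()

core_fn = torch_xla.compile(core_fn)
result = core_fn(batch)
```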