[JAX] Quickstart documentation #2310
Conversation
Signed-off-by: tdophung <[email protected]>
2796e91 to 733d61b
Greptile Overview
Greptile Summary
This PR adds a comprehensive JAX-specific quickstart tutorial notebook that demonstrates four progressive implementations of a transformer layer: pure JAX/Flax, basic Transformer Engine (TE) modules, fused TE modules, and the full TE TransformerLayer with optional FP8 support. The notebook includes performance benchmarking utilities and parameter-sharing helpers to enable fair comparisons between implementations. A minor build infrastructure change updates the PyTorch dependency installation in the GitHub Actions workflow to explicitly use CUDA 13.0 wheels, ensuring compatibility with the JAX container environment. This documentation effort mirrors the structure of the existing PyTorch quickstart guide, providing JAX users with a dedicated migration and optimization path.
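As an illustrative sketch only (not the notebook's actual code), the pure JAX/Flax baseline described above might look roughly like the following; all module names, sizes, and shapes here are placeholders chosen for the example:

```python
# Illustrative sketch of a hand-rolled attention and a pre-LayerNorm block in
# plain JAX/Flax, in the spirit of the notebook's first iteration.
import jax
import jax.numpy as jnp
import flax.linen as nn


def dot_product_attention(q, k, v, mask=None):
    """Scaled dot-product attention; q/k/v have shape [batch, seq, heads, dim]."""
    scale = q.shape[-1] ** -0.5
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale
    if mask is not None:
        scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    return jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(scores, axis=-1), v)


class ToyTransformerLayer(nn.Module):
    hidden_size: int = 1024
    ffn_size: int = 4096
    num_heads: int = 16
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, mask=None, deterministic=True):
        head_dim = self.hidden_size // self.num_heads
        # Attention sub-layer: LayerNorm -> fused QKV projection -> attention.
        y = nn.LayerNorm()(x)
        qkv = nn.Dense(3 * self.hidden_size)(y)
        qkv = qkv.reshape(*x.shape[:2], 3, self.num_heads, head_dim)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
        y = dot_product_attention(q, k, v, mask).reshape(x.shape)
        y = nn.Dense(self.hidden_size)(y)  # output projection
        x = x + nn.Dropout(self.dropout_rate)(y, deterministic=deterministic)
        # MLP sub-layer: LayerNorm -> Dense -> GELU -> Dense.
        y = nn.LayerNorm()(x)
        y = nn.Dense(self.ffn_size)(y)
        y = nn.Dense(self.hidden_size)(nn.gelu(y))
        return x + nn.Dropout(self.dropout_rate)(y, deterministic=deterministic)
```

The TE iterations described in the summary then progressively replace these building blocks with `transformer_engine.jax.flax` modules.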
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| .github/workflows/build.yml | 5/5 | Split torch installation into separate command with explicit CUDA 13.0 index URL |
| docs/examples/quickstart_jax.ipynb | 3/5 | Added comprehensive JAX quickstart notebook with 4 progressive transformer implementations |
| docs/examples/quickstart_jax_utils.py | 2/5 | Added utility functions for benchmarking and parameter sharing between JAX and TE models |
Confidence score: 2/5
- This PR contains critical runtime errors that will prevent the tutorial from executing successfully
- The main issue is a function signature mismatch in `quickstart_jax_utils.py`: `speedometer()` calls `train_step_fn` with 5 arguments but `create_train_step_fn()` returns a function expecting only 4 parameters, causing a guaranteed TypeError
- Additional concerns include variable scope issues in the notebook (line 795 references `te_transformer_params_template`, which may not be in scope after FP8 initialization) and unused parameters indicating an incomplete implementation
- Pay close attention to `docs/examples/quickstart_jax_utils.py` lines 37-46 (function call mismatch) and the notebook cells around lines 787-801 (variable scoping)
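To make the flagged mismatch concrete, here is a hedged sketch of how a consistent pairing of the benchmark helper and the train-step factory could look. The names (`speedometer`, `create_train_step_fn`, `dropout_key`) are taken from the review text; the bodies are hypothetical and are not the actual contents of `quickstart_jax_utils.py`:

```python
# Hypothetical sketch: the timing helper and the train-step factory must agree
# on the train step's signature, which is the mismatch the review flags.
import time
import jax
import jax.numpy as jnp


def create_train_step_fn(model_apply_fn):
    """Return a JIT-compiled step computing loss and gradients w.r.t. variables."""

    def loss_fn(variables, inp, output_grad, dropout_key):
        out = model_apply_fn(variables, inp, rngs={"dropout": dropout_key})
        # Surrogate loss chosen so that d(loss)/d(out) == output_grad.
        return jnp.sum(out * output_grad)

    # Gradients with respect to the first argument (the model variables).
    return jax.jit(jax.value_and_grad(loss_fn, argnums=0))


def speedometer(model_apply_fn, variables, inp, output_grad, timing_iters=50):
    train_step_fn = create_train_step_fn(model_apply_fn)
    dropout_key = jax.random.PRNGKey(0)

    # Call with exactly the four arguments the step function expects.
    loss, grads = train_step_fn(variables, inp, output_grad, dropout_key)
    loss.block_until_ready()  # warmup / compile

    start = time.perf_counter()
    for _ in range(timing_iters):
        loss, grads = train_step_fn(variables, inp, output_grad, dropout_key)
    loss.block_until_ready()
    print(f"Mean time: {(time.perf_counter() - start) / timing_iters * 1e3:.2f} ms")
```

The key point is simply that the arguments passed by the timing loop must match the parameters of the function returned by the factory.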
Sequence Diagram
sequenceDiagram
participant User
participant Notebook as quickstart_jax.ipynb
participant Utils as quickstart_jax_utils.py
participant JAX as JAX/Flax
participant TE as Transformer Engine
User->>Notebook: Execute notebook cells
Notebook->>JAX: Import jax, jax.numpy, flax.linen
Notebook->>TE: Import transformer_engine.jax
Notebook->>Utils: Import quickstart_jax_utils
Note over Notebook,JAX: 1. Build BasicTransformerLayer (pure JAX/Flax)
Notebook->>JAX: Initialize BasicTransformerLayer
JAX-->>Notebook: Return initialized model
Notebook->>JAX: init(key, x) to create params
JAX-->>Notebook: Return params
Notebook->>Utils: speedometer(model_apply_fn, variables, input, output_grad)
Utils->>Utils: create_train_step_fn() - JIT compile fwd/bwd
Utils->>JAX: Run warmup iterations
Utils->>JAX: Run timing iterations (forward + backward)
JAX-->>Utils: Return loss and gradients
Utils-->>Notebook: Print mean time
Note over Notebook,TE: 2. Build BasicTETransformerLayer (TE modules)
Notebook->>TE: Initialize BasicTETransformerLayer
TE-->>Notebook: Return TE model
Notebook->>TE: init(key, x) to create TE params
TE-->>Notebook: Return TE params template
Notebook->>Utils: share_parameters_with_basic_te_model()
Utils-->>Notebook: Return shared params
Notebook->>Utils: speedometer() with TE model
Utils->>TE: Run timing iterations with TE layers
TE-->>Utils: Return loss and gradients
Utils-->>Notebook: Print improved mean time
Note over Notebook,TE: 3. Build FusedTETransformerLayer (fused ops)
Notebook->>TE: Initialize FusedTETransformerLayer
TE-->>Notebook: Return fused TE model
Notebook->>TE: init() with LayerNormDenseGeneral, LayerNormMLP
TE-->>Notebook: Return fused params
Notebook->>Utils: share_fused_parameters_with_basic_te_model()
Utils-->>Notebook: Return shared fused params
Notebook->>Utils: speedometer() with fused model
Utils->>TE: Run timing with fused kernels
TE-->>Utils: Return loss and gradients
Utils-->>Notebook: Print further improved time
Note over Notebook,TE: 4. Use TE TransformerLayer with FP8
Notebook->>TE: Create fp8_recipe (DelayedScaling)
Notebook->>TE: fp8_autocast(enabled=True, fp8_recipe)
Notebook->>TE: Initialize TransformerLayer within fp8_autocast
TE-->>Notebook: Return params with fp8_metas collection
Notebook->>Utils: share_parameters_with_transformerlayer_te_model()
Utils-->>Notebook: Return shared params with fp8_metas
Notebook->>Utils: speedometer() with fp8_autocast_kwargs
Utils->>TE: Run timing with FP8 precision
TE-->>Utils: Return loss and gradients (FP8)
Utils-->>Notebook: Print fastest mean time
Notebook-->>User: Display performance comparison
3 files reviewed, 1 comment
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review. The developer has cleaned up the JAX quickstart notebook by removing execution artifacts (output cells, execution counts), fixing typos and type annotations, and correcting markdown link syntax. The quickstart_jax_utils.py file was updated to remove an unused key parameter from train_step_fn calls and add a 2025 copyright header. However, a critical logic error remains unfixed in the notebook at line 772: the code references te_transformer_params_template which is undefined in the FP8 context (it was only defined outside the FP8 block in line 671). This variable should be te_transformer_params instead.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| docs/examples/quickstart_jax.ipynb | 2/5 | Cleaned notebook execution artifacts and fixed typos/types, but introduced critical undefined variable bug in FP8 initialization block (line 772) |
| docs/examples/quickstart_jax_utils.py | 3/5 | Removed unused key parameter from train step functions and added 2025 copyright; shallow copy operations may allow unintended parameter mutation |
Confidence score: 2/5
- This PR cannot be merged safely due to a critical logic error that will cause runtime failure.
- Score reflects one critical bug (undefined variable `te_transformer_params_template` at line 772 in the notebook) that will break FP8 initialization when users try to execute the notebook. The developer needs to replace `te_transformer_params_template` with `te_transformer_params` in the FP8 context. Additionally, the shallow `copy.deepcopy` operations in the utils file may cause subtle parameter-sharing bugs if nested dictionaries are mutated.
- The notebook file (quickstart_jax.ipynb) requires immediate attention, specifically line 772 in the FP8 initialization block where an undefined variable is referenced.
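A hedged sketch of the kind of fix being requested: initialize the parameters inside the autocast context and reuse that same variable for sharing and benchmarking. The layer sizes, rng handling, and constructor arguments below are assumptions for illustration, not the notebook's exact cell (later commits also rename `fp8_autocast` to `autocast`):

```python
# Hypothetical sketch of the FP8 initialization pattern the review discusses.
# Constructor arguments follow Transformer Engine's JAX Flax API but should be
# checked against the installed TE version.
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.common import recipe

fp8_recipe = recipe.DelayedScaling()
x = jnp.zeros((32, 128, 1024))  # [batch, seq, hidden] placeholder input
init_key, dropout_key = jax.random.split(jax.random.PRNGKey(0))

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    te_transformer = te_flax.TransformerLayer(
        hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16
    )
    # Initialize *inside* the context so the fp8_metas collection is created,
    # and reuse this variable (not a *_template defined outside the FP8 block)
    # for parameter sharing and benchmarking.
    te_transformer_params = te_transformer.init(
        {"params": init_key, "dropout": dropout_key}, x
    )
```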
2 files reviewed, 2 comments
d9859d1 to b742244
Greptile Overview
Greptile Summary
This review covers only the latest changes to two documentation files: docs/examples/quickstart_jax.ipynb and docs/examples/quickstart_jax_utils.py. Recent commits have addressed previous formatting and documentation issues (markdown links, type annotations, typos), but critical runtime bugs remain unresolved. The quickstart notebook demonstrates progressive optimization of a JAX/Flax transformer using Transformer Engine, but the utility file contains a function signature mismatch that will cause immediate failures when users attempt to run the notebook.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| docs/examples/quickstart_jax.ipynb | 2/5 | Cleared execution artifacts and fixed minor issues, but contains a critical NameError on line 772 where undefined variable te_transformer_params_template is referenced in FP8 initialization block |
| docs/examples/quickstart_jax_utils.py | 1/5 | Added benchmarking utilities for JAX/Flax models with critical signature mismatch—train_step_fn called with 4 arguments (lines 41, 47) but function signatures expect 4 parameters (dropout_key required); will fail immediately at runtime |
Confidence score: 1/5
- This PR will cause immediate runtime failures when users attempt to follow the quickstart guide due to function signature mismatches and undefined variable references
- Score reflects two critical blocking issues: (1) the `train_step_fn` signature mismatch in `quickstart_jax_utils.py`, where calls on lines 41 and 47 provide 4 arguments but the underlying functions require `dropout_key` as a parameter, and (2) the undefined `te_transformer_params_template` reference on line 772 of the notebook that will raise a NameError
- Both utility file functions (`create_train_step_fn` and `create_train_step_fn_vjp`) and the FP8 initialization block in the notebook require immediate correction before this documentation can be published
2 files reviewed, no comments
docs/examples/quickstart_jax.ipynb
Outdated
| " )\n", | ||
| " x = attention(q, k, v, attention_mask, deterministic=deterministic)\n", | ||
| " \n", | ||
| " # Attention built-in. Comment out if not used\n", |
Can we add a bit more detail to this comment, something like: If you would like to use Flax's built-in attention implementation, you can use this instead
Want to clarify to users that this is an alternative option and not some debug code we left behind by mistake.
I don't think this comment really got resolved. Let me make a suggestion for how I think it could look instead.
Greptile Overview
Greptile Summary
This PR introduces a comprehensive JAX/Flax quickstart tutorial that mirrors the PyTorch quickstart guide. The tutorial demonstrates four progressive implementations of a Transformer layer, showing performance improvements from pure JAX/Flax (~28ms) to fully optimized TE with FP8 (~8ms).
Key Changes:
- New `quickstart_jax.ipynb` notebook with 4 transformer implementations (pure JAX, basic TE, fused TE, full TE with FP8)
- Supporting utility file `quickstart_jax_utils.py` with benchmarking functions and custom attention/MLP modules
- Workflow update to use specific PyTorch CUDA index URL for build consistency
Issues Found:
- Several typos in documentation (already flagged in previous comments)
- Malformed markdown link in notebook cell
- Unused `datasets` import
- Minor documentation link format issues
Positive Aspects:
- Tutorial provides clear progressive optimization path
- Benchmarking utilities are well-structured and correct
- Good pedagogical approach with 4 iterations showing incremental improvements
- Code is functionally correct despite some previous incorrect review comments
Confidence Score: 4/5
- This PR is safe to merge after addressing minor documentation issues
- Score reflects that this is a documentation-only PR with correct functional code. The issues found are primarily cosmetic (typos, formatting) rather than logical errors. Several previous comments incorrectly flagged non-issues (e.g., VJP signature, missing imports that don't exist). The actual issues are: unused import, malformed markdown link, and typos - all easily fixable and non-blocking for a draft PR.
- docs/examples/quickstart_jax.ipynb needs attention for typos and formatting issues. The utility file and workflow changes are solid.
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| .github/workflows/build.yml | 5/5 | Workflow updated to install torch from specific CUDA index URL - improves build consistency |
| docs/examples/quickstart_jax.ipynb | 3/5 | New JAX quickstart tutorial with 4 transformer implementations; contains typos and malformed markdown link already flagged |
| docs/examples/quickstart_jax_utils.py | 4/5 | Utility functions for speedometer benchmarking and custom attention/MLP modules; code is correct despite some previous incorrect comments |
Sequence Diagram
sequenceDiagram
participant User as User/Tutorial
participant Basic as BasicTransformerLayer
participant BasicTE as BasicTETransformerLayer
participant Fused as FusedTETransformerLayer
participant Full as TransformerLayer (TE)
participant Utils as quickstart_jax_utils
User->>Basic: Initialize with pure JAX/Flax
Basic->>Basic: nn.LayerNorm + nn.Dense (QKV)
Basic->>Utils: DotProductAttention
Utils-->>Basic: attention output
Basic->>Basic: nn.Dense (projection) + nn.Dropout
Basic->>Utils: BasicMLP
Utils-->>Basic: MLP output
Basic-->>User: ~28ms training step
User->>BasicTE: Replace with TE modules
BasicTE->>BasicTE: te_flax.LayerNorm + te_flax.DenseGeneral
BasicTE->>Utils: DotProductAttention
Utils-->>BasicTE: attention output
BasicTE->>BasicTE: te_flax.DenseGeneral + nn.Dropout
BasicTE->>BasicTE: BasicTEMLP (te_flax.DenseGeneral)
BasicTE-->>User: ~17ms training step
User->>Fused: Use fused TE operations
Fused->>Fused: te_flax.LayerNormDenseGeneral (fused)
Fused->>Utils: DotProductAttention
Utils-->>Fused: attention output
Fused->>Fused: te_flax.DenseGeneral + nn.Dropout
Fused->>Fused: te_flax.LayerNormMLP (fused)
Fused-->>User: ~18ms training step
User->>Full: Use full TE TransformerLayer
Full->>Full: Complete optimized implementation
Full-->>User: ~12ms training step
User->>Full: Enable FP8 with fp8_autocast
Full->>Full: FP8 precision compute
Full-->>User: ~8ms training step
3 files reviewed, 1 comment
There seems to be some issue with the commits being mixed up. Could you rebase on top of the current main, @tdophung?
docs/examples/quickstart_jax.ipynb
Outdated
| "- `DotProductAttention`: `DotProductAttention` from [quickstart_jax_utils.py](quickstart_jax_utils.py)\n", | ||
| "- `Projection`: `nn.Dense` (JAX/Flax)\n", | ||
| "- `Dropout`: `nn.Dropout` (JAX/Flax)\n", | ||
| "- `MLP`: `BasicMLP` from [quickstart_jax_utils.py](quickstart_jax_utils.py)\n", |
@ksivaman putting on your radar that this will require more substitutions when we create the docs package.
docs/examples/quickstart_jax.ipynb
Outdated
| "- `nn.Dropout`: JAX/Flax Dropout\n", | ||
| "- `BasicMLP`: Custom MLP from [quickstart_jax_utils.py](quickstart_jax_utils.py)\n", | ||
| "\n", | ||
| "<small> (**) _The code below also shows how to use the built-in attention sub-layer from either pure Flax or TE Flax in commented code if you wish to use those instead of the custom attention in [quickstart_jax_utils.py]. The implementation is there for your reference of how attention is roughly implemented in our source_</small>\n", |
I don't like this sentence - it sounds too informal. A quick pass with copilot gave this instead:
The code snippet below also includes commented examples demonstrating how to utilize the built-in attention sub-layer from either pure Flax or TE Flax. These alternatives are provided should you prefer them over the custom attention implementation found in quickstart_jax_utils.py. This reference is intended to offer insight into the general structure and approach used for implementing attention mechanisms in our source code.
I don't really like that one either (although it sounds better), but maybe we could use it to get something in between.
What is the purpose of having both the flax.linen.MultiHeadDotProductAttention and the self-built attention? If the self-built attention is implemented correctly, they should have exactly the same ops as the linen one, right?
I think it will be less confusing to simply start with the linen attention module directly, then swap it with the TE one. In that way, we can expose those two modules in this tutorial directly rather than hiding their details in the util file.
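A rough sketch of the swap this comment proposes, starting from Flax's built-in attention; the TE constructor arguments are deliberately left out since they vary by Transformer Engine version and are not spelled out in this thread:

```python
# Sketch of the reviewer's suggestion: start from Flax's built-in attention,
# then later swap in the TE equivalent with an equivalent configuration.
import flax.linen as nn


class AttnBlock(nn.Module):
    num_heads: int = 16
    hidden_size: int = 1024

    @nn.compact
    def __call__(self, x, mask=None, deterministic=True):
        # Flax built-in multi-head self-attention...
        attn = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads, qkv_features=self.hidden_size
        )
        # ...which would later be replaced by the TE Flax attention module
        # (e.g. the attention classes in transformer_engine.jax.flax).
        return attn(x, x, mask=mask, deterministic=deterministic)
```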
I wonder if it is helpful to have the self-built attention, since it shows how attention works and also stays close to the PyTorch tutorial.
I think we don't have commented code anymore?
Greptile Overview
Greptile Summary
Adds comprehensive JAX quickstart documentation mirroring the PyTorch tutorial structure, demonstrating progressive performance optimization through four iterations of a Transformer layer implementation.
Key Changes:
- New Jupyter notebook `quickstart_jax.ipynb` (699 lines) with hands-on tutorial
- Supporting utility file `quickstart_jax_utils.py` (267 lines) containing reusable components
- CI workflow update to separate PyTorch installation with explicit CUDA 13.0 index URL
Tutorial Progression:
- Pure JAX/Flax baseline (~26ms) - Custom attention and MLP implementations
- Basic TE modules (~16ms) - Using `te_flax.LayerNorm` and `DenseGeneral`
- Fused TE modules (~16ms) - Using `LayerNormDenseGeneral` and `LayerNormMLP`
- Full TE TransformerLayer (~11ms) - Complete optimized implementation
- FP8 enabled (~7ms) - 73% speedup over baseline with FP8 precision
Issues Found:
- Multiple typos in documentation ('similicity', 'attetntion', 'inistead', 'linnen', 'buiil-in')
- Hardcoded absolute path `os.chdir('/workspace/docs/examples')` will break for users
- Malformed markdown link syntax `[quickstart_jax_utils.py] (**)(quickstart_jax_utils.py)`
- Extra quote in string: `'no_mask''` should be `'no_mask'`
- Unused `datasets` import in notebook
- `layernorm_eps` defined as `int` instead of `float` (style issue, doesn't affect functionality)
Confidence Score: 3/5
- Safe to merge after fixing documentation issues - no functional bugs found, only typos and hardcoded paths that will impact user experience
- Score of 3 reflects that while the code is functionally correct and the tutorial successfully demonstrates TE capabilities, there are several documentation quality issues that should be fixed: hardcoded paths will break for users running the notebook outside the expected environment, multiple typos reduce professionalism, and malformed markdown links affect navigation. The CI change is safe and improves dependency management.
- docs/examples/quickstart_jax.ipynb requires fixing hardcoded paths (line 59) and typos throughout; docs/examples/quickstart_jax_utils.py needs typo corrections
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| .github/workflows/build.yml | 5/5 | Splits PyTorch installation into separate command with explicit CUDA index URL for better dependency management |
| docs/examples/quickstart_jax_utils.py | 3/5 | New utility file with attention implementations and speedometer function; contains several typos and documentation formatting issues flagged in previous comments |
| docs/examples/quickstart_jax.ipynb | 3/5 | New JAX quickstart tutorial notebook demonstrating progressive optimization with TE; contains hardcoded paths, typos, and malformed markdown links flagged in previous comments |
Sequence Diagram
sequenceDiagram
participant User
participant Notebook as quickstart_jax.ipynb
participant Utils as quickstart_jax_utils.py
participant JAX as JAX/Flax
participant TE as TransformerEngine
User->>Notebook: Run tutorial cells
Notebook->>JAX: Import jax, jax.numpy, flax.linen
Notebook->>Utils: Import quickstart_jax_utils
Notebook->>TE: Import transformer_engine.jax
Note over Notebook: Iteration 1: Pure JAX/Flax
Notebook->>JAX: Define BasicTransformerLayer
Notebook->>Utils: Use AttentionWrapper(CUSTOM_DOT_PRODUCT)
Notebook->>Utils: Use BasicMLP
Notebook->>Utils: speedometer() - measure baseline
Note over Notebook: Iteration 2: Basic TE
Notebook->>TE: Use te_flax.LayerNorm, DenseGeneral
Notebook->>Utils: Use AttentionWrapper(TE_FLAX_MULTIHEAD)
Notebook->>Utils: speedometer() - faster performance
Note over Notebook: Iteration 3: Fused TE
Notebook->>TE: Use LayerNormDenseGeneral, LayerNormMLP
Notebook->>Utils: speedometer() - similar to iteration 2
Note over Notebook: Iteration 4: Full TE TransformerLayer
Notebook->>TE: Use te_flax.TransformerLayer
Notebook->>Utils: speedometer() - best performance
Note over Notebook: Enable FP8
Notebook->>TE: fp8_autocast context manager
Notebook->>TE: Initialize with FP8 recipe
Notebook->>Utils: speedometer(fp8_autocast_kwargs) - fastest
Utils->>JAX: create_train_step_fn()
Utils->>JAX: jax.value_and_grad() for gradients
Utils->>JAX: jax.jit() for compilation
Utils-->>Notebook: Return timing results
3 files reviewed, no comments
Wonder why I cannot reply to Pawel's suggestion to do tabbing in HTML. But I intend for this PR to just be the notebook first, and to work on a separate PR for displaying it in HTML, since that will also touch the Python quickstart guide. This PR is just for a functional ipynb.
b70bddf to c6b6ca3
Greptile Overview
Greptile Summary
Adds comprehensive JAX quickstart documentation demonstrating progressive optimization of Transformer layers using TransformerEngine. The tutorial walks through 4 implementations showing increasing performance gains (26ms → 7ms).
Key Changes:
- New Jupyter notebook (`quickstart_jax.ipynb`) with 4 progressive transformer implementations
- Supporting utility module (`quickstart_jax_utils.py`) with attention wrappers and performance benchmarking
- CI workflow update for PyTorch installation with explicit CUDA index URL
Implementation Progression:
- BasicTransformerLayer: Pure JAX/Flax baseline (~26ms per step)
- BasicTETransformerLayer: Using TE's `DenseGeneral` and `LayerNorm` modules (~16ms); see the sketch after this list
- FusedTETransformerLayer: Leveraging fused kernels like `LayerNormDenseGeneral` (~16ms)
- Full TransformerLayer with FP8: Complete TE layer with FP8 precision (~7ms)
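A hedged sketch of the module swap in the second item above; the constructor arguments are assumptions and may differ across Transformer Engine versions:

```python
# Flax building blocks next to their Transformer Engine Flax counterparts.
# Argument names (features, epsilon) are assumptions for illustration only.
import flax.linen as nn
import transformer_engine.jax.flax as te_flax

# Pure-Flax baseline modules...
flax_norm = nn.LayerNorm(epsilon=1e-5)
flax_proj = nn.Dense(features=1024)

# ...and the TE modules the "basic TE" iteration swaps in.
te_norm = te_flax.LayerNorm(epsilon=1e-5)
te_proj = te_flax.DenseGeneral(features=1024)
```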
Strengths:
- Well-structured tutorial mirroring PyTorch quickstart for consistency
- Clear performance comparisons with timing measurements
- Good documentation of FP8 initialization requirements
- Comprehensive utility code with multiple attention implementations
Previous Review Feedback Addressed:
Most syntax issues (typos, formatting) and several logic concerns from previous reviews appear to have been addressed in recent commits, including fixing attention implementation inconsistencies and adding a proper enum for attention types.
Confidence Score: 4/5
- This documentation PR is safe to merge with minor issues already flagged in previous reviews
- Score reflects that this is a documentation-only PR with comprehensive tutorial content. The code examples are well-structured and functional. Previous review comments have identified syntax issues (typos, formatting) that should be addressed, but these are non-blocking for a documentation PR. The utility code is correct and the notebook provides educational value. The CI workflow change is sensible and low-risk.
- All files are in good shape. The notebook may have minor typos/formatting issues noted in previous comments that should be cleaned up before merge.
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| docs/examples/quickstart_jax.ipynb | 4/5 | Comprehensive JAX quickstart notebook with 4 progressive transformer implementations. Most previous syntax/typo issues appear addressed. Minor formatting improvements needed. |
| docs/examples/quickstart_jax_utils.py | 5/5 | Well-structured utility module with attention implementations and performance testing. Code is correct and follows best practices. |
Sequence Diagram
sequenceDiagram
participant User as User/Notebook
participant Utils as quickstart_jax_utils
participant JAX as JAX/Flax
participant TE as TransformerEngine
Note over User: 1. BasicTransformerLayer (Pure JAX/Flax)
User->>JAX: Initialize BasicTransformerLayer
JAX-->>User: params
User->>Utils: speedometer(model, params, data)
Utils->>Utils: create_train_step_fn()
Utils->>JAX: Forward/Backward Pass
JAX-->>Utils: loss, gradients
Utils-->>User: Mean time: ~26ms
Note over User: 2. BasicTETransformerLayer (TE modules)
User->>TE: Initialize with DenseGeneral, LayerNorm
TE-->>User: te_params
User->>Utils: speedometer(model, params, data)
Utils->>TE: Forward/Backward with TE modules
TE-->>Utils: loss, gradients
Utils-->>User: Mean time: ~16ms
Note over User: 3. FusedTETransformerLayer (Fused TE)
User->>TE: Initialize with LayerNormDenseGeneral, LayerNormMLP
TE-->>User: fused_params
User->>Utils: speedometer(model, params, data)
Utils->>TE: Forward/Backward with fused kernels
TE-->>Utils: loss, gradients
Utils-->>User: Mean time: ~16ms
Note over User: 4. TransformerLayer with FP8
User->>TE: fp8_autocast context
User->>TE: Initialize TransformerLayer
TE-->>User: params + fp8_metas
User->>Utils: speedometer(model, params, data, fp8_enabled=True)
Utils->>TE: Forward/Backward with FP8 precision
TE-->>Utils: loss, gradients
Utils-->>User: Mean time: ~7ms
2 files reviewed, no comments
…comaptibility with speedometer Signed-off-by: tdophung <[email protected]>
Signed-off-by: tdophung <[email protected]>
Signed-off-by: tdophung <[email protected]>
…Layer Signed-off-by: tdophung <[email protected]> [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: tdophung <[email protected]>
…e and unfused TE impls to achieve same performance (removing extra dropout layer in fused layers. Also some minor wording changes Signed-off-by: tdophung <[email protected]> [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: tdophung <[email protected]>
for more information, see https://pre-commit.ci
…ch, ...] instead of [batch, sequence,...] Signed-off-by: tdophung <[email protected]>
…or fuse to take effect because quantization exist as suggested. Also make TransformerLayer perf get closer to Fused by setting hidden_dropout=0 Signed-off-by: tdophung <[email protected]>
…ll of BasicTETransformerLayer and demonstrated difference in runtime between using flax and using te's attetion implementation Signed-off-by: tdophung <[email protected]>
…ayer Signed-off-by: tdophung <[email protected]>
…mpl only, removing last mention of Pytorch Signed-off-by: tdophung <[email protected]>
Signed-off-by: tdophung <[email protected]>
fda70e7 to 816b379
Greptile Overview
Greptile Summary
This PR adds a comprehensive JAX quickstart guide that mirrors the PyTorch tutorial structure. The notebook demonstrates a progressive approach to optimizing Transformer layers, showing 4 iterations with decreasing training times:
- Pure Flax baseline (~17.67ms) - Basic implementation using standard Flax modules
- Basic TE with Flax attention (~16.39ms) - Replaces Dense/LayerNorm with TE equivalents
- Full basic TE (~13.01ms) - Adds TE's DotProductAttention
- TE with FP8 (~10.07ms) - Enables FP8 quantization
- Fused TE layers (~9.79ms) - Uses LayerNormDenseGeneral and LayerNormMLP
- Built-in TransformerLayer (~9.71ms) - Uses TE's complete TransformerLayer module
The utility file (quickstart_jax_utils.py) provides clean helper functions for benchmarking with proper JIT compilation and gradient computation.
Most issues flagged in previous comments are valid (typos, documentation links, unused imports). The code logic appears sound - the jax.value_and_grad usage with argnums=(0,1) correctly returns gradients for both variables and inp parameters.
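For context on that claim, a minimal standalone example of the `jax.value_and_grad(..., argnums=(0, 1))` pattern (not the utility file's actual code) shows the behavior described: it returns the loss plus gradients for both the variables and the input.

```python
# value_and_grad with argnums=(0, 1) returns gradients for both the model
# variables and the input tensor; loss_fn here is a stand-in for model.apply.
import jax
import jax.numpy as jnp


def loss_fn(variables, inp):
    # Simple linear map followed by a sum, purely for demonstration.
    return jnp.sum(inp @ variables["w"])


variables = {"w": jnp.ones((4, 2))}
inp = jnp.ones((3, 4))

fwd_bwd = jax.jit(jax.value_and_grad(loss_fn, argnums=(0, 1)))
loss, (grad_vars, grad_inp) = fwd_bwd(variables, inp)
print(loss, grad_vars["w"].shape, grad_inp.shape)  # scalar, (4, 2), (3, 4)
```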
Confidence Score: 4/5
- This documentation PR is safe to merge after addressing typos and minor issues flagged in previous comments
- Score reflects that this is a documentation-only change with good educational content. The code examples are functional and demonstrate proper usage patterns. Previous reviewers identified legitimate issues (typos, unused imports, minor logic concerns) that should be addressed, but none are critical blockers. The utility functions are correctly implemented with proper JAX patterns.
- Focus on `docs/examples/quickstart_jax.ipynb` to fix the typos, hardcoded paths, and documentation link formats identified in previous comments
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| docs/examples/quickstart_jax.ipynb | 4/5 | Comprehensive JAX quickstart tutorial with good progression from basic Flax to TE modules. Contains some typos and minor documentation issues already flagged in previous comments. |
| docs/examples/quickstart_jax_utils.py | 5/5 | Helper utilities for benchmarking JAX models. Clean implementation of training step functions with proper JIT compilation and gradient computation. |
Sequence Diagram
sequenceDiagram
participant User
participant Notebook as quickstart_jax.ipynb
participant Utils as quickstart_jax_utils.py
participant JAX
participant TE as TransformerEngine
User->>Notebook: Run notebook cells
Notebook->>JAX: Define BasicTransformerLayer (Flax)
Notebook->>JAX: Initialize & benchmark baseline
Notebook->>Utils: Call speedometer()
Utils->>Utils: create_train_step_fn()
Utils->>JAX: jax.jit(fwd_bwd_fn)
Utils->>JAX: Execute warmup iterations
Utils->>JAX: Execute timing iterations
Utils-->>Notebook: Return mean time
Notebook->>TE: Import transformer_engine.jax
Notebook->>TE: Define BasicTETransformerLayer (TE modules)
Notebook->>Utils: Call speedometer()
Utils-->>Notebook: Return improved time
Notebook->>TE: Enable FP8 with fp8_autocast
Notebook->>TE: Initialize model in FP8 context
Notebook->>Utils: Call speedometer() with FP8
Utils->>TE: Wrap training in fp8_autocast
Utils-->>Notebook: Return FP8 time
Notebook->>TE: Define FusedTETransformerLayer
Notebook->>TE: Use LayerNormDenseGeneral & LayerNormMLP
Notebook->>Utils: Call speedometer()
Utils-->>Notebook: Return fused time
Notebook->>TE: Use te_flax.TransformerLayer (built-in)
Notebook->>Utils: Call speedometer()
Utils-->>Notebook: Return best time
2 files reviewed, no comments
Signed-off-by: tdophung <[email protected]>
Signed-off-by: tdophung <[email protected]>
816b379 to 35df868
2 files reviewed, no comments
…hanges Signed-off-by: tdophung <[email protected]>
2 files reviewed, no comments
/te-ci jax
jberchtold-nvidia left a comment
LGTM, thanks!
2 files reviewed, no comments
* jax quickstart guide first commit Signed-off-by: tdophung <[email protected]>
* edit the syntax errors and remove unnecessary comments in utils. Add some footnotes in the quick start notebook Signed-off-by: tdophung <[email protected]>
* Fix greptiles comments on spelling, deepcopy, vjp function signature comaptibility with speedometer Signed-off-by: tdophung <[email protected]>
* Add Copyright to utils and fix some more greptiles complaints Signed-off-by: tdophung <[email protected]>
* Add comments to alternative of layers Signed-off-by: tdophung <[email protected]>
* Remove weight sharing between different iterations of the transformerLayer Signed-off-by: tdophung <[email protected]> [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: tdophung <[email protected]>
* Add enum for attention implementations. Fix inconsistency between fuse and unfused TE impls to achieve same performance (removing extra dropout layer in fused layers. Also some minor wording changes Signed-off-by: tdophung <[email protected]> [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: tdophung <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* Fix bug in TransformerLayer expected input shape being [sequence, batch, ...] instead of [batch, sequence,...] Signed-off-by: tdophung <[email protected]>
* Changing structure of notebook to bring fp8 ahead of fuse, to allow for fuse to take effect because quantization exist as suggested. Also make TransformerLayer perf get closer to Fused by setting hidden_dropout=0 Signed-off-by: tdophung <[email protected]>
* add option to choose between different attention implementation in call of BasicTETransformerLayer and demonstrated difference in runtime between using flax and using te's attetion implementation Signed-off-by: tdophung <[email protected]>
* Fix mistake in lacking attention_implementation in FuseTETransformerLayer Signed-off-by: tdophung <[email protected]>
* Removing AttentionWrapper and custom built DPA, using flax and TE's impl only, removing last mention of Pytorch Signed-off-by: tdophung <[email protected]>
* More changing to markdowns to remove pytorch Signed-off-by: tdophung <[email protected]>
* cosmetics fixes Signed-off-by: tdophung <[email protected]>
* changing names of all implementations Signed-off-by: tdophung <[email protected]>
* change fp8_autocast to autocast, make causal mask, and some wording changes Signed-off-by: tdophung <[email protected]>
--------- Signed-off-by: tdophung <[email protected]> Co-authored-by: tdophung <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: jberchtold-nvidia <[email protected]>

Description
This change adds the new quickstart notebook for JAX, mirroring the same Transformer architecture as the PyTorch guide so that users familiar with the PyTorch guide can easily follow along. It contains 4 iterations of the layer with different training step time durations:
This might also include changes to how Sphinx displays this content on the HTML docs page in later commits.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: