
Conversation


@tdophung tdophung commented Oct 28, 2025

Description

This change adds a new quick start notebook for JAX, mirroring the Transformer architecture of the PyTorch guide so that users already familiar with that guide can follow along easily. It contains four iterations of the layer with different training step durations (a minimal sketch of the pure JAX/Flax baseline follows this list):

  1. Pure JAX/Flax implementation
  2. Basic TE implementation without any fused layers
  3. TE implementation with mixed fused and unfused layers
  4. Full TransformerLayer from TE
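
For orientation, a minimal sketch of what iteration 1 (the pure JAX/Flax baseline) roughly looks like is below. It is illustrative only: the hyperparameters, module layout, and names are assumptions and do not reproduce the notebook's exact code.

```python
import flax.linen as nn


class BasicTransformerLayer(nn.Module):
    """Illustrative pure-Flax layer: a self-attention block followed by an MLP block."""

    hidden_size: int = 1024
    num_heads: int = 16
    ffn_hidden_size: int = 4096
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, attention_mask=None, deterministic=False):
        # Attention block: LayerNorm -> self-attention -> dropout -> residual
        y = nn.LayerNorm()(x)
        y = nn.MultiHeadDotProductAttention(num_heads=self.num_heads)(
            y, y, mask=attention_mask, deterministic=deterministic
        )
        y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=deterministic)
        x = x + y
        # MLP block: LayerNorm -> Dense -> GELU -> Dense -> dropout -> residual
        z = nn.LayerNorm()(x)
        z = nn.Dense(self.ffn_hidden_size)(z)
        z = nn.gelu(z)
        z = nn.Dense(self.hidden_size)(z)
        z = nn.Dropout(rate=self.dropout_rate)(z, deterministic=deterministic)
        return x + z
```

Iterations 2-4 keep this structure and progressively swap the Flax modules for their Transformer Engine counterparts (te_flax.DenseGeneral/LayerNorm, then the fused te_flax.LayerNormDenseGeneral/LayerNormMLP, and finally te_flax.TransformerLayer).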

This might also include changes to how Sphinx displays this content on the HTML docs page in later commits.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Added a quickstart guide written in JAX

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • [DOES NOT APPLY] I have added tests that prove my fix is effective or that my feature works
  • [DOES NOT APPLY] New and existing unit tests pass locally with my changes

greptile-apps[bot]

This comment was marked as outdated.


@tdophung tdophung force-pushed the tdophung/jax_quickstart_documentation branch from 2796e91 to 733d61b on October 28, 2025 18:39
Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR adds a comprehensive JAX-specific quickstart tutorial notebook that demonstrates four progressive implementations of a transformer layer: pure JAX/Flax, basic Transformer Engine (TE) modules, fused TE modules, and the full TE TransformerLayer with optional FP8 support. The notebook includes performance benchmarking utilities and parameter-sharing helpers to enable fair comparisons between implementations. A minor build infrastructure change updates the PyTorch dependency installation in the GitHub Actions workflow to explicitly use CUDA 13.0 wheels, ensuring compatibility with the JAX container environment. This documentation effort mirrors the structure of the existing PyTorch quickstart guide, providing JAX users with a dedicated migration and optimization path.

Important Files Changed

Filename Score Overview
.github/workflows/build.yml 5/5 Split torch installation into separate command with explicit CUDA 13.0 index URL
docs/examples/quickstart_jax.ipynb 3/5 Added comprehensive JAX quickstart notebook with 4 progressive transformer implementations
docs/examples/quickstart_jax_utils.py 2/5 Added utility functions for benchmarking and parameter sharing between JAX and TE models

Confidence score: 2/5

  • This PR contains critical runtime errors that will prevent the tutorial from executing successfully
  • The main issue is function signature mismatch in quickstart_jax_utils.py: speedometer() calls train_step_fn with 5 arguments but create_train_step_fn() returns a function expecting only 4 parameters, causing a guaranteed TypeError
  • Additional concerns include variable scope issues in the notebook (line 795 references te_transformer_params_template which may not be in scope after FP8 initialization) and unused parameters indicating incomplete implementation
  • Pay close attention to docs/examples/quickstart_jax_utils.py lines 37-46 (function call mismatch) and the notebook cells around lines 787-801 (variable scoping)

Sequence Diagram

sequenceDiagram
    participant User
    participant Notebook as quickstart_jax.ipynb
    participant Utils as quickstart_jax_utils.py
    participant JAX as JAX/Flax
    participant TE as Transformer Engine

    User->>Notebook: Execute notebook cells
    Notebook->>JAX: Import jax, jax.numpy, flax.linen
    Notebook->>TE: Import transformer_engine.jax
    Notebook->>Utils: Import quickstart_jax_utils
    
    Note over Notebook,JAX: 1. Build BasicTransformerLayer (pure JAX/Flax)
    Notebook->>JAX: Initialize BasicTransformerLayer
    JAX-->>Notebook: Return initialized model
    Notebook->>JAX: init(key, x) to create params
    JAX-->>Notebook: Return params
    Notebook->>Utils: speedometer(model_apply_fn, variables, input, output_grad)
    Utils->>Utils: create_train_step_fn() - JIT compile fwd/bwd
    Utils->>JAX: Run warmup iterations
    Utils->>JAX: Run timing iterations (forward + backward)
    JAX-->>Utils: Return loss and gradients
    Utils-->>Notebook: Print mean time
    
    Note over Notebook,TE: 2. Build BasicTETransformerLayer (TE modules)
    Notebook->>TE: Initialize BasicTETransformerLayer
    TE-->>Notebook: Return TE model
    Notebook->>TE: init(key, x) to create TE params
    TE-->>Notebook: Return TE params template
    Notebook->>Utils: share_parameters_with_basic_te_model()
    Utils-->>Notebook: Return shared params
    Notebook->>Utils: speedometer() with TE model
    Utils->>TE: Run timing iterations with TE layers
    TE-->>Utils: Return loss and gradients
    Utils-->>Notebook: Print improved mean time
    
    Note over Notebook,TE: 3. Build FusedTETransformerLayer (fused ops)
    Notebook->>TE: Initialize FusedTETransformerLayer
    TE-->>Notebook: Return fused TE model
    Notebook->>TE: init() with LayerNormDenseGeneral, LayerNormMLP
    TE-->>Notebook: Return fused params
    Notebook->>Utils: share_fused_parameters_with_basic_te_model()
    Utils-->>Notebook: Return shared fused params
    Notebook->>Utils: speedometer() with fused model
    Utils->>TE: Run timing with fused kernels
    TE-->>Utils: Return loss and gradients
    Utils-->>Notebook: Print further improved time
    
    Note over Notebook,TE: 4. Use TE TransformerLayer with FP8
    Notebook->>TE: Create fp8_recipe (DelayedScaling)
    Notebook->>TE: fp8_autocast(enabled=True, fp8_recipe)
    Notebook->>TE: Initialize TransformerLayer within fp8_autocast
    TE-->>Notebook: Return params with fp8_metas collection
    Notebook->>Utils: share_parameters_with_transformerlayer_te_model()
    Utils-->>Notebook: Return shared params with fp8_metas
    Notebook->>Utils: speedometer() with fp8_autocast_kwargs
    Utils->>TE: Run timing with FP8 precision
    TE-->>Utils: Return loss and gradients (FP8)
    Utils-->>Notebook: Print fastest mean time
    
    Notebook-->>User: Display performance comparison
3 files reviewed, 1 comment

Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review. The developer has cleaned up the JAX quickstart notebook by removing execution artifacts (output cells, execution counts), fixing typos and type annotations, and correcting markdown link syntax. The quickstart_jax_utils.py file was updated to remove an unused key parameter from train_step_fn calls and add a 2025 copyright header. However, a critical logic error remains unfixed in the notebook at line 772: the code references te_transformer_params_template which is undefined in the FP8 context (it was only defined outside the FP8 block in line 671). This variable should be te_transformer_params instead.

Important Files Changed

Filename Score Overview
docs/examples/quickstart_jax.ipynb 2/5 Cleaned notebook execution artifacts and fixed typos/types, but introduced critical undefined variable bug in FP8 initialization block (line 772)
docs/examples/quickstart_jax_utils.py 3/5 Removed unused key parameter from train step functions and added a 2025 copyright header; shallow copy operations may allow unintended parameter mutation

Confidence score: 2/5

  • This PR cannot be merged safely due to a critical logic error that will cause runtime failure.
  • Score reflects one critical bug (undefined variable te_transformer_params_template at line 772 in the notebook) that will break FP8 initialization when users try to execute the notebook. The developer needs to replace te_transformer_params_template with te_transformer_params in the FP8 context. Additionally, the shallow copy operations in the utils file (where copy.deepcopy would be needed) may cause subtle parameter-sharing bugs if nested dictionaries are mutated; a toy example of this pitfall follows after this list.
  • The notebook file (quickstart_jax.ipynb) requires immediate attention - specifically line 772 in the FP8 initialization block where an undefined variable is referenced.
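
A toy example of the shallow-copy pitfall mentioned above (not the repository's actual parameter trees): with a shallow copy, nested containers are shared between the copy and the original, so mutating one mutates both.

```python
import copy

params = {"dense": {"kernel": [1.0, 2.0]}}

shallow = copy.copy(params)    # top-level dict copied, nested objects shared
deep = copy.deepcopy(params)   # every level copied

shallow["dense"]["kernel"][0] = 99.0
print(params["dense"]["kernel"][0])  # 99.0 -- the original was mutated too
print(deep["dense"]["kernel"][0])    # 1.0  -- the deep copy is independent
```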

2 files reviewed, 2 comments


@tdophung tdophung force-pushed the tdophung/jax_quickstart_documentation branch from d9859d1 to b742244 on October 28, 2025 21:19
Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This review covers only the latest changes to two documentation files: docs/examples/quickstart_jax.ipynb and docs/examples/quickstart_jax_utils.py. Recent commits have addressed previous formatting and documentation issues (markdown links, type annotations, typos), but critical runtime bugs remain unresolved. The quickstart notebook demonstrates progressive optimization of a JAX/Flax transformer using Transformer Engine, but the utility file contains a function signature mismatch that will cause immediate failures when users attempt to run the notebook.

Important Files Changed

Filename Score Overview
docs/examples/quickstart_jax.ipynb 2/5 Cleared execution artifacts and fixed minor issues, but contains a critical NameError on line 772 where undefined variable te_transformer_params_template is referenced in FP8 initialization block
docs/examples/quickstart_jax_utils.py 1/5 Added benchmarking utilities for JAX/Flax models with a critical signature mismatch: train_step_fn is called with 4 arguments (lines 41, 47) but the underlying functions also require a dropout_key parameter; will fail immediately at runtime

Confidence score: 1/5

  • This PR will cause immediate runtime failures when users attempt to follow the quickstart guide due to function signature mismatches and undefined variable references
  • Score reflects two critical blocking issues: (1) the train_step_fn signature mismatch in quickstart_jax_utils.py where calls on lines 41 and 47 provide 4 arguments but the underlying functions require dropout_key as a parameter, and (2) the undefined te_transformer_params_template reference on line 772 of the notebook that will raise NameError
  • Both utility file functions (create_train_step_fn and create_train_step_fn_vjp) and the FP8 initialization block in the notebook require immediate correction before this documentation can be published

2 files reviewed, no comments


" )\n",
" x = attention(q, k, v, attention_mask, deterministic=deterministic)\n",
" \n",
" # Attention built-in. Comment out if not used\n",
Collaborator

Can we add a bit more detail to this comment, something like: If you would like to use Flax's built-in attention implementation, you can use this instead

Want to clarify to users that this is an alternative option and not some debug code we left in by mistake.
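
For illustration, the inline comment plus the alternative could look roughly like the sketch below. This is only a suggestion; the exact module and argument names in the notebook may differ, and note that flax.linen.MultiHeadDotProductAttention performs its own QKV projections, so it replaces more than just the score/softmax step.

```python
# If you would like to use Flax's built-in attention implementation instead of
# the custom DotProductAttention from quickstart_jax_utils.py, you can swap in
# flax.linen.MultiHeadDotProductAttention, for example:
#
#   attention = nn.MultiHeadDotProductAttention(
#       num_heads=self.num_attention_heads,   # name assumed for illustration
#       dropout_rate=self.attention_dropout,  # name assumed for illustration
#   )
#   x = attention(hidden_states, hidden_states, mask=attention_mask,
#                 deterministic=deterministic)
```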

Member

I don't think this comment really got resolved. Let me make a suggestion for how I think it could look instead.

@phu0ngng phu0ngng requested review from phu0ngng and ptrendx October 30, 2025 00:17
Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR introduces a comprehensive JAX/Flax quickstart tutorial that mirrors the PyTorch quickstart guide. The tutorial demonstrates four progressive implementations of a Transformer layer, showing performance improvements from pure JAX/Flax (~28ms) to fully optimized TE with FP8 (~8ms).

Key Changes:

  • New quickstart_jax.ipynb notebook with 4 transformer implementations (pure JAX, basic TE, fused TE, full TE with FP8)
  • Supporting utility file quickstart_jax_utils.py with benchmarking functions and custom attention/MLP modules
  • Workflow update to use specific PyTorch CUDA index URL for build consistency

Issues Found:

  • Several typos in documentation (already flagged in previous comments)
  • Malformed markdown link in notebook cell
  • Unused datasets import
  • Minor documentation link format issues

Positive Aspects:

  • Tutorial provides clear progressive optimization path
  • Benchmarking utilities are well-structured and correct
  • Good pedagogical approach with 4 iterations showing incremental improvements
  • Code is functionally correct despite some previous incorrect review comments

Confidence Score: 4/5

  • This PR is safe to merge after addressing minor documentation issues
  • Score reflects that this is a documentation-only PR with correct functional code. The issues found are primarily cosmetic (typos, formatting) rather than logical errors. Several previous comments incorrectly flagged non-issues (e.g., VJP signature, missing imports that don't exist). The actual issues are: unused import, malformed markdown link, and typos - all easily fixable and non-blocking for a draft PR.
  • docs/examples/quickstart_jax.ipynb needs attention for typos and formatting issues. The utility file and workflow changes are solid.

Important Files Changed

File Analysis

Filename Score Overview
.github/workflows/build.yml 5/5 Workflow updated to install torch from specific CUDA index URL - improves build consistency
docs/examples/quickstart_jax.ipynb 3/5 New JAX quickstart tutorial with 4 transformer implementations; contains typos and malformed markdown link already flagged
docs/examples/quickstart_jax_utils.py 4/5 Utility functions for speedometer benchmarking and custom attention/MLP modules; code is correct despite some previous incorrect comments

Sequence Diagram

sequenceDiagram
    participant User as User/Tutorial
    participant Basic as BasicTransformerLayer
    participant BasicTE as BasicTETransformerLayer
    participant Fused as FusedTETransformerLayer
    participant Full as TransformerLayer (TE)
    participant Utils as quickstart_jax_utils
    
    User->>Basic: Initialize with pure JAX/Flax
    Basic->>Basic: nn.LayerNorm + nn.Dense (QKV)
    Basic->>Utils: DotProductAttention
    Utils-->>Basic: attention output
    Basic->>Basic: nn.Dense (projection) + nn.Dropout
    Basic->>Utils: BasicMLP
    Utils-->>Basic: MLP output
    Basic-->>User: ~28ms training step
    
    User->>BasicTE: Replace with TE modules
    BasicTE->>BasicTE: te_flax.LayerNorm + te_flax.DenseGeneral
    BasicTE->>Utils: DotProductAttention
    Utils-->>BasicTE: attention output
    BasicTE->>BasicTE: te_flax.DenseGeneral + nn.Dropout
    BasicTE->>BasicTE: BasicTEMLP (te_flax.DenseGeneral)
    BasicTE-->>User: ~17ms training step
    
    User->>Fused: Use fused TE operations
    Fused->>Fused: te_flax.LayerNormDenseGeneral (fused)
    Fused->>Utils: DotProductAttention
    Utils-->>Fused: attention output
    Fused->>Fused: te_flax.DenseGeneral + nn.Dropout
    Fused->>Fused: te_flax.LayerNormMLP (fused)
    Fused-->>User: ~18ms training step
    
    User->>Full: Use full TE TransformerLayer
    Full->>Full: Complete optimized implementation
    Full-->>User: ~12ms training step
    
    User->>Full: Enable FP8 with fp8_autocast
    Full->>Full: FP8 precision compute
    Full-->>User: ~8ms training step
3 files reviewed, 1 comment

@phu0ngng phu0ngng changed the title from "[DRAFT] Jax (separate from PyTorch) Quickstart documentation" to "[JAX] Quickstart documentation" on Oct 31, 2025
@ptrendx
Member

ptrendx commented Oct 31, 2025

There seems to be some issue with the commits getting mixed up. Could you rebase on top of the current main, @tdophung?

"- `DotProductAttention`: `DotProductAttention` from [quickstart_jax_utils.py](quickstart_jax_utils.py)\n",
"- `Projection`: `nn.Dense` (JAX/Flax)\n",
"- `Dropout`: `nn.Dropout` (JAX/Flax)\n",
"- `MLP`: `BasicMLP` from [quickstart_jax_utils.py](quickstart_jax_utils.py)\n",
Member

@ksivaman putting on your radar that this will require more substitutions when we create the docs package.

"- `nn.Dropout`: JAX/Flax Dropout\n",
"- `BasicMLP`: Custom MLP from [quickstart_jax_utils.py](quickstart_jax_utils.py)\n",
"\n",
"<small> (**) _The code below also shows how to use the built-in attention sub-layer from either pure Flax or TE Flax in commented code if you wish to use those instead of the custom attention in [quickstart_jax_utils.py]. The implementation is there for your reference of how attention is roughly implemented in our source_</small>\n",
Member

I don't like this sentence - it sounds too informal. A quick pass with copilot gave this instead:

The code snippet below also includes commented examples demonstrating how to utilize the built-in attention sub-layer from either pure Flax or TE Flax. These alternatives are provided should you prefer them over the custom attention implementation found in quickstart_jax_utils.py. This reference is intended to offer insight into the general structure and approach used for implementing attention mechanisms in our source code.

I don't really like that one either (although it sounds better), but maybe we could use it to get something in between.

Collaborator

What is the purpose of having both the flax.linen.MultiHeadDotProductAttention and the self-built attention? If the self-built attention is implemented correctly, they should have exactly the same ops as the linen one, right?

I think it will be less confusing to simply start with the linen attention module directly, then swap it with the TE one. In that way, we can expose those two modules in this tutorial directly rather than hiding their details in the util file.
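
For reference, the core ops behind such a self-built attention are small. Below is a minimal sketch, assuming a [batch, seq, heads, head_dim] layout and omitting dropout and bias; this is essentially what flax.linen's dot-product attention computes internally, which is the point above about the two implementations performing the same ops.

```python
import jax
import jax.numpy as jnp


def dot_product_attention(q, k, v, mask=None):
    # q, k, v: [batch, seq, heads, head_dim]
    depth = q.shape[-1]
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(depth)
    if mask is not None:
        # mask is True where attention is allowed
        scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)
```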

Collaborator Author

I wonder if it is helpful to have the self-built attention, since it shows how attention works and also stays close to the PyTorch tutorial.

Collaborator

I think we don't have commented code anymore?

Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

Adds comprehensive JAX quickstart documentation mirroring the PyTorch tutorial structure, demonstrating progressive performance optimization through four iterations of a Transformer layer implementation.

Key Changes:

  • New Jupyter notebook quickstart_jax.ipynb (699 lines) with hands-on tutorial
  • Supporting utility file quickstart_jax_utils.py (267 lines) containing reusable components
  • CI workflow update to separate PyTorch installation with explicit CUDA 13.0 index URL

Tutorial Progression:

  1. Pure JAX/Flax baseline (~26ms) - Custom attention and MLP implementations
  2. Basic TE modules (~16ms) - Using te_flax.LayerNorm and DenseGeneral
  3. Fused TE modules (~16ms) - Using LayerNormDenseGeneral and LayerNormMLP
  4. Full TE TransformerLayer (~11ms) - Complete optimized implementation
  5. FP8 enabled (~7ms) - 73% speedup over baseline with FP8 precision
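
As a rough illustration of the FP8 step in the progression above: enabling FP8 in TE/JAX is done with a recipe plus an autocast context, roughly as sketched below. The argument names are assumptions based on the modules mentioned in this review (a later commit renames fp8_autocast to autocast), so check the TE documentation for your version.

```python
import transformer_engine.jax as te
from transformer_engine.common import recipe

# Delayed-scaling FP8 recipe with default margin / amax-history settings.
fp8_recipe = recipe.DelayedScaling()

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    # Initialize the TE model and run the training step inside this context so
    # that the FP8 metadata ("fp8_metas" collection) is created and used, e.g.
    # (names below are placeholders):
    #   variables = model.init(init_rng, dummy_input)
    #   loss, grads = train_step_fn(variables, dummy_input, output_grad)
    pass
```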

Issues Found:

  • Multiple typos in documentation ('similicity', 'attetntion', 'inistead', 'linnen', 'buiil-in')
  • Hardcoded absolute path os.chdir('/workspace/docs/examples') will break for users
  • Malformed markdown link syntax [quickstart_jax_utils.py] (**)(quickstart_jax_utils.py)
  • Extra quote in string 'no_mask'' should be 'no_mask'
  • Unused datasets import in notebook
  • layernorm_eps defined as int instead of float (style issue, doesn't affect functionality)

Confidence Score: 3/5

  • Safe to merge after fixing documentation issues - no functional bugs found, only typos and hardcoded paths that will impact user experience
  • Score of 3 reflects that while the code is functionally correct and the tutorial successfully demonstrates TE capabilities, there are several documentation quality issues that should be fixed: hardcoded paths will break for users running the notebook outside the expected environment, multiple typos reduce professionalism, and malformed markdown links affect navigation. The CI change is safe and improves dependency management.
  • docs/examples/quickstart_jax.ipynb requires fixing hardcoded paths (line 59) and typos throughout; docs/examples/quickstart_jax_utils.py needs typo corrections

Important Files Changed

File Analysis

Filename Score Overview
.github/workflows/build.yml 5/5 Splits PyTorch installation into separate command with explicit CUDA index URL for better dependency management
docs/examples/quickstart_jax_utils.py 3/5 New utility file with attention implementations and speedometer function; contains several typos and documentation formatting issues flagged in previous comments
docs/examples/quickstart_jax.ipynb 3/5 New JAX quickstart tutorial notebook demonstrating progressive optimization with TE; contains hardcoded paths, typos, and malformed markdown links flagged in previous comments

Sequence Diagram

sequenceDiagram
    participant User
    participant Notebook as quickstart_jax.ipynb
    participant Utils as quickstart_jax_utils.py
    participant JAX as JAX/Flax
    participant TE as TransformerEngine

    User->>Notebook: Run tutorial cells
    Notebook->>JAX: Import jax, jax.numpy, flax.linen
    Notebook->>Utils: Import quickstart_jax_utils
    Notebook->>TE: Import transformer_engine.jax
    
    Note over Notebook: Iteration 1: Pure JAX/Flax
    Notebook->>JAX: Define BasicTransformerLayer
    Notebook->>Utils: Use AttentionWrapper(CUSTOM_DOT_PRODUCT)
    Notebook->>Utils: Use BasicMLP
    Notebook->>Utils: speedometer() - measure baseline
    
    Note over Notebook: Iteration 2: Basic TE
    Notebook->>TE: Use te_flax.LayerNorm, DenseGeneral
    Notebook->>Utils: Use AttentionWrapper(TE_FLAX_MULTIHEAD)
    Notebook->>Utils: speedometer() - faster performance
    
    Note over Notebook: Iteration 3: Fused TE
    Notebook->>TE: Use LayerNormDenseGeneral, LayerNormMLP
    Notebook->>Utils: speedometer() - similar to iteration 2
    
    Note over Notebook: Iteration 4: Full TE TransformerLayer
    Notebook->>TE: Use te_flax.TransformerLayer
    Notebook->>Utils: speedometer() - best performance
    
    Note over Notebook: Enable FP8
    Notebook->>TE: fp8_autocast context manager
    Notebook->>TE: Initialize with FP8 recipe
    Notebook->>Utils: speedometer(fp8_autocast_kwargs) - fastest
    
    Utils->>JAX: create_train_step_fn()
    Utils->>JAX: jax.value_and_grad() for gradients
    Utils->>JAX: jax.jit() for compilation
    Utils-->>Notebook: Return timing results
3 files reviewed, no comments

@pggPL
Collaborator

pggPL commented Nov 4, 2025

  1. It's not being rendered in the docs now, so I cannot see it - you need to put it into some proper .rst file.
  2. I had a discussion with @ptrendx yesterday and he said that you considered using tabs like in d2l.ai. I did something like this here using the sphinx-tabs package. You can download the docs from here and go to section features -> low precision training -> custom recipe to see it. Tbh it is 100% Cursor-generated, so the code may not make sense, but it looks nice. Idk if you want to use it, just letting you know.
(Screenshot: tabbed documentation rendering, 2025-11-04 17:57)

@tdophung
Collaborator Author

tdophung commented Nov 5, 2025

I wonder why I cannot reply to Pawel's suggestion to do tabbing in HTML. I intend for this PR to just be the notebook first and to work on a separate PR for displaying on HTML, since that will also touch the Python quickstart guide. This PR is just for a functional ipynb.

@tdophung tdophung force-pushed the tdophung/jax_quickstart_documentation branch from b70bddf to c6b6ca3 on November 5, 2025 00:47
Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

Adds comprehensive JAX quickstart documentation demonstrating progressive optimization of Transformer layers using TransformerEngine. The tutorial walks through 4 implementations showing increasing performance gains (26ms → 7ms).

Key Changes:

  • New Jupyter notebook (quickstart_jax.ipynb) with 4 progressive transformer implementations
  • Supporting utility module (quickstart_jax_utils.py) with attention wrappers and performance benchmarking
  • CI workflow update for PyTorch installation with explicit CUDA index URL

Implementation Progression:

  1. BasicTransformerLayer: Pure JAX/Flax baseline (~26ms per step)
  2. BasicTETransformerLayer: Using TE's DenseGeneral and LayerNorm modules (~16ms)
  3. FusedTETransformerLayer: Leveraging fused kernels like LayerNormDenseGeneral (~16ms)
  4. Full TransformerLayer with FP8: Complete TE layer with FP8 precision (~7ms)
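
As a rough sketch of the module swap in iteration 2, the MLP sub-block goes from flax.linen.Dense to TE's DenseGeneral; the constructor argument names below are assumptions that mirror Flax, so verify against the transformer_engine.jax.flax docs.

```python
import flax.linen as nn
import transformer_engine.jax.flax as te_flax


class BasicTEMLP(nn.Module):
    """MLP sub-block with Flax Dense swapped for TE DenseGeneral (sketch)."""

    hidden_size: int = 1024
    ffn_hidden_size: int = 4096

    @nn.compact
    def __call__(self, x):
        x = te_flax.DenseGeneral(features=self.ffn_hidden_size)(x)
        x = nn.gelu(x)
        x = te_flax.DenseGeneral(features=self.hidden_size)(x)
        return x
```

Iteration 3 then fuses the preceding LayerNorm into these calls via te_flax.LayerNormDenseGeneral and te_flax.LayerNormMLP, and iteration 4 replaces the whole block with te_flax.TransformerLayer.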

Strengths:

  • Well-structured tutorial mirroring PyTorch quickstart for consistency
  • Clear performance comparisons with timing measurements
  • Good documentation of FP8 initialization requirements
  • Comprehensive utility code with multiple attention implementations

Previous Review Feedback Addressed:
Most syntax issues (typos, formatting) and several logic concerns from previous reviews appear to have been addressed in recent commits, including fixing attention implementation inconsistencies and adding proper enum for attention types.

Confidence Score: 4/5

  • This documentation PR is safe to merge with minor issues already flagged in previous reviews
  • Score reflects that this is a documentation-only PR with comprehensive tutorial content. The code examples are well-structured and functional. Previous review comments have identified syntax issues (typos, formatting) that should be addressed, but these are non-blocking for a documentation PR. The utility code is correct and the notebook provides educational value. The CI workflow change is sensible and low-risk.
  • All files are in good shape. The notebook may have minor typos/formatting issues noted in previous comments that should be cleaned up before merge.

Important Files Changed

File Analysis

Filename Score Overview
docs/examples/quickstart_jax.ipynb 4/5 Comprehensive JAX quickstart notebook with 4 progressive transformer implementations. Most previous syntax/typo issues appear addressed. Minor formatting improvements needed.
docs/examples/quickstart_jax_utils.py 5/5 Well-structured utility module with attention implementations and performance testing. Code is correct and follows best practices.

Sequence Diagram

sequenceDiagram
    participant User as User/Notebook
    participant Utils as quickstart_jax_utils
    participant JAX as JAX/Flax
    participant TE as TransformerEngine
    
    Note over User: 1. BasicTransformerLayer (Pure JAX/Flax)
    User->>JAX: Initialize BasicTransformerLayer
    JAX-->>User: params
    User->>Utils: speedometer(model, params, data)
    Utils->>Utils: create_train_step_fn()
    Utils->>JAX: Forward/Backward Pass
    JAX-->>Utils: loss, gradients
    Utils-->>User: Mean time: ~26ms
    
    Note over User: 2. BasicTETransformerLayer (TE modules)
    User->>TE: Initialize with DenseGeneral, LayerNorm
    TE-->>User: te_params
    User->>Utils: speedometer(model, params, data)
    Utils->>TE: Forward/Backward with TE modules
    TE-->>Utils: loss, gradients
    Utils-->>User: Mean time: ~16ms
    
    Note over User: 3. FusedTETransformerLayer (Fused TE)
    User->>TE: Initialize with LayerNormDenseGeneral, LayerNormMLP
    TE-->>User: fused_params
    User->>Utils: speedometer(model, params, data)
    Utils->>TE: Forward/Backward with fused kernels
    TE-->>Utils: loss, gradients
    Utils-->>User: Mean time: ~16ms
    
    Note over User: 4. TransformerLayer with FP8
    User->>TE: fp8_autocast context
    User->>TE: Initialize TransformerLayer
    TE-->>User: params + fp8_metas
    User->>Utils: speedometer(model, params, data, fp8_enabled=True)
    Utils->>TE: Forward/Backward with FP8 precision
    TE-->>Utils: loss, gradients
    Utils-->>User: Mean time: ~7ms
2 files reviewed, no comments

tdophung and others added 12 commits November 10, 2025 10:14
…comaptibility with speedometer

Signed-off-by: tdophung <[email protected]>
…Layer

Signed-off-by: tdophung <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: tdophung <[email protected]>
…e and unfused TE impls to achieve same performance (removing extra dropout layer in fused layers. Also some minor wording changes

Signed-off-by: tdophung <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: tdophung <[email protected]>
…ch, ...] instead of [batch, sequence,...]

Signed-off-by: tdophung <[email protected]>
…or fuse to take effect because quantization exist as suggested. Also make TransformerLayer perf get closer to Fused by setting hidden_dropout=0

Signed-off-by: tdophung <[email protected]>
…ll of BasicTETransformerLayer and demonstrated difference in runtime between using flax and using te's attetion implementation

Signed-off-by: tdophung <[email protected]>
…mpl only, removing last mention of Pytorch

Signed-off-by: tdophung <[email protected]>
@tdophung tdophung force-pushed the tdophung/jax_quickstart_documentation branch from fda70e7 to 816b379 on November 10, 2025 18:14
Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR adds a comprehensive JAX quickstart guide that mirrors the PyTorch tutorial structure. The notebook demonstrates a progressive approach to optimizing Transformer layers, showing 4 iterations with decreasing training times:

  • Pure Flax baseline (~17.67ms) - Basic implementation using standard Flax modules
  • Basic TE with Flax attention (~16.39ms) - Replaces Dense/LayerNorm with TE equivalents
  • Full basic TE (~13.01ms) - Adds TE's DotProductAttention
  • TE with FP8 (~10.07ms) - Enables FP8 quantization
  • Fused TE layers (~9.79ms) - Uses LayerNormDenseGeneral and LayerNormMLP
  • Built-in TransformerLayer (~9.71ms) - Uses TE's complete TransformerLayer module

The utility file (quickstart_jax_utils.py) provides clean helper functions for benchmarking with proper JIT compilation and gradient computation.

Most issues flagged in previous comments are valid (typos, documentation links, unused imports). The code logic appears sound - the jax.value_and_grad usage with argnums=(0,1) correctly returns gradients for both variables and inp parameters.
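
To make that pattern concrete, here is a minimal sketch of a create_train_step_fn-style factory in the spirit described above. The real quickstart_jax_utils.py may differ in argument names and RNG handling; the loss shown (contracting the output with a fixed output gradient) is just one way to seed the backward pass.

```python
import jax
import jax.numpy as jnp


def create_train_step_fn(model_apply_fn):
    def loss_fn(variables, inp, output_grad):
        out = model_apply_fn(variables, inp)
        # Contracting with a fixed "output gradient" seeds the backward pass,
        # mimicking a dL/dout supplied from outside the model.
        return jnp.sum(out * output_grad)

    # argnums=(0, 1): gradients w.r.t. both the variables and the input.
    fwd_bwd = jax.value_and_grad(loss_fn, argnums=(0, 1))
    return jax.jit(fwd_bwd)
```

A speedometer around this would run a few warmup iterations, then time repeated calls, blocking on the results (e.g. with block_until_ready) so JAX's asynchronous dispatch does not skew the measurement.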

Confidence Score: 4/5

  • This documentation PR is safe to merge after addressing typos and minor issues flagged in previous comments
  • Score reflects that this is a documentation-only change with good educational content. The code examples are functional and demonstrate proper usage patterns. Previous reviewers identified legitimate issues (typos, unused imports, minor logic concerns) that should be addressed, but none are critical blockers. The utility functions are correctly implemented with proper JAX patterns.
  • Focus on docs/examples/quickstart_jax.ipynb to fix the typos, hardcoded paths, and documentation link formats identified in previous comments

Important Files Changed

File Analysis

Filename Score Overview
docs/examples/quickstart_jax.ipynb 4/5 Comprehensive JAX quickstart tutorial with good progression from basic Flax to TE modules. Contains some typos and minor documentation issues already flagged in previous comments.
docs/examples/quickstart_jax_utils.py 5/5 Helper utilities for benchmarking JAX models. Clean implementation of training step functions with proper JIT compilation and gradient computation.

Sequence Diagram

sequenceDiagram
    participant User
    participant Notebook as quickstart_jax.ipynb
    participant Utils as quickstart_jax_utils.py
    participant JAX
    participant TE as TransformerEngine

    User->>Notebook: Run notebook cells
    Notebook->>JAX: Define BasicTransformerLayer (Flax)
    Notebook->>JAX: Initialize & benchmark baseline
    Notebook->>Utils: Call speedometer()
    Utils->>Utils: create_train_step_fn()
    Utils->>JAX: jax.jit(fwd_bwd_fn)
    Utils->>JAX: Execute warmup iterations
    Utils->>JAX: Execute timing iterations
    Utils-->>Notebook: Return mean time
    
    Notebook->>TE: Import transformer_engine.jax
    Notebook->>TE: Define BasicTETransformerLayer (TE modules)
    Notebook->>Utils: Call speedometer()
    Utils-->>Notebook: Return improved time
    
    Notebook->>TE: Enable FP8 with fp8_autocast
    Notebook->>TE: Initialize model in FP8 context
    Notebook->>Utils: Call speedometer() with FP8
    Utils->>TE: Wrap training in fp8_autocast
    Utils-->>Notebook: Return FP8 time
    
    Notebook->>TE: Define FusedTETransformerLayer
    Notebook->>TE: Use LayerNormDenseGeneral & LayerNormMLP
    Notebook->>Utils: Call speedometer()
    Utils-->>Notebook: Return fused time
    
    Notebook->>TE: Use te_flax.TransformerLayer (built-in)
    Notebook->>Utils: Call speedometer()
    Utils-->>Notebook: Return best time
2 files reviewed, no comments

@tdophung tdophung force-pushed the tdophung/jax_quickstart_documentation branch from 816b379 to 35df868 on November 13, 2025 22:41
Contributor

@greptile-apps greptile-apps bot left a comment

2 files reviewed, no comments


@ptrendx ptrendx added the 2.10.0 label Nov 14, 2025
Contributor

@greptile-apps greptile-apps bot left a comment

2 files reviewed, no comments


@jberchtold-nvidia
Collaborator

/te-ci jax

Collaborator

@jberchtold-nvidia jberchtold-nvidia left a comment

LGTM, thanks!

Contributor

@greptile-apps greptile-apps bot left a comment

2 files reviewed, no comments


@ptrendx ptrendx merged commit 42d2274 into NVIDIA:main Nov 15, 2025
10 of 13 checks passed
ptrendx pushed a commit that referenced this pull request Nov 15, 2025
* jax quickstart guide first commit

Signed-off-by: tdophung <[email protected]>

* edit the syntax errors and remove unnecessary comments in utils. Add some footnotes in the quick start notebook

Signed-off-by: tdophung <[email protected]>

* Fix Greptile's comments on spelling, deepcopy, and vjp function signature compatibility with speedometer

Signed-off-by: tdophung <[email protected]>

* Add Copyright to utils and fix some more greptiles complaints

Signed-off-by: tdophung <[email protected]>

* Add comments to alternative of layers

Signed-off-by: tdophung <[email protected]>

* Remove weight sharing between different iterations of the transformerLayer

Signed-off-by: tdophung <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: tdophung <[email protected]>

* Add enum for attention implementations. Fix inconsistency between fused and unfused TE impls to achieve the same performance (removing the extra dropout layer in fused layers). Also some minor wording changes

Signed-off-by: tdophung <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: tdophung <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix bug in TransformerLayer expected input shape being [sequence, batch, ...] instead of [batch, sequence,...]

Signed-off-by: tdophung <[email protected]>

* Change the structure of the notebook to bring FP8 ahead of fusion, so that fusion can take effect because quantization exists, as suggested. Also make TransformerLayer perf get closer to Fused by setting hidden_dropout=0

Signed-off-by: tdophung <[email protected]>

* Add option to choose between different attention implementations in the call of BasicTETransformerLayer and demonstrate the difference in runtime between using Flax's and TE's attention implementations

Signed-off-by: tdophung <[email protected]>

* Fix mistake in lacking attention_implementation in FuseTETransformerLayer

Signed-off-by: tdophung <[email protected]>

* Remove AttentionWrapper and custom-built DPA, using Flax's and TE's impls only; remove last mention of PyTorch

Signed-off-by: tdophung <[email protected]>

* More changing to markdowns to remove pytorch

Signed-off-by: tdophung <[email protected]>

* cosmetics fixes

Signed-off-by: tdophung <[email protected]>

* changing names of all implementations

Signed-off-by: tdophung <[email protected]>

* change fp8_autocast to autocast, make causal mask, and some wording changes

Signed-off-by: tdophung <[email protected]>

---------

Signed-off-by: tdophung <[email protected]>
Co-authored-by: tdophung <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: jberchtold-nvidia <[email protected]>