Conversation

liangel-02 (Contributor) commented Aug 22, 2025

addressing #2833

updating the test to include mxfp8: torchrun --nproc_per_node=${NUM_GPUS} -m pytest test_fsdp.py; all tests pass

pytorch-bot bot commented Aug 22, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2849

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 7b0add0 with merge base df7bf37:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla bot added the CLA Signed label Aug 22, 2025
danielvegamyhre (Contributor) commented Aug 22, 2025

@liangel-02 we need to use this torchtitan API:

set_token_group_alignment_size_m(16) # fp8

or

set_token_group_alignment_size_m(32) # mxfp8

This is because TMA (some background here) requires the contiguous (stride-1) dim to be 128-bit (16-byte) aligned.

In the backward pass, when grad_weight = grad_output_t @ input, the “M” dimension (flattened token groups) becomes this stride-1 dim. Therefore, each token group must be 16-byte aligned for this grouped gemm.
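
To make the shapes concrete, here is a minimal sketch of that weight-gradient contraction for a single token group (the shapes are illustrative, not taken from the actual kernels):

import torch

# Illustrative shapes for one expert's token group: M tokens, K in-features, N out-features.
M, K, N = 64, 512, 1024
inp = torch.randn(M, K, dtype=torch.bfloat16)
grad_output = torch.randn(M, N, dtype=torch.bfloat16)

# The weight gradient contracts over M (the flattened token-group dim),
# which is why each token group's extent along M must satisfy the alignment rules below.
grad_weight = grad_output.t() @ inp  # [N, K]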

  • For bf16, this means groups divisible by 8 elements (16 bytes / 2 bytes per elem).
  • For fp8, groups divisible by 16 elements (16 bytes / 1 byte per elem).
  • For mxfp8, I enforce that they must be divisible by block_size (32), not because of TMA, but to ensure that when we quantize along that dim we have (1) no scaling groups crossing logically distinct tensor boundaries, and (2) no “partial” scaling group at the end (e.g., a 1x17 chunk because the tensor ends). See the short sketch after this list for the arithmetic.
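
A tiny sketch of that arithmetic (the dict and its keys are purely for illustration):

TMA_ALIGNMENT_BYTES = 16
# Required per-token-group multiple along M, derived from 16-byte alignment
# (mxfp8 uses the scaling block_size instead, as noted above).
required_group_multiple = {
    "bf16": TMA_ALIGNMENT_BYTES // 2,  # 2 bytes/elem -> multiples of 8
    "fp8": TMA_ALIGNMENT_BYTES // 1,   # 1 byte/elem  -> multiples of 16
    "mxfp8": 32,                       # block_size
}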

This might all be a bit confusing without background knowledge on GPU architecture and MoE models, but don't worry! Plenty of time to learn.

For now, just implement this setting in some clean/interpretable way in the test code and it should (hopefully) work.
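
For example, something along these lines could work in the test setup (the import path, helper name, and recipe strings here are assumptions; only set_token_group_alignment_size_m itself is the torchtitan API named above):

# Import path is an assumption; set_token_group_alignment_size_m is the torchtitan API above.
from torchtitan.distributed.expert_parallel import set_token_group_alignment_size_m

def set_alignment_for_recipe(recipe: str) -> None:
    # Hypothetical helper: pick the token group alignment the grouped GEMMs need for this recipe.
    if recipe == "mxfp8":
        set_token_group_alignment_size_m(32)  # mxfp8 scaling block_size
    else:
        set_token_group_alignment_size_m(16)  # fp8: 16 bytes / 1 byte per elem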

liangel-02 added the topic: not user facing label Aug 25, 2025
liangel-02 changed the title from [mxfp8 moe training] Add mxfp8 to FSDP and TP tests to [mxfp8 moe training] Add mxfp8 to FSDP tests Aug 25, 2025
liangel-02 marked this pull request as ready for review August 25, 2025 16:53
danielvegamyhre (Contributor) left a comment


lgtm! 1 minor comment before landing, thanks

@@ -83,7 +95,8 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         return False
 
     # quantize test model
-    config = MoETrainingConfig()
+    config = MoETrainingConfig(recipe)
+    # config = MoETrainingConfig()
danielvegamyhre (Contributor): remove commented code before landing

liangel-02 merged commit 8537883 into main Aug 25, 2025 (18 checks passed)
liangel-02 deleted the mxfp8_tests branch August 25, 2025 19:46
liangel-02 restored the mxfp8_tests branch August 25, 2025 22:46