[moe training] Cast to mixed precision policy param dtype in fsdp_pre_all_gather hook #2455

Merged
5 commits merged into main on Jul 2, 2025

Conversation

@danielvegamyhre (Contributor) commented Jun 27, 2025

Stack

Summary

  • After examining this example and discussing with @weifengpy, we concluded that when implementing the fsdp pre/post all-gather hooks for a tensor subclass, the developer must cast the params to the mixed precision policy's param dtype themselves in the pre all-gather hook (a sketch of this follows after this list).
  • Update the fsdp_post_all_gather hook to correctly handle the case where out != None (see code comments for details).
  • Remove the dtype param from ScaledGroupedMMTensor, as it is no longer needed now that the cast happens in the pre all-gather hook rather than the post all-gather hook.
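
For illustration only, here is a minimal sketch of the first point. The import paths and hook signature assume a recent PyTorch with the FSDP2 tensor-subclass extension point that passes the MixedPrecisionPolicy to the hook; MySubclassTensor and its _data field are made-up names, not the actual torchao ScaledGroupedMMTensor implementation:

```python
import torch
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import MixedPrecisionPolicy


class MySubclassTensor(torch.Tensor):
    """Toy wrapper subclass; `_data` holds the local shard it wraps."""

    @staticmethod
    def __new__(cls, data: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device,
            requires_grad=data.requires_grad,
        )

    def __init__(self, data: torch.Tensor):
        self._data = data

    def fsdp_pre_all_gather(
        self,
        mesh: DeviceMesh,
        outer_size: torch.Size,
        outer_stride: tuple,
        module: nn.Module,
        mp_policy: MixedPrecisionPolicy,
    ):
        # The subclass must cast to the mixed precision policy's param_dtype
        # itself here; FSDP does not perform this cast for subclass params.
        data = self._data
        if mp_policy.param_dtype is not None and data.dtype != mp_policy.param_dtype:
            data = data.to(mp_policy.param_dtype)
        # Return (all_gather_inputs, metadata); metadata is handed back to
        # fsdp_post_all_gather after the all-gather completes.
        return (data,), None
```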

Test plan

  • ./test/prototype/moe_training/test_fsdp.sh
  • Manual test with torchtitan llama4 debug model: NGPU=2 CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" ./run_train.sh --training.steps=10 --model.converters="float8" --float8.recipe_name="rowwise" --float8.moe_fqns_prototype="experts"


@facebook-github-bot added the CLA Signed label Jun 27, 2025
@danielvegamyhre marked this pull request as draft June 27, 2025 19:30

pytorch-bot bot commented Jun 27, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2455

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 5 Pending

As of commit bb9626e with merge base ac14d92:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@danielvegamyhre changed the title from "[WIP] [moe training] ScaledGroupedMMTensor - set dtype" to "[moe training] Cast to MP policy param dtype in fsdp_pre_all_gather hook" Jul 1, 2025
@danielvegamyhre marked this pull request as ready for review July 1, 2025 20:49
@danielvegamyhre added the topic: not user facing label Jul 1, 2025
@danielvegamyhre force-pushed the dtype branch 2 times, most recently from 8df3fbb to fd933ea July 1, 2025 21:28
@danielvegamyhre (Contributor Author) commented:

cc @drisspg @vkuzo for review

@danielvegamyhre requested review from vkuzo and drisspg July 1, 2025 21:35
@danielvegamyhre changed the title from "[moe training] Cast to MP policy param dtype in fsdp_pre_all_gather hook" to "[moe training] Cast to mixed precision policy param dtype in fsdp_pre_all_gather hook" Jul 1, 2025
out._data.copy_(data)
return

# For training step 0, out=None, so we need to return a new ScaledGroupedMMTensor.
A reviewer (Contributor) commented on the excerpt above:

do we have a test for this?

@danielvegamyhre (Contributor Author) replied:

We have a test for float MoE + FSDP training. We don't have a test verifying which code branch is followed in this fsdp_post_all_gather hook at training step 0 vs 1, but I think the FSDP test alone is sufficient. Let me know if you have other thoughts.
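
For context, a minimal sketch of what the two branches in the excerpt above look like when filled out, assuming the FSDP2 tensor-subclass extension point; MySubclassTensor and _data are illustrative stand-ins, not the actual ScaledGroupedMMTensor internals:

```python
import torch


class MySubclassTensor(torch.Tensor):
    """Toy wrapper subclass; `_data` holds the tensor it wraps."""

    @staticmethod
    def __new__(cls, data: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device,
            requires_grad=data.requires_grad,
        )

    def __init__(self, data: torch.Tensor):
        self._data = data

    def fsdp_post_all_gather(
        self,
        all_gather_outputs,  # tuple of tensors produced by the all-gather
        metadata,            # whatever fsdp_pre_all_gather returned as metadata
        param_dtype,         # the MP policy param dtype
        *,
        out=None,
    ):
        (data,) = all_gather_outputs
        if out is not None:
            # Steps >= 1: FSDP passes back the output returned at step 0, so
            # copy the newly gathered data into it in place and return nothing.
            assert isinstance(out, MySubclassTensor)
            out._data.copy_(data)
            return
        # Step 0: out is None, so return a new subclass instance plus the
        # inner tensors FSDP should hold onto for reuse on later steps.
        output = MySubclassTensor(data)
        return output, (data,)
```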

@danielvegamyhre (Contributor Author) commented:

MPS test failures are unrelated to this change.

@danielvegamyhre merged commit 01f7352 into main Jul 2, 2025
18 of 19 checks passed