Support store_param_remainders feature from Apex in TE Fused Adam #1408

Merged
29 commits merged into NVIDIA:main on Jan 31, 2025

Conversation

sanandaraj5597
Contributor

Description

When the master parameter is in FP32 and the model parameters are in BF16, we can store the trailing 16 remainder bits and reconstruct the master FP32 param from (BF16 model param + the remainder).

This halves the master parameter memory usage.
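
To illustrate the idea, here is a minimal PyTorch sketch (not the actual TE/Apex kernel code): it splits the FP32 bits by truncation for simplicity, whereas the actual kernels may handle the BF16 rounding differently.

```python
import torch

def split_master_fp32(master: torch.Tensor):
    """Split an FP32 master param into a BF16 param plus the trailing 16 bits.

    Truncating split, for illustration only.
    """
    bits = master.view(torch.int32)
    bf16_param = (bits >> 16).to(torch.int16).view(torch.bfloat16)  # top 16 bits
    lo = bits & 0xFFFF                                              # bottom 16 bits
    remainder = torch.where(lo >= 0x8000, lo - 0x10000, lo).to(torch.int16)
    return bf16_param, remainder

def reconstruct_master_fp32(bf16_param: torch.Tensor, remainder: torch.Tensor):
    """Rebuild the FP32 master param from the BF16 param and the remainder."""
    hi = bf16_param.view(torch.int16).to(torch.int32) * 65536  # BF16 bits into the high half
    lo = remainder.to(torch.int32) & 0xFFFF                    # low half, without sign extension
    return (hi + lo).view(torch.float32)

x = torch.randn(8)
bf16, rem = split_master_fp32(x)
assert torch.equal(reconstruct_master_fp32(bf16, rem), x)  # exact round trip
```

The optimizer then only needs an int16 remainder tensor per FP32 master param instead of a full FP32 copy, which is where the halving of the master parameter memory comes from.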

transformer_engine/pytorch/optimizers/fused_adam.py

@@ -243,13 +256,14 @@ def _apply_scale(self, state_name, unscaled_state, scaled_state, scale):
         unscaled_state.mul_(rscale)
         scaled_state.copy_(unscaled_state)

-    def get_unscaled_state(self, param, state_name):
+    def get_unscaled_state(self, param, state_name, store_param_remainders=False):


The default value of store_param_remainders is False here, but it's True by default in the constructor. I think that's misleading; why not just set it to True here as well?

Contributor Author

I don't want to store param remainders for any state_name other than master_params, which is why it defaults to False.

Collaborator

I'd prefer if this function didn't expose this kwarg since it makes its behavior less obvious. get_unscaled_state implies that it produces an FP32 value that is ready to use, so it would be better if step called a different function to access the BF16 remainder. If we want to keep this overall logic, we should change the function name to something more accurate (although a vague name like get_state_for_adam_kernel is a code smell).

Contributor Author

Worked around it. Resolving conversation.

Collaborator

It's better, although we still have the problem that state scaling and BF16 remainders both use this function in different ways. It's troubling that get_unscaled_state might not return the unscaled state.

Contributor Author
@sanandaraj5597 commented Jan 30, 2025

I agree with your point. But reusing this function both with and without the feature keeps things efficient; writing a separate function would require new call sites across the step function, checkpointing, etc.

I've also added assert checks inside the function to tighten up correctness and make the intent clearer. Hope you are fine with it.
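
(For illustration, the kind of check meant here might look like the sketch below; the state key name and conditions are assumptions, not the PR's actual code.)

```python
# Hypothetical sketch of such an assert; the real state key name and
# checks inside get_unscaled_state may differ.
if store_param_remainders:
    assert state_name == "master_param", (
        "param remainders are only stored for master parameters"
    )
    assert param.dtype == torch.bfloat16, (
        "param remainders require BF16 model parameters"
    )
```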

@MaciejBalaNV

I'm getting NaNs when using this feature. You can reproduce it by running the test_fused_optimizer tests after setting store_param_remainders=True in the _initialize_state method (otherwise it fails earlier) and commenting out the torch.testing.assert_close(ref_params, master_params) check (that check is expected to fail, since we now keep master_params as int16).

Still, with all these changes, the tests fail at torch.testing.assert_close(ref_params, model_params_to_fp32, rtol=1e-2, atol=1e-2, equal_nan=True) with an error message saying the weights are NaN.


Selvaraj Anandaraj and others added 17 commits January 29, 2025 21:28
@MaciejBalaNV commented Jan 30, 2025

Another thing: this code fails with CUDA memory access errors when both capturable=True and store_param_remainders=True are passed. I think the tests didn't catch this because the only test with capturable=True and master_weights=True uses FP16, not BF16, so param remainders are silently not used. At the very least we should have an assert to make sure capturable=False is passed. Of course, making it work with capturable=True would be even better.

@sanandaraj5597
Contributor Author

@MaciejBalaNV I added a failure guard for capturable mode. We don't plan to use CUDA graphs with optimizers, so this support isn't worth adding.
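
(A minimal sketch of what such a guard could look like; the actual wording and placement in the PR may differ.)

```python
# Hypothetical guard against using param remainders under CUDA graph capture;
# the real check in fused_adam.py may be worded differently.
if store_param_remainders and capturable:
    raise RuntimeError(
        "store_param_remainders=True is not supported with capturable=True "
        "(CUDA graph capture); disable one of the two options."
    )
```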

Collaborator
@timmoon10 left a comment

Making this feature opt-in makes me feel a lot less worried. Users should know what they are doing before enabling this optimization.

I'll change this PR to merge into release_v2.0 and start a CI pipeline.

Comment on lines +691 to +692
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
Collaborator

We've already confirmed that p_in_type is BF16, so dispatching for FP16 and FP32 is unnecessary. See #1408 (comment).

Suggested change
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",

Contributor Author

Agreed, but it shouldn't break the code I guess.


@timmoon10 changed the base branch from main to release_v2.0 on January 30, 2025 at 22:34
@timmoon10 changed the base branch from release_v2.0 to main on January 30, 2025 at 22:35
@timmoon10
Collaborator

/te-ci pytorch

Selvaraj Anandaraj added 2 commits January 30, 2025 17:03
@sanandaraj5597
Contributor Author

/te-ci pytorch

@timmoon10
Collaborator

#1443 is identical to this PR, but rebased on the release_v2.0 branch.

@timmoon10 merged commit e536954 into NVIDIA:main on Jan 31, 2025
14 checks passed
@sanandaraj5597 deleted the param_remainder branch on January 31, 2025 at 01:57