Support store_param_remainders feature from Apex in TE Fused Adam #1408
Conversation
@@ -243,13 +256,14 @@ def _apply_scale(self, state_name, unscaled_state, scaled_state, scale):
        unscaled_state.mul_(rscale)
        scaled_state.copy_(unscaled_state)

-    def get_unscaled_state(self, param, state_name):
+    def get_unscaled_state(self, param, state_name, store_param_remainders=False):
The default value of `store_param_remainders` is `False` here, but it's `True` by default in the constructor. I think it's misleading; why not just set it to `True` here as well?
I don't want to store param remainders for any `state_name` other than `master_params`; that's why it defaults to `False`.
I'd prefer if this function didn't expose this kwarg, since it makes its behavior less obvious. `get_unscaled_state` implies that it produces an FP32 value that is ready to use, so it would be better if `step` called a different function to access the BF16 remainder. If we want to keep this overall logic, we should change the function name to something more accurate (although a vague name like `get_state_for_adam_kernel` is a code smell).
Worked around it. Resolving conversation.
It's better, although we still have the problem that state scaling and BF16 remainders are both using this function in different ways. It's troubling that `get_unscaled_state` might not get the unscaled state.
I agree with your point, but using this function both with and without the feature keeps the code efficient; writing a separate function would require new call sites across the step function, checkpointing, etc.
I've also added assert checks inside the function to tighten the understanding/correctness. Hope you are fine with it.
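For concreteness, here is a minimal sketch of the assert-guarded behavior being discussed; the class layout, the `master_param` state key, and the int16 remainder dtype are assumptions for illustration, not TE's actual implementation:

```python
import torch

# Hypothetical sketch of the discussed design, not TE's actual code.
class _FusedAdamSketch:
    def __init__(self):
        # per-parameter optimizer state, e.g. self.state[param]["master_param"]
        self.state = {}

    def get_unscaled_state(self, param, state_name, store_param_remainders=False):
        state = self.state[param][state_name]
        if store_param_remainders:
            # Remainders only exist for the FP32 master copy of a BF16 param,
            # so tighten the contract with asserts as mentioned above.
            assert state_name == "master_param"
            assert state.dtype == torch.int16
            # Raw 16-bit remainders, consumed directly by the fused Adam kernel.
            return state
        # Regular path: return an FP32 tensor that is ready to use.
        return state.float()
```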
I'm getting NaNs when using this feature. You can reproduce it by running …
Still, with all these changes, the tests fail at …
transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu (two review threads, outdated and resolved)
Another thing: this code is failing with CUDA memory access errors when we pass both …
@MaciejBalaNV I added a failure guard for capturable mode. We don't have a plan to use CUDA graphs with optimizers, so it's not worth supporting.
Making this feature opt-in makes me feel a lot less worried. Users should know what they are doing before enabling this optimization.
I'll change this PR to merge into `release_v2.0` and start a CI pipeline.
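As a usage sketch of the opt-in path, here is what enabling the feature might look like; the exact `FusedAdam` import path and kwargs such as `master_weights` are assumptions here, not confirmed by this thread:

```python
import torch
from transformer_engine.pytorch.optimizers import FusedAdam  # import path assumed

# Opt-in sketch: BF16 model params with FP32 master weights kept as 16-bit remainders.
model = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)

optimizer = FusedAdam(
    model.parameters(),
    lr=1e-4,
    master_weights=True,          # keep FP32 master copies (kwarg name assumed)
    store_param_remainders=True,  # opt-in: store only the trailing 16 bits
)
```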
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
    p_in_type, 0, "adam",
We've already confirmed that `p_in_type` is BF16, so dispatching for FP16 and FP32 is unnecessary. See #1408 (comment).
Agreed, but it shouldn't break the code I guess.
/te-ci pytorch
/te-ci pytorch
#1443 is identical to this PR, but rebased on the …
Description
When the master parameters are in FP32 and the model parameters are in BF16, we can store only the trailing 16 remainder bits and reconstruct the FP32 master param from the BF16 model param plus the remainder.
This halves the master-parameter memory usage.
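A minimal PyTorch sketch of that reconstruction, assuming the BF16 model param holds the top 16 bits of the FP32 master value (i.e. truncation); this is illustrative only, not the kernel's code:

```python
import torch

# Illustrative only: split an FP32 master param into a BF16 param + 16 remainder bits,
# then rebuild the exact FP32 value. Assumes the BF16 copy is the truncated top half.
master = torch.randn(4, dtype=torch.float32)

bits = master.view(torch.int32)                                  # reinterpret FP32 bits
bf16_param = (bits >> 16).to(torch.int16).view(torch.bfloat16)   # top 16 bits -> model param
remainder = (bits & 0xFFFF).to(torch.int16)                      # bottom 16 bits -> optimizer state

# Reconstruction: BF16 model param + remainder -> FP32 master param.
hi = bf16_param.view(torch.int16).to(torch.int32) << 16
lo = remainder.to(torch.int32) & 0xFFFF                          # undo int16 sign extension
rebuilt = (hi | lo).view(torch.float32)

assert torch.equal(rebuilt, master)
```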