[PyTorch] Enabling Per-Tensor Current Scaling Recipe #1471
base: main
Conversation
We should also include FP8 current-scaling in our existing numerics tests. Even if they're not as strict as the golden-value tests, they cover more use cases. Try adding the current-scaling recipe in test_numerics.py at:

fp8_recipes = [
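For instance, the suggested change might look like this sketch (the Float8CurrentScaling constructor name is an assumption based on the naming discussion later in this thread, not necessarily the PR's final API):

```python
fp8_recipes = [
    recipe.DelayedScaling(),
    recipe.MXFP8BlockScaling(),
    # Hypothetical constructor for the recipe added in this PR:
    recipe.Float8CurrentScaling(),
]
```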
"""Scaling factor to multiply when quantizing to FP8""" | ||
scale: torch.Tensor | ||
"""Max-abs value from last FP8 cast""" | ||
amax: torch.Tensor |
These are not actually necessary, since we could allocate temporary buffers when doing the FP8 cast. That said, I would expect this helps avoid some CPU overhead from dealing with PyTorch's memory pool. This design trades (possibly) reduced CPU overhead for increased complexity and a larger surface area for bugs.
I did that in the early stages of development for simplicity. It was also useful for debugging.
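To make the trade-off concrete, here is a minimal sketch of the two allocation strategies under discussion; all class and function names here are hypothetical, not the PR's actual code:

```python
import torch

class CurrentScalingQuantizer:
    """Sketch of the persistent-buffer design: scale and amax live on the
    quantizer, so each cast reuses them instead of re-allocating."""

    def __init__(self, device: str = "cuda") -> None:
        self.scale = torch.ones(1, device=device)   # persistent buffer
        self.amax = torch.zeros(1, device=device)   # persistent buffer

    def quantize(self, x: torch.Tensor) -> torch.Tensor:
        self.amax.copy_(x.abs().amax())             # reuse the buffer in place
        # ... derive self.scale from self.amax, then cast ...
        return (x * self.scale).to(torch.float8_e4m3fn)

def quantize_with_temporaries(x: torch.Tensor) -> torch.Tensor:
    """The alternative: amax and scale are temporaries allocated per call,
    which is simpler but goes through PyTorch's caching allocator each time."""
    amax = x.abs().amax()
    scale = torch.ones_like(amax)  # placeholder for the real scale computation
    return (x * scale).to(torch.float8_e4m3fn)
```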
transformer_engine/common/include/transformer_engine/transformer_engine.h
@@ -21,7 +21,7 @@ using namespace transformer_engine;

 namespace {

-template <typename InputType, typename OutputType>
+template <typename InputType, typename OutputType, bool UPDATE_AMAX>
The template arg probably isn't giving us much benefit compared to passing in a bool. It doesn't affect anything within the inner loop.
// current tensor scaling
performTestCurrentScaling<InputType, OutputType>(size);
It would be better to have a separate test suite (perhaps in a different file) to make it easier to debug test failures.
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestCurrentScalingFloat8Tensor:
I don't think it's necessary to duplicate these tests. TestFloat8Tensor is more about testing the basic tensor infrastructure than the quantization itself.
I removed the test cases about Float8Tensor itself (set_data, etc.) from this class and only kept quantize/dequantize tests, to avoid repetitive code.
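As a sketch of what such a focused quantize/dequantize test might look like (Float8CurrentScalingQuantizer is a hypothetical name; the class in the PR may differ):

```python
import pytest
import torch

@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestCurrentScalingFloat8Tensor:
    """Focused on quantize/dequantize only; tensor-infrastructure
    tests stay in TestFloat8Tensor."""

    def test_quantize_dequantize(self) -> None:
        # Hypothetical quantizer class for the current-scaling recipe.
        quantizer = Float8CurrentScalingQuantizer(device="cuda")
        x = torch.randn(128, 128, device="cuda")
        x_fp8 = quantizer.quantize(x)
        # The round-trip error should be within FP8 E4M3 resolution.
        torch.testing.assert_close(x_fp8.dequantize(), x, atol=0.1, rtol=0.1)
```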
tests/pytorch/test_recipe.py
from recipe_numerics_base import TestFP8RecipeLinearBase
from recipe_numerics_base import GetRecipes

TENSOR_DUMP_DIR = pathlib.Path(__file__).resolve().parent.parent.parent / "tensor_dumps"
We should make this configurable so that we can keep a copy of these tensors in our testing systems. Also, we'll need to make corresponding changes in our CI infrastructure.
How about using an env variable? What name would you like?
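For example, a minimal sketch of making the dump directory configurable (NVTE_TENSOR_DUMP_DIR is a hypothetical variable name, pending the reviewers' preference):

```python
import os
import pathlib

# Fall back to the in-repo default when the env variable is unset.
_DEFAULT_DUMP_DIR = pathlib.Path(__file__).resolve().parent.parent.parent / "tensor_dumps"
TENSOR_DUMP_DIR = pathlib.Path(os.getenv("NVTE_TENSOR_DUMP_DIR", str(_DEFAULT_DUMP_DIR)))
```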
float scale = 1.f;
float scale_inv = 1.f;

if (isinf(clamp_amax) || clamp_amax == 0.f) {
I have two differing cases:
- qscale inf -> qscale = finfo(input_dtype).max
- qscale nan or amax == 0 -> qscale = 1.0
@timmoon10 What do you think? When we have scale=inf, should we return max(input_type) or just 1?
Ideally, every upcoming recipe could share the same compute_scale_from_amax function.
I haven't actually triggered this case in the tests, though.
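For discussion purposes, a Python sketch of the branching described above (function and parameter names are illustrative; the real implementation is a C++ kernel):

```python
import math
import torch

def compute_scale_from_amax(amax: float, fp8_max: float,
                            input_dtype: torch.dtype = torch.float32) -> float:
    """Per-tensor current scaling: scale = fp8_max / amax, with the two
    special cases from the comment above."""
    if amax == 0.0 or math.isnan(amax):
        return 1.0  # case 2: zero or nan amax -> identity scale
    scale = fp8_max / amax
    if math.isinf(scale):
        return torch.finfo(input_dtype).max  # case 1: saturate the scale
    return scale
```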
It is really helpful to see all the recipe setup in this MR! Thanks Zhongbo.
transformer_engine/pytorch/fp8.py
@@ -743,6 +756,8 @@ def create(
     cls = DelayedScalingRecipeState
 elif recipe.mxfp8():
     cls = MXFP8BlockScalingRecipeState
+elif recipe.current_scaled():
Suggested change:
- elif recipe.current_scaled():
+ elif recipe.current():
I think we should go in the other direction and be more explicit: #1471 (comment). "current" is very unclear to me and makes me think it's a proxy class or something asynchronous.
Sounds good, replaced.
fp8_gemm_fprop: MMParams, default MMParams.use_split_accumulator=False
    used for calculating output y in forward pass
fp8_gemm_dgrad: MMParams, default MMParams.use_split_accumulator=True
    used for calculating dgrad in backward pass
fp8_gemm_wgrad: MMParams, default MMParams.use_split_accumulator=True
    used for calculating wgrad in backward pass
Not too fond of the name MMParams (assuming it stands for matmul params).
For delayed scaling, it was a deliberate choice to keep these knobs on the Python side for easy toggling but not fully expose them in the recipe APIs (as they are here), because they are low-level GEMM details. Is this something that is required as part of the recipe? If they need to be modified for studies, could an env variable be a better option?
These GEMM configs were indeed hand-picked and tested by the research folks (I set split accumulator to True everywhere here for simplicity, to pass the tests, but in training we can afford to set it to False in the forward pass). Since the research folks want to control the split-accumulator config for each GEMM, env variables would add many more configs.
I agree that we can later expose some of these knobs at the Python level through Megatron params, but for now I just write
Researchers want to control even low-level details if they affect numerics substantially. I agree with Zhongbo that environment variables don't scale well compared to an abstraction like Recipe, which already offers per-layer and per-GEMM control granularity.
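As a concrete sketch of the recipe-level abstraction being argued for (the Float8CurrentScaling class name is an assumption; field names and defaults follow the docstring in the diff above):

```python
from dataclasses import dataclass, field

@dataclass
class MMParams:
    """Low-level matmul knobs, one instance per GEMM in a layer."""
    use_split_accumulator: bool = True

@dataclass
class Float8CurrentScaling:  # hypothetical recipe class name
    # Per the docstring above: fprop can afford the faster non-split
    # accumulator, while dgrad/wgrad keep split accumulation.
    fp8_gemm_fprop: MMParams = field(default_factory=lambda: MMParams(use_split_accumulator=False))
    fp8_gemm_dgrad: MMParams = field(default_factory=lambda: MMParams(use_split_accumulator=True))
    fp8_gemm_wgrad: MMParams = field(default_factory=lambda: MMParams(use_split_accumulator=True))
```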
Description
[WIP] Enable the per-tensor current-scaling recipe as an alternative to delayed scaling.
Changes
- C++ unit tests
- Python unit tests