Parallel Cross Entropy using online softmax #1456
Conversation
Can you add a test in tests/pytorch?
@timmoon10 Added tests.
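For reference, a minimal sketch of the kind of parity test that could live under tests/pytorch, checking the new loss against torch.nn.functional.cross_entropy. The Transformer Engine call is left as a commented placeholder because the exact function name and signature added by this PR are not shown here:

```python
# Illustrative parity-test sketch; the Transformer Engine call is a placeholder,
# not the actual API added by this PR.
import pytest
import torch


@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_cross_entropy_matches_torch(dtype):
    torch.manual_seed(0)
    n_tokens, vocab = 32, 50304
    logits = torch.randn(n_tokens, vocab, dtype=dtype, device="cuda")
    labels = torch.randint(0, vocab, (n_tokens,), device="cuda")

    # Reference computed in float32 for numerical stability.
    ref = torch.nn.functional.cross_entropy(logits.float(), labels)

    # Replace this placeholder with the cross entropy function added by the PR.
    out = ref

    torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)
```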
@sanandaraj5597 Could you fix the linting errors? You could run it locally using bash qa/L0_pytorch_lint/test.sh
Fixed lint errors.
LGTM, pending CI
/te-ci pytorch
/te-ci pytorch
LGTM!
/te-ci pytorch
/te-ci pytorch
/te-ci pytorch
It seems that the Triton dependency is breaking other tests. The latest upstream Triton (3.2.0) doesn't support Blackwell, so the NVIDIA PyTorch container uses a custom internal build. I've removed Triton as a formal dependency, but we should add it back once Blackwell support is upstreamed.
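A minimal sketch of what treating Triton as an optional rather than formal dependency might look like; the module and flag names below are illustrative assumptions, not the actual Transformer Engine code:

```python
# Guarded import so the package still works when Triton is not installed.
# Names here are illustrative assumptions, not taken from this PR.
try:
    import triton  # noqa: F401

    HAVE_TRITON = True
except ImportError:
    HAVE_TRITON = False


def triton_cross_entropy_available() -> bool:
    """Return True if a Triton-based cross entropy kernel could be used."""
    return HAVE_TRITON
```

Callers can then fall back to the standard PyTorch cross entropy path when the flag is False.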
/te-ci pytorch
Description
This PR implements a parallel cross entropy function using the online softmax technique. This feature has multiple aspects:
[Thanks to the Liger kernel implementation for the idea of online softmax and in-place gradient calculation.]
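To illustrate the idea, here is a minimal PyTorch sketch of cross entropy computed with an online (chunked) softmax, i.e. a running maximum and a rescaled running sum of exponentials. This is only a reference for the technique, not the Triton kernel added by this PR:

```python
# Online-softmax cross entropy sketch: the full softmax is never materialized;
# a running max and running sum are updated one vocabulary chunk at a time.
import torch


def online_softmax_cross_entropy(
    logits: torch.Tensor, labels: torch.Tensor, chunk_size: int = 1024
) -> torch.Tensor:
    """Per-token loss: loss_i = logsumexp(logits_i) - logits_i[label_i]."""
    n_tokens, vocab = logits.shape
    running_max = torch.full((n_tokens,), float("-inf"), device=logits.device)
    running_sum = torch.zeros(n_tokens, device=logits.device)

    for start in range(0, vocab, chunk_size):
        chunk = logits[:, start : start + chunk_size].float()
        chunk_max = chunk.max(dim=1).values
        new_max = torch.maximum(running_max, chunk_max)
        # Rescale the running sum to the new maximum before adding this chunk.
        running_sum = running_sum * torch.exp(running_max - new_max)
        running_sum = running_sum + torch.exp(chunk - new_max.unsqueeze(1)).sum(dim=1)
        running_max = new_max

    target_logits = logits.gather(1, labels.unsqueeze(1)).squeeze(1).float()
    return torch.log(running_sum) + running_max - target_logits
```

The result matches torch.nn.functional.cross_entropy(logits.float(), labels, reduction="none") up to floating-point error, while only ever holding one vocabulary chunk of exponentials in memory.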