[Pytorch] Decoupling framework extensions from common module #1498
Conversation
Force-pushed from 99871f3 to cd20471
/te-ci pytorch
Force-pushed from 7750ea4 to c62be40
#include "common/fused_attn/thd_utils.h" | ||
#include "extensions.h" | ||
|
||
using namespace transformer_engine::fused_attn; | ||
#include "thd_util.cuh" |
This is the new header file created in the framework extensions for the structs that were being used in attention.cu from transformer_engine/common/fused_attn/thd_utils.h.
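For orientation, here is a minimal sketch of what such a framework-side header could look like (the file name comes from the PR, but the struct name and body below are illustrative assumptions, not the PR's actual contents). Because the device functor is header-only, attention.cu in the PyTorch extensions compiles its own copy instead of relying on a definition living in transformer_engine/common.

// thd_util.cuh -- illustrative sketch only; struct name and body are assumptions.
#pragma once

#include <cstddef>
#include <cuda_runtime.h>

// Header-only device functor: every extension .cu file that includes this
// header compiles its own copy, so no symbol from libtransformer_engine.so
// is needed at link time.
struct ReadHalfFunctor {
  __forceinline__ __device__ static void run(float *dst, const float *src, size_t idx) {
    dst[idx] = src[idx];
  }
};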
@@ -0,0 +1,59 @@
/*************************************************************************
Is transformer_engine/common/include/transformer_engine a better place for this?
/te-ci pytorch
Force-pushed from adf5d8a to 56a8500
Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
…el kernels ONLY for invoking recompilation and not directly using the pre-compiled symbols in libtransformer.so
Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
… common.h
Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
…de header Code cleanup
Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
Force-pushed from 40c3120 to bd6182e
/te-ci pytorch
@@ -33,39 +86,78 @@ __forceinline__ __device__ int binary_search(int target, int *array, int len) {
/***************************************************************************************************
 * Support THD format for Context Parallel: Generate partitioned indices for input tokens
 **************************************************************************************************/

// Templatizing this kernel ONLY so that when it is used in framework
// extensions, it is dynamically compiled
We could probably remove these comments now.
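To illustrate the trick these comments refer to (a minimal sketch under assumptions: the kernel carries a _sketch suffix and its body is a placeholder, not the PR's real partitioning logic): making the kernel a template means it is only compiled where it is instantiated, so a framework extension that includes the header gets its own compiled copy instead of resolving the pre-compiled symbol from libtransformer_engine.so.

// Sketch: a templated __global__ kernel defined in a header.  The otherwise
// unused template parameter exists only to force instantiation (and hence
// compilation) in the including translation unit.
template <typename T = int>
__global__ void thd_partition_indices_kernel_sketch(int *output, int total_tokens) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < total_tokens;
       i += blockDim.x * gridDim.x) {
    output[i] = i;  // placeholder for the real partitioning logic
  }
}

// Usage in an extension .cu file: the explicit <> triggers local instantiation.
// thd_partition_indices_kernel_sketch<><<<grid, block>>>(d_out, total_tokens);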
#include <cuda.h>
#include <cuda_bf16.h>
/***************************************************************************************************
 * Support THD format for Context Parallel: softmax_lse related operations
 **************************************************************************************************/
Maybe clean up these comments; they seem unrelated to the functions below. Same for the Gradients ones.
…d_read_half_tensor_kernel
Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
Code clean up
Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
Force-pushed from d32096e to 08b82a0
for more information, see https://pre-commit.ci
Description
Motivation: Currently, the framework extensions have a dependency on transformer_engine/common (libtransformer_engine.so) which is met via the symbols exposed by libtransformer_engine.version. However, this could result in two problems.

This PR is the first in a series of decoupling PRs to follow.
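As a rough sketch of the direction this decoupling takes (every name below is a hypothetical stand-in, not the actual Transformer Engine API): the extension stops touching the library's internal C++ types and only goes through a C-style surface with opaque handles, so it no longer depends on whichever internal symbols the version script of libtransformer_engine.so happens to export.

// Illustrative only -- hypothetical names, not the real transformer_engine API.
#include <cstddef>

// Stable C surface exported by the shared library (sketch).
extern "C" {
typedef void *ExampleTensorHandle;  // opaque handle, internals hidden
ExampleTensorHandle example_tensor_create(void *dptr, size_t numel);
void example_fused_op(ExampleTensorHandle in, ExampleTensorHandle out);
void example_tensor_destroy(ExampleTensorHandle t);
}

// Framework-extension side after decoupling (sketch): no internal
// transformer_engine::Tensor, only the exported C entry points.
void run_fused_op(void *in_ptr, void *out_ptr, size_t numel) {
  ExampleTensorHandle in = example_tensor_create(in_ptr, numel);
  ExampleTensorHandle out = example_tensor_create(out_ptr, numel);
  example_fused_op(in, out);
  example_tensor_destroy(out);
  example_tensor_destroy(in);
}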
Type of change
Changes
Files in the framework extensions identified to have this dependency in the first pass are:
- In pytorch/csrc/extensions: attention.cu, gemm.cpp, multi_tensor_adam.cu
- In jax/csrc/extensions: attention.cpp, utils.cu
This PR only decouples pytorch/csrc/extensions/attention.cu by:
- Using nvte_* calls to replace the usage of transformer_engine::Tensor
- Moving the needed structs from t_e/common/fused_attn/thd_util.h to a new file t_e/pytorch/csrc/thd_util.cuh
- Templatizing thd_partition_indices_kernel and thd_read_half_tensor_kernel so that they are "re-compiled" on instantiation in the framework extensions

Checklist: