
[Pytorch] Decoupling framework extensions from common module #1498

Conversation

KshitijLakhani
Collaborator

@KshitijLakhani KshitijLakhani commented Feb 20, 2025

Description

Motivation: Currently, the framework extensions have a dependency on transformer_engine/common (libtransformer_engine.so), which is satisfied through the symbols exposed via the libtransformer_engine.version linker script (see the sketch after the list below). This can result in two problems:

  1. Over-exposure of internal transformer_engine symbols to the user
  2. Possible ABI compatibility issues if the .so and the framework extension code are built with different compilers or compiler versions
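For context, which symbols a shared library exports is typically controlled with a linker version script. Below is a minimal sketch of what a script like libtransformer_engine.version might look like, assuming the public C API follows the nvte_* naming convention; this is illustrative only, not the actual file contents:

```
{
  global:
    nvte_*;   /* expose only the public C API */
  local:
    *;        /* hide all other internal symbols */
};
```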

This PR is the first in a series of decoupling PRs to follow.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Files in the framework extensions identified to have this dependency in the first pass:

In pytorch/csrc/extensions:
attention.cu, gemm.cpp, multi_tensor_adam.cu

In jax/csrc/extensions:
attention.cpp, utils.cu

This PR decouples only pytorch/csrc/extensions/attention.cu, by:

  1. Using nvte_* calls in place of direct usage of transformer_engine::Tensor
  2. Moving the structs from transformer_engine/common/fused_attn/thd_utils.h into a new file, transformer_engine/pytorch/csrc/thd_util.cuh
  3. Templatizing thd_partition_indices_kernel and thd_read_half_tensor_kernel so that they are "re-compiled" on instantiation in the framework extensions (see the sketch after this list)
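To illustrate point 3, here is a minimal, self-contained sketch of the templatization pattern; the kernel and launcher below are hypothetical stand-ins, not the actual thd kernels. Because a template carries no pre-compiled definition in libtransformer_engine.so, each framework extension that instantiates it compiles its own copy with its own compiler, which also sidesteps the compiler-mismatch concern above.

```cuda
#include <cuda_runtime.h>

// Shared header (e.g. a .cuh shipped with the extensions): the kernel is
// a template, so no symbol for it is baked into the common .so.
template <typename T>
__global__ void scale_kernel(T *data, T factor, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= factor;
}

// Framework extension (e.g. attention.cu): instantiating the template
// here forces it to be compiled in this translation unit.
void launch_scale(float *d_data, float factor, int n, cudaStream_t stream) {
  constexpr int kBlock = 256;
  const int grid = (n + kBlock - 1) / kBlock;
  scale_kernel<float><<<grid, kBlock, 0, stream>>>(d_data, factor, n);
}
```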

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@KshitijLakhani KshitijLakhani self-assigned this Feb 20, 2025
@KshitijLakhani KshitijLakhani force-pushed the klakhani/maint/decouple-te-in-fwexten branch from 99871f3 to cd20471 on February 20, 2025 04:19
@KshitijLakhani
Collaborator Author

/te-ci pytorch

@KshitijLakhani KshitijLakhani force-pushed the klakhani/maint/decouple-te-in-fwexten branch from 7750ea4 to c62be40 on February 20, 2025 10:40
#include "common/fused_attn/thd_utils.h"
#include "extensions.h"

using namespace transformer_engine::fused_attn;
#include "thd_util.cuh"
@KshitijLakhani
Collaborator Author
This is the new header file created in the framework extensions for the structs that were being used in attention.cu from transformer_engine/common/fused_attn/thd_utils.h.
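As a rough sketch of the kind of header this describes (the functor below is a hypothetical stand-in, not the actual contents of thd_util.cuh):

```cuda
// thd_util.cuh (hypothetical sketch): device-side helper structs kept
// header-only inside the extensions, so attention.cu no longer needs to
// include common/fused_attn/thd_utils.h.
#pragma once
#include <cuda_runtime.h>

struct CopyFunctor {  // hypothetical name for a moved struct
  template <typename T>
  __forceinline__ __device__ static void run(T *dst, const T *src, size_t idx) {
    dst[idx] = src[idx];
  }
};
```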

@@ -0,0 +1,59 @@
/*************************************************************************
@KshitijLakhani
Collaborator Author
Is transformer_engine/common/include/transformer_engine a better place for this?

@cyanguwa
Collaborator

/te-ci pytorch

@KshitijLakhani KshitijLakhani force-pushed the klakhani/maint/decouple-te-in-fwexten branch from adf5d8a to 56a8500 on February 21, 2025 04:38
KshitijLakhani and others added 6 commits February 20, 2025 20:40
…el kernels ONLY for invoking recompilation and not directly using the pre-compiled symbols in libtransformer.so

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
…de header

Code cleanup

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/maint/decouple-te-in-fwexten branch from 40c3120 to bd6182e on February 21, 2025 04:40
@KshitijLakhani KshitijLakhani changed the title Decoupling framework extensions from common module [Pytorch]Decoupling framework extensions from common module Feb 21, 2025
@KshitijLakhani KshitijLakhani changed the title [Pytorch]Decoupling framework extensions from common module [Pytorch] Decoupling framework extensions from common module Feb 21, 2025
@ksivaman
Member

/te-ci pytorch

@@ -33,39 +86,78 @@ __forceinline__ __device__ int binary_search(int target, int *array, int len) {
/***************************************************************************************************
* Support THD format for Context Parallel: Generate partitioned indices for input tokens
**************************************************************************************************/

// Templatizing this kernel ONLY so that when it is used in framework
// extensions, it is dynamically compiled
Collaborator
We could probably remove these comments now.

#include <cuda.h>
#include <cuda_bf16.h>
/***************************************************************************************************
* Support THD format for Context Parallel: softmax_lse related operations
**************************************************************************************************/
Collaborator
Maybe have a cleanup of these comments. They seem unrelated to the functions below. Same for the Gradients ones.

…d_read_half_tensor_kernel

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>

Code clean up

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/maint/decouple-te-in-fwexten branch from d32096e to 08b82a0 on February 21, 2025 18:40
@KshitijLakhani KshitijLakhani marked this pull request as ready for review February 21, 2025 18:41
@KshitijLakhani KshitijLakhani merged commit 7f2dcf9 into NVIDIA:main Feb 22, 2025
11 of 12 checks passed
@KshitijLakhani KshitijLakhani deleted the klakhani/maint/decouple-te-in-fwexten branch February 22, 2025 05:54