Merge branch 'main' into columnwise-only-tensors
timmoon10 authored Feb 25, 2025
2 parents 7f4dfdb + 229dd04 commit 04a067b
Showing 9 changed files with 273 additions and 213 deletions.
34 changes: 19 additions & 15 deletions qa/L0_pytorch_unittest/test.sh
@@ -2,22 +2,26 @@
#
# See LICENSE for license information.

set -e

: ${TE_PATH:=/opt/transformerengine}

pip install pytest==8.2.1
pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py
pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py
pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py
pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py
pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py
NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py

FAIL=0

pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py || FAIL=1
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1
NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || FAIL=1
NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1

exit $FAIL
17 changes: 10 additions & 7 deletions qa/L1_pytorch_distributed_unittest/test.sh
@@ -2,14 +2,17 @@
#
# See LICENSE for license information.

set -e

: ${TE_PATH:=/opt/transformerengine}

pip install pytest==8.2.1
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py

FAIL=0

pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || FAIL=1
# pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py ### TODO Debug UB support with te.Sequential
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || FAIL=1

exit $FAIL
1 change: 0 additions & 1 deletion transformer_engine/common/CMakeLists.txt
@@ -65,7 +65,6 @@ list(APPEND transformer_engine_SOURCES
activation/swiglu.cu
fused_attn/fused_attn_fp8.cu
fused_attn/fused_attn.cpp
fused_attn/thd_utils.cu
fused_attn/utils.cu
gemm/cublaslt_gemm.cu
normalization/common.cpp
76 changes: 0 additions & 76 deletions transformer_engine/common/fused_attn/thd_utils.cu

This file was deleted.

68 changes: 41 additions & 27 deletions transformer_engine/pytorch/csrc/extensions/attention.cu
@@ -3,12 +3,8 @@
*
* See LICENSE for license information.
************************************************************************/

#include "common/common.h"
#include "common/fused_attn/thd_utils.h"
#include "extensions.h"

using namespace transformer_engine::fused_attn;
#include "thd_utils.cuh"

constexpr int block_size = 512;
constexpr int ctas_per_sm = 4;
@@ -208,28 +204,40 @@ std::vector<py::object> fused_attn_fwd(
std::vector<py::object> output_tensors;
output_tensors.push_back(o_python);
for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) {
auto tensor = reinterpret_cast<transformer_engine::Tensor *>(nvte_aux_tensor_pack.tensors[i]);
// allocate memory for nvte_aux_tensor_pack.tensors
at::Tensor output_tensor;
if (nvte_aux_tensor_pack.size >= 2) {
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) {
if (i < nvte_aux_tensor_pack.size - 2) {
output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false);
NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]);
output_tensor = allocateSpace(
nvte_shape_to_vector(temp_shape),
static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);
} else if (i == nvte_aux_tensor_pack.size - 2) {
output_tensor = rng_state;
} else if (i == nvte_aux_tensor_pack.size - 1) {
output_tensor = Bias.value();
}
} else {
output_tensor = (i < nvte_aux_tensor_pack.size - 1)
? allocateSpace(tensor->data.shape, tensor->data.dtype, false)
: rng_state;
NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]);
output_tensor =
(i < nvte_aux_tensor_pack.size - 1)
? allocateSpace(
nvte_shape_to_vector(temp_shape),
static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false)
: rng_state;
}
} else {
output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false);
NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]);
output_tensor = allocateSpace(
nvte_shape_to_vector(temp_shape),
static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);
}
output_tensors.push_back(py::cast(output_tensor));
tensor->data.dptr = output_tensor.data_ptr();
NVTEBasicTensor temp_data = {output_tensor.data_ptr(),
nvte_tensor_type(nvte_aux_tensor_pack.tensors[i]),
nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])};
nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data);
}

// execute the kernel
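The loop above binds freshly allocated PyTorch buffers to the auxiliary tensor pack through the public NVTE C API (nvte_tensor_shape, nvte_tensor_type, nvte_set_tensor_param) instead of casting to the internal transformer_engine::Tensor struct. A minimal sketch of that pattern follows, assuming the headers already included by this file; the helper name bind_aux_output and the pack type NVTETensorPack are assumptions for illustration, not part of this commit.

// Hypothetical helper (illustration only, not from this commit): allocate a
// PyTorch buffer for one pack entry and attach it as that entry's row-wise
// data through the NVTE C API.
static at::Tensor bind_aux_output(NVTETensorPack &pack, size_t i) {
  // Query shape and dtype through the C API rather than internal structs.
  const NVTEShape shape = nvte_tensor_shape(pack.tensors[i]);
  const NVTEDType dtype = nvte_tensor_type(pack.tensors[i]);

  // Allocate backing storage on the PyTorch side.
  at::Tensor buf =
      allocateSpace(nvte_shape_to_vector(shape), static_cast<DType>(dtype), false);

  // Point the NVTE tensor's row-wise data at the new buffer.
  NVTEBasicTensor data = {buf.data_ptr(), dtype, shape};
  nvte_set_tensor_param(&pack.tensors[i], kNVTERowwiseData, &data);
  return buf;
}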
@@ -425,11 +433,14 @@ std::vector<py::object> fused_attn_bwd(
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size();
for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) {
auto tensor = reinterpret_cast<transformer_engine::Tensor *>(nvte_aux_tensor_pack.tensors[i]);
tensor->data.dptr = Aux_CTX_Tensors[i].data_ptr();
std::vector<int64_t> tmp(Aux_CTX_Tensors[i].sizes().vec());
tensor->data.shape = std::vector<size_t>(tmp.begin(), tmp.end());
tensor->data.dtype = GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type());
auto temp_vec = std::vector<size_t>(tmp.begin(), tmp.end());
const NVTEShape temp_shape = {temp_vec.data(), temp_vec.size()};
NVTEBasicTensor temp_data = {
Aux_CTX_Tensors[i].data_ptr(),
static_cast<NVTEDType>(GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type())),
temp_shape};
nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data);
}

// create dBias the same shape as Bias
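In the backward pass the direction is reversed: existing torch tensors (Aux_CTX_Tensors) are described to NVTE rather than allocated from it. A minimal sketch of wrapping an at::Tensor into a pack slot, under the same assumptions as the sketch above (hypothetical helper name; the shape vector only needs to outlive the nvte_set_tensor_param call, mirroring the loop above):

// Hypothetical helper (illustration only, not from this commit): describe an
// existing torch tensor and attach it as the row-wise data of a pack entry.
static void wrap_aux_input(NVTETensorPack &pack, size_t i, const at::Tensor &src) {
  // Convert the torch int64_t sizes into the size_t vector NVTE expects.
  std::vector<int64_t> sizes(src.sizes().vec());
  std::vector<size_t> shape_vec(sizes.begin(), sizes.end());
  const NVTEShape shape = {shape_vec.data(), shape_vec.size()};

  // Build the basic-tensor descriptor from the existing storage and dtype.
  NVTEBasicTensor data = {
      src.data_ptr(),
      static_cast<NVTEDType>(GetTransformerEngineDType(src.scalar_type())),
      shape};
  nvte_set_tensor_param(&pack.tensors[i], kNVTERowwiseData, &data);
}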
@@ -662,8 +673,8 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s
grid_y *= tensor.size(i);
}
dim3 grid = {grid_x, grid_y};
thd_read_half_tensor_kernel<<<grid, block, sizeof(int) * (batch + 1),
at::cuda::getCurrentCUDAStream()>>>(
transformer_engine::fused_attn::thd_read_half_tensor_kernel<<<
grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
half.data_ptr(), tensor.data_ptr(), cu_seqlens.data_ptr<int>(), batch, hidden_size_in_bytes,
half_idx, tensor.size(seq_dim));

@@ -713,13 +724,14 @@ void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_st
unsigned int grid_x = (lse_seqlen / 2 + block - 1) / block;
unsigned int grid_y = num_heads;
dim3 grid = {grid_x, grid_y};

if (lse_packed) {
thd_lse_kernel<double, true, LseCorrectionFunctor>
transformer_engine::fused_attn::thd_lse_kernel<double, true, LseCorrectionFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<double>(), lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
batch, num_heads, lse_seqlen, second_half_lse_seqlen);
} else {
thd_lse_kernel<double, false, LseCorrectionFunctor>
transformer_engine::fused_attn::thd_lse_kernel<double, false, LseCorrectionFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<double>(), lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
batch, num_heads, lse_seqlen, second_half_lse_seqlen);
@@ -764,13 +776,14 @@ at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_
unsigned int grid_x = (lse_seqlen / 2 + block - 1) / block;
unsigned int grid_y = num_heads;
dim3 grid = {grid_x, grid_y};

if (lse_packed) {
thd_lse_kernel<float, true, ReadLseFunctor>
transformer_engine::fused_attn::thd_lse_kernel<float, true, ReadLseFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<float>(), half_lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch,
num_heads, lse_seqlen, second_half_lse_seqlen);
} else {
thd_lse_kernel<float, false, ReadLseFunctor>
transformer_engine::fused_attn::thd_lse_kernel<float, false, ReadLseFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<float>(), half_lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch,
num_heads, lse_seqlen, second_half_lse_seqlen);
@@ -829,13 +842,13 @@ static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_
dim3 grid = {grid_x, (unsigned int)num_heads};

if (lse_packed) {
thd_out_correction_kernel<dtype, only_second_half, tile, true>
transformer_engine::fused_attn::thd_out_correction_kernel<dtype, only_second_half, tile, true>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
out.data_ptr<dtype>(), out_per_step.data_ptr<dtype>(), lse.data_ptr<float>(),
lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch, num_heads,
dim_per_head, lse_seqlen, lse_per_step_seqlen);
} else {
thd_out_correction_kernel<dtype, only_second_half, tile, false>
transformer_engine::fused_attn::thd_out_correction_kernel<dtype, only_second_half, tile, false>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
out.data_ptr<dtype>(), out_per_step.data_ptr<dtype>(), lse.data_ptr<float>(),
lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch, num_heads,
@@ -925,7 +938,8 @@ static void thd_grad_correction_helper(at::Tensor grad, const at::Tensor &grad_p
}
dim3 grid = {grid_x, grid_y};

thd_grad_correction_kernel<dtype, Functor_0, Functor_1, functor_idx, 32>
transformer_engine::fused_attn::thd_grad_correction_kernel<dtype, Functor_0, Functor_1,
functor_idx, 32>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
grad.data_ptr<dtype>(), grad_per_step.data_ptr<dtype>(), cu_seqlens.data_ptr<int>(),
batch, hidden_size, total_tokens);
@@ -992,8 +1006,8 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t

constexpr unsigned int block = 256;
unsigned int grid = (output.size(0) + block - 1) / block;
thd_partition_indices_kernel<<<grid, block, sizeof(int) * (batch + 1),
at::cuda::getCurrentCUDAStream()>>>(
transformer_engine::fused_attn::thd_partition_indices_kernel<<<
grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
output.data_ptr<int>(), cu_seqlens.data_ptr<int>(), batch, total_tokens, world_size, rank);

return output;