From f0d22ca12f574233053da20516997e45d99eb65c Mon Sep 17 00:00:00 2001
From: Xin Yao
Date: Thu, 13 Feb 2025 09:55:38 +0800
Subject: [PATCH] Fix a bug for D being nullptr in grouped gemm (#1475)

* fix a bug for at::from_blob with nullptr

Signed-off-by: Xin Yao

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix a bug for non-TN

Signed-off-by: Xin Yao

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Xin Yao
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
---
 tests/pytorch/test_numerics.py       | 36 ++++++++++++-------
 .../pytorch/csrc/extensions/gemm.cpp |  8 +++--
 .../pytorch/module/grouped_linear.py |  3 +-
 3 files changed, 32 insertions(+), 15 deletions(-)

diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py
index 2401f3ca95..22735c5292 100644
--- a/tests/pytorch/test_numerics.py
+++ b/tests/pytorch/test_numerics.py
@@ -2131,21 +2131,30 @@ def test_grouped_gemm(shape, dtype, layout, accumulate):
 
     if layout == "TN":
         A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
-        B = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)  # input
-        out = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)  # output
+        B = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits))  # input
+        out = [torch.randn(m, n, dtype=dtype, device="cuda")]  # output
+        out_ref = [o.clone() for o in torch.split(out[0], m_splits)]
         grad = False
+        single_output = True
     elif layout == "NN":
         A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
-        B = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)  # grad_output
-        out = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)  # dgrad
+        B = list(
+            torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)
+        )  # grad_output
+        out = [torch.randn(m, k, dtype=dtype, device="cuda")]  # dgrad
+        out_ref = [o.clone() for o in torch.split(out[0], m_splits)]
         grad = True
+        single_output = True
     else:  # layout == "NT"
-        A = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)  # input
-        B = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)  # grad_output
+        A = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits))  # input
+        B = list(
+            torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)
+        )  # grad_output
         out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # wgrad
+        out_ref = [o.clone() for o in out]
         grad = True
+        single_output = False
 
-    out_ref = [o.clone() for o in out]
     for i in range(z):
         general_gemm(
             A[i],
@@ -2157,17 +2166,20 @@ def test_grouped_gemm(shape, dtype, layout, accumulate):
             layout=layout,
             out=out_ref[i],
         )
+    if single_output:
+        out_ref = [torch.cat(out_ref)]
 
     general_grouped_gemm(
         A,
-        list(B),
-        list(out),
+        B,
+        out,
         dtype,
         get_multi_stream_cublas_workspace(),
-        m_splits=[k] * n,  # TODO, not sure
+        m_splits=m_splits,
         grad=grad,
         accumulate=accumulate,
         layout=layout,
+        single_output=single_output,
     )
 
     # should be bit-wise match
@@ -2190,7 +2202,7 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):
         pytest.skip(reason_for_no_fp8)
 
     z, m, k, n = shape
-    m_splits = m // z
+    m_splits = [m // z] * z
     dtype = torch.bfloat16
     A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
@@ -2242,7 +2254,7 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):
         out,
         dtype,
         get_multi_stream_cublas_workspace(),
-        m_splits=[k] * m_splits,
+        m_splits=m_splits,
         accumulate=accumulate,
     )
 
diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp
index b044c9f604..54bd52f136 100644
--- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp
@@ -336,9 +336,13 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
     auto dtype = GetATenDType(D_type);
     auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
     if (single_output) {
-      out_tensor = at::from_blob(output_data_ptr, D_shape, opts);
+      if (output_data_ptr == nullptr) {
+        out_tensor = at::empty(D_shape, opts);
+      } else {
+        out_tensor = at::from_blob(output_data_ptr, D_shape, opts);
+      }
       char* char_ptr = reinterpret_cast<char*>(output_data_ptr);
-      char_ptr += m_splits[i] * te_A.size(0) * (*D)[0].element_size();
+      char_ptr += D_shape[0] * D_shape[1] * (*D)[0].element_size();
       output_data_ptr = reinterpret_cast<void*>(char_ptr);
       D_vectors.emplace_back(out_tensor);
     } else {
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index 2f9de58984..cab8dff7c2 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -269,9 +269,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
             general_grouped_gemm(
                 weights,
                 grad_output,
-                torch.split(dgrad, ctx.m_splits),
+                [dgrad],
                 ctx.activation_dtype,
                 get_multi_stream_cublas_workspace(),
+                single_output=True,
                 layout="NN",
                 m_splits=ctx.m_splits,
                 grad=True,
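
Reviewer note (not part of the patch): the gemm.cpp change guards against handing a null pointer to at::from_blob when the caller did not supply an output buffer D in single-output mode; in that case a fresh tensor is allocated with at::empty instead. Below is a minimal standalone sketch of that guard pattern, assuming only libtorch; the helper name wrap_or_allocate and the CPU tensor options are illustrative, not Transformer Engine's actual API.

    #include <torch/torch.h>
    #include <iostream>
    #include <vector>

    // at::from_blob wraps caller-owned memory and must not receive nullptr,
    // so fall back to allocating a fresh tensor when no buffer was supplied.
    // (Hypothetical helper for illustration only.)
    static at::Tensor wrap_or_allocate(void* data_ptr, at::IntArrayRef shape,
                                       const at::TensorOptions& opts) {
      if (data_ptr == nullptr) {
        return at::empty(shape, opts);
      }
      return at::from_blob(data_ptr, shape, opts);
    }

    int main() {
      // CPU options keep the sketch runnable anywhere; the patch itself uses kCUDA.
      auto opts = torch::TensorOptions().dtype(torch::kFloat32);
      std::vector<float> buf(6, 1.0f);
      auto wrapped = wrap_or_allocate(buf.data(), {2, 3}, opts);  // wraps existing memory
      auto fresh = wrap_or_allocate(nullptr, {2, 3}, opts);       // allocates instead of crashing
      std::cout << wrapped.sum().item<float>() << " " << fresh.sizes() << std::endl;
      return 0;
    }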