Fix a bug for D being nullptr in grouped gemm (#1475)

* Fix a bug where at::from_blob is passed a nullptr

Signed-off-by: Xin Yao <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix a bug in the non-TN layouts

Signed-off-by: Xin Yao <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Xin Yao <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <[email protected]>
3 people authored Feb 13, 2025
1 parent ee4a17d commit f0d22ca
Showing 3 changed files with 32 additions and 15 deletions.
36 changes: 24 additions & 12 deletions tests/pytorch/test_numerics.py
@@ -2131,21 +2131,30 @@ def test_grouped_gemm(shape, dtype, layout, accumulate):
 
     if layout == "TN":
         A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
-        B = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)  # input
-        out = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)  # output
+        B = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits))  # input
+        out = [torch.randn(m, n, dtype=dtype, device="cuda")]  # output
+        out_ref = [o.clone() for o in torch.split(out[0], m_splits)]
         grad = False
+        single_output = True
     elif layout == "NN":
         A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
-        B = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)  # grad_output
-        out = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)  # dgrad
+        B = list(
+            torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)
+        )  # grad_output
+        out = [torch.randn(m, k, dtype=dtype, device="cuda")]  # dgrad
+        out_ref = [o.clone() for o in torch.split(out[0], m_splits)]
         grad = True
+        single_output = True
     else:  # layout == "NT"
-        A = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)  # input
-        B = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)  # grad_output
+        A = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits))  # input
+        B = list(
+            torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)
+        )  # grad_output
         out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # wgrad
+        out_ref = [o.clone() for o in out]
         grad = True
+        single_output = False
 
-    out_ref = [o.clone() for o in out]
     for i in range(z):
         general_gemm(
             A[i],
@@ -2157,17 +2166,20 @@ def test_grouped_gemm(shape, dtype, layout, accumulate):
             layout=layout,
             out=out_ref[i],
         )
+    if single_output:
+        out_ref = [torch.cat(out_ref)]
 
     general_grouped_gemm(
         A,
-        list(B),
-        list(out),
+        B,
+        out,
         dtype,
         get_multi_stream_cublas_workspace(),
-        m_splits=[k] * n,  # TODO, not sure
+        m_splits=m_splits,
        grad=grad,
         accumulate=accumulate,
         layout=layout,
+        single_output=single_output,
     )
 
     # should be bit-wise match
@@ -2190,7 +2202,7 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):
         pytest.skip(reason_for_no_fp8)
 
     z, m, k, n = shape
-    m_splits = m // z
+    m_splits = [m // z] * z
 
     dtype = torch.bfloat16
     A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
@@ -2242,7 +2254,7 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):
         out,
         dtype,
         get_multi_stream_cublas_workspace(),
-        m_splits=[k] * m_splits,
+        m_splits=m_splits,
         accumulate=accumulate,
     )
 
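Note (illustrative, not part of the commit): the updated test now compares one contiguous output buffer against per-split references concatenated with torch.cat. Below is a minimal plain-PyTorch sketch of that comparison pattern, with made-up sizes and a plain matmul standing in for general_gemm / general_grouped_gemm:

import torch

# Illustrative sizes only; the real test derives these from its "shape" parameter.
z, m, k, n = 4, 128, 64, 32
m_splits = [m // z] * z

# Per-split reference path: one stand-in GEMM per group, results kept separate.
A = [torch.randn(n, k) for _ in range(z)]           # weights, one per group
B = list(torch.split(torch.randn(m, k), m_splits))  # inputs, split along m
out_ref = [b @ a.t() for a, b in zip(A, B)]         # per-group reference outputs

# Single-output path: all groups written into one contiguous (m, n) buffer.
out = torch.empty(m, n)
for chunk, o_ref in zip(torch.split(out, m_splits), out_ref):
    chunk.copy_(o_ref)  # each chunk is a view into `out`

# Mirrors the test's `out_ref = [torch.cat(out_ref)]` comparison for single_output=True.
torch.testing.assert_close(out, torch.cat(out_ref))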
8 changes: 6 additions & 2 deletions transformer_engine/pytorch/csrc/extensions/gemm.cpp
@@ -336,9 +336,13 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
     auto dtype = GetATenDType(D_type);
     auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
     if (single_output) {
-      out_tensor = at::from_blob(output_data_ptr, D_shape, opts);
+      if (output_data_ptr == nullptr) {
+        out_tensor = at::empty(D_shape, opts);
+      } else {
+        out_tensor = at::from_blob(output_data_ptr, D_shape, opts);
+      }
       char* char_ptr = reinterpret_cast<char*>(output_data_ptr);
-      char_ptr += m_splits[i] * te_A.size(0) * (*D)[0].element_size();
+      char_ptr += D_shape[0] * D_shape[1] * (*D)[0].element_size();
       output_data_ptr = reinterpret_cast<void*>(char_ptr);
       D_vectors.emplace_back(out_tensor);
     } else {
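Note (illustrative, not part of the commit): a hedged Python analogue of the single_output branch above; carve_outputs is a hypothetical helper name used only for this sketch. With a preallocated flat buffer, each group's output is a view at a running offset that now advances by D_shape[0] * D_shape[1] elements per group; when no buffer is supplied (the D == nullptr case), a fresh tensor is allocated instead of calling at::from_blob on a null pointer:

import torch

def carve_outputs(out_buffer, group_shapes, dtype=torch.bfloat16, device="cpu"):
    """Return one output tensor per group (Python analogue of the single_output branch)."""
    outs, offset = [], 0
    for rows, cols in group_shapes:
        if out_buffer is None:
            # No preallocated buffer: allocate instead of aliasing a null pointer.
            outs.append(torch.empty(rows, cols, dtype=dtype, device=device))
        else:
            # Preallocated buffer: each group's output is a view at the running offset.
            outs.append(out_buffer[offset : offset + rows * cols].view(rows, cols))
        offset += rows * cols  # advance by rows * cols elements per group
    return outs

# Preallocated single output: views share storage with the flat buffer.
shapes = [(3, 4), (5, 4), (2, 4)]
buf = torch.empty(sum(r * c for r, c in shapes), dtype=torch.bfloat16)
views = carve_outputs(buf, shapes)
assert views[1].data_ptr() == buf.data_ptr() + 3 * 4 * buf.element_size()

# No buffer supplied: outputs are freshly allocated rather than aliased.
fresh = carve_outputs(None, shapes)
assert fresh[0].shape == (3, 4)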
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/module/grouped_linear.py
@@ -269,9 +269,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
             general_grouped_gemm(
                 weights,
                 grad_output,
-                torch.split(dgrad, ctx.m_splits),
+                [dgrad],
                 ctx.activation_dtype,
                 get_multi_stream_cublas_workspace(),
+                single_output=True,
                 layout="NN",
                 m_splits=ctx.m_splits,
                 grad=True,
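Note (illustrative, not part of the commit): the backward change works because the per-split dgrad chunks were already views into one contiguous tensor, so passing [dgrad] once with single_output=True targets the same memory the split views did. A small plain-PyTorch sketch of that equivalence, with made-up sizes:

import torch

m_splits = [3, 5, 2]
hidden = 8
dgrad = torch.empty(sum(m_splits), hidden)

# torch.split returns non-owning views into dgrad's storage, each offset by the
# rows already consumed, so writing into the chunks writes into dgrad itself.
chunks = torch.split(dgrad, m_splits)
assert chunks[1].data_ptr() == dgrad.data_ptr() + 3 * hidden * dgrad.element_size()

# Filling the views (the old per-split call) ...
for i, c in enumerate(chunks):
    c.fill_(float(i))
# ... leaves exactly the result a single contiguous write would produce.
expected = torch.cat([torch.full((s, hidden), float(i)) for i, s in enumerate(m_splits)])
assert torch.equal(dgrad, expected)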
