
Commit 0163f2f

add split mxfp8 quantized tensor and enable mxfp8 input for grouped linear

Signed-off-by: Xin Yao <[email protected]>

1 parent f1b18ed
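
For orientation before the diff: the new GroupedLinear path lets the gated-activation output be quantized to MXFP8 once in the forward pass and passed straight into the grouped GEMM. Below is a minimal sketch of that usage, based on the test added in this commit; the layer sizes, token splits, and the `te`/`recipe` import aliases are illustrative assumptions, not part of the commit.

# Minimal sketch of the MXFP8-input path exercised by test_grouped_linear_fp8_input below.
# Sizes and splits are illustrative; requires an MXFP8-capable GPU.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

num_gemms, hidden = 4, 256
m_splits = [128] * num_gemms  # tokens per GEMM, each divisible by 32 for MXFP8

grouped_linear = te.GroupedLinear(
    num_gemms, hidden, 4 * hidden, bias=False, params_dtype=torch.bfloat16, device="cuda"
)

# Gated activation whose output is quantized to MXFP8 in the forward pass only,
# so GroupedLinear receives a QuantizedTensor instead of a bf16 tensor.
activation = te.ops.Sequential(
    te.ops.SwiGLU(),
    te.ops.Quantize(forward=True, backward=False),
)

# SwiGLU halves the feature dimension, so the raw input is 2 * hidden wide.
inp = torch.randn((sum(m_splits), 2 * hidden), dtype=torch.bfloat16, device="cuda")

with te.fp8_autocast(enabled=True, fp8_recipe=recipe.MXFP8BlockScaling()):
    out = grouped_linear(activation(inp), m_splits)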

File tree: 8 files changed (+316, -59 lines)

tests/pytorch/test_numerics.py

Lines changed: 140 additions & 3 deletions
@@ -12,11 +12,13 @@
 import torch.nn as nn
 from torch.nn import Parameter
 
+import transformer_engine.pytorch as te
 from transformer_engine.pytorch.fp8 import (
     FP8GlobalStateManager,
     fp8_autocast,
     fp8_model_init,
 )
+from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
 from transformer_engine.pytorch.utils import (
     init_method_normal,
     scaled_init_method_normal,
@@ -1697,13 +1699,16 @@ def _test_grouped_linear_accuracy(
     fp8,
     fuse_wgrad_accumulation,
     delay_wgrad_compute=False,
+    activation_func=None,  # assume gated activation function
 ):
     reset_rng_states()
     if fp8:
         FP8GlobalStateManager.reset()
 
+    # assume gated activation function
+    hidden_size = config.hidden_size if activation_func is None else 2 * config.hidden_size
     inp_hidden_states = torch.randn(
-        (config.max_seqlen_q, bs, config.hidden_size),
+        (config.max_seqlen_q, bs, hidden_size),
         dtype=dtype,
         device="cuda",
         requires_grad=True,
@@ -1728,11 +1733,11 @@ def _test_grouped_linear_accuracy(
     with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
         if isinstance(block, GroupedLinear):
             m_splits = m_splits * bs
-            out = block(inp_hidden_states, m_splits.tolist())
+            out = block(activation_func(inp_hidden_states), m_splits.tolist())
         else:
             out = torch.cat(
                 [
-                    block[i](inp)
+                    block[i](activation_func(inp))
                     for i, inp in enumerate(torch.split(inp_hidden_states, m_splits.tolist()))
                 ]
             )
@@ -1967,6 +1972,92 @@ def test_grouped_linear_accuracy_single_gemm(recipe):
 )


+@pytest.mark.skipif(not mxfp8_available, reason="MXFP8 is not available")
+@pytest.mark.parametrize("dtype", param_types, ids=str)
+@pytest.mark.parametrize("num_gemms", [3, 6])
+@pytest.mark.parametrize("bs", batch_sizes)
+@pytest.mark.parametrize("model", ["126m"])
+@pytest.mark.parametrize("recipe", [recipe.MXFP8BlockScaling()])
+@pytest.mark.parametrize("fp8_model_params", all_boolean)
+def test_grouped_linear_fp8_input(
+    dtype,
+    num_gemms,
+    bs,
+    model,
+    recipe,
+    fp8_model_params,
+):
+    config = model_configs[model]
+    if config.max_seqlen_q % 32 != 0:
+        pytest.skip("MXFP8 requires sequence length to be divisible by 32.")
+
+    with fp8_model_init(enabled=fp8_model_params, recipe=recipe):
+        grouped_linear_bf16_input = GroupedLinear(
+            num_gemms,
+            config.hidden_size,
+            4 * config.hidden_size,
+            bias=False,
+            params_dtype=dtype,
+            device="cuda",
+            fuse_wgrad_accumulation=True,
+        ).eval()
+
+        grouped_linear_fp8_input = GroupedLinear(
+            num_gemms,
+            config.hidden_size,
+            4 * config.hidden_size,
+            bias=False,
+            params_dtype=dtype,
+            device="cuda",
+            fuse_wgrad_accumulation=True,
+        ).eval()
+
+    # Share params
+    with torch.no_grad():
+        for i in range(num_gemms):
+            setattr(
+                grouped_linear_fp8_input,
+                f"weight{i}",
+                Parameter(getattr(grouped_linear_bf16_input, f"weight{i}").clone()),
+            )
+            weight_i = getattr(grouped_linear_bf16_input, f"weight{i}")
+            weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
+            weight_i_copy = getattr(grouped_linear_fp8_input, f"weight{i}")
+            weight_i_copy.main_grad = weight_i.main_grad.clone()
+
+    bf16_activation = te.ops.SwiGLU()
+    fp8_activation = te.ops.Sequential(
+        te.ops.SwiGLU(),
+        te.ops.Quantize(forward=True, backward=False),  # Output QuantizedTensor in forward
+    )
+
+    outputs_ref = _test_grouped_linear_accuracy(
+        grouped_linear_bf16_input,
+        num_gemms,
+        bs,
+        dtype,
+        config,
+        recipe,
+        fp8=True,
+        fuse_wgrad_accumulation=True,
+        activation_func=bf16_activation,
+    )
+    outputs = _test_grouped_linear_accuracy(
+        grouped_linear_fp8_input,
+        num_gemms,
+        bs,
+        dtype,
+        config,
+        recipe,
+        fp8=True,
+        fuse_wgrad_accumulation=True,
+        activation_func=fp8_activation,
+    )
+    # Should be a bit-wise match
+    for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
+        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
+
+
 def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False):

     def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
@@ -2706,3 +2797,49 @@ def _run_module(m, inp):
     out = _run_module(g2, b)

     assert_allclose(out, outT, 1e-7)
+
+
+@pytest.mark.skipif(not mxfp8_available, reason="MXFP8 is not available")
+@pytest.mark.parametrize("dtype", param_types, ids=str)
+@pytest.mark.parametrize("num_experts", [8])
+@pytest.mark.parametrize("m", [64, 128, 256])
+@pytest.mark.parametrize("k", [64, 128, 256])
+def test_split_quantized_tensor(dtype, num_experts, m, k):
+
+    tensor = torch.randn((m * num_experts, k), dtype=dtype, device="cuda")
+    m_splits = [m] * num_experts
+
+    quantizer = MXFP8Quantizer(
+        fp8_dtype=tex.DType.kFloat8E4M3,
+        rowwise=True,
+        columnwise=True,
+    )
+
+    # Split and quantize one by one
+    ref_mxfp8 = tex.split_quantize(tensor, m_splits, [quantizer] * num_experts)
+
+    # Quantize as a whole and then split
+    out_mxfp8 = tex.split_quantized_tensor(quantizer.quantize(tensor), m_splits)
+
+    for ref, out in zip(ref_mxfp8, out_mxfp8):
+        assert ref._quantizer.rowwise_usage == out._quantizer.rowwise_usage
+        assert ref._quantizer.columnwise_usage == out._quantizer.columnwise_usage
+        assert ref._quantizer.dtype == out._quantizer.dtype
+        assert ref._quantizer.internal == out._quantizer.internal
+
+        torch.testing.assert_close(ref._rowwise_data, out._rowwise_data, rtol=0, atol=0)
+        # Padded areas are randomly filled.
+        torch.testing.assert_close(
+            ref._rowwise_scale_inv[:m, : k // 32],
+            out._rowwise_scale_inv[:m, : k // 32],
+            rtol=0,
+            atol=0,
+        )
+        torch.testing.assert_close(ref._columnwise_data, out._columnwise_data, rtol=0, atol=0)
+        # Padded areas are randomly filled.
+        torch.testing.assert_close(
+            ref._columnwise_scale_inv[: m // 32, :k],
+            out._columnwise_scale_inv[: m // 32, :k],
+            rtol=0,
+            atol=0,
+        )

transformer_engine/pytorch/csrc/common.cpp

Lines changed: 16 additions & 0 deletions
@@ -291,4 +291,20 @@ size_t roundup(const size_t value, const size_t multiple) {
   return ((value + multiple - 1) / multiple) * multiple;
 }
 
+at::Tensor make_torch_view(std::shared_ptr<at::Tensor>& buffer, const std::vector<size_t>& shape,
+                           size_t offset, at::ScalarType dtype) {
+  std::vector<int64_t> shape_int64(shape.begin(), shape.end());
+  // If the full buffer is empty because the local rank receives no tokens for any expert, the
+  // data_ptr is nullptr and we must return an empty tensor instead of calling from_blob. When
+  // some experts receive tokens and some do not, we still want to use from_blob as much as
+  // possible to avoid CPU overhead.
+  if (buffer->data_ptr<uint8_t>() == nullptr) {
+    return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype));
+  }
+  return at::from_blob(
+      buffer->data_ptr<uint8_t>() + offset, shape_int64,
+      [buffer](void*) {},  // deleter holds shared_ptr
+      at::device(at::kCUDA).dtype(dtype));
+}
+
 }  // namespace transformer_engine::pytorch

transformer_engine/pytorch/csrc/common.h

Lines changed: 8 additions & 0 deletions
@@ -420,6 +420,14 @@ std::vector<size_t> convertShape(const NVTEShape& shape);
 size_t roundup(const size_t value, const size_t multiple);
 
 NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape);
+
+/*! @brief Helper function to construct tensor view
+ *
+ *  Note: Deleter holds a shared_ptr for the buffer, so the buffer
+ *  will survive until all views are deleted.
+ */
+at::Tensor make_torch_view(std::shared_ptr<at::Tensor>& buffer, const std::vector<size_t>& shape,
+                           size_t offset, at::ScalarType dtype);
 }  // namespace transformer_engine::pytorch
 
 namespace std {

transformer_engine/pytorch/csrc/extensions/cast.cpp

Lines changed: 0 additions & 38 deletions
@@ -199,25 +199,6 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp
   constexpr size_t fp8_elem_size = 1;
   constexpr size_t scale_elem_size = 4;
 
-  // Helper function to construct tensor view
-  // Note: Deleter holds a shared_ptr for the buffer, so the buffer
-  // will survive until all views are deleted.
-  auto make_torch_view = [](std::shared_ptr<at::Tensor> &buffer, const std::vector<size_t> &shape,
-                            size_t offset, at::ScalarType dtype) -> at::Tensor {
-    std::vector<int64_t> shape_int64(shape.begin(), shape.end());
-    // in the case where full buffer is empty because local rank receives no tokens for all the experts
-    // then the data_ptr is nullptr, we need to return an empty tensor instead of calling from_blob
-    // but in the case where some experts receive tokens, some not, we want to leverage from_blob
-    // as much as possible to avoid CPU overhead
-    if (buffer->data_ptr<uint8_t>() == nullptr) {
-      return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype));
-    }
-    return at::from_blob(
-        buffer->data_ptr<uint8_t>() + offset, shape_int64,
-        [buffer](void *) {},  // deleter holds shared_ptr
-        at::device(at::kCUDA).dtype(dtype));
-  };
-
   // Allocate row-wise data
   std::vector<at::Tensor> rowwise_data_list, rowwise_scale_list;
   std::vector<std::vector<size_t>> rowwise_data_shapes, rowwise_scale_shapes;
@@ -353,25 +334,6 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
   constexpr size_t fp8_elem_size = 1;
   constexpr size_t scale_elem_size = 1;
 
-  // Helper function to construct tensor view
-  // Note: Deleter holds a shared_ptr for the buffer, so the buffer
-  // will survive until all views are deleted.
-  auto make_torch_view = [](std::shared_ptr<at::Tensor> &buffer, const std::vector<size_t> &shape,
-                            size_t offset, at::ScalarType dtype) -> at::Tensor {
-    std::vector<int64_t> shape_int64(shape.begin(), shape.end());
-    // in the case where full buffer is empty because local rank receives no tokens for all the experts
-    // then the data_ptr is nullptr, we need to return an empty tensor instead of calling from_blob
-    // but in the case where some experts receive tokens, some not, we want to leverage from_blob
-    // as much as possible to avoid CPU overhead
-    if (buffer->data_ptr<uint8_t>() == nullptr) {
-      return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype));
-    }
-    return at::from_blob(
-        buffer->data_ptr<uint8_t>() + offset, shape_int64,
-        [buffer](void *) {},  // deleter holds shared_ptr
-        at::device(at::kCUDA).dtype(dtype));
-  };
-
   // Allocate row-wise data
   std::vector<at::Tensor> rowwise_data_list, rowwise_scale_list;
   std::vector<std::vector<size_t>> rowwise_data_shapes, rowwise_scale_shapes;

transformer_engine/pytorch/csrc/extensions/pybind.cpp

Lines changed: 3 additions & 0 deletions
@@ -18,6 +18,7 @@
 
 #include "../common.h"
 #include "../extensions.h"
+#include "../util.h"
 #include "common.h"
 
 namespace transformer_engine::pytorch {
@@ -240,6 +241,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Fused Multi-tensor padding", py::call_guard<py::gil_scoped_release>());
   m.def("fused_multi_row_unpadding", &transformer_engine::pytorch::fused_multi_row_unpadding,
         "Fused Multi-tensor unpadding", py::call_guard<py::gil_scoped_release>());
+  m.def("split_quantized_tensor", &transformer_engine::pytorch::split_quantized_tensor,
+        "Split quantized tensor");
 
   // attention kernels
   m.def("fa_prepare_fwd", &transformer_engine::pytorch::fa_prepare_fwd,
