Pad MXFP8 scale inverses at the time of creation (#1431)
* Create scale_inv for block scaling already padded

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* fix

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Remove old file, fix CG test

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fixes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Change default value of env

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
ksivaman authored Jan 29, 2025
1 parent fb1a241 commit cffec60
Showing 13 changed files with 57 additions and 85 deletions.
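For readers skimming the diff, the shape arithmetic that recurs below: MXFP8 stores one E8M0 scale inverse per 32-element block, and the swizzled scale layout consumed by the block-scaled GEMM expects the rowwise scale tensor padded to a multiple of (128, 4) and the columnwise one to (4, 128). This commit allocates the scale buffers pre-padded instead of padding them on the fly before each GEMM. A minimal Python sketch of the shape computation for a 2-D tensor (round_up and padded_scale_inv_shapes are illustrative names, not TE functions):

MXFP8_BLOCK_SIZE = 32  # one E8M0 scale inverse per 32 consecutive elements

def round_up(value: int, multiple: int) -> int:
    return ((value + multiple - 1) // multiple) * multiple

def padded_scale_inv_shapes(rows: int, cols: int):
    # Rowwise: one scale per (row, 32-column block), padded to (128, 4) tiles.
    rowwise = (round_up(rows, 128), round_up(cols // MXFP8_BLOCK_SIZE, 4))
    # Columnwise: one scale per (32-row block, column), padded to (4, 128) tiles.
    columnwise = (round_up(rows // MXFP8_BLOCK_SIZE, 4), round_up(cols, 128))
    return rowwise, columnwise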
2 changes: 1 addition & 1 deletion qa/L0_pytorch_unittest/test.sh
@@ -11,7 +11,7 @@ pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py
pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
35 changes: 0 additions & 35 deletions tests/cpp/run_norm_tests.sh

This file was deleted.

6 changes: 2 additions & 4 deletions transformer_engine/common/common.h
@@ -423,10 +423,8 @@ struct is_fp8<fp8e5m2> : std::true_type {};
size_t typeToSize(const DType type);

void CheckNoopTensor(const Tensor &t, const std::string &name);
void CheckInputTensor(const Tensor &t, const std::string &name,
bool check_scale_inv_alignment = false);
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false,
bool check_scale_inv_alignment = false);
void CheckInputTensor(const Tensor &t, const std::string &name);
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false);

bool is_fp8_dtype(const DType t);

4 changes: 2 additions & 2 deletions transformer_engine/common/swizzle/swizzle.cu
@@ -210,8 +210,8 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
return;
}

CheckInputTensor(*input, "scaling_factor_input", true);
CheckInputTensor(*output, "scaling_factor_output", true);
CheckInputTensor(*input, "scaling_factor_input");
CheckInputTensor(*output, "scaling_factor_output");

auto& scaling_mode = input->scaling_mode;

12 changes: 5 additions & 7 deletions transformer_engine/common/transformer_engine.cpp
@@ -65,7 +65,7 @@ void CheckNoopTensor(const Tensor &t, const std::string &name) {
}
}

void CheckScaleTensorShape(const Tensor &t, bool check_scale_inv_alignment) {
void CheckScaleTensorShape(const Tensor &t) {
NVTE_CHECK(t.scaling_mode != NVTE_INVALID_SCALING, "Invalid scaling mode!");
if (is_tensor_scaling(t.scaling_mode)) {
// per-tensor scaling
@@ -80,7 +80,6 @@ void CheckScaleTensorShape(const Tensor &t, bool check_scale_inv_alignment) {
}
} else {
if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) {
if (!check_scale_inv_alignment) return;
// Need (4, 128) alignment even for e8 scaling factor
auto block_alignment = std::vector<size_t>{128ul / typeToSize(t.scale_inv.dtype),
4ul / typeToSize(t.scale_inv.dtype)};
@@ -111,7 +110,7 @@ }
}
}

void CheckInputTensor(const Tensor &t, const std::string &name, bool check_scale_inv_alignment) {
void CheckInputTensor(const Tensor &t, const std::string &name) {
const DType type = t.dtype();
if (is_fp8_dtype(type)) {
// FP8 input needs to have scale_inv
@@ -143,11 +142,10 @@ void CheckInputTensor(const Tensor &t, const std::string &name, bool check_scale
}
NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Input ", name, " is not allocated!");

CheckScaleTensorShape(t, check_scale_inv_alignment);
CheckScaleTensorShape(t);
}

void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty,
bool check_scale_inv_alignment) {
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty) {
const DType type = t.dtype();
if (is_fp8_dtype(type)) {
// FP8 output needs to have scale, scale_inv and (if delayed scaling) amax
@@ -189,7 +187,7 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Output ", name, " is not allocated!");
}

CheckScaleTensorShape(t, check_scale_inv_alignment);
CheckScaleTensorShape(t);
}

} // namespace transformer_engine
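With padding guaranteed at creation, the alignment check in CheckScaleTensorShape now runs unconditionally instead of behind the removed check_scale_inv_alignment flag. For the 1-byte E8M0 scale dtype, the block_alignment expression above reduces to multiples of 128 and 4. A rough Python rendering of the rowwise case (illustrative names; the columnwise case presumably swaps the two factors):

def check_rowwise_scale_inv_shape(sinv_rows: int, sinv_cols: int) -> None:
    # 128 / typeToSize(E8M0) == 128 and 4 / typeToSize(E8M0) == 4 for 1-byte scales
    if sinv_rows % 128 != 0 or sinv_cols % 4 != 0:
        raise ValueError(
            f"scale_inv shape ({sinv_rows}, {sinv_cols}) is not padded to multiples of (128, 4)"
        )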
5 changes: 5 additions & 0 deletions transformer_engine/pytorch/csrc/common.cpp
@@ -223,4 +223,9 @@ std::vector<size_t> convertShape(const NVTEShape& shape) {
return std::vector<size_t>(shape.data, shape.data + shape.ndim);
}

int roundup(const int value, const int multiple) {
assert(multiple > 0);
return ((value + multiple - 1) / multiple) * multiple;
}

} // namespace transformer_engine::pytorch
2 changes: 2 additions & 0 deletions transformer_engine/pytorch/csrc/common.h
@@ -252,6 +252,8 @@ void* getDataPtr(at::Tensor tensor, int offset = 0);

std::vector<size_t> convertShape(const NVTEShape& shape);

int roundup(const int value, const int multiple);

} // namespace transformer_engine::pytorch

namespace std {
2 changes: 0 additions & 2 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -361,8 +361,6 @@ at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv);

at::Tensor columnwise_swizzle(at::Tensor input, at::Tensor scale_inv);

at::Tensor pad_scale_inv(at::Tensor scale_inv, bool rowwise);

/***************************************************************************************************
* Comm+GEMM Overlap Wrappers
**************************************************************************************************/
24 changes: 15 additions & 9 deletions transformer_engine/pytorch/csrc/extensions/quantizer.cpp
@@ -174,28 +174,34 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
auto last_dim = torch_shape.back();

NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0,
"MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE,
" (got shape=", torch_shape, ")");

at::Tensor data;
if (rowwise_usage) {
if (rowwise_data.has_value()) {
data = std::move(*rowwise_data);
} else {
data = at::empty(torch_shape, opts);
}
rowwise_scale_inv = at::empty({numel / last_dim, last_dim / MXFP8_BLOCK_SIZE}, opts);
auto sinv0 = roundup(numel / last_dim, 128);
auto sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4);
rowwise_scale_inv = at::zeros({sinv0, sinv1}, opts);
tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape);
tensor.set_rowwise_scale_inv(
rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
std::vector<size_t>{numel / last_dim, last_dim / MXFP8_BLOCK_SIZE});
} else {
tensor.set_rowwise_scale_inv(rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
std::vector<size_t>{sinv0, sinv1});
}

if (columnwise_usage) {
auto sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4);
auto sinv1 = roundup(last_dim, 128);
columnwise_data = at::empty(torch_shape, opts);
columnwise_scale_inv = at::empty({numel / (last_dim * MXFP8_BLOCK_SIZE), last_dim}, opts);
columnwise_scale_inv = at::zeros({sinv0, sinv1}, opts);

tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape);
tensor.set_columnwise_scale_inv(
columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
std::vector<size_t>{numel / (last_dim * MXFP8_BLOCK_SIZE), last_dim});
tensor.set_columnwise_scale_inv(columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
std::vector<size_t>{sinv0, sinv1});
}
this->set_quantization_params(&tensor);

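A worked example of the new allocation logic for a 192 x 1024 tensor, both dims divisible by MXFP8_BLOCK_SIZE = 32 so the NVTE_CHECK passes; the snippet just replays the sinv0/sinv1 formulas above:

def roundup(value: int, multiple: int) -> int:
    return ((value + multiple - 1) // multiple) * multiple

rows, cols, block = 192, 1024, 32
# Rowwise: one scale per (row, 32-column block), padded to (128, 4) tiles.
assert (roundup(rows, 128), roundup(cols // block, 4)) == (256, 32)
# Columnwise: one scale per (32-row block, column), padded to (4, 128) tiles.
assert (roundup(rows // block, 4), roundup(cols, 128)) == (8, 1024)

Only the leading 192 x 32 (rowwise) and 6 x 1024 (columnwise) entries hold real scales; the remainder is zero padding from at::zeros.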
19 changes: 2 additions & 17 deletions transformer_engine/pytorch/csrc/extensions/swizzle.cpp
@@ -65,24 +65,11 @@ void swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool roww
}
}

at::Tensor pad_scale_inv(at::Tensor scale_inv, bool rowwise) {
size_t dim_1_mod = (rowwise) ? 128 : 4;
size_t dim_2_mod = (rowwise) ? 4 : 128;
size_t dim_1_pad = (dim_1_mod - scale_inv.sizes()[0] % dim_1_mod) % dim_1_mod;
size_t dim_2_pad = (dim_2_mod - scale_inv.sizes()[1] % dim_2_mod) % dim_2_mod;
if (dim_1_pad == 0 && dim_2_pad == 0) {
return scale_inv;
}
return at::constant_pad_nd(scale_inv, {0, dim_2_pad, 0, dim_1_pad}, 0.0);
}

at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor _scale_inv) {
at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv) {
using namespace transformer_engine::pytorch;

NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors.");

auto scale_inv = pad_scale_inv(_scale_inv, true);

auto options = at::TensorOptions().dtype(scale_inv.dtype()).device(torch::kCUDA);
auto swizzled_scale_inv = at::empty_like(scale_inv, options);

@@ -102,13 +89,11 @@ at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor _scale_inv) {
return swizzled_scale_inv;
}

at::Tensor columnwise_swizzle(at::Tensor input, at::Tensor _scale_inv) {
at::Tensor columnwise_swizzle(at::Tensor input, at::Tensor scale_inv) {
using namespace transformer_engine::pytorch;

NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors.");

auto scale_inv = pad_scale_inv(_scale_inv, false);

auto options = at::TensorOptions().dtype(scale_inv.dtype()).device(torch::kCUDA);
auto swizzled_scale_inv = at::empty_like(scale_inv, options);

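For reference, the deleted pad_scale_inv helper padded the scale tensor on the fly just before swizzling. An equivalent PyTorch sketch of the removed logic (this function no longer exists in TE):

import torch
import torch.nn.functional as F

def pad_scale_inv(scale_inv: torch.Tensor, rowwise: bool) -> torch.Tensor:
    # Rowwise scales pad to multiples of (128, 4); columnwise to (4, 128).
    dim_1_mod = 128 if rowwise else 4
    dim_2_mod = 4 if rowwise else 128
    dim_1_pad = (dim_1_mod - scale_inv.shape[0] % dim_1_mod) % dim_1_mod
    dim_2_pad = (dim_2_mod - scale_inv.shape[1] % dim_2_mod) % dim_2_mod
    if dim_1_pad == 0 and dim_2_pad == 0:
        return scale_inv
    # Zero-pad on the right/bottom; F.pad orders pads from the last dim inward.
    return F.pad(scale_inv, (0, dim_2_pad, 0, dim_1_pad), value=0.0)

Allocating the buffers pre-padded removes this extra copy and keeps the scale_inv pointer stable, which is presumably what the CUDA-graph test fix in the commit message relates to.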
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/module/_common.py
@@ -16,7 +16,7 @@
from ..tensor.mxfp8_tensor import MXFP8Quantizer


_use_cudnn_mxfp8_norm = bool(int(os.getenv("NVTE_CUDNN_MXFP8_NORM", "1")))
_use_cudnn_mxfp8_norm = bool(int(os.getenv("NVTE_CUDNN_MXFP8_NORM", "0")))


def _get_normalization_func(normalization: str, forward: bool):
22 changes: 15 additions & 7 deletions transformer_engine/pytorch/tensor/mxfp8_tensor.py
@@ -13,7 +13,7 @@

from transformer_engine_torch import DType as TE_DType
from ..constants import MXFP8_BLOCK_SCALING_SIZE
from ..utils import devices_match
from ..utils import devices_match, round_up_to_nearest_multiple

from ._internal.mxfp8_tensor_base import MXFP8TensorBase, _FromMXFP8Func
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
@@ -79,11 +79,19 @@ def make_empty(
if device is None:
device = torch.device("cuda")

assert (
shape[-1] % MXFP8_BLOCK_SCALING_SIZE == 0
and math.prod(shape[:-1]) % MXFP8_BLOCK_SCALING_SIZE == 0
), (
f"Incorrect shape {shape} for MXFP8. Tensor dims must be divisible by"
f" {MXFP8_BLOCK_SCALING_SIZE}"
)

# Allocate FP8 data
data = torch.empty(shape, dtype=torch.uint8, device=device)
scale_inv = torch.empty(
math.prod(shape[:-1]),
shape[-1] // MXFP8_BLOCK_SCALING_SIZE,
scale_inv = torch.zeros(
round_up_to_nearest_multiple(math.prod(shape[:-1]), 128),
round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
dtype=torch.uint8,
device=device,
)
@@ -93,9 +101,9 @@ def make_empty(
columnwise_scale_inv = None
if self.columnwise_usage:
columnwise_data = torch.empty_like(data)
columnwise_scale_inv = torch.empty(
math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE,
shape[-1],
columnwise_scale_inv = torch.zeros(
round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4),
round_up_to_nearest_multiple(shape[-1], 128),
dtype=torch.uint8,
device=device,
)
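The Python path now mirrors the C++ quantizer: tensor dims must be divisible by MXFP8_BLOCK_SCALING_SIZE, and the scale buffers switch from torch.empty to torch.zeros, presumably so the padded region holds deterministic values for downstream kernels. A self-contained sketch of the rowwise allocation under those assumptions (alloc_rowwise is an illustrative name, not the TE API):

import math
import torch

MXFP8_BLOCK_SCALING_SIZE = 32

def round_up(value: int, multiple: int) -> int:
    return ((value + multiple - 1) // multiple) * multiple

def alloc_rowwise(shape, device="cuda"):
    lead = math.prod(shape[:-1])
    assert shape[-1] % MXFP8_BLOCK_SCALING_SIZE == 0 and lead % MXFP8_BLOCK_SCALING_SIZE == 0
    data = torch.empty(shape, dtype=torch.uint8, device=device)
    # zeros, not empty: the padding rows/columns must hold valid scale bytes
    scale_inv = torch.zeros(
        round_up(lead, 128),
        round_up(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
        dtype=torch.uint8,
        device=device,
    )
    return data, scale_inv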
7 changes: 7 additions & 0 deletions transformer_engine/pytorch/utils.py
@@ -319,3 +319,10 @@ def devices_match(device1: torch.device, device2: torch.device) -> bool:
def get_sm_count() -> int:
"""Returns the number of streaming multiprocessors in the current device."""
return torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count


def round_up_to_nearest_multiple(value, multiple):
"""Round up `value` to the next mutiple of `multiple`"""
if multiple == 0:
raise ValueError("multiple cannot be zero.")
return ((value + multiple - 1) // multiple) * multiple
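
A quick sanity check of the new helper; like the C++ roundup, it is intended for non-negative sizes (assuming transformer_engine is installed with this change):

from transformer_engine.pytorch.utils import round_up_to_nearest_multiple

assert round_up_to_nearest_multiple(6, 4) == 8      # 6 columnwise scale rows -> padded to 8
assert round_up_to_nearest_multiple(32, 4) == 32    # already aligned
assert round_up_to_nearest_multiple(192, 128) == 256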
