Pad MXFP8 scale inverses at the time of creation #1431

Merged: 6 commits, Jan 29, 2025

Changes from 3 commits
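For context: this PR moves the (128, 4) / (4, 128) padding of MXFP8 scale-inverse tensors from the swizzle path (the pad_scale_inv helper removed below) to the point where the tensors are first allocated. A minimal Python sketch of the padded shapes, assuming the MXFP8 block size is 32 (MXFP8_BLOCK_SCALING_SIZE) and using the same ceiling-division rounding as the new roundup / round_up_to_nearest_multiple helpers:

import math

MXFP8_BLOCK_SIZE = 32  # assumed value of MXFP8_BLOCK_SCALING_SIZE

def round_up(value: int, multiple: int) -> int:
    # Same rounding as the roundup / round_up_to_nearest_multiple helpers added in this PR.
    return ((value + multiple - 1) // multiple) * multiple

def padded_scale_inv_shapes(shape):
    # Rowwise scale-inv needs (128, 4) alignment, columnwise needs (4, 128).
    rows, cols = math.prod(shape[:-1]), shape[-1]
    rowwise = (round_up(rows, 128), round_up(cols // MXFP8_BLOCK_SIZE, 4))
    columnwise = (round_up(rows // MXFP8_BLOCK_SIZE, 4), round_up(cols, 128))
    return rowwise, columnwise

# Example: a (100, 64) tensor gets a (128, 4) rowwise and a (4, 128) columnwise scale-inv.
print(padded_scale_inv_shapes((100, 64)))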
2 changes: 1 addition & 1 deletion tests/pytorch/test_cuda_graphs.py
@@ -90,7 +90,7 @@ def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None)
if failed_tensors:
print(failure_message)
t1, t2 = failed_tensors[0]
torch.testing.assert_close(t1, t2, rtol=0, atol=0)
torch.testing.assert_close(t1, t2, rtol=0, atol=0, equal_nan=True)


def generate_data(
6 changes: 2 additions & 4 deletions transformer_engine/common/common.h
@@ -423,10 +423,8 @@ struct is_fp8<fp8e5m2> : std::true_type {};
size_t typeToSize(const DType type);

void CheckNoopTensor(const Tensor &t, const std::string &name);
void CheckInputTensor(const Tensor &t, const std::string &name,
bool check_scale_inv_alignment = false);
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false,
bool check_scale_inv_alignment = false);
void CheckInputTensor(const Tensor &t, const std::string &name);
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false);

bool is_fp8_dtype(const DType t);

4 changes: 2 additions & 2 deletions transformer_engine/common/swizzle/swizzle.cu
@@ -210,8 +210,8 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
return;
}

CheckInputTensor(*input, "scaling_factor_input", true);
CheckInputTensor(*output, "scaling_factor_output", true);
CheckInputTensor(*input, "scaling_factor_input");
CheckInputTensor(*output, "scaling_factor_output");

auto& scaling_mode = input->scaling_mode;

12 changes: 5 additions & 7 deletions transformer_engine/common/transformer_engine.cpp
@@ -65,7 +65,7 @@ void CheckNoopTensor(const Tensor &t, const std::string &name) {
}
}

void CheckScaleTensorShape(const Tensor &t, bool check_scale_inv_alignment) {
void CheckScaleTensorShape(const Tensor &t) {
NVTE_CHECK(t.scaling_mode != NVTE_INVALID_SCALING, "Invalid scaling mode!");
if (is_tensor_scaling(t.scaling_mode)) {
// per-tensor scaling
@@ -80,7 +80,6 @@ void CheckScaleTensorShape(const Tensor &t, bool check_scale_inv_alignment) {
}
} else {
if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) {
if (!check_scale_inv_alignment) return;
// Need (4, 128) alignment even for e8 scaling factor
auto block_alignment = std::vector<size_t>{128ul / typeToSize(t.scale_inv.dtype),
4ul / typeToSize(t.scale_inv.dtype)};
@@ -111,7 +110,7 @@ void CheckScaleTensorShape(const Tensor &t, bool check_scale_inv_alignment) {
}
}

void CheckInputTensor(const Tensor &t, const std::string &name, bool check_scale_inv_alignment) {
void CheckInputTensor(const Tensor &t, const std::string &name) {
const DType type = t.dtype();
if (is_fp8_dtype(type)) {
// FP8 input needs to have scale_inv
@@ -143,11 +142,10 @@ void CheckInputTensor(const Tensor &t, const std::string &name, bool check_scale
}
NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Input ", name, " is not allocated!");

CheckScaleTensorShape(t, check_scale_inv_alignment);
CheckScaleTensorShape(t);
}

void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty,
bool check_scale_inv_alignment) {
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty) {
const DType type = t.dtype();
if (is_fp8_dtype(type)) {
// FP8 output needs to have scale, scale_inv and (if delayed scaling) amax
@@ -189,7 +187,7 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Output ", name, " is not allocated!");
}

CheckScaleTensorShape(t, check_scale_inv_alignment);
CheckScaleTensorShape(t);
}

} // namespace transformer_engine
5 changes: 5 additions & 0 deletions transformer_engine/pytorch/csrc/common.cpp
@@ -223,4 +223,9 @@ std::vector<size_t> convertShape(const NVTEShape& shape) {
return std::vector<size_t>(shape.data, shape.data + shape.ndim);
}

int roundup(const int value, const int multiple) {
assert(multiple > 0);
return ((value + multiple - 1) / multiple) * multiple;
}

} // namespace transformer_engine::pytorch
3 changes: 3 additions & 0 deletions transformer_engine/pytorch/csrc/common.h
@@ -39,6 +39,7 @@

#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <cassert>
#include <cmath>
#include <cstring>
#include <iostream>
#include <memory>
@@ -252,6 +253,8 @@ void* getDataPtr(at::Tensor tensor, int offset = 0);

std::vector<size_t> convertShape(const NVTEShape& shape);

int roundup(const int value, const int multiple);

} // namespace transformer_engine::pytorch

namespace std {
2 changes: 0 additions & 2 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -361,8 +361,6 @@ at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv);

at::Tensor columnwise_swizzle(at::Tensor input, at::Tensor scale_inv);

at::Tensor pad_scale_inv(at::Tensor scale_inv, bool rowwise);

/***************************************************************************************************
* Comm+GEMM Overlap Wrappers
**************************************************************************************************/
20 changes: 11 additions & 9 deletions transformer_engine/pytorch/csrc/extensions/quantizer.cpp
@@ -181,21 +181,23 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
} else {
data = at::empty(torch_shape, opts);
}
rowwise_scale_inv = at::empty({numel / last_dim, last_dim / MXFP8_BLOCK_SIZE}, opts);
auto sinv0 = roundup(numel / last_dim, 128);
auto sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4);
rowwise_scale_inv = at::zeros({sinv0, sinv1}, opts);
tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape);
tensor.set_rowwise_scale_inv(
rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
std::vector<size_t>{numel / last_dim, last_dim / MXFP8_BLOCK_SIZE});
} else {
tensor.set_rowwise_scale_inv(rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
std::vector<size_t>{sinv0, sinv1});
}

if (columnwise_usage) {
auto sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4);
auto sinv1 = roundup(last_dim, 128);
columnwise_data = at::empty(torch_shape, opts);
columnwise_scale_inv = at::empty({numel / (last_dim * MXFP8_BLOCK_SIZE), last_dim}, opts);
columnwise_scale_inv = at::zeros({sinv0, sinv1}, opts);

tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape);
tensor.set_columnwise_scale_inv(
columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
std::vector<size_t>{numel / (last_dim * MXFP8_BLOCK_SIZE), last_dim});
tensor.set_columnwise_scale_inv(columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
std::vector<size_t>{sinv0, sinv1});
}
this->set_quantization_params(&tensor);

19 changes: 2 additions & 17 deletions transformer_engine/pytorch/csrc/extensions/swizzle.cpp
@@ -65,24 +65,11 @@ void swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool roww
}
}

at::Tensor pad_scale_inv(at::Tensor scale_inv, bool rowwise) {
size_t dim_1_mod = (rowwise) ? 128 : 4;
size_t dim_2_mod = (rowwise) ? 4 : 128;
size_t dim_1_pad = (dim_1_mod - scale_inv.sizes()[0] % dim_1_mod) % dim_1_mod;
size_t dim_2_pad = (dim_2_mod - scale_inv.sizes()[1] % dim_2_mod) % dim_2_mod;
if (dim_1_pad == 0 && dim_2_pad == 0) {
return scale_inv;
}
return at::constant_pad_nd(scale_inv, {0, dim_2_pad, 0, dim_1_pad}, 0.0);
}

at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor _scale_inv) {
at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv) {
using namespace transformer_engine::pytorch;

NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors.");

auto scale_inv = pad_scale_inv(_scale_inv, true);

auto options = at::TensorOptions().dtype(scale_inv.dtype()).device(torch::kCUDA);
auto swizzled_scale_inv = at::empty_like(scale_inv, options);

@@ -102,13 +89,11 @@ at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor _scale_inv) {
return swizzled_scale_inv;
}

at::Tensor columnwise_swizzle(at::Tensor input, at::Tensor _scale_inv) {
at::Tensor columnwise_swizzle(at::Tensor input, at::Tensor scale_inv) {
using namespace transformer_engine::pytorch;

NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors.");

auto scale_inv = pad_scale_inv(_scale_inv, false);

auto options = at::TensorOptions().dtype(scale_inv.dtype()).device(torch::kCUDA);
auto swizzled_scale_inv = at::empty_like(scale_inv, options);

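For comparison, the removed pad_scale_inv helper above did this padding lazily, right before swizzling. A rough Python rendering of that old behavior, for illustration only (torch.nn.functional.pad stands in for at::constant_pad_nd):

import torch
import torch.nn.functional as F

def pad_scale_inv(scale_inv: torch.Tensor, rowwise: bool) -> torch.Tensor:
    # Rowwise scale-inv is padded to (128, 4) alignment, columnwise to (4, 128).
    dim_1_mod = 128 if rowwise else 4
    dim_2_mod = 4 if rowwise else 128
    dim_1_pad = (dim_1_mod - scale_inv.shape[0] % dim_1_mod) % dim_1_mod
    dim_2_pad = (dim_2_mod - scale_inv.shape[1] % dim_2_mod) % dim_2_mod
    if dim_1_pad == 0 and dim_2_pad == 0:
        return scale_inv
    # Pad the last dim on the right by dim_2_pad and the first dim at the bottom by dim_1_pad,
    # mirroring at::constant_pad_nd(scale_inv, {0, dim_2_pad, 0, dim_1_pad}, 0.0).
    return F.pad(scale_inv, (0, dim_2_pad, 0, dim_1_pad), value=0)

# e.g. an unpadded (100, 2) rowwise scale-inv becomes (128, 4).
print(pad_scale_inv(torch.zeros(100, 2, dtype=torch.uint8), rowwise=True).shape)

With the buffers now allocated pre-padded (and zero-filled) at creation time, this runtime padding step is no longer needed.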
14 changes: 7 additions & 7 deletions transformer_engine/pytorch/tensor/mxfp8_tensor.py
@@ -13,7 +13,7 @@

from transformer_engine_torch import DType as TE_DType
from ..constants import MXFP8_BLOCK_SCALING_SIZE
from ..utils import devices_match
from ..utils import devices_match, round_up_to_nearest_multiple

from ._internal.mxfp8_tensor_base import MXFP8TensorBase, _FromMXFP8Func
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
@@ -81,9 +81,9 @@ def make_empty(

# Allocate FP8 data
data = torch.empty(shape, dtype=torch.uint8, device=device)
scale_inv = torch.empty(
math.prod(shape[:-1]),
shape[-1] // MXFP8_BLOCK_SCALING_SIZE,
scale_inv = torch.zeros(
round_up_to_nearest_multiple(math.prod(shape[:-1]), 128),
round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
dtype=torch.uint8,
device=device,
)
@@ -93,9 +93,9 @@
columnwise_scale_inv = None
if self.columnwise_usage:
columnwise_data = torch.empty_like(data)
columnwise_scale_inv = torch.empty(
math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE,
shape[-1],
columnwise_scale_inv = torch.zeros(
round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4),
round_up_to_nearest_multiple(shape[-1], 128),
dtype=torch.uint8,
device=device,
)
7 changes: 7 additions & 0 deletions transformer_engine/pytorch/utils.py
@@ -319,3 +319,10 @@ def devices_match(device1: torch.device, device2: torch.device) -> bool:
def get_sm_count() -> int:
"""Returns the number of streaming multiprocessors in the current device."""
return torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count


def round_up_to_nearest_multiple(value, multiple):
"""Round up `value` to the next mutiple of `multiple`"""
if multiple == 0:
raise ValueError("multiple cannot be zero.")
return ((value + multiple - 1) // multiple) * multiple
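
A quick usage sketch of the new helper (illustrative values only):

assert round_up_to_nearest_multiple(2, 4) == 4
assert round_up_to_nearest_multiple(100, 128) == 128
assert round_up_to_nearest_multiple(256, 128) == 256  # already a multiple, unchanged
# round_up_to_nearest_multiple(3, 0) raises ValueError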