From fcfd1182e245e414d3d2df5347c8f0c69656b860 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 21 Feb 2025 23:57:42 +0000 Subject: [PATCH 1/7] Delete row-wise data in single-GPU linear forward Signed-off-by: Tim Moon --- transformer_engine/pytorch/module/linear.py | 1 + .../pytorch/tensor/float8_tensor.py | 47 +++++++++----- .../pytorch/tensor/mxfp8_tensor.py | 61 ++++++++++++------- .../pytorch/tensor/quantized_tensor.py | 6 +- 4 files changed, 78 insertions(+), 37 deletions(-) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index e51513630f..b3a3138693 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -152,6 +152,7 @@ def forward( ) if not isinstance(inputmat, QuantizedTensor): inputmat = input_quantizer(inputmat) + own_quantized_input = True elif backward_needs_input: inputmat.update_usage(rowwise_usage=True, columnwise_usage=True) inputmat_total = inputmat diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index da788182a0..7210217856 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -267,21 +267,40 @@ def _create_transpose(self): self._transpose = tex.fp8_transpose(data, self._fp8_dtype, out=self._transpose) self._transpose_invalid = False - def update_usage(self, rowwise_usage=True, columnwise_usage=True): - assert rowwise_usage or columnwise_usage, "Could not disable all usages of the tensor" - if rowwise_usage: - assert self._data is not None, "Rowwise usage of the tensor was already disabled" - else: - if not non_tn_fp8_gemm_supported(): - if self._transpose is None or self._transpose_invalid: - self._create_transpose() - self._data = None - if columnwise_usage: - if self._transpose is None or self._transpose_invalid: - assert self._data is not None, "The tensor does not hold any data anymore" - if not non_tn_fp8_gemm_supported(): - self._create_transpose() + def update_usage( + self, + rowwise_usage: Optional[bool] = None, + columnwise_usage: Optional[bool] = None, + ): + # Figure out what data is available and what is required + has_data = self._data is not None + has_data_transpose = self._transpose is not None and not self._transpose_invalid + needs_data = has_data + needs_data_transpose = has_data_transpose + if non_tn_fp8_gemm_supported(): + if rowwise_usage is not None and rowwise_usage: + needs_data = True + if columnwise_usage is not None and columnwise_usage: + needs_data = True + needs_data_transpose = False else: + if rowwise_usage is not None: + needs_data = rowwise_usage + if columnwise_usage is not None: + needs_data_transpose = columnwise_usage + + # Generate data that is required + if needs_data and not has_data: + raise RuntimeError("Cannot generate FP8 data, even from FP8 data transpose") + if needs_data_transpose and not has_data_transpose: + if not has_data: + raise RuntimeError("FP8 data is required to generate FP8 data transpose") + self._create_transpose() + + # Delete data that is not required + if not needs_data: + self._data = None + if not needs_data_transpose: self._transpose = None self._transpose_invalid = True diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 86b13415a1..87445edca7 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -207,36 +207,53 @@ def detach(self) -> MXFP8Tensor: # TODO(ksivamani): Fix the detach bug return MXFP8Tensor.make_like(self) - def update_usage(self, rowwise_usage=True, columnwise_usage=True): + def update_usage( + self, + rowwise_usage: Optional[bool] = None, + columnwise_usage: Optional[bool] = None, + ): """ For MXFP8, columnwise scaled output is only produced by x2 scaling kernels, so this function only disables usages. """ - assert rowwise_usage or columnwise_usage, "Could not disable all usages of the tensor." - - if columnwise_usage and rowwise_usage: - assert ( - self._rowwise_data is not None - and self._rowwise_scale_inv is not None - and self._columnwise_data is not None - and self._columnwise_scale_inv is not None - ), "Cannot update to rowwise and columnwise usage." - return + # Default usage is based on available data + if rowwise_usage is None: + rowwise_usage = self._rowwise_data is not None + if columnwise_usage is None: + columnwise_usage = self._columnwise_data is not None + + # Update row-scaled data if rowwise_usage: - assert ( - self._rowwise_data is not None and self._rowwise_scale_inv is not None - ), "Cannot update to rowwise usage." + if self._rowwise_data is None: + raise RuntimeError( + "Requested row-wise usage, " + "but MXFP8Tensor is missing row-scaled FP8 data" + ) + if self._rowwise_scale_inv is None: + raise RuntimeError( + "Requested row-wise usage, " + "but MXFP8Tensor is missing row-scaled scale-inverses" + ) + else: + self._rowwise_data = None + self._rowwise_scale_inv = None + + # Update column-scaled data + if columnwise_usage: + if self._columnwise_data is None: + raise RuntimeError( + "Requested column-wise usage, " + "but MXFP8Tensor is missing column-scaled FP8 data" + ) + if self._columnwise_scale_inv is None: + raise RuntimeError( + "Requested column-wise usage, " + "but MXFP8Tensor is missing column-scaled scale-inverses" + ) + else: self._columnwise_data = None self._columnwise_scale_inv = None - return - - assert ( - self._columnwise_data is not None and self._columnwise_scale_inv is not None - ), "Cannot update to columnwise usage." - self._rowwise_data = None - self._rowwise_scale_inv = None - return def clone(self) -> MXFP8Tensor: # pylint: disable=missing-function-docstring diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index ef21412ca7..2f449a3b09 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -289,7 +289,11 @@ def detach(self) -> QuantizedTensor: f"{self.__class__.__name__} class does not implement detach function" ) - def update_usage(self, rowwise_usage=True, columnwise_usage=True): + def update_usage( + self, + rowwise_usage: Optional[bool] = None, + columnwise_usage: Optional[bool] = None, + ): """Indicate to the tensor how it is going to be used This enables optimizations to memory usage in some cases From 90002d916d1450f46a9c739de2d7475059b64cba Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Sat, 22 Feb 2025 01:03:25 +0000 Subject: [PATCH 2/7] Debug Python->C++ parsing of transpose-only Float8Tensors Signed-off-by: Tim Moon --- .../csrc/extensions/type_converters.cpp | 92 ++++++++++--------- 1 file changed, 49 insertions(+), 43 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp index d2607e4ed0..792f1e299d 100644 --- a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp @@ -4,73 +4,79 @@ * See LICENSE for license information. ************************************************************************/ -#include "common.h" #include "pybind.h" +#include +#include +#include + +#include "common.h" + namespace transformer_engine::pytorch { namespace detail { TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer) { - const at::Tensor &data = tensor.attr("_data").cast(); - const at::Tensor &scale_inv = tensor.attr("_scale_inv").cast(); - float *scale_inv_dptr = reinterpret_cast(scale_inv.data_ptr()); - const DType dtype = tensor.attr("_fp8_dtype").cast(); - - const auto &shape = getTensorShape(data); + auto ret = TensorWrapper(quantizer->get_scaling_mode()); - bool transpose_valid = !tensor.attr("_transpose_invalid").cast(); - std::optional transpose = std::nullopt; - if (transpose_valid) { - transpose = tensor.attr("_transpose").cast>(); + // FP8 data + const DType fp8_dtype = tensor.attr("_fp8_dtype").cast(); + if (!tensor.attr("_data").is_none()) { + const auto &data = tensor.attr("_data").cast(); + ret.set_rowwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data)); } - auto ret = TensorWrapper(quantizer->get_scaling_mode()); + // FP8 data transpose + if (!tensor.attr("_transpose_invalid").cast() + && !tensor.attr("_data").is_none()) { + const auto &data_transpose = tensor.attr("_transpose").cast(); + ret.set_columnwise_data(data_transpose.data_ptr(), + fp8_dtype, + getTensorShape(data_transpose)); + } - ret.set_rowwise_data(data.data_ptr(), dtype, shape); - if (transpose_valid && transpose != std::nullopt) { - const auto &transpose_shape = getTensorShape(*transpose); - ret.set_columnwise_data(transpose->data_ptr(), dtype, transpose_shape); + // Scale-inverse + { + const auto &scale_inv = tensor.attr("_scale_inv").cast(); + float *dptr = reinterpret_cast(scale_inv.data_ptr()); + const auto &dtype = GetTransformerEngineDType(scale_inv.scalar_type()); + const auto &shape = getTensorShape(scale_inv); + ret.set_rowwise_scale_inv(dptr, dtype, shape); + ret.set_columnwise_scale_inv(dptr, dtype, shape); } - const auto scale_inv_dtype = GetTransformerEngineDType(scale_inv.scalar_type()); - const auto scale_inv_shape = getTensorShape(scale_inv); - ret.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape); - ret.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape); + // Quantizer state quantizer->set_quantization_params(&ret); + return ret; } TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer) { - const DType dtype = tensor.attr("_fp8_dtype").cast(); auto ret = TensorWrapper(NVTE_MXFP8_1D_SCALING); - bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none()); - bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); - - if (rowwise_usage) { - const at::Tensor &data_rowwise = tensor.attr("_rowwise_data").cast(); - const at::Tensor &scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast(); - void *scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr(); - const auto &shape = getTensorShape(data_rowwise); - ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, shape); - - const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise); - ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat8E8M0, scale_inv_rowwise_shape); + // Row-scaled data + const DType fp8_dtype = tensor.attr("_fp8_dtype").cast(); + if (!tensor.attr("_rowwise_data").is_none()) { + const auto &data = tensor.attr("_rowwise_data").cast(); + const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast(); + ret.set_rowwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data)); + ret.set_rowwise_scale_inv(scale_inv.data_ptr(), + DType::kFloat8E8M0, + getTensorShape(scale_inv)); } - if (columnwise_usage) { - const at::Tensor &data_colwise = tensor.attr("_columnwise_data").cast(); - const at::Tensor &scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast(); - void *scale_inv_colwise_dptr = scale_inv_colwise.data_ptr(); - const auto &shape = getTensorShape(data_colwise); - ret.set_columnwise_data(data_colwise.data_ptr(), dtype, shape); - - const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise); - ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat8E8M0, - scale_inv_colwise_shape); + // Column-scaled data + if (!tensor.attr("_columnwise_data").is_none()) { + const auto &data = tensor.attr("_columnwise_data").cast(); + const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast(); + ret.set_columnwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data)); + ret.set_columnwise_scale_inv(scale_inv.data_ptr(), + DType::kFloat8E8M0, + getTensorShape(scale_inv)); } + // Quantizer state quantizer->set_quantization_params(&ret); + return ret; } From f9170022fdc9f112bfb10b3e330cf837c842217b Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Sat, 22 Feb 2025 03:21:50 +0000 Subject: [PATCH 3/7] Debug tensor shape calculation without row-wise data Signed-off-by: Tim Moon --- .../common/transformer_engine.cpp | 85 +++++++++++-------- transformer_engine/common/util/logging.h | 7 ++ .../csrc/extensions/type_converters.cpp | 2 +- 3 files changed, 59 insertions(+), 35 deletions(-) diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index faf6ec990d..3a5080b61d 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -212,37 +212,57 @@ NVTEDType nvte_tensor_type(const NVTETensor tensor) { } NVTEShape nvte_tensor_shape(const NVTETensor tensor) { - if (tensor == nullptr) return {nullptr, 0}; - const auto &t = *reinterpret_cast(tensor); - NVTEShape ret; - - // FP8 tensor keeps shape in rowwise data - if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { - ret.data = t.data.shape.data(); - ret.ndim = t.data.shape.size(); - return ret; + if (tensor == nullptr) { + NVTE_ERROR("Invalid tensor"); } + NVTEShape ret; - // Get shape based on what data is available - if (t.has_data()) { - ret.data = t.data.shape.data(); - ret.ndim = t.data.shape.size(); - return ret; - } - if (t.has_columnwise_data()) { - ret.data = t.columnwise_data.shape.data(); - ret.ndim = t.columnwise_data.shape.size(); - return ret; + // Determine tensor shape depending on tensor format + const auto &t = *reinterpret_cast(tensor); + switch (t.scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: + { + if (!t.has_data() && t.has_columnwise_data()) { + // We can infer tensor shape if FP8 tensor only has FP8 data + // transpose. However, NVTEShape only contains a pointer and + // cannot store temporary data. We hack around this by caching + // the tensor shape within the empty FP8 data. + auto& shape_cache = const_cast&>(t.data.shape); + shape_cache.clear(); + if (!t.columnwise_data.shape.empty()) { + for (size_t i=1; i < t.columnwise_data.shape.size(); i++) { + shape_cache.push_back(t.columnwise_data.shape[i]); + } + shape_cache.push_back(t.columnwise_data.shape.front()); + } + } + ret.data = t.data.shape.data(); + ret.ndim = t.data.shape.size(); + break; + } + case NVTE_MXFP8_1D_SCALING: + { + if (!t.has_data() && t.has_columnwise_data()) { + ret.data = t.columnwise_data.shape.data(); + ret.ndim = t.columnwise_data.shape.size(); + } else { + ret.data = t.data.shape.data(); + ret.ndim = t.data.shape.size(); + } + break; + } + default: + NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", + transformer_engine::to_string(t.scaling_mode), "\""); } - // Tensor has no data - ret.data = t.data.shape.data(); - ret.ndim = t.data.shape.size(); return ret; } NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) { - if (tensor == nullptr) return {nullptr, 0}; + if (tensor == nullptr) { + NVTE_ERROR("Invalid tensor"); + } const auto &t = *reinterpret_cast(tensor); NVTEShape ret; ret.data = t.columnwise_data.shape.data(); @@ -251,24 +271,21 @@ NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) { } size_t nvte_tensor_ndim(const NVTETensor tensor) { - if (tensor == nullptr) return 0; - const auto &t = *reinterpret_cast(tensor); - return t.data.shape.size(); + return nvte_tensor_shape(tensor).ndim; } size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim) { - if (tensor == nullptr) return 0; - const auto &t = *reinterpret_cast(tensor); - NVTE_CHECK(dim >= 0 && dim < t.data.shape.size(), "Invalid dimension index: ", dim); - return t.data.shape[dim]; + const auto &shape = nvte_tensor_shape(tensor); + NVTE_CHECK(0 <= dim && dim < shape.ndim, + "Attempted to access index ", dim, " in a shape array with ", shape.ndim, " entries"); + return shape.data[dim]; } size_t nvte_tensor_numel(const NVTETensor tensor) { - if (tensor == nullptr) return 0; - const auto &t = *reinterpret_cast(tensor); + const auto &shape = nvte_tensor_shape(tensor); size_t numel = 1; - for (auto size : t.data.shape) { - numel *= size; + for (size_t i=0; i < shape.ndim; i++) { + numel *= shape.data[i]; } return numel; } diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index 10a4ec28dc..821c0bfb8b 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h @@ -16,6 +16,13 @@ #include "../util/string.h" +#define NVTE_WARN(...) \ + do { \ + std::cerr << ::transformer_engine::concat_strings( \ + __FILE__ ":", __LINE__, " in function ", __func__, ": ", \ + ::transformer_engine::concat_strings(__VA_ARGS__), "\n"); \ + } while (false) + #define NVTE_ERROR(...) \ do { \ throw ::std::runtime_error(::transformer_engine::concat_strings( \ diff --git a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp index 792f1e299d..836e0fc0d9 100644 --- a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp @@ -27,7 +27,7 @@ TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer // FP8 data transpose if (!tensor.attr("_transpose_invalid").cast() - && !tensor.attr("_data").is_none()) { + && !tensor.attr("_transpose").is_none()) { const auto &data_transpose = tensor.attr("_transpose").cast(); ret.set_columnwise_data(data_transpose.data_ptr(), fp8_dtype, From 03d95e584fa297f0331f68f112a8d5262941ded7 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 25 Feb 2025 00:51:01 +0000 Subject: [PATCH 4/7] Debug correctness issues with only column-wise data Signed-off-by: Tim Moon --- transformer_engine/common/common.h | 80 ++++++++++++------- .../transformer_engine/transformer_engine.h | 2 +- .../common/transformer_engine.cpp | 11 ++- transformer_engine/common/util/logging.h | 1 + .../pytorch/ops/basic/basic_linear.py | 4 +- 5 files changed, 59 insertions(+), 39 deletions(-) diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index ca9103532d..fa4f24ec99 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -28,6 +28,9 @@ namespace transformer_engine { +std::string to_string(DType type); +std::string to_string(const NVTEScalingMode &mode); + inline size_t product(const std::vector &shape, const size_t begin, const size_t end) { NVTE_CHECK(begin <= end && end <= shape.size(), "Attempted to access entries ", begin, " to ", end, " in a vector with ", shape.size(), " entries"); @@ -95,17 +98,8 @@ struct Tensor { scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {} int numel() const { - NVTE_CHECK(data.dptr != nullptr || columnwise_data.dptr != nullptr, - "Tensor does not hold any data!"); size_t acc = 1; - if (data.dptr != nullptr) { - for (const auto &dim : data.shape) { - acc *= dim; - } - return acc; - } - // data is empty, use columnwise_data - for (const auto &dim : columnwise_data.shape) { + for (const auto dim: shape()) { acc *= dim; } return acc; @@ -122,24 +116,54 @@ struct Tensor { return data.dtype; } + std::vector shape() const { + /* Note: We sometimes experience spurious compiler errors + * (-Wstringop-overflow) from this function. It appears that GCC + * has some bugs with std::vector (see + * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=109569). + */ + switch (scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: + if (!has_data() && has_columnwise_data()) { + std::vector ret; + if (!columnwise_data.shape.empty()) { + for (size_t i=1; i < columnwise_data.shape.size(); i++) { + ret.push_back(columnwise_data.shape[i]); + } + ret.push_back(columnwise_data.shape.front()); + } + return ret; + } else { + return data.shape; + } + break; + case NVTE_MXFP8_1D_SCALING: + if (!has_data() && has_columnwise_data()) { + return columnwise_data.shape; + } else { + return data.shape; + } + break; + default: + NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", to_string(scaling_mode), "\""); + return {}; + } + } + /*! Matrix height after tensor is flattened to 2D * * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted * as a (D1*D2*...*D(n-1), Dn) matrix. */ size_t flat_first_dim() const { - if (!has_data() && has_columnwise_data()) { - const auto &data_shape = columnwise_data.shape; - if (data_shape.empty()) return 1; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { - return product(data_shape, 1, data_shape.size()); - } else { - return product(data_shape, 0, data_shape.size() - 1); + const auto& full_shape = shape(); + size_t ret = 1; + if (!full_shape.empty()) { + for (size_t i=0; i < full_shape.size() - 1; i++) { + ret *= full_shape[i]; } } - const auto &data_shape = data.shape; - if (data_shape.empty()) return 1; - return product(data_shape, 0, data_shape.size() - 1); + return ret; } /*! Matrix width after tensor is flattened to 2D @@ -148,18 +172,12 @@ struct Tensor { * as a (D1*D2*...*D(n-1), Dn) matrix. */ size_t flat_last_dim() const { - if (!has_data() && has_columnwise_data()) { - const auto &data_shape = columnwise_data.shape; - if (data_shape.empty()) return 1; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { - return data_shape.front(); - } else { - return data_shape.back(); - } + const auto& full_shape = shape(); + if (full_shape.empty()) { + return 1; + } else { + return full_shape.back(); } - const auto &data_shape = data.shape; - if (data_shape.empty()) return 1; - return data_shape.back(); } }; diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index e393dbffc4..3523d960e6 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -501,7 +501,7 @@ class TensorWrapper { * \return Number of elements in the tensor. */ size_t numel() const noexcept { - if (tensor_ == nullptr || this->dptr() == nullptr) return 0; + if (tensor_ == nullptr) return 0; return nvte_tensor_numel(tensor_); } diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 3a5080b61d..edd7a3d3e5 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -227,7 +227,7 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) { // transpose. However, NVTEShape only contains a pointer and // cannot store temporary data. We hack around this by caching // the tensor shape within the empty FP8 data. - auto& shape_cache = const_cast&>(t.data.shape); + auto &shape_cache = const_cast&>(t.data.shape); shape_cache.clear(); if (!t.columnwise_data.shape.empty()) { for (size_t i=1; i < t.columnwise_data.shape.size(); i++) { @@ -235,9 +235,12 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) { } shape_cache.push_back(t.columnwise_data.shape.front()); } + ret.data = shape_cache.data(); + ret.ndim = shape_cache.size(); + } else { + ret.data = t.data.shape.data(); + ret.ndim = t.data.shape.size(); } - ret.data = t.data.shape.data(); - ret.ndim = t.data.shape.size(); break; } case NVTE_MXFP8_1D_SCALING: @@ -270,7 +273,7 @@ NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) { return ret; } -size_t nvte_tensor_ndim(const NVTETensor tensor) { +size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; } diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index 821c0bfb8b..e0ee89b59b 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h @@ -12,6 +12,7 @@ #include #include +#include #include #include "../util/string.h" diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 892e120da1..cb93eb5e6b 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -523,9 +523,7 @@ def _functional_forward( # Configure input tensor for backward pass if own_quantized_x_local: - ### TODO Restore once column-wise usage is supported by itself # pylint: disable=fixme - # x_local.update_usage(rowwise_usage=False) - pass + x_local.update_usage(rowwise_usage=False) # Detach input tensor if needed # Note: PyTorch autograd produces esoteric errors if we save From 20997269d9e9a9edf5f4a977b96b4d8d6f9ce6c4 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 25 Feb 2025 01:44:37 +0000 Subject: [PATCH 5/7] Only cache column-wise input in LayerNormLinear Signed-off-by: Tim Moon --- transformer_engine/pytorch/module/layernorm_linear.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 01bda64101..55f7ba7653 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -310,6 +310,12 @@ def forward( clear_tensor_data(ln_out, ln_out_total) if is_grad_enabled: + + # Input with column-wise usage is needed for dgrad GEMM + if backward_needs_input: + if isinstance(ln_out, QuantizedTensor): + ln_out.update_usage(rowwise_usage=False) + if cpu_offloading: if fp8 and weightmat is not None: set_offloading_param(weightmat, "weight_offloading", True) From 7f4dfdb0503babf08e9503ca5488d7e8fe54a7ff Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 25 Feb 2025 02:02:44 +0000 Subject: [PATCH 6/7] Support MXFP8 all-gather with only column-wise data Signed-off-by: Tim Moon --- transformer_engine/pytorch/distributed.py | 99 ++++++++++++++--------- 1 file changed, 59 insertions(+), 40 deletions(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index fe023208d1..e23035ab8a 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -923,20 +923,27 @@ def _all_gather_mxfp8( if out_shape is None: out_shape = [in_shape[0] * world_size] + in_shape[1:] - # Gather MXFP8 data for row-wise usage - if quantizer.rowwise_usage and not quantizer.columnwise_usage: + # Cast input tensor to MXFP8 with required data + if not isinstance(input_, MXFP8TensorBase): + input_ = quantizer(input_) + elif ( + input_.rowwise_data is None and quantizer.rowwise_usage + or input_.columnwise_data is None and quantizer.columnwise_usage + ): + warnings.warn( + "Input and quantizer do not have matching usages. " + "Dequantizing and requantizing to MXFP8." + ) + input_ = quantizer(input_.dequantize()) - # Cast input tensor to MXFP8 if needed - if not isinstance(input_, MXFP8TensorBase): - input_ = quantizer(input_) + # Construct MXFP8 output tensor + out = quantizer.make_empty(out_shape, dtype=input_.dtype, device=input._device) - # Construct MXFP8 output tensor - dtype = torch.float32 - device = "cuda" - if isinstance(input_, MXFP8Tensor): - dtype = input_.dtype - device = input_.device - out = quantizer.make_empty(out_shape, dtype=dtype, device=device) + # Async op handle + handle = None + + # Gather MXFP8 data for row-wise usage + if quantizer.rowwise_usage: # Remove padding from MXFP8 scale-inverses in_scale_inv = input_._rowwise_scale_inv @@ -948,36 +955,48 @@ def _all_gather_mxfp8( out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size] # Launch all-gathers - with torch.distributed._coalescing_manager( + if handle is not None: + handle.wait() + torch.distributed.all_gather_into_tensor( + out_scale_inv, + in_scale_inv, group=process_group, - device=device, - async_ops=async_op, - ) as coalescing_manager: - torch.distributed.all_gather_into_tensor( - out._rowwise_data, - input_._rowwise_data, - group=process_group, - ) - torch.distributed.all_gather_into_tensor( - out_scale_inv, - in_scale_inv, - group=process_group, - ) - handle = coalescing_manager if async_op else None - return out, handle + ) + handle = torch.distributed.all_gather_into_tensor( + out._rowwise_data, + input_._rowwise_data, + group=process_group, + async_op=async_op, + ) - # Gather in high precision and quantize for column-wise usage - if isinstance(input_, QuantizedTensor): - input_ = input_.dequantize(dtype=torch.bfloat16) - out = torch.empty( - out_shape, - dtype=input_.dtype, - device=input_.device, - memory_format=torch.contiguous_format, - ) - torch.distributed.all_gather_into_tensor(out, input_, group=process_group) - out = quantizer(out) - return out, None + # Gather MXFP8 data for column-wise usage + if quantizer.columnwise_usage: + + # Remove padding from MXFP8 scale-inverses + in_scale_inv = input_._columnwise_scale_inv + out_scale_inv = out._columnwise_scale_inv + flattened_in_shape0 = math.prod(in_shape[:-1]) // 32 + if in_scale_inv.size(0) != flattened_in_shape0: + in_scale_inv = in_scale_inv[:flattened_in_shape0] + out_scale_inv[flattened_in_shape0 * world_size :].zero_() + out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size] + + # Launch all-gathers + if handle is not None: + handle.wait() + torch.distributed.all_gather_into_tensor( + out_scale_inv, + in_scale_inv, + group=process_group, + ) + handle = torch.distributed.all_gather_into_tensor( + out._columnwise_data, + input_._columnwise_data, + group=process_group, + async_op=async_op, + ) + + return out, handle def gather_along_first_dim( From 351aecea68d7301dd6952d205cdf16b9ad94f001 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 Feb 2025 02:15:50 +0000 Subject: [PATCH 7/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/common.h | 52 +++++++++---------- .../common/transformer_engine.cpp | 26 ++++------ transformer_engine/common/util/logging.h | 10 ++-- .../csrc/extensions/type_converters.cpp | 17 ++---- transformer_engine/pytorch/distributed.py | 6 ++- .../pytorch/tensor/mxfp8_tensor.py | 9 ++-- 6 files changed, 54 insertions(+), 66 deletions(-) diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index fa4f24ec99..6dad57cdfa 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -99,7 +99,7 @@ struct Tensor { int numel() const { size_t acc = 1; - for (const auto dim: shape()) { + for (const auto dim : shape()) { acc *= dim; } return acc; @@ -123,30 +123,30 @@ struct Tensor { * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=109569). */ switch (scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: - if (!has_data() && has_columnwise_data()) { - std::vector ret; - if (!columnwise_data.shape.empty()) { - for (size_t i=1; i < columnwise_data.shape.size(); i++) { - ret.push_back(columnwise_data.shape[i]); + case NVTE_DELAYED_TENSOR_SCALING: + if (!has_data() && has_columnwise_data()) { + std::vector ret; + if (!columnwise_data.shape.empty()) { + for (size_t i = 1; i < columnwise_data.shape.size(); i++) { + ret.push_back(columnwise_data.shape[i]); + } + ret.push_back(columnwise_data.shape.front()); } - ret.push_back(columnwise_data.shape.front()); + return ret; + } else { + return data.shape; } - return ret; - } else { - return data.shape; - } - break; - case NVTE_MXFP8_1D_SCALING: - if (!has_data() && has_columnwise_data()) { - return columnwise_data.shape; - } else { - return data.shape; - } - break; - default: - NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", to_string(scaling_mode), "\""); - return {}; + break; + case NVTE_MXFP8_1D_SCALING: + if (!has_data() && has_columnwise_data()) { + return columnwise_data.shape; + } else { + return data.shape; + } + break; + default: + NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", to_string(scaling_mode), "\""); + return {}; } } @@ -156,10 +156,10 @@ struct Tensor { * as a (D1*D2*...*D(n-1), Dn) matrix. */ size_t flat_first_dim() const { - const auto& full_shape = shape(); + const auto &full_shape = shape(); size_t ret = 1; if (!full_shape.empty()) { - for (size_t i=0; i < full_shape.size() - 1; i++) { + for (size_t i = 0; i < full_shape.size() - 1; i++) { ret *= full_shape[i]; } } @@ -172,7 +172,7 @@ struct Tensor { * as a (D1*D2*...*D(n-1), Dn) matrix. */ size_t flat_last_dim() const { - const auto& full_shape = shape(); + const auto &full_shape = shape(); if (full_shape.empty()) { return 1; } else { diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index edd7a3d3e5..cc935f2d75 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -220,17 +220,16 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) { // Determine tensor shape depending on tensor format const auto &t = *reinterpret_cast(tensor); switch (t.scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: - { + case NVTE_DELAYED_TENSOR_SCALING: { if (!t.has_data() && t.has_columnwise_data()) { // We can infer tensor shape if FP8 tensor only has FP8 data // transpose. However, NVTEShape only contains a pointer and // cannot store temporary data. We hack around this by caching // the tensor shape within the empty FP8 data. - auto &shape_cache = const_cast&>(t.data.shape); + auto &shape_cache = const_cast &>(t.data.shape); shape_cache.clear(); if (!t.columnwise_data.shape.empty()) { - for (size_t i=1; i < t.columnwise_data.shape.size(); i++) { + for (size_t i = 1; i < t.columnwise_data.shape.size(); i++) { shape_cache.push_back(t.columnwise_data.shape[i]); } shape_cache.push_back(t.columnwise_data.shape.front()); @@ -243,8 +242,7 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) { } break; } - case NVTE_MXFP8_1D_SCALING: - { + case NVTE_MXFP8_1D_SCALING: { if (!t.has_data() && t.has_columnwise_data()) { ret.data = t.columnwise_data.shape.data(); ret.ndim = t.columnwise_data.shape.size(); @@ -254,9 +252,9 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) { } break; } - default: - NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", - transformer_engine::to_string(t.scaling_mode), "\""); + default: + NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", + transformer_engine::to_string(t.scaling_mode), "\""); } return ret; @@ -273,21 +271,19 @@ NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) { return ret; } -size_t nvte_tensor_ndims(const NVTETensor tensor) { - return nvte_tensor_shape(tensor).ndim; -} +size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; } size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim) { const auto &shape = nvte_tensor_shape(tensor); - NVTE_CHECK(0 <= dim && dim < shape.ndim, - "Attempted to access index ", dim, " in a shape array with ", shape.ndim, " entries"); + NVTE_CHECK(0 <= dim && dim < shape.ndim, "Attempted to access index ", dim, + " in a shape array with ", shape.ndim, " entries"); return shape.data[dim]; } size_t nvte_tensor_numel(const NVTETensor tensor) { const auto &shape = nvte_tensor_shape(tensor); size_t numel = 1; - for (size_t i=0; i < shape.ndim; i++) { + for (size_t i = 0; i < shape.ndim; i++) { numel *= shape.data[i]; } return numel; diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index e0ee89b59b..173aad52af 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h @@ -17,11 +17,11 @@ #include "../util/string.h" -#define NVTE_WARN(...) \ - do { \ - std::cerr << ::transformer_engine::concat_strings( \ - __FILE__ ":", __LINE__, " in function ", __func__, ": ", \ - ::transformer_engine::concat_strings(__VA_ARGS__), "\n"); \ +#define NVTE_WARN(...) \ + do { \ + std::cerr << ::transformer_engine::concat_strings( \ + __FILE__ ":", __LINE__, " in function ", __func__, ": ", \ + ::transformer_engine::concat_strings(__VA_ARGS__), "\n"); \ } while (false) #define NVTE_ERROR(...) \ diff --git a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp index 836e0fc0d9..ed29aab703 100644 --- a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp @@ -4,13 +4,12 @@ * See LICENSE for license information. ************************************************************************/ -#include "pybind.h" - #include #include #include #include "common.h" +#include "pybind.h" namespace transformer_engine::pytorch { namespace detail { @@ -26,12 +25,9 @@ TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer } // FP8 data transpose - if (!tensor.attr("_transpose_invalid").cast() - && !tensor.attr("_transpose").is_none()) { + if (!tensor.attr("_transpose_invalid").cast() && !tensor.attr("_transpose").is_none()) { const auto &data_transpose = tensor.attr("_transpose").cast(); - ret.set_columnwise_data(data_transpose.data_ptr(), - fp8_dtype, - getTensorShape(data_transpose)); + ret.set_columnwise_data(data_transpose.data_ptr(), fp8_dtype, getTensorShape(data_transpose)); } // Scale-inverse @@ -59,9 +55,7 @@ TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer) const auto &data = tensor.attr("_rowwise_data").cast(); const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast(); ret.set_rowwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data)); - ret.set_rowwise_scale_inv(scale_inv.data_ptr(), - DType::kFloat8E8M0, - getTensorShape(scale_inv)); + ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E8M0, getTensorShape(scale_inv)); } // Column-scaled data @@ -69,8 +63,7 @@ TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer) const auto &data = tensor.attr("_columnwise_data").cast(); const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast(); ret.set_columnwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data)); - ret.set_columnwise_scale_inv(scale_inv.data_ptr(), - DType::kFloat8E8M0, + ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E8M0, getTensorShape(scale_inv)); } diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index e23035ab8a..04fa8aa1ed 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -927,8 +927,10 @@ def _all_gather_mxfp8( if not isinstance(input_, MXFP8TensorBase): input_ = quantizer(input_) elif ( - input_.rowwise_data is None and quantizer.rowwise_usage - or input_.columnwise_data is None and quantizer.columnwise_usage + input_.rowwise_data is None + and quantizer.rowwise_usage + or input_.columnwise_data is None + and quantizer.columnwise_usage ): warnings.warn( "Input and quantizer do not have matching usages. " diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 87445edca7..0db213684c 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -227,13 +227,11 @@ def update_usage( if rowwise_usage: if self._rowwise_data is None: raise RuntimeError( - "Requested row-wise usage, " - "but MXFP8Tensor is missing row-scaled FP8 data" + "Requested row-wise usage, but MXFP8Tensor is missing row-scaled FP8 data" ) if self._rowwise_scale_inv is None: raise RuntimeError( - "Requested row-wise usage, " - "but MXFP8Tensor is missing row-scaled scale-inverses" + "Requested row-wise usage, but MXFP8Tensor is missing row-scaled scale-inverses" ) else: self._rowwise_data = None @@ -243,8 +241,7 @@ def update_usage( if columnwise_usage: if self._columnwise_data is None: raise RuntimeError( - "Requested column-wise usage, " - "but MXFP8Tensor is missing column-scaled FP8 data" + "Requested column-wise usage, but MXFP8Tensor is missing column-scaled FP8 data" ) if self._columnwise_scale_inv is None: raise RuntimeError(