diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h
index fa4f24ec99..6dad57cdfa 100644
--- a/transformer_engine/common/common.h
+++ b/transformer_engine/common/common.h
@@ -99,7 +99,7 @@ struct Tensor {
   int numel() const {
     size_t acc = 1;
-    for (const auto dim: shape()) {
+    for (const auto dim : shape()) {
       acc *= dim;
     }
     return acc;
   }
@@ -123,30 +123,30 @@ struct Tensor {
      * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=109569).
      */
     switch (scaling_mode) {
-    case NVTE_DELAYED_TENSOR_SCALING:
-      if (!has_data() && has_columnwise_data()) {
-        std::vector<size_t> ret;
-        if (!columnwise_data.shape.empty()) {
-          for (size_t i=1; i < columnwise_data.shape.size(); i++) {
-            ret.push_back(columnwise_data.shape[i]);
+      case NVTE_DELAYED_TENSOR_SCALING:
+        if (!has_data() && has_columnwise_data()) {
+          std::vector<size_t> ret;
+          if (!columnwise_data.shape.empty()) {
+            for (size_t i = 1; i < columnwise_data.shape.size(); i++) {
+              ret.push_back(columnwise_data.shape[i]);
+            }
+            ret.push_back(columnwise_data.shape.front());
           }
-          ret.push_back(columnwise_data.shape.front());
+          return ret;
+        } else {
+          return data.shape;
         }
-        return ret;
-      } else {
-        return data.shape;
-      }
-      break;
-    case NVTE_MXFP8_1D_SCALING:
-      if (!has_data() && has_columnwise_data()) {
-        return columnwise_data.shape;
-      } else {
-        return data.shape;
-      }
-      break;
-    default:
-      NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", to_string(scaling_mode), "\"");
-      return {};
+        break;
+      case NVTE_MXFP8_1D_SCALING:
+        if (!has_data() && has_columnwise_data()) {
+          return columnwise_data.shape;
+        } else {
+          return data.shape;
+        }
+        break;
+      default:
+        NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", to_string(scaling_mode), "\"");
+        return {};
     }
   }
@@ -156,10 +156,10 @@ struct Tensor {
    * as a (D1*D2*...*D(n-1), Dn) matrix.
    */
   size_t flat_first_dim() const {
-    const auto& full_shape = shape();
+    const auto &full_shape = shape();
     size_t ret = 1;
     if (!full_shape.empty()) {
-      for (size_t i=0; i < full_shape.size() - 1; i++) {
+      for (size_t i = 0; i < full_shape.size() - 1; i++) {
         ret *= full_shape[i];
       }
     }
@@ -172,7 +172,7 @@ struct Tensor {
    * as a (D1*D2*...*D(n-1), Dn) matrix.
    */
   size_t flat_last_dim() const {
-    const auto& full_shape = shape();
+    const auto &full_shape = shape();
     if (full_shape.empty()) {
       return 1;
     } else {
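The `Tensor::shape()` hunk above preserves the shape-inference rule for `NVTE_DELAYED_TENSOR_SCALING`: when only the column-wise (transposed) buffer is populated, the logical shape is recovered by moving that buffer's leading dimension to the end. A minimal Python sketch of the rule (the helper name is ours, not part of the library):

```python
def shape_from_columnwise(columnwise_shape):
    """Mirror of Tensor::shape() when only column-wise (transposed) data
    exists: the transpose stores the logical trailing dim first, so rotate
    it back to the end."""
    if not columnwise_shape:
        return []
    return list(columnwise_shape[1:]) + [columnwise_shape[0]]


# A column-wise buffer of shape (768, 32, 64) implies a logical shape of (32, 64, 768).
assert shape_from_columnwise([768, 32, 64]) == [32, 64, 768]
```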
diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp
index edd7a3d3e5..cc935f2d75 100644
--- a/transformer_engine/common/transformer_engine.cpp
+++ b/transformer_engine/common/transformer_engine.cpp
@@ -220,17 +220,16 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
   // Determine tensor shape depending on tensor format
   const auto &t = *reinterpret_cast<const Tensor *>(tensor);
   switch (t.scaling_mode) {
-  case NVTE_DELAYED_TENSOR_SCALING:
-    {
+    case NVTE_DELAYED_TENSOR_SCALING: {
       if (!t.has_data() && t.has_columnwise_data()) {
         // We can infer tensor shape if FP8 tensor only has FP8 data
         // transpose. However, NVTEShape only contains a pointer and
         // cannot store temporary data. We hack around this by caching
         // the tensor shape within the empty FP8 data.
-        auto &shape_cache = const_cast<std::vector<size_t>&>(t.data.shape);
+        auto &shape_cache = const_cast<std::vector<size_t> &>(t.data.shape);
         shape_cache.clear();
         if (!t.columnwise_data.shape.empty()) {
-          for (size_t i=1; i < t.columnwise_data.shape.size(); i++) {
+          for (size_t i = 1; i < t.columnwise_data.shape.size(); i++) {
             shape_cache.push_back(t.columnwise_data.shape[i]);
           }
           shape_cache.push_back(t.columnwise_data.shape.front());
@@ -243,8 +242,7 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
       }
       break;
     }
-  case NVTE_MXFP8_1D_SCALING:
-    {
+    case NVTE_MXFP8_1D_SCALING: {
       if (!t.has_data() && t.has_columnwise_data()) {
         ret.data = t.columnwise_data.shape.data();
         ret.ndim = t.columnwise_data.shape.size();
@@ -254,9 +252,9 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
       }
       break;
     }
-  default:
-    NVTE_ERROR("Cannot parse tensor shape with scaling mode \"",
-               transformer_engine::to_string(t.scaling_mode), "\"");
+    default:
+      NVTE_ERROR("Cannot parse tensor shape with scaling mode \"",
+                 transformer_engine::to_string(t.scaling_mode), "\"");
   }

   return ret;
@@ -273,21 +271,19 @@ NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) {
   return ret;
 }

-size_t nvte_tensor_ndims(const NVTETensor tensor) {
-  return nvte_tensor_shape(tensor).ndim;
-}
+size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; }

 size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim) {
   const auto &shape = nvte_tensor_shape(tensor);
-  NVTE_CHECK(0 <= dim && dim < shape.ndim,
-             "Attempted to access index ", dim, " in a shape array with ", shape.ndim, " entries");
+  NVTE_CHECK(0 <= dim && dim < shape.ndim, "Attempted to access index ", dim,
+             " in a shape array with ", shape.ndim, " entries");
   return shape.data[dim];
 }

 size_t nvte_tensor_numel(const NVTETensor tensor) {
   const auto &shape = nvte_tensor_shape(tensor);
   size_t numel = 1;
-  for (size_t i=0; i < shape.ndim; i++) {
+  for (size_t i = 0; i < shape.ndim; i++) {
     numel *= shape.data[i];
   }
   return numel;
diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h
index e0ee89b59b..173aad52af 100644
--- a/transformer_engine/common/util/logging.h
+++ b/transformer_engine/common/util/logging.h
@@ -17,11 +17,11 @@

 #include "../util/string.h"

-#define NVTE_WARN(...)                                                \
-  do {                                                                \
-    std::cerr << ::transformer_engine::concat_strings(                \
-      __FILE__ ":", __LINE__, " in function ", __func__, ": ",        \
-      ::transformer_engine::concat_strings(__VA_ARGS__), "\n");       \
+#define NVTE_WARN(...)                                                \
+  do {                                                                \
+    std::cerr << ::transformer_engine::concat_strings(                \
+        __FILE__ ":", __LINE__, " in function ", __func__, ": ",      \
+        ::transformer_engine::concat_strings(__VA_ARGS__), "\n");     \
   } while (false)

 #define NVTE_ERROR(...) \
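The comment inside `nvte_tensor_shape` is the key to the hunk above: `NVTEShape` only carries a borrowed pointer, so the shape computed for a transpose-only tensor must outlive the call, and the function stashes it in the (empty) row-wise `data.shape` vector. A toy Python analogue of that caching pattern (illustrative only; the class and method names are ours):

```python
class _TensorLike:
    """Toy analogue of the caching trick in nvte_tensor_shape: the computed
    shape is stored on the object itself, so the returned view stays valid
    as long as the tensor does."""

    def __init__(self, columnwise_shape):
        self.columnwise_shape = columnwise_shape
        self._shape_cache = []  # plays the role of the empty data.shape

    def shape_view(self):
        cs = self.columnwise_shape
        # Recompute into the cache; the caller's reference stays alive with self.
        self._shape_cache = (list(cs[1:]) + [cs[0]]) if cs else []
        return self._shape_cache
```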
diff --git a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp
index 836e0fc0d9..ed29aab703 100644
--- a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp
@@ -4,13 +4,12 @@
  * See LICENSE for license information.
 ************************************************************************/

-#include "pybind.h"
-
 #include <ATen/ATen.h>
 #include <pybind11/pybind11.h>
 #include <transformer_engine/transformer_engine.h>

 #include "common.h"
+#include "pybind.h"

 namespace transformer_engine::pytorch {
 namespace detail {
@@ -26,12 +25,9 @@ TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer
   }

   // FP8 data transpose
-  if (!tensor.attr("_transpose_invalid").cast<bool>()
-      && !tensor.attr("_transpose").is_none()) {
+  if (!tensor.attr("_transpose_invalid").cast<bool>() && !tensor.attr("_transpose").is_none()) {
     const auto &data_transpose = tensor.attr("_transpose").cast<at::Tensor>();
-    ret.set_columnwise_data(data_transpose.data_ptr(),
-                            fp8_dtype,
-                            getTensorShape(data_transpose));
+    ret.set_columnwise_data(data_transpose.data_ptr(), fp8_dtype, getTensorShape(data_transpose));
   }

   // Scale-inverse
@@ -59,9 +55,7 @@ TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer)
     const auto &data = tensor.attr("_rowwise_data").cast<at::Tensor>();
     const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast<at::Tensor>();
     ret.set_rowwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data));
-    ret.set_rowwise_scale_inv(scale_inv.data_ptr(),
-                              DType::kFloat8E8M0,
-                              getTensorShape(scale_inv));
+    ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E8M0, getTensorShape(scale_inv));
   }

   // Column-scaled data
@@ -69,8 +63,7 @@ TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer)
     const auto &data = tensor.attr("_columnwise_data").cast<at::Tensor>();
     const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast<at::Tensor>();
     ret.set_columnwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data));
-    ret.set_columnwise_scale_inv(scale_inv.data_ptr(),
-                                 DType::kFloat8E8M0,
+    ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E8M0,
                                  getTensorShape(scale_inv));
   }
diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py
index e23035ab8a..04fa8aa1ed 100644
--- a/transformer_engine/pytorch/distributed.py
+++ b/transformer_engine/pytorch/distributed.py
@@ -927,8 +927,10 @@ def _all_gather_mxfp8(
     if not isinstance(input_, MXFP8TensorBase):
         input_ = quantizer(input_)
     elif (
-        input_.rowwise_data is None and quantizer.rowwise_usage
-        or input_.columnwise_data is None and quantizer.columnwise_usage
+        input_.rowwise_data is None
+        and quantizer.rowwise_usage
+        or input_.columnwise_data is None
+        and quantizer.columnwise_usage
     ):
         warnings.warn(
             "Input and quantizer do not have matching usages. "
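In the `_all_gather_mxfp8` hunk, the reformatted `elif` condition relies on Python grouping `and` tighter than `or`, so its meaning is unchanged: warn when row-wise data is missing but requested, or column-wise data is missing but requested. An explicitly parenthesized restatement (hypothetical helper, for illustration):

```python
def usage_mismatch(rowwise_data, columnwise_data, rowwise_usage, columnwise_usage):
    """Explicitly parenthesized form of the elif condition in _all_gather_mxfp8."""
    return (rowwise_data is None and rowwise_usage) or (
        columnwise_data is None and columnwise_usage
    )


# Row-wise data absent while the quantizer wants row-wise usage -> mismatch.
assert usage_mismatch(None, object(), rowwise_usage=True, columnwise_usage=False)
```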
" diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 87445edca7..0db213684c 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -227,13 +227,11 @@ def update_usage( if rowwise_usage: if self._rowwise_data is None: raise RuntimeError( - "Requested row-wise usage, " - "but MXFP8Tensor is missing row-scaled FP8 data" + "Requested row-wise usage, but MXFP8Tensor is missing row-scaled FP8 data" ) if self._rowwise_scale_inv is None: raise RuntimeError( - "Requested row-wise usage, " - "but MXFP8Tensor is missing row-scaled scale-inverses" + "Requested row-wise usage, but MXFP8Tensor is missing row-scaled scale-inverses" ) else: self._rowwise_data = None @@ -243,8 +241,7 @@ def update_usage( if columnwise_usage: if self._columnwise_data is None: raise RuntimeError( - "Requested column-wise usage, " - "but MXFP8Tensor is missing column-scaled FP8 data" + "Requested column-wise usage, but MXFP8Tensor is missing column-scaled FP8 data" ) if self._columnwise_scale_inv is None: raise RuntimeError(