Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 25, 2025
1 parent 04a067b commit 351aece
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 66 deletions.
52 changes: 26 additions & 26 deletions transformer_engine/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ struct Tensor {

int numel() const {
size_t acc = 1;
for (const auto dim: shape()) {
for (const auto dim : shape()) {
acc *= dim;
}
return acc;
Expand All @@ -123,30 +123,30 @@ struct Tensor {
* https://gcc.gnu.org/bugzilla/show_bug.cgi?id=109569).
*/
switch (scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING:
if (!has_data() && has_columnwise_data()) {
std::vector<size_t> ret;
if (!columnwise_data.shape.empty()) {
for (size_t i=1; i < columnwise_data.shape.size(); i++) {
ret.push_back(columnwise_data.shape[i]);
case NVTE_DELAYED_TENSOR_SCALING:
if (!has_data() && has_columnwise_data()) {
std::vector<size_t> ret;
if (!columnwise_data.shape.empty()) {
for (size_t i = 1; i < columnwise_data.shape.size(); i++) {
ret.push_back(columnwise_data.shape[i]);
}
ret.push_back(columnwise_data.shape.front());
}
ret.push_back(columnwise_data.shape.front());
return ret;
} else {
return data.shape;
}
return ret;
} else {
return data.shape;
}
break;
case NVTE_MXFP8_1D_SCALING:
if (!has_data() && has_columnwise_data()) {
return columnwise_data.shape;
} else {
return data.shape;
}
break;
default:
NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", to_string(scaling_mode), "\"");
return {};
break;
case NVTE_MXFP8_1D_SCALING:
if (!has_data() && has_columnwise_data()) {
return columnwise_data.shape;
} else {
return data.shape;
}
break;
default:
NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", to_string(scaling_mode), "\"");
return {};
}
}

Expand All @@ -156,10 +156,10 @@ struct Tensor {
* as a (D1*D2*...*D(n-1), Dn) matrix.
*/
size_t flat_first_dim() const {
const auto& full_shape = shape();
const auto &full_shape = shape();
size_t ret = 1;
if (!full_shape.empty()) {
for (size_t i=0; i < full_shape.size() - 1; i++) {
for (size_t i = 0; i < full_shape.size() - 1; i++) {
ret *= full_shape[i];
}
}
Expand All @@ -172,7 +172,7 @@ struct Tensor {
* as a (D1*D2*...*D(n-1), Dn) matrix.
*/
size_t flat_last_dim() const {
const auto& full_shape = shape();
const auto &full_shape = shape();
if (full_shape.empty()) {
return 1;
} else {
Expand Down
26 changes: 11 additions & 15 deletions transformer_engine/common/transformer_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,17 +220,16 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
// Determine tensor shape depending on tensor format
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
switch (t.scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING:
{
case NVTE_DELAYED_TENSOR_SCALING: {
if (!t.has_data() && t.has_columnwise_data()) {
// We can infer tensor shape if FP8 tensor only has FP8 data
// transpose. However, NVTEShape only contains a pointer and
// cannot store temporary data. We hack around this by caching
// the tensor shape within the empty FP8 data.
auto &shape_cache = const_cast<std::vector<size_t>&>(t.data.shape);
auto &shape_cache = const_cast<std::vector<size_t> &>(t.data.shape);
shape_cache.clear();
if (!t.columnwise_data.shape.empty()) {
for (size_t i=1; i < t.columnwise_data.shape.size(); i++) {
for (size_t i = 1; i < t.columnwise_data.shape.size(); i++) {
shape_cache.push_back(t.columnwise_data.shape[i]);
}
shape_cache.push_back(t.columnwise_data.shape.front());
Expand All @@ -243,8 +242,7 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
}
break;
}
case NVTE_MXFP8_1D_SCALING:
{
case NVTE_MXFP8_1D_SCALING: {
if (!t.has_data() && t.has_columnwise_data()) {
ret.data = t.columnwise_data.shape.data();
ret.ndim = t.columnwise_data.shape.size();
Expand All @@ -254,9 +252,9 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
}
break;
}
default:
NVTE_ERROR("Cannot parse tensor shape with scaling mode \"",
transformer_engine::to_string(t.scaling_mode), "\"");
default:
NVTE_ERROR("Cannot parse tensor shape with scaling mode \"",
transformer_engine::to_string(t.scaling_mode), "\"");
}

return ret;
Expand All @@ -273,21 +271,19 @@ NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) {
return ret;
}

size_t nvte_tensor_ndims(const NVTETensor tensor) {
return nvte_tensor_shape(tensor).ndim;
}
size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; }

size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim) {
const auto &shape = nvte_tensor_shape(tensor);
NVTE_CHECK(0 <= dim && dim < shape.ndim,
"Attempted to access index ", dim, " in a shape array with ", shape.ndim, " entries");
NVTE_CHECK(0 <= dim && dim < shape.ndim, "Attempted to access index ", dim,
" in a shape array with ", shape.ndim, " entries");
return shape.data[dim];
}

size_t nvte_tensor_numel(const NVTETensor tensor) {
const auto &shape = nvte_tensor_shape(tensor);
size_t numel = 1;
for (size_t i=0; i < shape.ndim; i++) {
for (size_t i = 0; i < shape.ndim; i++) {
numel *= shape.data[i];
}
return numel;
Expand Down
10 changes: 5 additions & 5 deletions transformer_engine/common/util/logging.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@

#include "../util/string.h"

#define NVTE_WARN(...) \
do { \
std::cerr << ::transformer_engine::concat_strings( \
__FILE__ ":", __LINE__, " in function ", __func__, ": ", \
::transformer_engine::concat_strings(__VA_ARGS__), "\n"); \
#define NVTE_WARN(...) \
do { \
std::cerr << ::transformer_engine::concat_strings( \
__FILE__ ":", __LINE__, " in function ", __func__, ": ", \
::transformer_engine::concat_strings(__VA_ARGS__), "\n"); \
} while (false)

#define NVTE_ERROR(...) \
Expand Down
17 changes: 5 additions & 12 deletions transformer_engine/pytorch/csrc/extensions/type_converters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
* See LICENSE for license information.
************************************************************************/

#include "pybind.h"

#include <ATen/ATen.h>
#include <pybind11/pybind11.h>
#include <transformer_engine/transformer_engine.h>

#include "common.h"
#include "pybind.h"

namespace transformer_engine::pytorch {
namespace detail {
Expand All @@ -26,12 +25,9 @@ TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer
}

// FP8 data transpose
if (!tensor.attr("_transpose_invalid").cast<bool>()
&& !tensor.attr("_transpose").is_none()) {
if (!tensor.attr("_transpose_invalid").cast<bool>() && !tensor.attr("_transpose").is_none()) {
const auto &data_transpose = tensor.attr("_transpose").cast<at::Tensor>();
ret.set_columnwise_data(data_transpose.data_ptr(),
fp8_dtype,
getTensorShape(data_transpose));
ret.set_columnwise_data(data_transpose.data_ptr(), fp8_dtype, getTensorShape(data_transpose));
}

// Scale-inverse
Expand Down Expand Up @@ -59,18 +55,15 @@ TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer)
const auto &data = tensor.attr("_rowwise_data").cast<at::Tensor>();
const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast<at::Tensor>();
ret.set_rowwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data));
ret.set_rowwise_scale_inv(scale_inv.data_ptr(),
DType::kFloat8E8M0,
getTensorShape(scale_inv));
ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E8M0, getTensorShape(scale_inv));
}

// Column-scaled data
if (!tensor.attr("_columnwise_data").is_none()) {
const auto &data = tensor.attr("_columnwise_data").cast<at::Tensor>();
const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast<at::Tensor>();
ret.set_columnwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data));
ret.set_columnwise_scale_inv(scale_inv.data_ptr(),
DType::kFloat8E8M0,
ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E8M0,
getTensorShape(scale_inv));
}

Expand Down
6 changes: 4 additions & 2 deletions transformer_engine/pytorch/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,8 +927,10 @@ def _all_gather_mxfp8(
if not isinstance(input_, MXFP8TensorBase):
input_ = quantizer(input_)
elif (
input_.rowwise_data is None and quantizer.rowwise_usage
or input_.columnwise_data is None and quantizer.columnwise_usage
input_.rowwise_data is None
and quantizer.rowwise_usage
or input_.columnwise_data is None
and quantizer.columnwise_usage
):
warnings.warn(
"Input and quantizer do not have matching usages. "
Expand Down
9 changes: 3 additions & 6 deletions transformer_engine/pytorch/tensor/mxfp8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,13 +227,11 @@ def update_usage(
if rowwise_usage:
if self._rowwise_data is None:
raise RuntimeError(
"Requested row-wise usage, "
"but MXFP8Tensor is missing row-scaled FP8 data"
"Requested row-wise usage, but MXFP8Tensor is missing row-scaled FP8 data"
)
if self._rowwise_scale_inv is None:
raise RuntimeError(
"Requested row-wise usage, "
"but MXFP8Tensor is missing row-scaled scale-inverses"
"Requested row-wise usage, but MXFP8Tensor is missing row-scaled scale-inverses"
)
else:
self._rowwise_data = None
Expand All @@ -243,8 +241,7 @@ def update_usage(
if columnwise_usage:
if self._columnwise_data is None:
raise RuntimeError(
"Requested column-wise usage, "
"but MXFP8Tensor is missing column-scaled FP8 data"
"Requested column-wise usage, but MXFP8Tensor is missing column-scaled FP8 data"
)
if self._columnwise_scale_inv is None:
raise RuntimeError(
Expand Down

0 comments on commit 351aece

Please sign in to comment.