
Commit 9c3a675

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 91c6dba · commit 9c3a675

File tree

6 files changed: 26 additions & 24 deletions


transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp

Lines changed: 18 additions & 16 deletions
@@ -861,7 +861,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
   // Chunk dims
   std::vector<size_t> input_b_chunk_shape =
       (transb ? std::vector<size_t>{2 * k_chunk, n} : std::vector<size_t>{2 * n_chunk, k});
-      // (transb ? std::vector<size_t>{k, 2 * n_chunk} : std::vector<size_t>{2 * n_chunk, k});
+  // (transb ? std::vector<size_t>{k, 2 * n_chunk} : std::vector<size_t>{2 * n_chunk, k});
   std::vector<size_t> output_chunk_shape = {(transb ? 1 : 2) * n_chunk, m};
   input_a_chunk_size *= transb ? 2 : 1;
   input_b_chunk_size *= 2;
@@ -894,22 +894,22 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
 
     // GEMM
     TensorWrapper input_a_chunk, input_b_chunk;
-    if (ag_on_B) {  // AllGather is performed on input B tensor (default case).
-                    // Use case: AG->{FC2, PROJ}_Wgrad, AG->{FC1, QKV}_FPROP.
-      input_a_chunk = get_tensor_chunk(A, transb ? input_a_chunk_size * send_chunk_id / 2 : 0,
+    if (ag_on_B) {  // AllGather is performed on input B tensor (default case).
+                    // Use case: AG->{FC2, PROJ}_Wgrad, AG->{FC1, QKV}_FPROP.
+      input_a_chunk = get_tensor_chunk(
+          A, transb ? input_a_chunk_size * send_chunk_id / 2 : 0,
           transb ? std::vector<size_t>{k_chunk * 2, m} : shape_to_vector(A.shape()));
       input_b_chunk =
           get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id / 2, input_b_chunk_shape);
-    } else {  // AllGather is performed on input A tensor. Use case: AG->{FC1, QKV}_Wgrad.
+    } else {  // AllGather is performed on input A tensor. Use case: AG->{FC1, QKV}_Wgrad.
       assert(transa == false && transb == true);
-      input_a_chunk = get_buffer_chunk_like(
-          A, input_a_chunk_size * send_chunk_id / 2, std::vector<size_t>{k_chunk * 2, m}
-      );
+      input_a_chunk = get_buffer_chunk_like(A, input_a_chunk_size * send_chunk_id / 2,
+                                            std::vector<size_t>{k_chunk * 2, m});
       input_b_chunk =
           get_tensor_chunk(B, input_b_chunk_size * send_chunk_id / 2, input_b_chunk_shape);
     }
-    auto output_chunk =
-        get_tensor_chunk(D, transb ? 0 : output_chunk_size * send_chunk_id / 2, output_chunk_shape);
+    auto output_chunk = get_tensor_chunk(D, transb ? 0 : output_chunk_size * send_chunk_id / 2,
+                                         output_chunk_shape);
     auto aux_chunk = (do_gelu)
                          ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id / 2,
                                             {2 * n_chunk, k})
@@ -964,15 +964,17 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
 
     // GEMM
    TensorWrapper input_a_chunk, input_b_chunk;
-    if (ag_on_B) {  // AllGather is performed on input B tensor (default case).
-                    // Use case: AG->{FC2, PROJ}_Wgrad, AG->{FC1, QKV}_FPROP.
-      input_a_chunk = get_tensor_chunk(A, transb ? input_a_chunk_size * send_chunk_id : 0,
-          transb ? std::vector<size_t>{k_chunk, m} : shape_to_vector(A.shape()));
+    if (ag_on_B) {  // AllGather is performed on input B tensor (default case).
+                    // Use case: AG->{FC2, PROJ}_Wgrad, AG->{FC1, QKV}_FPROP.
+      input_a_chunk =
+          get_tensor_chunk(A, transb ? input_a_chunk_size * send_chunk_id : 0,
+                           transb ? std::vector<size_t>{k_chunk, m} : shape_to_vector(A.shape()));
       input_b_chunk =
           get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape);
-    } else {  // AllGather is performed on input A tensor. Use case: AG->{FC1, QKV}_Wgrad.
+    } else {  // AllGather is performed on input A tensor. Use case: AG->{FC1, QKV}_Wgrad.
       assert(transa == false && transb == true);
-      input_a_chunk = get_buffer_chunk_like(A, input_a_chunk_size * send_chunk_id,
+      input_a_chunk = get_buffer_chunk_like(
+          A, input_a_chunk_size * send_chunk_id,
           transb ? std::vector<size_t>{k_chunk, m} : std::vector<size_t>{m, k});
       input_b_chunk =
           get_tensor_chunk(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape);
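
For orientation, the chunk bookkeeping in these hunks reduces to slicing a row-major 2-D tensor by element offset and chunk shape. The standalone C++ sketch below models only that arithmetic; row_chunk is a hypothetical stand-in for TE's get_tensor_chunk / get_buffer_chunk_like, not the actual implementation.

// Minimal sketch, not TransformerEngine code: models the per-chunk offset
// arithmetic (chunk_size * send_chunk_id) used in split_overlap_ag above.
#include <cstddef>
#include <cstdio>
#include <vector>

struct ChunkView {
  size_t offset;              // element offset into the flat buffer
  std::vector<size_t> shape;  // logical [rows, cols] shape of the chunk
};

// Hypothetical helper: slice rows / num_chunks rows starting at the
// chunk_id-th boundary of a row-major [rows, cols] tensor.
ChunkView row_chunk(size_t rows, size_t cols, size_t chunk_id, size_t num_chunks) {
  const size_t chunk_rows = rows / num_chunks;
  return {chunk_id * chunk_rows * cols, {chunk_rows, cols}};
}

int main() {
  // E.g. an [8, 4] tensor gathered across 4 ranks -> each chunk is [2, 4].
  for (size_t id = 0; id < 4; ++id) {
    const ChunkView c = row_chunk(8, 4, id, 4);
    std::printf("chunk %zu: offset %zu, shape [%zu, %zu]\n", id, c.offset, c.shape[0],
                c.shape[1]);
  }
  return 0;
}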

transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h

Lines changed: 2 additions & 3 deletions
@@ -130,9 +130,8 @@ class CommOverlapCore {
   virtual void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B,
                                 bool transb, TensorWrapper &D, TensorWrapper &bias,
                                 TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
-                                bool accumulate, bool use_split_accumulator,
-                                bool ag_on_B, TensorWrapper &B_copy,
-                                cudaStream_t stream_main) {
+                                bool accumulate, bool use_split_accumulator, bool ag_on_B,
+                                TensorWrapper &B_copy, cudaStream_t stream_main) {
     NVTE_ERROR("Operation is not implemented.");
   }
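
Any backend overriding this virtual must match the reflowed parameter order (..., ag_on_B, B_copy, stream_main) exactly. A sketch only, assuming the TE header is on the include path; the derived class name and namespace usage are illustrative:

// Sketch: overriding split_overlap_ag() in a hypothetical backend.
#include <transformer_engine/comm_gemm_overlap.h>  // assumed include path

using namespace transformer_engine;  // namespace assumed

class MyOverlap : public CommOverlapCore {
 public:
  void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B,
                        bool transb, TensorWrapper &D, TensorWrapper &bias,
                        TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
                        bool accumulate, bool use_split_accumulator, bool ag_on_B,
                        TensorWrapper &B_copy, cudaStream_t stream_main) override {
    // Overlapped all-gather + chunked GEMM would go here.
  }
};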

transformer_engine/pytorch/cpp_extensions/gemm.py

Lines changed: 1 addition & 1 deletion
@@ -108,7 +108,7 @@ def general_gemm(
         workspace.shape[0],
         accumulate,
         use_split_accumulator,
-        ag_on_B, # ag_on_B
+        ag_on_B,  # ag_on_B
     )
     kwargs = {
         "comm_overlap": ub,

transformer_engine/pytorch/csrc/extensions.h

Lines changed: 2 additions & 1 deletion
@@ -120,7 +120,8 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool transb,
                              py::handle quantizer, std::optional<DType> out_dtype, MaybeTensor bias,
                              DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad,
                              at::Tensor workspace, size_t workspaceSize, bool accumulate,
-                             bool use_split_accumulator, bool ag_on_B, CommOverlapCore *comm_overlap = nullptr,
+                             bool use_split_accumulator, bool ag_on_B,
+                             CommOverlapCore *comm_overlap = nullptr,
                              std::optional<CommOverlapType> comm_type = std::nullopt,
                              MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false);

transformer_engine/pytorch/csrc/extensions/gemm.cpp

Lines changed: 2 additions & 1 deletion
@@ -90,7 +90,8 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool transb,
                              py::handle quantizer, std::optional<DType> out_dtype, MaybeTensor bias,
                              DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad,
                              at::Tensor workspace, size_t workspaceSize, bool accumulate,
-                             bool use_split_accumulator, bool ag_on_B, CommOverlapCore* comm_overlap,
+                             bool use_split_accumulator, bool ag_on_B,
+                             CommOverlapCore* comm_overlap,
                              std::optional<CommOverlapType> comm_type, MaybeTensor extra_output,
                              bool bulk_overlap) {
   // Input tensors

transformer_engine/pytorch/csrc/extensions/pybind.cpp

Lines changed: 1 addition & 2 deletions
@@ -110,8 +110,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("quantizer"), py::arg("output_dtype"), py::arg("bias"), py::arg("bias_type"),
         py::arg("gelu"), py::arg("gelu_in"), py::arg("grad"), py::arg("workspace"),
         py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator"),
-        py::arg("ag_on_B"),
-        py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
+        py::arg("ag_on_B"), py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
         py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false);
   m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"),
         py::arg("quantizer"));
