@@ -861,7 +861,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
861861 // Chunk dims
862862 std::vector<size_t > input_b_chunk_shape =
863863 (transb ? std::vector<size_t >{2 * k_chunk, n} : std::vector<size_t >{2 * n_chunk, k});
864- // (transb ? std::vector<size_t>{k, 2 * n_chunk} : std::vector<size_t>{2 * n_chunk, k});
864+ // (transb ? std::vector<size_t>{k, 2 * n_chunk} : std::vector<size_t>{2 * n_chunk, k});
865865 std::vector<size_t > output_chunk_shape = {(transb ? 1 : 2 ) * n_chunk, m};
866866 input_a_chunk_size *= transb ? 2 : 1 ;
867867 input_b_chunk_size *= 2 ;
@@ -894,22 +894,22 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
894894
895895 // GEMM
896896 TensorWrapper input_a_chunk, input_b_chunk;
897- if (ag_on_B) { // AllGather is performed on input B tensor (default case).
898- // Use case: AG->{FC2, PROJ}_Wgrad, AG->{FC1, QKV}_FPROP.
899- input_a_chunk = get_tensor_chunk (A, transb ? input_a_chunk_size * send_chunk_id / 2 : 0 ,
897+ if (ag_on_B) { // AllGather is performed on input B tensor (default case).
898+ // Use case: AG->{FC2, PROJ}_Wgrad, AG->{FC1, QKV}_FPROP.
899+ input_a_chunk = get_tensor_chunk (
900+ A, transb ? input_a_chunk_size * send_chunk_id / 2 : 0 ,
900901 transb ? std::vector<size_t >{k_chunk * 2 , m} : shape_to_vector (A.shape ()));
901902 input_b_chunk =
902903 get_buffer_chunk_like (B, input_b_chunk_size * send_chunk_id / 2 , input_b_chunk_shape);
903- } else { // AllGather is performed on input A tensor. Use case: AG->{FC1, QKV}_Wgrad.
904+ } else { // AllGather is performed on input A tensor. Use case: AG->{FC1, QKV}_Wgrad.
904905 assert (transa == false && transb == true );
905- input_a_chunk = get_buffer_chunk_like (
906- A, input_a_chunk_size * send_chunk_id / 2 , std::vector<size_t >{k_chunk * 2 , m}
907- );
906+ input_a_chunk = get_buffer_chunk_like (A, input_a_chunk_size * send_chunk_id / 2 ,
907+ std::vector<size_t >{k_chunk * 2 , m});
908908 input_b_chunk =
909909 get_tensor_chunk (B, input_b_chunk_size * send_chunk_id / 2 , input_b_chunk_shape);
910910 }
911- auto output_chunk =
912- get_tensor_chunk (D, transb ? 0 : output_chunk_size * send_chunk_id / 2 , output_chunk_shape);
911+ auto output_chunk = get_tensor_chunk (D, transb ? 0 : output_chunk_size * send_chunk_id / 2 ,
912+ output_chunk_shape);
913913 auto aux_chunk = (do_gelu)
914914 ? get_tensor_chunk (pre_gelu_out, output_chunk_size * send_chunk_id / 2 ,
915915 {2 * n_chunk, k})
@@ -964,15 +964,17 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
964964
965965 // GEMM
966966 TensorWrapper input_a_chunk, input_b_chunk;
967- if (ag_on_B) { // AllGather is performed on input B tensor (default case).
968- // Use case: AG->{FC2, PROJ}_Wgrad, AG->{FC1, QKV}_FPROP.
969- input_a_chunk = get_tensor_chunk (A, transb ? input_a_chunk_size * send_chunk_id : 0 ,
970- transb ? std::vector<size_t >{k_chunk, m} : shape_to_vector (A.shape ()));
967+ if (ag_on_B) { // AllGather is performed on input B tensor (default case).
968+ // Use case: AG->{FC2, PROJ}_Wgrad, AG->{FC1, QKV}_FPROP.
969+ input_a_chunk =
970+ get_tensor_chunk (A, transb ? input_a_chunk_size * send_chunk_id : 0 ,
971+ transb ? std::vector<size_t >{k_chunk, m} : shape_to_vector (A.shape ()));
971972 input_b_chunk =
972973 get_buffer_chunk_like (B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape);
973- } else { // AllGather is performed on input A tensor. Use case: AG->{FC1, QKV}_Wgrad.
974+ } else { // AllGather is performed on input A tensor. Use case: AG->{FC1, QKV}_Wgrad.
974975 assert (transa == false && transb == true );
975- input_a_chunk = get_buffer_chunk_like (A, input_a_chunk_size * send_chunk_id,
976+ input_a_chunk = get_buffer_chunk_like (
977+ A, input_a_chunk_size * send_chunk_id,
976978 transb ? std::vector<size_t >{k_chunk, m} : std::vector<size_t >{m, k});
977979 input_b_chunk =
978980 get_tensor_chunk (B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape);
0 commit comments