curious about atomic_gemm_overlap_ag #1583

Open
ehion opened this issue Mar 17, 2025 · 0 comments

Comments


ehion commented Mar 17, 2025

`atomic_gemm_overlap_ag` seems different from a naive ring-based all-gather, where the last chunk of the last rank is sent into the last chunk slot of the first rank (a sketch of the ring all-gather I have in mind is below). In the implementation of `atomic_gemm_overlap_ag` the loop runs over [0, tp-1), so how does every rank end up with the same all-gathered output? In my view only the last rank can do the full GEMM, and at each send/recv step `_ubufs[rank]` on devices with rank > 0 gets overwritten.
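To be concrete, this is roughly the naive ring all-gather chunk rotation I have in mind. It is a standalone index trace, not code from this repo; `tp = 4` is picked arbitrarily for illustration:

```cpp
#include <cstdio>
#include <vector>

int main() {
  const int tp = 4;  // tensor-parallel size, picked arbitrarily for this trace
  // have[r][c] is true once rank r holds chunk c; rank r starts with only its own chunk.
  std::vector<std::vector<bool>> have(tp, std::vector<bool>(tp, false));
  for (int r = 0; r < tp; ++r) have[r][r] = true;

  for (int step = 0; step < tp - 1; ++step) {
    for (int r = 0; r < tp; ++r) {
      // At each step, rank r forwards the chunk it obtained `step` hops ago to
      // the next rank and receives the corresponding chunk from the previous rank,
      // so chunks keep their global ids as they travel around the ring.
      int send_chunk = (r - step + tp) % tp;
      int recv_chunk = (r - 1 - step + tp) % tp;
      have[r][recv_chunk] = true;
      std::printf("step %d: rank %d sends chunk %d, receives chunk %d\n",
                  step, r, send_chunk, recv_chunk);
    }
  }
  // After tp - 1 steps every rank holds all tp chunks; in this naive scheme
  // chunk c always occupies slot c of every rank's buffer.
  return 0;
}
```

The implementation I am asking about is quoted below and uses a different slot indexing, which is where my confusion comes from: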

```cpp
void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B,
                                                bool transb, TensorWrapper &D, TensorWrapper &bias,
                                                TensorWrapper &pre_gelu_out,
                                                TensorWrapper &workspace, bool grad,
                                                bool accumulate, bool use_split_accumulator,
                                                TensorWrapper &B_copy, cudaStream_t stream_main) {
  int ori_sms = _ub_comm->sms;
  _ub_comm->use_ce = _use_ce;
  _ub_comm->sms = _num_comm_sm;
  _ub_comm->cga_size = _cga_size;

  // Get GEMM dimensions between TN and NN input layouts
  const size_t m = (transa) ? A.size(0) : A.size(1);
  const size_t n = _ubuf.size(0);
  const size_t n_chunk = n / _tp_size;
  assert(pre_gelu_out.numel() == 0);

  // Get communication and GEMM output chunk sizes
  const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();

  // Create an GEMM output buffer with N+1 chunks in a contiguous memory
  void *D_buffer_ptr;
  int D_chunk_bytes = n_chunk * m * D.element_size();
  NVTE_CHECK_CUDA(cudaMallocAsync(&D_buffer_ptr, (_tp_size + 1) * D_chunk_bytes, stream_main));
  auto D_buffer = TensorWrapper(D_buffer_ptr, D.shape(), D.dtype(), D.amax(), D.scale(), nullptr);

  // Reset atomic counters
  int *counter_ptr = reinterpret_cast<int *>(_counter.dptr());
  reset_counters(counter_ptr, _tp_size, true, stream_main);

  // Catch up the default torch stream
  NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0));
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));

  auto input_b = TensorWrapper(_ubuf.dptr(), B.shape(), B.dtype(), nullptr, nullptr, B.scale_inv());
  size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
  auto workspace_chunk =
      TensorWrapper(workspace.dptr(), std::vector<size_t>{workspace_size_chunk}, workspace.dtype());

  for (int i = 0; i < _tp_size - 1; i++) {
    // Set the userbuffer id. Buffer under send is the input for the current
    // GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to
    // have the AG output in all ranks to be contiguous after the ring
    // exchanges
    int send_chunk_id = i;
    int recv_chunk_id = i + 1;
    int send_offset = comm_bytes * send_chunk_id;
    int recv_offset = comm_bytes * recv_chunk_id;

    if (_use_multiatomic_ag) {
      if (i == 0) {
        _ub_comm->use_ce = 0;
        userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes,
                                         _ub_comm, _next_rank, _prev_rank, _tp_size, counter_ptr,
                                         true, _stream_recv);
      }
    } else {
      userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, _next_rank,
                       _stream_recv);
      userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, _prev_rank,
                       _stream_recv);
      producer(counter_ptr, recv_chunk_id, _stream_recv);
    }
    if (i == 0) {
      nvte_cublas_atomic_gemm(A.data(), input_b.data(), D_buffer.data(), bias.data(),
                              pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
                              accumulate, use_split_accumulator, _math_sms, 0, _tp_size, false,
                              _counter.data(), stream_main);
    }
  }

  // Store the input activation for backprop
  if (B_copy.numel() > 0) {
    assert(B_copy.numel() == _ubufs[_self_chunk_id].numel());
    assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size());
    NVTE_CHECK_CUDA(
        cudaMemcpyAsync(B_copy.dptr(), _ubufs[_self_chunk_id].dptr(),
                        _ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(),
                        cudaMemcpyDeviceToDevice, _stream_send));
    NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send));
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
  }

  // Copy the first GEMM output chunk to the end chunk position of D_buffer
  char *src_ptr = reinterpret_cast<char *>(D_buffer.dptr());
  NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + (D.numel() * D.element_size()), src_ptr, D_chunk_bytes,
                                  cudaMemcpyDeviceToDevice, stream_main));

  // Return the last N rows of D_buffer
  NVTE_CHECK_CUDA(cudaMemcpyAsync(D.dptr(), src_ptr + D_chunk_bytes, D.numel() * D.element_size(),
                                  cudaMemcpyDeviceToDevice, stream_main));

  // Clean up buffer allocation
  NVTE_CHECK_CUDA(cudaFreeAsync(D_buffer_ptr, stream_main));

  _ub_comm->sms = ori_sms;
}
```
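To make my confusion concrete, this is how I trace the slot offsets used by the loop above. It is just the same index arithmetic run standalone, with an assumed `tp = 4` and `comm_bytes = 1` as a placeholder so offsets equal slot ids, not actual communication:

```cpp
#include <cstdio>

int main() {
  const int tp = 4;          // assumed tensor-parallel size, for illustration only
  const int comm_bytes = 1;  // placeholder so that offset == chunk slot id
  for (int i = 0; i < tp - 1; ++i) {
    // Same arithmetic as the loop in atomic_gemm_overlap_ag above.
    int send_chunk_id = i;
    int recv_chunk_id = i + 1;
    int send_offset = comm_bytes * send_chunk_id;
    int recv_offset = comm_bytes * recv_chunk_id;
    // As I read the send/recv calls: every rank pushes its local slot i into
    // slot i + 1 of _next_rank and pulls _prev_rank's slot i into its own
    // slot i + 1 -- the slot ids are identical on all ranks.
    std::printf("iter %d: send slot %d (offset %d), recv into slot %d (offset %d)\n",
                i, send_chunk_id, send_offset, recv_chunk_id, recv_offset);
  }
  return 0;
}
```

Since the slot ids are the same on every rank, it looks to me like `_ubufs[rank]` on devices with rank > 0 is overwritten before every rank can produce the full AG result, which is the core of my question.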
