atomic_gemm_overlap_ag seems different from a naive ring-based allgather, where the last chunk of the last rank is sent into the last chunk slot of the first rank.
In the implementation of atomic_gemm_overlap_ag, the loop runs over [0, tp-1). How does every rank end up with the same all-gathered output? In my view, only the last rank can perform the full GEMM, and in each send/recv step ubufs[rank] on devices with rank > 0 gets overwritten.
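For comparison, this is roughly the naive ring all-gather I have in mind, as a minimal index-only simulation (the `tp` value, slot layout, and bookkeeping here are purely illustrative, not TransformerEngine code): every rank forwards the chunk it received in the previous step, and after tp-1 steps all ranks hold all chunks in the same order.

```cpp
// Index-only simulation of a naive ring all-gather (no real communication).
// buf[rank][slot] holds the id of the chunk stored in that slot, or -1 if empty.
// Each rank starts with only its own chunk; in step i it forwards the chunk it
// received i steps ago to the next rank in the ring.
#include <cstdio>
#include <vector>

int main() {
  const int tp = 4;  // illustrative tensor-parallel size
  std::vector<std::vector<int>> buf(tp, std::vector<int>(tp, -1));
  for (int rank = 0; rank < tp; ++rank) buf[rank][rank] = rank;

  for (int i = 0; i < tp - 1; ++i) {
    std::vector<std::vector<int>> next = buf;
    for (int rank = 0; rank < tp; ++rank) {
      int send_slot = (rank - i + tp) % tp;  // chunk this rank received i steps ago
      int dst = (rank + 1) % tp;             // wraps around: last rank sends to rank 0
      next[dst][send_slot] = buf[rank][send_slot];
    }
    buf = next;
  }

  // After tp - 1 steps every rank holds chunks 0..tp-1 in the same slot order.
  for (int rank = 0; rank < tp; ++rank) {
    std::printf("rank %d:", rank);
    for (int slot = 0; slot < tp; ++slot) std::printf(" %d", buf[rank][slot]);
    std::printf("\n");
  }
  return 0;
}
```

The quoted code below instead uses `send_chunk_id = i` and `recv_chunk_id = i + 1` on every rank, which is the part that confuses me.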
```cpp
void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B,
                                                bool transb, TensorWrapper &D, TensorWrapper &bias,
                                                TensorWrapper &pre_gelu_out,
                                                TensorWrapper &workspace, bool grad,
                                                bool accumulate, bool use_split_accumulator,
                                                TensorWrapper &B_copy, cudaStream_t stream_main) {
  int ori_sms = _ub_comm->sms;
  _ub_comm->use_ce = _use_ce;
  _ub_comm->sms = _num_comm_sm;
  _ub_comm->cga_size = _cga_size;
  // Get GEMM dimensions between TN and NN input layouts
  const size_t m = (transa) ? A.size(0) : A.size(1);
  const size_t n = _ubuf.size(0);
  const size_t n_chunk = n / _tp_size;
  assert(pre_gelu_out.numel() == 0);
  // Get communication and GEMM output chunk sizes
  const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
  // Create an GEMM output buffer with N+1 chunks in a contiguous memory
  void *D_buffer_ptr;
  int D_chunk_bytes = n_chunk * m * D.element_size();
  NVTE_CHECK_CUDA(cudaMallocAsync(&D_buffer_ptr, (_tp_size + 1) * D_chunk_bytes, stream_main));
  auto D_buffer = TensorWrapper(D_buffer_ptr, D.shape(), D.dtype(), D.amax(), D.scale(), nullptr);
  // Reset atomic counters
  int *counter_ptr = reinterpret_cast<int *>(_counter.dptr());
  reset_counters(counter_ptr, _tp_size, true, stream_main);
  // Catch up the default torch stream
  NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0));
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
  auto input_b = TensorWrapper(_ubuf.dptr(), B.shape(), B.dtype(), nullptr, nullptr, B.scale_inv());
  size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
  auto workspace_chunk =
      TensorWrapper(workspace.dptr(), std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
  for (int i = 0; i < _tp_size - 1; i++) {
    // Set the userbuffer id. Buffer under send is the input for the current
    // GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to
    // have the AG output in all ranks to be contiguous after the ring
    // exchanges
    int send_chunk_id = i;
    int recv_chunk_id = i + 1;
    int send_offset = comm_bytes * send_chunk_id;
    int recv_offset = comm_bytes * recv_chunk_id;
  }
  // Store the input activation for backprop
  if (B_copy.numel() > 0) {
    assert(B_copy.numel() == _ubufs[_self_chunk_id].numel());
    assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size());
    NVTE_CHECK_CUDA(
        cudaMemcpyAsync(B_copy.dptr(), _ubufs[_self_chunk_id].dptr(),
                        _ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(),
                        cudaMemcpyDeviceToDevice, _stream_send));
    NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send));
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
  }
  // Copy the first GEMM output chunk to the end chunk position of D_buffer
  char *src_ptr = reinterpret_cast<char *>(D_buffer.dptr());
  NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + (D.numel() * D.element_size()), src_ptr, D_chunk_bytes,
                                  cudaMemcpyDeviceToDevice, stream_main));
  // Return the last N rows of D_buffer
  NVTE_CHECK_CUDA(cudaMemcpyAsync(D.dptr(), src_ptr + D_chunk_bytes, D.numel() * D.element_size(),
                                  cudaMemcpyDeviceToDevice, stream_main));
  // Clean up buffer allocation
  NVTE_CHECK_CUDA(cudaFreeAsync(D_buffer_ptr, stream_main));
  _ub_comm->sms = ori_sms;
}
```
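To make sure I am reading the tail of the function correctly, here is an index-only sketch of the D_buffer handling at the end (illustrative values only, not the real kernel): chunk 0 of the GEMM output is copied into the extra (tp+1)-th slot, and D is taken from slots 1..tp.

```cpp
// Index-only sketch of how I read the D_buffer copies at the end of the function
// (illustrative; each integer stands for one n_chunk x m output tile).
#include <cstdio>
#include <vector>

int main() {
  const int tp_size = 4;  // illustrative
  // D_buffer has tp_size + 1 chunk slots; pretend slot j holds the GEMM output
  // tile produced for userbuffer chunk j (j = 0..tp_size-1).
  std::vector<int> D_buffer(tp_size + 1, -1);
  for (int j = 0; j < tp_size; ++j) D_buffer[j] = j;

  // "Copy the first GEMM output chunk to the end chunk position of D_buffer"
  D_buffer[tp_size] = D_buffer[0];

  // "Return the last N rows of D_buffer": D is chunks 1 .. tp_size
  std::vector<int> D(D_buffer.begin() + 1, D_buffer.end());

  std::printf("D tiles:");
  for (int j : D) std::printf(" %d", j);  // prints: 1 2 3 0
  std::printf("\n");
  return 0;
}
```

If that reading is right, the returned D on every rank is tiles [1, 2, ..., tp-1, 0] of D_buffer, and I do not see how this becomes the same, correctly ordered all-gather GEMM output on ranks other than the last one, given that their ubufs[rank] is overwritten during the exchanges.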