Commit 4285874

yaox12 and timmoon10 authored
[Common] Add checks to CUDA kernel launch and CUDA API calls (#2074)
* add checks to cuda kernel launch and cuda API calls
  Signed-off-by: Xin Yao <[email protected]>
* Remove exceptions from destructors
  Signed-off-by: Tim Moon <[email protected]>
* fix weird dispatch in ln/rmsnorm
  Signed-off-by: Xin Yao <[email protected]>
---------
Signed-off-by: Xin Yao <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
1 parent 715c3bb commit 4285874

37 files changed: +256 additions, -130 deletions
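The pattern applied throughout this diff is twofold: wrap every CUDA runtime API call in the NVTE_CHECK_CUDA error-checking macro, and follow every raw <<<...>>> kernel launch with a checked cudaGetLastError(), since launches return void and report configuration failures only through CUDA's sticky error state. As context for the hunks below, here is a minimal sketch of what such a checking macro typically looks like; the real NVTE_CHECK_CUDA lives in TransformerEngine's common headers and its exact definition may differ.

// Minimal sketch of an NVTE_CHECK_CUDA-style macro (illustrative; the actual
// TransformerEngine definition may differ in naming and error reporting).
#include <cuda_runtime.h>
#include <stdexcept>
#include <string>

#define CHECK_CUDA_SKETCH(expr)                                      \
  do {                                                               \
    const cudaError_t status_ = (expr);                              \
    if (status_ != cudaSuccess) {                                    \
      throw std::runtime_error(std::string("CUDA error: ") +         \
                               cudaGetErrorString(status_));         \
    }                                                                \
  } while (false)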

transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp

Lines changed: 16 additions & 10 deletions
@@ -101,10 +101,10 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
                                  DType::kInt32);
   }
   // CUDA event creation
-  cudaEventCreateWithFlags(&_start_compute, 0);
-  cudaEventCreateWithFlags(&_stop_compute, 0);
-  cudaEventCreateWithFlags(&_start_comm, 0);
-  cudaEventCreateWithFlags(&_stop_comm, 0);
+  NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_compute, 0));
+  NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_compute, 0));
+  NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_comm, 0));
+  NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_comm, 0));

   /*
     Defining the launcher order between the communication and GEMM kernels
@@ -114,11 +114,11 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
   */
   int max_connection = transformer_engine::getenv<int>("CUDA_DEVICE_MAX_CONNECTIONS", 8);
   int runtime_version = 0;
-  cudaRuntimeGetVersion(&runtime_version);
+  NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&runtime_version));
   cudaDeviceProp deviceProp;
-  cudaGetDeviceProperties(&deviceProp, 0);
+  NVTE_CHECK_CUDA(cudaGetDeviceProperties(&deviceProp, 0));
   if (runtime_version >= 12030 && deviceProp.major == 9 && max_connection > 1) {
-    cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming);
+    NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming));
   } else {
     _comm_launch_event = 0;
   }
@@ -129,9 +129,13 @@ CommOverlapCore::~CommOverlapCore() {
   cudaEventDestroy(_start_comm);
   cudaEventDestroy(_stop_compute);
   cudaEventDestroy(_start_compute);
-  if (_comm_launch_event) cudaEventDestroy(_comm_launch_event);
+  if (_comm_launch_event) {
+    cudaEventDestroy(_comm_launch_event);
+  }

-  if (_atomic_gemm) cudaFree(_counter.dptr());
+  if (_atomic_gemm) {
+    cudaFree(_counter.dptr());
+  }

   for (size_t i = 0; i < _stream_compute.size(); i++) {
     cudaStreamSynchronize(_stream_compute[i]);
@@ -698,7 +702,9 @@ CommOverlapP2PBase::~CommOverlapP2PBase() {
   cudaEventDestroy(_stop_recv);
   cudaEventDestroy(_stop_send);
   cudaStreamDestroy(_stream_recv);
-  for (size_t i = 0; i < _stream_send.size(); i++) cudaStreamDestroy(_stream_send[i]);
+  for (size_t i = 0; i < _stream_send.size(); i++) {
+    cudaStreamDestroy(_stream_send[i]);
+  }
 }

 TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source,
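One detail worth noting in the destructor hunks above: braces are added for style, but the cudaEventDestroy/cudaFree/cudaStreamDestroy calls deliberately remain unchecked. Per the commit message ("Remove exceptions from destructors"), a throwing macro like NVTE_CHECK_CUDA is unsafe in a destructor: an exception thrown during stack unwinding calls std::terminate. A hedged sketch of the non-throwing alternative one would use in cleanup paths (the helper name is illustrative, not from this PR):

// Illustrative log-and-continue cleanup for destructors, where throwing
// (as an NVTE_CHECK_CUDA-style macro would) risks std::terminate.
#include <cuda_runtime.h>
#include <cstdio>

inline void destroy_event_nothrow(cudaEvent_t event) noexcept {
  const cudaError_t err = cudaEventDestroy(event);
  if (err != cudaSuccess) {
    // Report the failure but never throw from a cleanup path.
    std::fprintf(stderr, "cudaEventDestroy failed: %s\n", cudaGetErrorString(err));
  }
}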

transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu

Lines changed: 12 additions & 1 deletion
@@ -2319,6 +2319,7 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds
   if (comm->push == 0) {
     kuserbuffers_pullsend<<<1, 1, 0, stream>>>(comm->myrank, peer, &(comm->send_id[peer]),
                                                reinterpret_cast<int *>(flagptr));
+    NVTE_CHECK_CUDA(cudaGetLastError());
   } else {
     void *srcptr = reinterpret_cast<char *>(comm->mem_ptr[srchandler]) + srcoffset;
     void *dstptr = reinterpret_cast<char *>(comm->peer_ptr[dsthandler][peerlocal]) + dstoffset;
@@ -2516,8 +2517,11 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds
         &(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler]), reinterpret_cast<int *>(flagptr),
         reinterpret_cast<int4 *>(srcptr), reinterpret_cast<int4 *>(dstptr),
         signalonly ? 0 : bytes / 16, comm->ub_timeout);
-    if (!signalonly)
+    NVTE_CHECK_CUDA(cudaGetLastError());
+    if (!signalonly) {
       kuserbuffers_inc<<<1, 1, 0, stream>>>(&(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler]));
+      NVTE_CHECK_CUDA(cudaGetLastError());
+    }
     if (comm->use_ce) {
       NVTE_CHECK_CUDA(cudaMemcpyAsync(dstptr, srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
     }
@@ -2532,6 +2536,7 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds
         reinterpret_cast<int *>(0 ?  // temporary disable
                                     GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 2)
                                     : nullptr));
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
 }

@@ -2612,24 +2617,28 @@ void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream) {
   dim3 block(1);
   dim3 grid(1);
   producer_kernel<<<grid, block, 0, stream>>>(atomic_ptr, chunk_i);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream) {
   dim3 block(1);
   dim3 grid(1);
   consumer_kernel<<<grid, block, 0, stream>>>(atomic_ptr, chunk_i);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStream_t stream) {
   dim3 block(1);
   dim3 grid(1);
   consumer_batch_kernel<<<grid, block, 0, stream>>>(atomic_ptr, first_chunk_i, num_chunks);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 void reset_counters(void *atomic_ptr, int num_chunks, bool allgather, cudaStream_t stream) {
   dim3 block(1);
   dim3 grid(1);
   reset_counters_kernel<<<grid, block, 0, stream>>>(atomic_ptr, num_chunks, allgather);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 template <typename fp8type, int nvec>
@@ -2683,6 +2692,7 @@ void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_in
   reduce_fp8_in_bf16_out_cuda<fp8type, nvec>
       <<<grid, block, 0, stream>>>(inputs, output, scale, num_inputs, input_size,
                                    num_aligned_elements_per_input, tot_input_size);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 template void reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(void *inputs, void *output, float *scale,
@@ -2738,4 +2748,5 @@ void reduce_bf16(void *inputs, void *output, int num_inputs, int input_size, cud
   dim3 grid(num_blocks);
   reduce_bf16_cuda<nvec><<<grid, block, 0, stream>>>(
       inputs, output, num_inputs, input_size, num_aligned_elements_per_input, tot_input_size);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }
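Every launch in this file now follows the same two-step idiom: fire the kernel, then immediately poll cudaGetLastError(), which is the only way to observe launch-time failures such as an invalid grid or an oversized shared-memory request. A self-contained sketch of the idiom, reusing the hypothetical CHECK_CUDA_SKETCH macro from the first sketch above:

// Sketch of the launch-then-check idiom used throughout userbuffers.cu.
__global__ void ping_kernel(int *out) { *out = 1; }

void launch_checked(int *d_out, cudaStream_t stream) {
  ping_kernel<<<1, 1, 0, stream>>>(d_out);
  // <<<...>>> returns void; launch errors surface only via the sticky
  // error state, so poll it right after the launch.
  CHECK_CUDA_SKETCH(cudaGetLastError());
}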

transformer_engine/common/common.cu

Lines changed: 3 additions & 1 deletion
@@ -50,6 +50,7 @@ void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) {
     update_tensor_scale_inv_kernel<<<1, 1, 0, stream>>>(
         reinterpret_cast<const float *>(t->scale.dptr),
         reinterpret_cast<float *>(t->scale_inv.dptr));
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
 }

@@ -91,6 +92,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
     dim3 grid(numBlocks, 1, 1);                                             \
     memset_kernel<vectorizedType>                                           \
         <<<grid, kThreadsPerBlock, 0, stream>>>(ptr, value, size_in_bytes); \
+    NVTE_CHECK_CUDA(cudaGetLastError());                                    \
     return;                                                                 \
   }

@@ -101,7 +103,7 @@ void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream

   if (size_in_bytes > 4096) {
     // Use cudaMemsetAsync for larger sizes.
-    cudaMemsetAsync(ptr, value, size_in_bytes, stream);
+    NVTE_CHECK_CUDA(cudaMemsetAsync(ptr, value, size_in_bytes, stream));
     return;
   }

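nvte_memset dispatches on size: buffers larger than 4096 bytes go through cudaMemsetAsync (now checked), while smaller ones use the in-house memset_kernel, whose launch gets a cudaGetLastError() check even inside the backslash-continued macro. A hedged sketch of that dispatch shape, with the 4096-byte threshold taken from the diff and the kernel name illustrative:

// Illustrative small-vs-large memset dispatch mirroring nvte_memset.
__global__ void small_memset_kernel(unsigned char *ptr, int value, size_t n) {
  // One block strides over the buffer; fine for small sizes.
  for (size_t i = threadIdx.x; i < n; i += blockDim.x) {
    ptr[i] = static_cast<unsigned char>(value);
  }
}

void memset_dispatch_sketch(void *ptr, int value, size_t size_in_bytes,
                            cudaStream_t stream) {
  if (size_in_bytes > 4096) {
    // Large buffers: the CUDA runtime call returns a status to check.
    CHECK_CUDA_SKETCH(cudaMemsetAsync(ptr, value, size_in_bytes, stream));
    return;
  }
  // Small buffers: custom kernel, so check the launch separately.
  small_memset_kernel<<<1, 128, 0, stream>>>(
      static_cast<unsigned char *>(ptr), value, size_in_bytes);
  CHECK_CUDA_SKETCH(cudaGetLastError());
}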
transformer_engine/common/fused_attn/context_parallel.cu

Lines changed: 9 additions & 0 deletions
@@ -341,6 +341,7 @@ void thd_read_half_tensor(const Tensor &tensor, const Tensor &cu_seqlens, Tensor
   thd_read_half_tensor_kernel<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
       half.data.dptr, tensor.data.dptr, reinterpret_cast<int *>(cu_seqlens.data.dptr), batch,
       hidden_size_in_bytes, half_idx, tensor_shape[seq_dim]);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 /***************************************************************************************************
@@ -397,11 +398,13 @@ void thd_second_half_lse_correction(Tensor lse, const Tensor &lse_per_step,
         reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(lse_per_step.data.dptr),
         reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
         second_half_lse_seqlen);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   } else {
     thd_lse_kernel<false, LseCorrectionFunctor><<<grid, block, sizeof(int) * (batch + 1), stream>>>(
         reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(lse_per_step.data.dptr),
         reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
         second_half_lse_seqlen);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
 }

@@ -446,11 +449,13 @@ void thd_read_second_half_lse(const Tensor &lse, const Tensor &cu_seqlens, Tenso
         reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(half_lse.data.dptr),
         reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
         second_half_lse_seqlen);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   } else {
     thd_lse_kernel<false, ReadLseFunctor><<<grid, block, sizeof(int) * (batch + 1), stream>>>(
         reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(half_lse.data.dptr),
         reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
         second_half_lse_seqlen);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
 }

@@ -519,6 +524,7 @@ static void thd_out_correction_helper(Tensor out, const Tensor &out_per_step, co
             reinterpret_cast<float *>(lse_per_step.data.dptr),
             reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, dim_per_head,
             lse_seqlen, lse_per_step_seqlen);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   } else {
     thd_out_correction_kernel<dtype, only_second_half, tile, false>
         <<<grid, block, sizeof(int) * (batch + 1), stream>>>(
@@ -528,6 +534,7 @@ static void thd_out_correction_helper(Tensor out, const Tensor &out_per_step, co
             reinterpret_cast<float *>(lse_per_step.data.dptr),
             reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, dim_per_head,
             lse_seqlen, lse_per_step_seqlen);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
 }

@@ -602,6 +609,7 @@ static void thd_grad_correction_helper(Tensor grad, const Tensor &grad_per_step,
       reinterpret_cast<dtype *>(grad.data.dptr),
       reinterpret_cast<dtype *>(grad_per_step.data.dptr),
       reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, hidden_size, total_tokens);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 template <typename dtype>
@@ -667,6 +675,7 @@ void thd_get_partitioned_indices(const Tensor &cu_seqlens, Tensor output, int to
   thd_partition_indices_kernel<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
       reinterpret_cast<int *>(output.data.dptr), reinterpret_cast<int *>(cu_seqlens.data.dptr),
       batch, total_tokens, world_size, rank);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 }  // namespace context_parallel

transformer_engine/common/fused_attn/flash_attn.cu

Lines changed: 2 additions & 0 deletions
@@ -91,6 +91,7 @@ void prepare_flash_attn_fwd(Tensor qkvi, Tensor qkv, cudaStream_t stream) {
       prepare_kernel_fwd<dtype><<<grid, threads, 0, stream>>>(
           reinterpret_cast<dtype *>(qkvi.data.dptr), reinterpret_cast<dtype *>(qkv.data.dptr),
           shape[1], shape[2], shape[3], shape[4]););
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream_t stream) {
@@ -129,6 +130,7 @@ void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream
           reinterpret_cast<dtype *>(q.data.dptr), reinterpret_cast<dtype *>(k.data.dptr),
           reinterpret_cast<dtype *>(v.data.dptr), reinterpret_cast<dtype *>(qkv.data.dptr),
           q_shape[0], q_shape[1], q_shape[2], q_shape[3]););
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 }  // namespace flash_attention
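In both hunks above, the launch sits inside a dtype-dispatch macro (hence the trailing `);` that closes the macro invocation), and the new cudaGetLastError() check is placed after the macro so one check covers every dtype branch. A hedged sketch of that placement, with a stand-in dispatcher rather than TransformerEngine's real type-switch macro:

// Stand-in dtype dispatcher; TransformerEngine's real macro differs.
#define DISPATCH_FLOAT_TYPES_SKETCH(is_double, type_alias, ...) \
  do {                                                          \
    if (is_double) {                                            \
      using type_alias = double;                                \
      __VA_ARGS__                                               \
    } else {                                                    \
      using type_alias = float;                                 \
      __VA_ARGS__                                               \
    }                                                           \
  } while (false)

template <typename T>
__global__ void scale_kernel(T *data, size_t n) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= T(2);
}

void launch_dispatched_sketch(bool is_double, void *d_data, size_t n,
                              cudaStream_t stream) {
  DISPATCH_FLOAT_TYPES_SKETCH(is_double, scalar_t,
      scale_kernel<scalar_t><<<(n + 255) / 256, 256, 0, stream>>>(
          static_cast<scalar_t *>(d_data), n););
  // One check after the dispatch macro covers both dtype branches.
  CHECK_CUDA_SKETCH(cudaGetLastError());
}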

transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu

Lines changed: 4 additions & 0 deletions
@@ -416,6 +416,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
         actual_b, b, static_cast<const int32_t *>(devPtrCuSeqlensQ),
         static_cast<const int32_t *>(devPtrCuSeqlensKV), static_cast<int32_t *>(devActualSeqlenQ),
         static_cast<int32_t *>(devActualSeqlenKV));
+    NVTE_CHECK_CUDA(cudaGetLastError());
     variant_pack[seq_q] = devActualSeqlenQ;
     variant_pack[seq_kv] = devActualSeqlenKV;
   }
@@ -454,6 +455,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
         layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast<int32_t *>(devPtrSeqOffsetsQ),
         static_cast<int32_t *>(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK,
         devOffsetsV, devOffsetsO, devOffsetsS);
+    NVTE_CHECK_CUDA(cudaGetLastError());
     if (is_ragged_q) {
       variant_pack[offset_q] = devOffsetsQ;
       variant_pack[offset_o] = devOffsetsO;
@@ -883,6 +885,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
         actual_b, b, static_cast<const int32_t *>(devPtrCuSeqlensQ),
         static_cast<const int32_t *>(devPtrCuSeqlensKV), static_cast<int32_t *>(devActualSeqlenQ),
         static_cast<int32_t *>(devActualSeqlenKV));
+    NVTE_CHECK_CUDA(cudaGetLastError());
     variant_pack[seq_q] = devActualSeqlenQ;
     variant_pack[seq_kv] = devActualSeqlenKV;
   }
@@ -916,6 +919,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
         layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast<int32_t *>(devPtrSeqOffsetsQ),
         static_cast<int32_t *>(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK,
         devOffsetsV, devOffsetsO, devOffsetsS);
+    NVTE_CHECK_CUDA(cudaGetLastError());
     if (is_ragged_q) {
       variant_pack[offset_q] = devOffsetsQ;
       variant_pack[offset_o] = devOffsetsO;

transformer_engine/common/fused_attn/fused_attn_fp8.cu

Lines changed: 4 additions & 0 deletions
@@ -1111,6 +1111,7 @@ void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
     cu_seqlens_to_offsets<<<gridDims, blockDims, 0, stream>>>(
         b, h, d, reinterpret_cast<int32_t*>(devPtrcuSeqlensQ), actual_seqlens_q, qkv_ragged_offset,
         o_ragged_offset);
+    NVTE_CHECK_CUDA(cudaGetLastError());
     void* devPtrQKVRaggedOffset = reinterpret_cast<void*>(qkv_ragged_offset);
     void* devPtrORaggedOffset = reinterpret_cast<void*>(o_ragged_offset);
     void* devPtrMNKOverride = reinterpret_cast<void*>(actual_seqlens_q);
@@ -1577,6 +1578,7 @@ void fused_attn_fp8_bwd_impl(
     cu_seqlens_to_offsets<<<gridDims, blockDims, 0, stream>>>(
         b, h, d, reinterpret_cast<int32_t*>(devPtrcuSeqlensQ), actual_seqlens_q, qkv_ragged_offset,
         o_ragged_offset);
+    NVTE_CHECK_CUDA(cudaGetLastError());
     void* devPtrQKVRaggedOffset = reinterpret_cast<void*>(qkv_ragged_offset);
     void* devPtrORaggedOffset = reinterpret_cast<void*>(o_ragged_offset);
     void* devPtrMNKOverride = reinterpret_cast<void*>(actual_seqlens_q);
@@ -1933,6 +1935,7 @@ void fused_attn_fp8_fwd_impl_v1(
         b, b, static_cast<const int32_t*>(devPtrcuSeqlensQ),  // TODO(pass max_b)
         static_cast<const int32_t*>(devPtrcuSeqlensKV), static_cast<int32_t*>(devActualSeqlenQ),
         static_cast<int32_t*>(devActualSeqlenKV));
+    NVTE_CHECK_CUDA(cudaGetLastError());
     variant_pack[seq_q] = devActualSeqlenQ;
     variant_pack[seq_kv] = devActualSeqlenKV;
   }
@@ -2329,6 +2332,7 @@ void fused_attn_fp8_bwd_impl_v1(
         b, b, static_cast<const int32_t*>(devPtrcuSeqlensQ),  // TODO(pass max_b)
         static_cast<const int32_t*>(devPtrcuSeqlensKV), static_cast<int32_t*>(devActualSeqlenQ),
         static_cast<int32_t*>(devActualSeqlenKV));
+    NVTE_CHECK_CUDA(cudaGetLastError());
     variant_pack[seq_q] = devActualSeqlenQ;
     variant_pack[seq_kv] = devActualSeqlenKV;
   }

transformer_engine/common/fused_attn/kv_cache.cu

Lines changed: 4 additions & 0 deletions
@@ -157,6 +157,7 @@ void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tenso
         reinterpret_cast<int *>(page_table.data.dptr),
         reinterpret_cast<int *>(cu_new_lens.data.dptr),
         reinterpret_cast<int *>(cu_cached_lens.data.dptr), h_kv, d_k, d_v, b, max_seq_len);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
   dim3 grid_size(b, max_ctx_len);
   copy_to_kv_cache_kernel<<<grid_size, block_size, 0, stream>>>(
@@ -166,6 +167,7 @@ void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tenso
         reinterpret_cast<int *>(cu_new_lens.data.dptr),
         reinterpret_cast<int *>(cu_cached_lens.data.dptr), qkv_format, h_kv, d_k, d_v, b,
         max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
 }

@@ -215,6 +217,7 @@ void convert_thd_to_bshd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_se
       reinterpret_cast<scalar_t *>(tensor.data.dptr),
       reinterpret_cast<scalar_t *>(new_tensor.data.dptr),
       reinterpret_cast<int *>(cu_seqlens.data.dptr), b, max_seq_len, h, d);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 void convert_thd_to_bshd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int b,
@@ -254,6 +257,7 @@ void convert_bshd_to_thd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_se
       reinterpret_cast<scalar_t *>(tensor.data.dptr),
       reinterpret_cast<scalar_t *>(new_tensor.data.dptr),
       reinterpret_cast<int *>(cu_seqlens.data.dptr), b, max_seq_len, h, d);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 void convert_bshd_to_thd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int t,

transformer_engine/common/fused_attn/utils.cu

Lines changed: 5 additions & 3 deletions
@@ -600,13 +600,14 @@ uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cud
   // workspace size requires 4 bytes
   uint32_t *dout = static_cast<uint32_t *>(workspace);
   uint32_t hout{};
-  cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream);
+  NVTE_CHECK_CUDA(cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream));
   constexpr int threads = 128;
   const int blocks = (len - 1) / threads + 1;
   get_runtime_num_segments_kernel<<<blocks, threads, 0, stream>>>(static_cast<int32_t *>(cu_seqlen),
                                                                   len, dout);
-  cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream);
-  cudaStreamSynchronize(stream);
+  NVTE_CHECK_CUDA(cudaGetLastError());
+  NVTE_CHECK_CUDA(cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream));
+  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
   return hout;
 }

@@ -633,4 +634,5 @@ void nvte_extract_seed_and_offset(int64_t *rng_state_ptr, int captured, int64_t

   fused_attn::extract_seed_and_offset<<<1, 1, 0, stream>>>(
       rng_state_ptr, captured, seed_ptr, seed_val, offset_ptr, offset_val, offset_intragraph);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }
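GetRuntimeNumSegments above is the fullest example in the PR: a checked async memset, a checked kernel launch, a checked device-to-host copy, and a checked stream synchronize, which must complete before the host reads the result. A condensed, self-contained sketch of that round trip, with hypothetical kernel and macro names:

// Sketch of a fully checked device-counter round trip (illustrative names).
#include <cstdint>

__global__ void count_nonzero_kernel(const int32_t *in, size_t len,
                                     uint32_t *count) {
  const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < len && in[i] != 0) atomicAdd(count, 1u);
}

uint32_t count_nonzero_checked(const int32_t *d_in, uint32_t *d_count,
                               size_t len, cudaStream_t stream) {
  uint32_t h_count = 0;
  CHECK_CUDA_SKETCH(cudaMemsetAsync(d_count, 0, sizeof(uint32_t), stream));
  const int threads = 128;
  const int blocks = static_cast<int>((len - 1) / threads + 1);
  count_nonzero_kernel<<<blocks, threads, 0, stream>>>(d_in, len, d_count);
  CHECK_CUDA_SKETCH(cudaGetLastError());  // catch launch-configuration errors
  CHECK_CUDA_SKETCH(cudaMemcpyAsync(&h_count, d_count, sizeof(uint32_t),
                                    cudaMemcpyDeviceToHost, stream));
  // The copy is asynchronous; synchronize before reading h_count.
  CHECK_CUDA_SKETCH(cudaStreamSynchronize(stream));
  return h_count;
}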

transformer_engine/common/fused_router/fused_moe_aux_loss.cu

Lines changed: 7 additions & 5 deletions
@@ -177,9 +177,9 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
     config.stream = stream;

     // Update the max cluster size based on the device
-    cudaOccupancyMaxPotentialClusterSize(
+    NVTE_CHECK_CUDA(cudaOccupancyMaxPotentialClusterSize(
         &cluster_size,
-        reinterpret_cast<void*>(fused_moe_aux_loss_forward_kernel<DataType, IndexType>), &config);
+        reinterpret_cast<void*>(fused_moe_aux_loss_forward_kernel<DataType, IndexType>), &config));

     cudaLaunchAttribute attribute[1];
     attribute[0].id = cudaLaunchAttributeClusterDimension;
@@ -189,14 +189,15 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
     config.numAttrs = 1;
     config.attrs = attribute;

-    cudaLaunchKernelEx(&config, fused_moe_aux_loss_forward_kernel<DataType, IndexType>, probs,
-                       tokens_per_expert, total_num_tokens, num_experts, num_rows, num_cols, topk,
-                       coeff, aux_loss, Const_buf);
+    NVTE_CHECK_CUDA(cudaLaunchKernelEx(
+        &config, fused_moe_aux_loss_forward_kernel<DataType, IndexType>, probs, tokens_per_expert,
+        total_num_tokens, num_experts, num_rows, num_cols, topk, coeff, aux_loss, Const_buf));
   } else {
     size_t smem_size = sizeof(CompType) * num_cols;
     fused_moe_aux_loss_forward_kernel<DataType, IndexType>
         <<<1, 1024, smem_size, stream>>>(probs, tokens_per_expert, total_num_tokens, num_experts,
                                          num_rows, num_cols, topk, coeff, aux_loss, Const_buf);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
 }

@@ -247,6 +248,7 @@ void fused_moe_aux_loss_backward_kernel_launcher(const float* Const_buf,
   int grid_size = (num_rows + block_size - 1) / block_size;
   fused_moe_aux_loss_backward_kernel<DataType, IndexType><<<grid_size, block_size, 0, stream>>>(
       Const_buf, tokens_per_expert, num_rows, num_cols, grad_aux_loss, grad_probs);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 void fused_moe_aux_loss_backward(const Tensor& Const_buf, const Tensor& tokens_per_expert,
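The launcher above mixes both error-reporting styles: cudaLaunchKernelEx returns a cudaError_t directly, so it can be wrapped in NVTE_CHECK_CUDA as-is, while the <<<...>>> fallback branch still needs a trailing cudaGetLastError(). A minimal sketch of a checked cudaLaunchKernelEx call carrying a cluster-dimension attribute, as in the hunk above (requires CUDA 11.8+ and cluster-capable hardware; names are illustrative):

// Sketch of a checked cudaLaunchKernelEx launch with a cluster attribute.
__global__ void fill_kernel(float *out) { out[blockIdx.x] = 1.0f; }

void launch_with_cluster_sketch(float *d_out, cudaStream_t stream) {
  cudaLaunchConfig_t config = {};
  config.gridDim = dim3(8);
  config.blockDim = dim3(128);
  config.stream = stream;

  // Ask for thread-block clusters of 2 blocks (Hopper and newer).
  cudaLaunchAttribute attribute[1];
  attribute[0].id = cudaLaunchAttributeClusterDimension;
  attribute[0].val.clusterDim.x = 2;
  attribute[0].val.clusterDim.y = 1;
  attribute[0].val.clusterDim.z = 1;
  config.attrs = attribute;
  config.numAttrs = 1;

  // Unlike <<<...>>>, cudaLaunchKernelEx reports errors via its return value.
  CHECK_CUDA_SKETCH(cudaLaunchKernelEx(&config, fill_kernel, d_out));
}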
