From 943f1e0ac9366491f8f589aefef62f65c3f3c07c Mon Sep 17 00:00:00 2001 From: Jennifer Zhou Date: Tue, 12 Nov 2024 20:30:40 -0800 Subject: [PATCH] Fix an int conversion error (#1325) fix an int conversion error Signed-off-by: Jennifer Zhou --- transformer_engine/jax/csrc/extensions/activation.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 9d5fb4f7b4..a2090bceba 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -264,8 +264,8 @@ Error_Type DActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act auto *output = output_buf->untyped_data(); auto act_input_dims = act_input_buf.dimensions(); - auto m = product(act_input_dims, 0, act_input_dims.size() - 2); - auto n = act_input_dims.back(); + auto m = static_cast(product(act_input_dims, 0, act_input_dims.size() - 2)); + auto n = static_cast(act_input_dims.back()); auto act_len = act_input_dims.end()[-2]; auto input_shape = std::vector{m, n};