Skip to content

Commit 78e097f

Browse files
alexeldeibyaox12
andauthored
[Jax] Fix narrowing conversions (#2094)
Signed-off-by: Ace Eldeib <[email protected]> Co-authored-by: Xin Yao <[email protected]>
1 parent d88137c commit 78e097f

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

transformer_engine/jax/csrc/extensions/activation.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
3737
auto is_2x = static_cast<bool>(is_2x_int);
3838
auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act axis
3939

40-
auto input_shape = std::vector<size_t>{m, act_len * n};
41-
auto output_shape = std::vector<size_t>{m, n};
42-
auto output_trans_shape = std::vector<size_t>{n, m};
40+
auto input_shape = std::vector<size_t>{m, static_cast<size_t>(act_len * n)};
41+
auto output_shape = std::vector<size_t>{m, static_cast<size_t>(n)};
42+
auto output_trans_shape = std::vector<size_t>{static_cast<size_t>(n), m};
4343
auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
4444
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
4545
output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape);
@@ -253,11 +253,11 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
253253
auto m = product(act_input_dims, 0, act_input_dims.size() - 2);
254254
auto n = input_dims.back();
255255

256-
auto input_shape = std::vector<size_t>{m, n};
257-
auto act_input_shape = std::vector<size_t>{m, n * act_len};
258-
auto output_shape = std::vector<size_t>{m, n * act_len};
259-
auto output_trans_shape = std::vector<size_t>{n * act_len, m};
260-
auto dbias_shape = std::vector<size_t>{n * act_len};
256+
auto input_shape = std::vector<size_t>{m, static_cast<size_t>(n)};
257+
auto act_input_shape = std::vector<size_t>{m, static_cast<size_t>(n * act_len)};
258+
auto output_shape = std::vector<size_t>{m, static_cast<size_t>(n * act_len)};
259+
auto output_trans_shape = std::vector<size_t>{static_cast<size_t>(n * act_len), m};
260+
auto dbias_shape = std::vector<size_t>{static_cast<size_t>(n * act_len)};
261261
std::vector<size_t> workspace_shape(workspace_dims.begin(), workspace_dims.end());
262262

263263
auto input_tensor =

transformer_engine/jax/csrc/extensions/normalization.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
118118
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
119119
std::vector<size_t>{
120120
product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1),
121-
scale_inv_buf->dimensions().back()});
121+
static_cast<size_t>(scale_inv_buf->dimensions().back())});
122122
}
123123

124124
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) {
@@ -135,7 +135,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
135135
convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()),
136136
std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0,
137137
colwise_scale_inv_buf->dimensions().size() - 1),
138-
colwise_scale_inv_buf->dimensions().back()});
138+
static_cast<size_t>(colwise_scale_inv_buf->dimensions().back())});
139139
}
140140

141141
if (_norm_type == NVTE_Norm_Type::LayerNorm) {

0 commit comments

Comments
 (0)