diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 963e4d03dd77b..f77b2629a19b0 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -168,6 +168,10 @@ static void ggml_cuda_get_rows_switch_src0_type( get_rows_cuda_float((const float *) src0_d, src1_d, dst_d, ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); break; + case GGML_TYPE_I32: + get_rows_cuda_float((const int32_t *) src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; case GGML_TYPE_BF16: get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d, ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); @@ -210,6 +214,10 @@ void get_rows_cuda( ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d, ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); break; + case GGML_TYPE_I32: + ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (int32_t *) dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; case GGML_TYPE_F16: ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d, ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 121e83641b23d..7e58c3ad5d2c9 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3192,6 +3192,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g switch (op->src[0]->type) { case GGML_TYPE_F16: case GGML_TYPE_F32: + case GGML_TYPE_BF16: + case GGML_TYPE_I32: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: