Skip to content

Commit b9c3eef

Browse files
authored
CUDA: add bf16 and i32 to getrows (#14529)
1 parent 6491d6e commit b9c3eef

File tree

2 files changed: 10 additions (+10) and 0 deletions (-0)

2 files changed: 10 additions (+10) and 0 deletions (-0)

ggml/src/ggml-cuda/getrows.cu

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -168,6 +168,10 @@ static void ggml_cuda_get_rows_switch_src0_type(
168168
get_rows_cuda_float((const float *) src0_d, src1_d, dst_d,
169169
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
170170
break;
171+
case GGML_TYPE_I32:
172+
get_rows_cuda_float((const int32_t *) src0_d, src1_d, dst_d,
173+
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
174+
break;
171175
case GGML_TYPE_BF16:
172176
get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
173177
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
@@ -210,6 +214,10 @@ void get_rows_cuda(
210214
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d,
211215
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
212216
break;
217+
case GGML_TYPE_I32:
218+
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (int32_t *) dst_d,
219+
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
220+
break;
213221
case GGML_TYPE_F16:
214222
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d,
215223
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3200,6 +3200,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
32003200
switch (op->src[0]->type) {
32013201
case GGML_TYPE_F16:
32023202
case GGML_TYPE_F32:
3203+
case GGML_TYPE_BF16:
3204+
case GGML_TYPE_I32:
32033205
case GGML_TYPE_Q4_0:
32043206
case GGML_TYPE_Q4_1:
32053207
case GGML_TYPE_Q5_0:

0 commit comments

Comments
 (0)