
Commit 8c7e863

CUDA: add set rows for f32 and f16
1 parent 67d1ef2 commit 8c7e863

3 files changed: +143 -0 lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 10 additions & 0 deletions
@@ -43,6 +43,7 @@
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv.cuh"
 #include "ggml-cuda/gla.cuh"
+#include "ggml-cuda/set-rows.cuh"
 #include "ggml.h"
 
 #include <algorithm>
@@ -2230,6 +2231,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_GET_ROWS_BACK:
             ggml_cuda_op_get_rows_back(ctx, dst);
             break;
+        case GGML_OP_SET_ROWS:
+            ggml_cuda_op_set_rows(ctx, dst);
+            break;
         case GGML_OP_DUP:
             ggml_cuda_dup(ctx, dst);
             break;
@@ -3214,6 +3218,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             {
                 return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
             } break;
+        case GGML_OP_SET_ROWS:
+            {
+                return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
+                       op->src[0]->type == GGML_TYPE_F32 &&
+                       op->src[1]->type == GGML_TYPE_I64;
+            } break;
         case GGML_OP_CPY:
             {
                 ggml_type src0_type = op->src[0]->type;
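Together, these hunks register the new operator with the CUDA backend: the include pulls in the kernel's declaration, the ggml_cuda_compute_forward case dispatches GGML_OP_SET_ROWS to ggml_cuda_op_set_rows, and the supports_op case advertises support only for F32 source rows with I64 indices written to an F32 or F16 destination. Type combinations outside that set are left for other backends (typically the CPU) to handle.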

ggml/src/ggml-cuda/set-rows.cu

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
#include "set-rows.cuh"

typedef void (*set_rows_kernel_t)(const char * src, char * dst);

static __device__ void set_rows_1_f32_f32(const char * src, char * dst) {
    const float * src_f = (const float *) src;
    float * dst_f = (float *) dst;
    *dst_f = *src_f;
}

static __device__ void set_rows_1_f32_f16(const char * src, char * dst) {
    const float * src_f = (const float *) src;
    half * dst_h = (half *) dst;
    *dst_h = __float2half(*src_f);
}

// TODO: consolidate kernels from cpy.cu, get_rows etc to make this function generic
template<set_rows_kernel_t set_rows_1>
static __global__ void k_set_rows(
        const char * __restrict__ src0, const int64_t * __restrict__ src1, char * __restrict__ dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const size_t nb01, const size_t nb02, const size_t nb03,
        const size_t nb10, const size_t nb11, const size_t nb12,
        const size_t nb1, const size_t nb2, const size_t nb3,
        const size_t src_type_size, const size_t dst_type_size) {

    const int i03 = blockIdx.z;
    const int i02 = blockIdx.y;
    const int i01 = blockIdx.x * blockDim.y + threadIdx.y; // Row index

    if (i01 >= ne01) {
        return;
    }

    const int i12 = i03 % ne12;
    const int i11 = i02 % ne11;
    const int i10 = i01;

    const int64_t dst_row = *(int64_t *)((char *)src1 + i10*nb10 + i11*nb11 + i12*nb12);

    const char * src0_row    = src0 + i01*nb01 + i02*nb02 + i03*nb03;
    char       * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3;

    for (int col = threadIdx.x; col < ne00; col += blockDim.x) {
        const char * src_elem = src0_row + col * src_type_size;
        char       * dst_elem = dst_row_ptr + col * dst_type_size;
        set_rows_1(src_elem, dst_elem);
    }
}
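A note on the indexing: blockIdx.z and blockIdx.y walk the two outer batch dimensions of src0, each y-slice of a block handles one row, and the threads along x stride over the ne00 columns of that row. The modulo arithmetic (i11 = i02 % ne11, i12 = i03 % ne12) lets a smaller index tensor src1 be reused across the batch dimensions of src0, in line with ggml's usual broadcast convention.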
template<set_rows_kernel_t set_rows_1>
static void set_rows_cuda(
        const char * src0_d, const int64_t * src1_d, char * dst_d,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const size_t nb01, const size_t nb02, const size_t nb03,
        const size_t nb10, const size_t nb11, const size_t nb12,
        const size_t nb1, const size_t nb2, const size_t nb3,
        const size_t src_type_size, const size_t dst_type_size,
        cudaStream_t stream) {

    const int max_threads_per_row = 256;
    const int threads_per_row = std::min((int)ne00, max_threads_per_row);

    const int max_threads_per_block = 256;
    const int rows_per_block = std::max(1, max_threads_per_block / threads_per_row);

    const dim3 block_size(threads_per_row, rows_per_block, 1);
    const dim3 grid_size(
        (ne01 + rows_per_block - 1) / rows_per_block, // thread-groups
        ne02,
        ne03
    );

    if (ne01 > 0 && ne00 > 0) {
        k_set_rows<set_rows_1><<<grid_size, block_size, 0, stream>>>(
            src0_d, src1_d, dst_d,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            src_type_size, dst_type_size
        );
    }
}
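The launch shape trades row-parallelism against column-parallelism under a fixed 256-thread budget: a wide row such as ne00 = 4096 gets threads_per_row = 256 and rows_per_block = 1, so each block handles one row and every thread copies 16 elements, while a narrow row such as ne00 = 32 gets threads_per_row = 32 and rows_per_block = 8, packing eight rows into one block. The ne01 > 0 && ne00 > 0 guard skips the launch entirely for empty tensors, since a zero-sized grid dimension would make the kernel launch fail.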
void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_I64);

    GGML_TENSOR_BINARY_OP_LOCALS

    const float * src0_d = (const float *)src0->data;
    const int64_t * src1_d = (const int64_t *)src1->data;

    cudaStream_t stream = ctx.stream();

    if (dst->type == GGML_TYPE_F32) {
        set_rows_cuda<set_rows_1_f32_f32>(
            (const char *)src0_d, src1_d, (char *)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            sizeof(float), sizeof(float),
            stream
        );
    } else if (dst->type == GGML_TYPE_F16) {
        set_rows_cuda<set_rows_1_f32_f16>(
            (const char *)src0_d, src1_d, (char *)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            sizeof(float), sizeof(half),
            stream
        );
    } else {
        GGML_ABORT("unsupported type");
    }
}
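To see the scatter pattern in isolation, here is a minimal self-contained sketch (not part of the commit; all names, sizes, and index values are illustrative). It copies F32 rows into an F16 destination at positions given by an I64 index array, which is the contiguous special case of the kernel above:

// Standalone illustration of the row-scatter pattern; simplified to
// contiguous 2D tensors, error checking omitted for brevity.
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include <vector>

__global__ void scatter_rows_f32_f16(const float * src, const int64_t * idx,
                                     half * dst, int64_t ncols, int64_t nrows) {
    // Same thread layout as k_set_rows: y picks the row, x strides over columns.
    const int64_t row = blockIdx.x * blockDim.y + threadIdx.y;
    if (row >= nrows) {
        return;
    }
    const int64_t dst_row = idx[row]; // where this source row lands
    for (int64_t col = threadIdx.x; col < ncols; col += blockDim.x) {
        dst[dst_row*ncols + col] = __float2half(src[row*ncols + col]);
    }
}

int main() {
    const int64_t ncols = 8, nrows = 4, dst_rows = 16;
    std::vector<float>   src(nrows*ncols);
    std::vector<int64_t> idx = {3, 0, 7, 12}; // destination row for each source row
    for (size_t i = 0; i < src.size(); ++i) src[i] = (float) i;

    float * d_src; int64_t * d_idx; half * d_dst;
    cudaMalloc(&d_src, src.size()*sizeof(float));
    cudaMalloc(&d_idx, idx.size()*sizeof(int64_t));
    cudaMalloc(&d_dst, dst_rows*ncols*sizeof(half));
    cudaMemset(d_dst, 0, dst_rows*ncols*sizeof(half));
    cudaMemcpy(d_src, src.data(), src.size()*sizeof(float),   cudaMemcpyHostToDevice);
    cudaMemcpy(d_idx, idx.data(), idx.size()*sizeof(int64_t), cudaMemcpyHostToDevice);

    // Launch shape in the spirit of set_rows_cuda: 8 threads per row, 32 rows per block.
    const dim3 block(8, 32, 1);
    const dim3 grid((nrows + block.y - 1) / block.y, 1, 1);
    scatter_rows_f32_f16<<<grid, block>>>(d_src, d_idx, d_dst, ncols, nrows);

    std::vector<half> dst(dst_rows*ncols);
    cudaMemcpy(dst.data(), d_dst, dst.size()*sizeof(half), cudaMemcpyDeviceToHost);
    printf("dst row  3, col 0 = %f (expect  0.0)\n", __half2float(dst[ 3*ncols]));
    printf("dst row 12, col 0 = %f (expect 24.0)\n", __half2float(dst[12*ncols]));

    cudaFree(d_src); cudaFree(d_idx); cudaFree(d_dst);
    return 0;
}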

ggml/src/ggml-cuda/set-rows.cuh

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
#pragma once

#include "common.cuh"

void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
