typedef void (*set_rows_kernel_t)(const char * src, char * dst);

-static __device__ void set_rows_1_f32_f32(const char * src, char * dst) {
-    const float * src_f = (const float *) src;
-    float * dst_f = (float *) dst;
-    *dst_f = *src_f;
+template<typename src_t, typename dst_t>
+__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {
+    GGML_ABORT("unsupport type for set_rows");
}

-static __device__ void set_rows_1_f32_f16(const char * src, char * dst) {
-    const float * src_f = (const float *) src;
-    half * dst_h = (half *) dst;
+template<>
+__device__ __forceinline__ void set_rows_1<float, half>(const float * src_f, half * dst_h) {
    *dst_h = __float2half(*src_f);
}

+template<>
+__device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) {
+    *dst_f = *src_f;
+}
+
// TODO: consolidate kernels from cpy.cu, get_rows etc to make this function generic
-template<set_rows_kernel_t set_rows_1>
+template<typename src_t, typename dst_t>
static __global__ void k_set_rows(
-        const char * __restrict__ src0, const int64_t * __restrict__ src1, char * __restrict__ dst,
+        const src_t * __restrict__ src0, const int64_t * __restrict__ src1, dst_t * __restrict__ dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const size_t nb01, const size_t nb02, const size_t nb03,
        const size_t nb10, const size_t nb11, const size_t nb12,
        const size_t nb1, const size_t nb2, const size_t nb3,
        const size_t src_type_size, const size_t dst_type_size) {

-    const int i03 = blockIdx.z;
-    const int i02 = blockIdx.y;
-    const int i01 = blockIdx.x * blockDim.y + threadIdx.y; // Row index
+    const int i03 = blockIdx.z / ne02;
+    const int i02 = blockIdx.z % ne02;
+    const int i01 = blockDim.x * blockIdx.x + threadIdx.x;
+    const int i00 = blockIdx.y;

    if (i01 >= ne01) {
        return;
@@ -37,21 +41,19 @@ static __global__ void k_set_rows(
    const int i11 = i02 % ne11;
    const int i10 = i01;

-    const int64_t dst_row = *(int64_t *)((char *)src1 + i10*nb10 + i11*nb11 + i12*nb12);
+    const int64_t dst_row = *(src1 + i10*nb10 + i11*nb11 + i12*nb12);

-    const char * src0_row = src0 + i01*nb01 + i02*nb02 + i03*nb03;
-    char * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3;
+    const src_t * src0_row = (const src_t *)src0 + i01*nb01 + i02*nb02 + i03*nb03;
+    dst_t * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3;

-    for (int col = threadIdx.x; col < ne00; col += blockDim.x) {
-        const char * src_elem = src0_row + col * src_type_size;
-        char * dst_elem = dst_row_ptr + col * dst_type_size;
-        set_rows_1(src_elem, dst_elem);
-    }
+    const src_t * src_elem = src0_row + i00;
+    dst_t * dst_elem = dst_row_ptr + i00;
+    set_rows_1(src_elem, dst_elem);
}

-template<set_rows_kernel_t set_rows_1>
+template<typename src_t, typename dst_t>
static void set_rows_cuda(
-        const char * src0_d, const int64_t * src1_d, char * dst_d,
+        const src_t * src0_d, const int64_t * src1_d, dst_t * dst_d,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const size_t nb01, const size_t nb02, const size_t nb03,
@@ -60,32 +62,39 @@ static void set_rows_cuda(
        const size_t src_type_size, const size_t dst_type_size,
        cudaStream_t stream) {

-    const int max_threads_per_row = 256;
-    const int threads_per_row = std::min((int)ne00, max_threads_per_row);
-
-    const int max_threads_per_block = 256;
-    const int rows_per_block = std::max(1, max_threads_per_block / threads_per_row);
-
-    const dim3 block_size(threads_per_row, rows_per_block, 1);
+    const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
    const dim3 grid_size(
-        (ne01 + rows_per_block - 1) / rows_per_block, // thread-groups
-        ne02,
-        ne03
+        (ne01 + CUDA_SET_ROWS_BLOCK_SIZE - 1)/CUDA_SET_ROWS_BLOCK_SIZE,
+        ne00,
+        ne03*ne02
    );

-    if (ne01 > 0 && ne00 > 0) {
-        k_set_rows<set_rows_1><<<grid_size, block_size, 0, stream>>>(
+    const int s1 = nb01 / sizeof(src_t);
+    const int s2 = nb02 / sizeof(src_t);
+    const int s3 = nb03 / sizeof(src_t);
+
+    const int s10 = nb10 / sizeof(int64_t);
+    const int s11 = nb11 / sizeof(int64_t);
+    const int s12 = nb12 / sizeof(int64_t);
+
+    const int s_dst = nb1 / sizeof(dst_t);
+    const int s_dst2 = nb2 / sizeof(dst_t);
+    const int s_dst3 = nb3 / sizeof(dst_t);
+
+
+    if (ne01 > 0 && ne00 > 0) {
+        k_set_rows<<<grid_size, block_size, 0, stream>>>(
            src0_d, src1_d, dst_d,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
-            nb01, nb02, nb03,
-            nb10, nb11, nb12,
-            nb1, nb2, nb3,
-            src_type_size, dst_type_size
-        );
+            s1, s2, s3,
+            s10, s11, s12,
+            s_dst, s_dst2, s_dst3,
+            src_type_size, dst_type_size);
    }
}

+
void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
@@ -101,8 +110,8 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    cudaStream_t stream = ctx.stream();

    if (dst->type == GGML_TYPE_F32) {
-        set_rows_cuda<set_rows_1_f32_f32>(
-            (const char *)src0_d, src1_d, (char *)dst->data,
+        set_rows_cuda(
+            src0_d, src1_d, (float *)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
@@ -112,8 +121,8 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
            stream
        );
    } else if (dst->type == GGML_TYPE_F16) {
-        set_rows_cuda<set_rows_1_f32_f16>(
-            (const char *)src0_d, src1_d, (char *)dst->data,
+        set_rows_cuda(
+            src0_d, src1_d, (half *)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
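
For reference, below is a minimal standalone sketch (not part of this commit) of the new launch geometry: one thread per element, with rows on `grid.x`, columns on `grid.y`, and all `ne02*ne03` planes folded into `grid.z` and decomposed in the kernel the same way `k_set_rows` does. The kernel name `k_show_mapping`, the demo shapes, and the assumption that `CUDA_SET_ROWS_BLOCK_SIZE` is 256 are made up for illustration.

```cpp
#include <cstdio>
#include <cuda_runtime.h>

#define CUDA_SET_ROWS_BLOCK_SIZE 256

// Toy kernel: reproduces only the index decomposition used by k_set_rows.
__global__ void k_show_mapping(int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03) {
    const int i03 = blockIdx.z / ne02;                     // batch index
    const int i02 = blockIdx.z % ne02;                     // plane index within the batch
    const int i01 = blockDim.x * blockIdx.x + threadIdx.x; // row index
    const int i00 = blockIdx.y;                            // column index

    if (i01 >= ne01) {
        return; // grid.x is rounded up, so out-of-range rows bail out
    }

    // a real kernel would now copy element (i03, i02, i01, i00) of src0
    // into the destination row selected by src1
    if (i00 == 0 && i01 == 0) {
        printf("plane (i03=%d, i02=%d) handled by blockIdx.z=%d\n", i03, i02, (int) blockIdx.z);
    }
}

int main() {
    // shapes chosen arbitrarily for the demo
    const int64_t ne00 = 8, ne01 = 1000, ne02 = 3, ne03 = 2;

    // same launch shape as the new set_rows_cuda:
    //   x: rows (ne01) split into blocks of CUDA_SET_ROWS_BLOCK_SIZE threads
    //   y: one block per column (ne00)
    //   z: all ne02*ne03 planes folded together
    const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
    const dim3 grid_size(
        (ne01 + CUDA_SET_ROWS_BLOCK_SIZE - 1)/CUDA_SET_ROWS_BLOCK_SIZE,
        ne00,
        ne03*ne02);

    k_show_mapping<<<grid_size, block_size>>>(ne00, ne01, ne02, ne03);
    cudaDeviceSynchronize();
    return 0;
}
```

Compared with the old launch (rows on `grid.x`/`threadIdx.y`, `ne02` on `grid.y`, `ne03` on `grid.z`, and a per-row loop over columns), this mapping assigns exactly one element to each thread, which is why the column loop in `k_set_rows` disappears in the diff above.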