|
| 1 | +#include "set-rows.cuh" |
| 2 | + |
| 3 | +typedef void (*set_rows_kernel_t)(const char * src, char * dst); |
| 4 | + |
| 5 | +template<typename src_t, typename dst_t> |
| 6 | +__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {} |
| 7 | + |
| 8 | +template<> |
| 9 | +__device__ __forceinline__ void set_rows_1<float, half>(const float * src_f, half * dst_h) { |
| 10 | + *dst_h = __float2half(*src_f); |
| 11 | +} |
| 12 | + |
| 13 | +template<> |
| 14 | +__device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) { |
| 15 | + *dst_f = *src_f; |
| 16 | +} |
| 17 | + |
| 18 | +template<typename src_t, typename dst_t> |
| 19 | +static __global__ void k_set_rows( |
| 20 | + const src_t * __restrict__ src0, const int64_t * __restrict__ src1, dst_t * __restrict__ dst, |
| 21 | + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, |
| 22 | + const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, |
| 23 | + const int64_t s01, const int64_t s02, const int64_t s03, |
| 24 | + const int64_t s10, const int64_t s11, const int64_t s12, |
| 25 | + const int64_t s1, const int64_t s2, const int64_t s3) { |
| 26 | + |
| 27 | + const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x; |
| 28 | + const int64_t ne_total = ne00 * ne01 * ne02 * ne03; |
| 29 | + |
| 30 | + if (i >= ne_total) { |
| 31 | + return; |
| 32 | + } |
| 33 | + |
| 34 | + const int64_t i03 = i / (ne00 * ne01 * ne02); |
| 35 | + const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01); |
| 36 | + const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00; |
| 37 | + const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00; |
| 38 | + |
| 39 | + const int64_t i12 = i03 % ne12; |
| 40 | + const int64_t i11 = i02 % ne11; |
| 41 | + const int64_t i10 = i01; |
| 42 | + |
| 43 | + const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12); |
| 44 | + |
| 45 | + const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03; |
| 46 | + dst_t * dst_row_ptr = dst + dst_row*s1 + i02*s2 + i03*s3; |
| 47 | + |
| 48 | + const src_t* src_elem = src0_row + i00; |
| 49 | + dst_t* dst_elem = dst_row_ptr + i00; |
| 50 | + set_rows_1(src_elem, dst_elem); |
| 51 | +} |
| 52 | + |
| 53 | +template<typename src_t, typename dst_t> |
| 54 | +static void set_rows_cuda( |
| 55 | + const src_t * src0_d, const int64_t * src1_d, dst_t * dst_d, |
| 56 | + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, |
| 57 | + const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, |
| 58 | + const size_t nb01, const size_t nb02, const size_t nb03, |
| 59 | + const size_t nb10, const size_t nb11, const size_t nb12, |
| 60 | + const size_t nb1, const size_t nb2, const size_t nb3, |
| 61 | + cudaStream_t stream) { |
| 62 | + |
| 63 | + const int64_t ne_total = ne00 * ne01 * ne02 * ne03; |
| 64 | + const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE; |
| 65 | + const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE); |
| 66 | + const dim3 grid_size(num_blocks); |
| 67 | + |
| 68 | + |
| 69 | + const int64_t s01 = nb01/sizeof(src_t); |
| 70 | + const int64_t s02 = nb02/sizeof(src_t); |
| 71 | + const int64_t s03 = nb03/sizeof(src_t); |
| 72 | + const int64_t s10 = nb10/sizeof(int64_t); |
| 73 | + const int64_t s11 = nb11/sizeof(int64_t); |
| 74 | + const int64_t s12 = nb12/sizeof(int64_t); |
| 75 | + const int64_t s1 = nb1/sizeof(dst_t); |
| 76 | + const int64_t s2 = nb2/sizeof(dst_t); |
| 77 | + const int64_t s3 = nb3/sizeof(dst_t); |
| 78 | + |
| 79 | + if (ne_total > 0) { |
| 80 | + k_set_rows<<<grid_size, block_size, 0, stream>>>( |
| 81 | + src0_d, src1_d, dst_d, |
| 82 | + ne00, ne01, ne02, ne03, |
| 83 | + ne10, ne11, ne12, ne13, |
| 84 | + s01, s02, s03, |
| 85 | + s10, s11, s12, |
| 86 | + s1, s2, s3); |
| 87 | + } |
| 88 | +} |
| 89 | + |
| 90 | + |
| 91 | +void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { |
| 92 | + const ggml_tensor * src0 = dst->src[0]; |
| 93 | + const ggml_tensor * src1 = dst->src[1]; |
| 94 | + |
| 95 | + GGML_ASSERT(src0->type == GGML_TYPE_F32); |
| 96 | + GGML_ASSERT(src1->type == GGML_TYPE_I64); |
| 97 | + |
| 98 | + GGML_TENSOR_BINARY_OP_LOCALS |
| 99 | + |
| 100 | + const float * src0_d = (const float *)src0->data; |
| 101 | + const int64_t * src1_d = (const int64_t *)src1->data; |
| 102 | + |
| 103 | + cudaStream_t stream = ctx.stream(); |
| 104 | + |
| 105 | + |
| 106 | + |
| 107 | + if (dst->type == GGML_TYPE_F32) { |
| 108 | + set_rows_cuda( |
| 109 | + src0_d, src1_d, (float*)dst->data, |
| 110 | + ne00, ne01, ne02, ne03, |
| 111 | + ne10, ne11, ne12, ne13, |
| 112 | + nb01, nb02, nb03, |
| 113 | + nb10, nb11, nb12, |
| 114 | + nb1, nb2, nb3, |
| 115 | + stream |
| 116 | + ); |
| 117 | + } else if (dst->type == GGML_TYPE_F16) { |
| 118 | + set_rows_cuda( |
| 119 | + src0_d, src1_d, (half*)dst->data, |
| 120 | + ne00, ne01, ne02, ne03, |
| 121 | + ne10, ne11, ne12, ne13, |
| 122 | + nb01, nb02, nb03, |
| 123 | + nb10, nb11, nb12, |
| 124 | + nb1, nb2, nb3, |
| 125 | + stream |
| 126 | + ); |
| 127 | + } else { |
| 128 | + GGML_ABORT("unsupported type"); |
| 129 | + } |
| 130 | +} |
0 commit comments