From 853bc5ecaf6112589b27e360bee67a020416a8a3 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sun, 6 Jul 2025 22:26:34 +0800 Subject: [PATCH 1/4] CUDA: add set rows for f32 and f16 --- ggml/src/ggml-cuda/ggml-cuda.cu | 10 +++ ggml/src/ggml-cuda/set-rows.cu | 128 ++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/set-rows.cuh | 5 ++ 3 files changed, 143 insertions(+) create mode 100644 ggml/src/ggml-cuda/set-rows.cu create mode 100644 ggml/src/ggml-cuda/set-rows.cuh diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 72406f0af3622..88b17dd682c95 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -43,6 +43,7 @@ #include "ggml-cuda/upscale.cuh" #include "ggml-cuda/wkv.cuh" #include "ggml-cuda/gla.cuh" +#include "ggml-cuda/set-rows.cuh" #include "ggml.h" #include @@ -2230,6 +2231,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_GET_ROWS_BACK: ggml_cuda_op_get_rows_back(ctx, dst); break; + case GGML_OP_SET_ROWS: + ggml_cuda_op_set_rows(ctx, dst); + break; case GGML_OP_DUP: ggml_cuda_dup(ctx, dst); break; @@ -3216,6 +3220,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g { return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1; } break; + case GGML_OP_SET_ROWS: + { + return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + op->src[0]->type == GGML_TYPE_F32 && + op->src[1]->type == GGML_TYPE_I64; + } break; case GGML_OP_CPY: { ggml_type src0_type = op->src[0]->type; diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu new file mode 100644 index 0000000000000..61030bef4e1a3 --- /dev/null +++ b/ggml/src/ggml-cuda/set-rows.cu @@ -0,0 +1,128 @@ +#include "set-rows.cuh" + +typedef void (*set_rows_kernel_t)(const char * src, char * dst); + +static __device__ void set_rows_1_f32_f32(const char * src, char * dst) { + const float * src_f = (const float *) src; + float * dst_f = (float *) dst; + *dst_f = *src_f; +} + +static __device__ void set_rows_1_f32_f16(const char * src, char * dst) { + const float * src_f = (const float *) src; + half * dst_h = (half *) dst; + *dst_h = __float2half(*src_f); +} + +//TODO: consolidate kernels from cpy.cu, get_rows etc to make this function generic +template +static __global__ void k_set_rows( + const char * __restrict__ src0, const int64_t * __restrict__ src1, char * __restrict__ dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, + const size_t nb01, const size_t nb02, const size_t nb03, + const size_t nb10, const size_t nb11, const size_t nb12, + const size_t nb1, const size_t nb2, const size_t nb3, + const size_t src_type_size, const size_t dst_type_size) { + + const int i03 = blockIdx.z; + const int i02 = blockIdx.y; + const int i01 = blockIdx.x * blockDim.y + threadIdx.y; // Row index + + if (i01 >= ne01) { + return; + } + + const int i12 = i03 % ne12; + const int i11 = i02 % ne11; + const int i10 = i01; + + const int64_t dst_row = *(int64_t *)((char *)src1 + i10*nb10 + i11*nb11 + i12*nb12); + + const char * src0_row = src0 + i01*nb01 + i02*nb02 + i03*nb03; + char * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3; + + for (int col = threadIdx.x; col < ne00; col += blockDim.x) { + const char * src_elem = src0_row + col * src_type_size; + char * dst_elem = dst_row_ptr + col * dst_type_size; + set_rows_1(src_elem, dst_elem); + } +} + +template +static void set_rows_cuda( + const char * src0_d, const int64_t * src1_d, char * dst_d, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, + const size_t nb01, const size_t nb02, const size_t nb03, + const size_t nb10, const size_t nb11, const size_t nb12, + const size_t nb1, const size_t nb2, const size_t nb3, + const size_t src_type_size, const size_t dst_type_size, + cudaStream_t stream) { + + const int max_threads_per_row = 256; + const int threads_per_row = std::min((int)ne00, max_threads_per_row); + + const int max_threads_per_block = 256; + const int rows_per_block = std::max(1, max_threads_per_block / threads_per_row); + + const dim3 block_size(threads_per_row, rows_per_block, 1); + const dim3 grid_size( + (ne01 + rows_per_block - 1) / rows_per_block, // thread-groups + ne02, + ne03 + ); + + if (ne01 > 0 && ne00 > 0) { + k_set_rows<<>>( + src0_d, src1_d, dst_d, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + src_type_size, dst_type_size + ); + } +} + +void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_I64); + + GGML_TENSOR_BINARY_OP_LOCALS + + const float * src0_d = (const float *)src0->data; + const int64_t * src1_d = (const int64_t *)src1->data; + + cudaStream_t stream = ctx.stream(); + + if (dst->type == GGML_TYPE_F32) { + set_rows_cuda( + (const char *)src0_d, src1_d, (char *)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + sizeof(float), sizeof(float), + stream + ); + } else if (dst->type == GGML_TYPE_F16) { + set_rows_cuda( + (const char *)src0_d, src1_d, (char *)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + sizeof(float), sizeof(half), + stream + ); + } else { + GGML_ABORT("unsupported type"); + } +} diff --git a/ggml/src/ggml-cuda/set-rows.cuh b/ggml/src/ggml-cuda/set-rows.cuh new file mode 100644 index 0000000000000..6d5022a64ac69 --- /dev/null +++ b/ggml/src/ggml-cuda/set-rows.cuh @@ -0,0 +1,5 @@ +#pragma once + +#include "common.cuh" + +void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst); From 15e1b89711da177a0aa8a188220e1610c436e80d Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 9 Jul 2025 00:49:29 +0800 Subject: [PATCH 2/4] Review: change kernel params, use strides from host --- ggml/src/ggml-cuda/set-rows.cu | 95 ++++++++++++++++++--------------- ggml/src/ggml-cuda/set-rows.cuh | 2 + 2 files changed, 54 insertions(+), 43 deletions(-) diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu index 61030bef4e1a3..56e20b875ec68 100644 --- a/ggml/src/ggml-cuda/set-rows.cu +++ b/ggml/src/ggml-cuda/set-rows.cu @@ -2,22 +2,25 @@ typedef void (*set_rows_kernel_t)(const char * src, char * dst); -static __device__ void set_rows_1_f32_f32(const char * src, char * dst) { - const float * src_f = (const float *) src; - float * dst_f = (float *) dst; - *dst_f = *src_f; +template +__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) { + GGML_ABORT("unsupport type for set_rows"); } -static __device__ void set_rows_1_f32_f16(const char * src, char * dst) { - const float * src_f = (const float *) src; - half * dst_h = (half *) dst; +template<> +__device__ __forceinline__ void set_rows_1(const float * src_f, half * dst_h) { *dst_h = __float2half(*src_f); } +template<> +__device__ __forceinline__ void set_rows_1(const float * src_f, float * dst_f) { + *dst_f = *src_f; +} + //TODO: consolidate kernels from cpy.cu, get_rows etc to make this function generic -template +template static __global__ void k_set_rows( - const char * __restrict__ src0, const int64_t * __restrict__ src1, char * __restrict__ dst, + const src_t * __restrict__ src0, const int64_t * __restrict__ src1, dst_t * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, const size_t nb01, const size_t nb02, const size_t nb03, @@ -25,9 +28,10 @@ static __global__ void k_set_rows( const size_t nb1, const size_t nb2, const size_t nb3, const size_t src_type_size, const size_t dst_type_size) { - const int i03 = blockIdx.z; - const int i02 = blockIdx.y; - const int i01 = blockIdx.x * blockDim.y + threadIdx.y; // Row index + const int i03 = blockIdx.z / ne02; + const int i02 = blockIdx.z % ne02; + const int i01 = blockDim.x * blockIdx.x + threadIdx.x; + const int i00 = blockIdx.y; if (i01 >= ne01) { return; @@ -37,21 +41,19 @@ static __global__ void k_set_rows( const int i11 = i02 % ne11; const int i10 = i01; - const int64_t dst_row = *(int64_t *)((char *)src1 + i10*nb10 + i11*nb11 + i12*nb12); + const int64_t dst_row = *(src1 + i10*nb10 + i11*nb11 + i12*nb12); - const char * src0_row = src0 + i01*nb01 + i02*nb02 + i03*nb03; - char * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3; + const src_t * src0_row = (const src_t *)src0 + i01*nb01 + i02*nb02 + i03*nb03; + dst_t * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3; - for (int col = threadIdx.x; col < ne00; col += blockDim.x) { - const char * src_elem = src0_row + col * src_type_size; - char * dst_elem = dst_row_ptr + col * dst_type_size; - set_rows_1(src_elem, dst_elem); - } + const src_t* src_elem = src0_row + i00; + dst_t* dst_elem = dst_row_ptr + i00; + set_rows_1(src_elem, dst_elem); } -template +template static void set_rows_cuda( - const char * src0_d, const int64_t * src1_d, char * dst_d, + const src_t * src0_d, const int64_t * src1_d, dst_t * dst_d, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, const size_t nb01, const size_t nb02, const size_t nb03, @@ -60,32 +62,39 @@ static void set_rows_cuda( const size_t src_type_size, const size_t dst_type_size, cudaStream_t stream) { - const int max_threads_per_row = 256; - const int threads_per_row = std::min((int)ne00, max_threads_per_row); - - const int max_threads_per_block = 256; - const int rows_per_block = std::max(1, max_threads_per_block / threads_per_row); - - const dim3 block_size(threads_per_row, rows_per_block, 1); + const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE); const dim3 grid_size( - (ne01 + rows_per_block - 1) / rows_per_block, // thread-groups - ne02, - ne03 + (ne01 + CUDA_SET_ROWS_BLOCK_SIZE - 1)/CUDA_SET_ROWS_BLOCK_SIZE, + ne00, + ne03*ne02 ); - if (ne01 > 0 && ne00 > 0) { - k_set_rows<<>>( + const int s1 = nb01 / sizeof(src_t); + const int s2 = nb02 / sizeof(src_t); + const int s3 = nb03 / sizeof(src_t); + + const int s10 = nb10 / sizeof(int64_t); + const int s11 = nb11 / sizeof(int64_t); + const int s12 = nb12 / sizeof(int64_t); + + const int s_dst = nb1 / sizeof(dst_t); + const int s_dst2 = nb2 / sizeof(dst_t); + const int s_dst3 = nb3 / sizeof(dst_t); + + + if(ne01 > 0 && ne00 > 0) { + k_set_rows<<>>( src0_d, src1_d, dst_d, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, - nb01, nb02, nb03, - nb10, nb11, nb12, - nb1, nb2, nb3, - src_type_size, dst_type_size - ); + s1, s2, s3, + s10, s11, s12, + s_dst, s_dst2, s_dst3, + src_type_size, dst_type_size); } } + void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; @@ -101,8 +110,8 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { cudaStream_t stream = ctx.stream(); if (dst->type == GGML_TYPE_F32) { - set_rows_cuda( - (const char *)src0_d, src1_d, (char *)dst->data, + set_rows_cuda( + src0_d, src1_d, (float*)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb01, nb02, nb03, @@ -112,8 +121,8 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { stream ); } else if (dst->type == GGML_TYPE_F16) { - set_rows_cuda( - (const char *)src0_d, src1_d, (char *)dst->data, + set_rows_cuda( + src0_d, src1_d, (half*)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb01, nb02, nb03, diff --git a/ggml/src/ggml-cuda/set-rows.cuh b/ggml/src/ggml-cuda/set-rows.cuh index 6d5022a64ac69..b5fe3d799539c 100644 --- a/ggml/src/ggml-cuda/set-rows.cuh +++ b/ggml/src/ggml-cuda/set-rows.cuh @@ -2,4 +2,6 @@ #include "common.cuh" +#define CUDA_SET_ROWS_BLOCK_SIZE 64 + void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst); From 85e2a202eaeb5096b6096d8e4ef387a6433f2e69 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 9 Jul 2025 14:03:51 +0800 Subject: [PATCH 3/4] Use 1-d kernel --- ggml/src/ggml-cuda/set-rows.cu | 69 +++++++++++++++------------------ ggml/src/ggml-cuda/set-rows.cuh | 2 +- 2 files changed, 32 insertions(+), 39 deletions(-) diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu index 56e20b875ec68..f06d06f8b6bf0 100644 --- a/ggml/src/ggml-cuda/set-rows.cu +++ b/ggml/src/ggml-cuda/set-rows.cu @@ -3,9 +3,7 @@ typedef void (*set_rows_kernel_t)(const char * src, char * dst); template -__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) { - GGML_ABORT("unsupport type for set_rows"); -} +__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {} template<> __device__ __forceinline__ void set_rows_1(const float * src_f, half * dst_h) { @@ -17,7 +15,6 @@ __device__ __forceinline__ void set_rows_1(const float * src_f, fl *dst_f = *src_f; } -//TODO: consolidate kernels from cpy.cu, get_rows etc to make this function generic template static __global__ void k_set_rows( const src_t * __restrict__ src0, const int64_t * __restrict__ src1, dst_t * __restrict__ dst, @@ -25,25 +22,27 @@ static __global__ void k_set_rows( const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, const size_t nb01, const size_t nb02, const size_t nb03, const size_t nb10, const size_t nb11, const size_t nb12, - const size_t nb1, const size_t nb2, const size_t nb3, - const size_t src_type_size, const size_t dst_type_size) { + const size_t nb1, const size_t nb2, const size_t nb3) { - const int i03 = blockIdx.z / ne02; - const int i02 = blockIdx.z % ne02; - const int i01 = blockDim.x * blockIdx.x + threadIdx.x; - const int i00 = blockIdx.y; + const int64_t i = blockDim.x * blockIdx.x + threadIdx.x; + const int64_t ne_total = ne00 * ne01 * ne02 * ne03; - if (i01 >= ne01) { + if (i >= ne_total) { return; } - const int i12 = i03 % ne12; - const int i11 = i02 % ne11; - const int i10 = i01; + const int64_t i03 = i / (ne00 * ne01 * ne02); + const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01); + const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00; + const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00; + + const int64_t i12 = i03 % ne12; + const int64_t i11 = i02 % ne11; + const int64_t i10 = i01; const int64_t dst_row = *(src1 + i10*nb10 + i11*nb11 + i12*nb12); - const src_t * src0_row = (const src_t *)src0 + i01*nb01 + i02*nb02 + i03*nb03; + const src_t * src0_row = src0 + i01*nb01 + i02*nb02 + i03*nb03; dst_t * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3; const src_t* src_elem = src0_row + i00; @@ -59,38 +58,32 @@ static void set_rows_cuda( const size_t nb01, const size_t nb02, const size_t nb03, const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb1, const size_t nb2, const size_t nb3, - const size_t src_type_size, const size_t dst_type_size, cudaStream_t stream) { + const int64_t ne_total = ne00 * ne01 * ne02 * ne03; + const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE; const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE); - const dim3 grid_size( - (ne01 + CUDA_SET_ROWS_BLOCK_SIZE - 1)/CUDA_SET_ROWS_BLOCK_SIZE, - ne00, - ne03*ne02 - ); - - const int s1 = nb01 / sizeof(src_t); - const int s2 = nb02 / sizeof(src_t); - const int s3 = nb03 / sizeof(src_t); + const dim3 grid_size(num_blocks); - const int s10 = nb10 / sizeof(int64_t); - const int s11 = nb11 / sizeof(int64_t); - const int s12 = nb12 / sizeof(int64_t); - const int s_dst = nb1 / sizeof(dst_t); - const int s_dst2 = nb2 / sizeof(dst_t); - const int s_dst3 = nb3 / sizeof(dst_t); + const int64_t s01 = nb01/sizeof(src_t); + const int64_t s02 = nb02/sizeof(src_t); + const int64_t s03 = nb03/sizeof(src_t); + const int64_t s10 = nb10/sizeof(int64_t); + const int64_t s11 = nb11/sizeof(int64_t); + const int64_t s12 = nb12/sizeof(int64_t); + const int64_t s1 = nb1/sizeof(dst_t); + const int64_t s2 = nb2/sizeof(dst_t); + const int64_t s3 = nb3/sizeof(dst_t); - - if(ne01 > 0 && ne00 > 0) { + if (ne_total > 0) { k_set_rows<<>>( src0_d, src1_d, dst_d, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, - s1, s2, s3, + s01, s02, s03, s10, s11, s12, - s_dst, s_dst2, s_dst3, - src_type_size, dst_type_size); + s1, s2, s3); } } @@ -109,6 +102,8 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { cudaStream_t stream = ctx.stream(); + + if (dst->type == GGML_TYPE_F32) { set_rows_cuda( src0_d, src1_d, (float*)dst->data, @@ -117,7 +112,6 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { nb01, nb02, nb03, nb10, nb11, nb12, nb1, nb2, nb3, - sizeof(float), sizeof(float), stream ); } else if (dst->type == GGML_TYPE_F16) { @@ -128,7 +122,6 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { nb01, nb02, nb03, nb10, nb11, nb12, nb1, nb2, nb3, - sizeof(float), sizeof(half), stream ); } else { diff --git a/ggml/src/ggml-cuda/set-rows.cuh b/ggml/src/ggml-cuda/set-rows.cuh index b5fe3d799539c..c140c0873c8a8 100644 --- a/ggml/src/ggml-cuda/set-rows.cuh +++ b/ggml/src/ggml-cuda/set-rows.cuh @@ -2,6 +2,6 @@ #include "common.cuh" -#define CUDA_SET_ROWS_BLOCK_SIZE 64 +#define CUDA_SET_ROWS_BLOCK_SIZE 256 void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst); From 9deb7644483af6f32866ad45516a89b293a843fa Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sat, 12 Jul 2025 20:32:30 +0800 Subject: [PATCH 4/4] Review: use int64_t for blockDim.x, rename nb->s for clarity --- ggml/src/ggml-cuda/set-rows.cu | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu index f06d06f8b6bf0..d8b3e63e1aa57 100644 --- a/ggml/src/ggml-cuda/set-rows.cu +++ b/ggml/src/ggml-cuda/set-rows.cu @@ -20,11 +20,11 @@ static __global__ void k_set_rows( const src_t * __restrict__ src0, const int64_t * __restrict__ src1, dst_t * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, - const size_t nb01, const size_t nb02, const size_t nb03, - const size_t nb10, const size_t nb11, const size_t nb12, - const size_t nb1, const size_t nb2, const size_t nb3) { + const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t s10, const int64_t s11, const int64_t s12, + const int64_t s1, const int64_t s2, const int64_t s3) { - const int64_t i = blockDim.x * blockIdx.x + threadIdx.x; + const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x; const int64_t ne_total = ne00 * ne01 * ne02 * ne03; if (i >= ne_total) { @@ -40,10 +40,10 @@ static __global__ void k_set_rows( const int64_t i11 = i02 % ne11; const int64_t i10 = i01; - const int64_t dst_row = *(src1 + i10*nb10 + i11*nb11 + i12*nb12); + const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12); - const src_t * src0_row = src0 + i01*nb01 + i02*nb02 + i03*nb03; - dst_t * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3; + const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03; + dst_t * dst_row_ptr = dst + dst_row*s1 + i02*s2 + i03*s3; const src_t* src_elem = src0_row + i00; dst_t* dst_elem = dst_row_ptr + i00;