From 50f88fc4caf5790e5902e5a63107b364f69f83a4 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 27 Jun 2025 11:21:26 +0200 Subject: [PATCH 01/18] ggml : add ggml_scale_bias --- ggml/include/ggml.h | 13 +++++++++++++ ggml/src/ggml-cpu/ops.cpp | 13 +++++++++---- ggml/src/ggml-metal/ggml-metal.m | 5 +++-- ggml/src/ggml-metal/ggml-metal.metal | 6 ++++-- ggml/src/ggml.c | 28 +++++++++++++++++++++++----- tests/test-backend-ops.cpp | 11 +++++++---- 6 files changed, 59 insertions(+), 17 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 9c4e24023b5ad..236ac52eb352d 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1185,6 +1185,19 @@ extern "C" { struct ggml_tensor * a, float s); + // x = s * a + b + GGML_API struct ggml_tensor * ggml_scale_bias( + struct ggml_context * ctx, + struct ggml_tensor * a, + float s, + float b); + + GGML_API struct ggml_tensor * ggml_scale_bias_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float s, + float b); + // b -> view(a,offset,nb1,nb2,3), return modified a GGML_API struct ggml_tensor * ggml_set( struct ggml_context * ctx, diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 8531baf6c57fb..bc61080797b2d 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -3937,9 +3937,11 @@ static void ggml_compute_forward_scale_f32( GGML_ASSERT(ggml_is_contiguous(dst)); GGML_ASSERT(ggml_are_same_shape(src0, dst)); - // scale factor - float v; - memcpy(&v, dst->op_params, sizeof(float)); + float s; // scale factor + float b; // bias + + memcpy(&s, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&b, (float *) dst->op_params + 1, sizeof(float)); const int ith = params->ith; const int nth = params->nth; @@ -3963,7 +3965,10 @@ static void ggml_compute_forward_scale_f32( // src0 is same shape as dst => same indices memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float)); } - ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v); + ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s); + if (b != 0.0f) { + ggml_vec_acc1_f32(nc, (float *) ((char *) dst->data + i1*nb1), b); + } } } diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index d8d30cc0b41ca..69b8a268bfcb7 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2189,8 +2189,8 @@ static bool ggml_metal_encode_node( { GGML_ASSERT(ggml_is_contiguous(src0)); - float scale; - memcpy(&scale, dst->op_params, sizeof(scale)); + float scale = ((const float *)(dst->op_params))[0]; + float bias = ((const float *)(dst->op_params))[1]; int64_t n = ggml_nelements(dst); @@ -2207,6 +2207,7 @@ static bool ggml_metal_encode_node( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; + [encoder setBytes:&bias length:sizeof(bias) atIndex:3]; [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 5f004a856bde6..ae012b1c79826 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -810,16 +810,18 @@ kernel void kernel_scale( device const float * src0, device float * dst, constant float & scale, + constant float & bias, uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * scale; + dst[tpig] = src0[tpig] * scale + bias; } kernel void kernel_scale_4( device const float4 * src0, device float4 * dst, constant float & scale, + constant float & bias, uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * scale; + dst[tpig] = src0[tpig] * scale + bias; } kernel void kernel_clamp( diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index ee605977f3a2c..e77d33fc7a1aa 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2858,12 +2858,14 @@ static struct ggml_tensor * ggml_scale_impl( struct ggml_context * ctx, struct ggml_tensor * a, float s, + float b, bool inplace) { GGML_ASSERT(ggml_is_padded_1d(a)); struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - ggml_set_op_params(result, &s, sizeof(s)); + float params[2] = { s, b }; + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_SCALE; result->src[0] = a; @@ -2875,14 +2877,30 @@ struct ggml_tensor * ggml_scale( struct ggml_context * ctx, struct ggml_tensor * a, float s) { - return ggml_scale_impl(ctx, a, s, false); + return ggml_scale_impl(ctx, a, s, 0.0, false); } struct ggml_tensor * ggml_scale_inplace( struct ggml_context * ctx, struct ggml_tensor * a, float s) { - return ggml_scale_impl(ctx, a, s, true); + return ggml_scale_impl(ctx, a, s, 0.0, true); +} + +struct ggml_tensor * ggml_scale_bias( + struct ggml_context * ctx, + struct ggml_tensor * a, + float s, + float b) { + return ggml_scale_impl(ctx, a, s, b, false); +} + +struct ggml_tensor * ggml_scale_bias_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float s, + float b) { + return ggml_scale_impl(ctx, a, s, b, true); } // ggml_set @@ -5472,7 +5490,7 @@ static void ggml_compute_backward( } break; case GGML_OP_MEAN: { if (src0_needs_grads) { - ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false)); + ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], 0.0, false)); } } break; case GGML_OP_REPEAT: { @@ -5549,7 +5567,7 @@ static void ggml_compute_backward( if (src0_needs_grads) { float s; memcpy(&s, tensor->op_params, sizeof(float)); - ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, false)); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, 0.0, false)); } } break; case GGML_OP_SET: { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 615c2dc008a8d..d1b2ff10d14ad 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1655,22 +1655,24 @@ struct test_scale : public test_case { const ggml_type type; const std::array ne; float scale; + float bias; std::string vars() override { - return VARS_TO_STR3(type, ne, scale); + return VARS_TO_STR4(type, ne, scale, bias); } test_scale(ggml_type type = GGML_TYPE_F32, std::array ne = {10, 10, 10, 10}, - float scale = 2.0f) - : type(type), ne(ne), scale(scale) {} + float scale = 2.0f, + float bias = 0.0f) + : type(type), ne(ne), scale(scale), bias(bias) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_set_param(a); ggml_set_name(a, "a"); - ggml_tensor * out = ggml_scale(ctx, a, scale); + ggml_tensor * out = ggml_scale_bias(ctx, a, scale, bias); ggml_set_name(out, "out"); return out; @@ -4209,6 +4211,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_add1()); test_cases.emplace_back(new test_scale()); + test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f)); test_cases.emplace_back(new test_silu_back()); for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) { From a5ccf168f163947f2548752153cad4218ae95933 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 8 Jul 2025 23:13:42 +0200 Subject: [PATCH 02/18] ggml_vec_mad1_f32 --- ggml/src/ggml-cpu/ops.cpp | 5 +---- ggml/src/ggml-cpu/vec.h | 30 ++++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index a427de404e360..7c50271732ba0 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -4671,10 +4671,7 @@ static void ggml_compute_forward_scale_f32( // src0 is same shape as dst => same indices memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float)); } - ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s); - if (b != 0.0f) { - ggml_vec_acc1_f32(nc, (float *) ((char *) dst->data + i1*nb1), b); - } + ggml_vec_mad1_f32(nc, (float *) ((char *) dst->data + i1*nb1), s, b); } } diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 1f5857a23e35c..e0109be51d183 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -351,6 +351,36 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int #endif } +inline static void ggml_vec_mad1_f32(const int n, float * y, const float s, const float b) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC vs = GGML_F32_VEC_SET1(s); + GGML_F32_VEC vb = GGML_F32_VEC_SET1(b); + + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_FMA(ay[j], vs, vb); + + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] = y[i]*s + b; + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] *= y[i]*s + b; + } +#endif +} + //inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; } inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { #if defined(GGML_USE_ACCELERATE) From e427af75fb27682167218d8a1d2fea13e8fe0e22 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 8 Jul 2025 23:19:16 +0200 Subject: [PATCH 03/18] add more simd --- ggml/src/ggml-cpu/vec.h | 64 +++++++++++++++++++++++++++++++---------- 1 file changed, 49 insertions(+), 15 deletions(-) diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index e0109be51d183..78c7ed2d157e0 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -352,27 +352,61 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int } inline static void ggml_vec_mad1_f32(const int n, float * y, const float s, const float b) { -#if defined(GGML_SIMD) - const int np = (n & ~(GGML_F32_STEP - 1)); +#if defined(GGML_USE_ACCELERATE) + vDSP_vsmul(y, 1, &s, y, 1, n); + vDSP_vsadd(y, 1, &b, y, 1, n); +#elif defined(GGML_SIMD) + #if defined(__ARM_FEATURE_SVE) + const int sve_register_length = ggml_cpu_get_sve_cnt() * 8; + const int ggml_f32_epr = sve_register_length / 32;//8;//svcntw(); // SVE128:4, SVE256:8, SVE512:16 + const int ggml_f32_step = 2 * ggml_f32_epr; + + GGML_F32_VEC vs = GGML_F32_VEC_SET1(s); + GGML_F32_VEC vb = GGML_F32_VEC_SET1(b); + + const int np = (n & ~(ggml_f32_step - 1)); + svfloat32_t ay1; + svfloat32_t ay2; + for (int i = 0; i < np; i += ggml_f32_step) { + ay1 = GGML_F32_VEC_LOAD(y + i); + ay1 = GGML_F32_VEC_FMA(ay1, vs, vb); + GGML_F32_VEC_STORE(y + i, ay1); + + ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr); + ay2 = GGML_F32_VEC_FMA(ay2, vs, vb); + GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2); + } + // leftovers + // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only + if (np < n) { + svbool_t pg = svwhilelt_b32(np, n); + ay1 = svld1_f32(pg, y + np); + ay1 = svmul_f32_m(pg, ay1, vs); + ay1 = svadd_f32_m(pg, ay1, vb); + svst1_f32(pg, y + np, ay1); + } + #else + const int np = (n & ~(GGML_F32_STEP - 1)); - GGML_F32_VEC vs = GGML_F32_VEC_SET1(s); - GGML_F32_VEC vb = GGML_F32_VEC_SET1(b); + GGML_F32_VEC vs = GGML_F32_VEC_SET1(s); + GGML_F32_VEC vb = GGML_F32_VEC_SET1(b); - GGML_F32_VEC ay[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; - for (int i = 0; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; j++) { - ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_FMA(ay[j], vs, vb); + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_FMA(ay[j], vs, vb); - GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } } - } - // leftovers - for (int i = np; i < n; ++i) { - y[i] = y[i]*s + b; - } + // leftovers + for (int i = np; i < n; ++i) { + y[i] = y[i]*s + b; + } + #endif #else // scalar for (int i = 0; i < n; ++i) { From 92a87384520853405412d2355ab3b82bd2887a6f Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 8 Jul 2025 23:26:21 +0200 Subject: [PATCH 04/18] add CUDA --- ggml/src/ggml-cuda/scale.cu | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-cuda/scale.cu b/ggml/src/ggml-cuda/scale.cu index 1405e066e86a2..eb4a5f0fcdf06 100644 --- a/ggml/src/ggml-cuda/scale.cu +++ b/ggml/src/ggml-cuda/scale.cu @@ -1,18 +1,18 @@ #include "scale.cuh" -static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) { +static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { return; } - dst[i] = scale * x[i]; + dst[i] = scale * x[i] + bias; } -static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) { +static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; - scale_f32<<>>(x, dst, scale, k); + scale_f32<<>>(x, dst, scale, bias, k); } void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -24,8 +24,8 @@ void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - float scale; - memcpy(&scale, dst->op_params, sizeof(float)); + float scale = ((const float *)(dst->op_params))[0]; + float bias = ((const float *)(dst->op_params))[1]; - scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream); + scale_f32_cuda(src0_d, dst_d, scale, bias, ggml_nelements(src0), stream); } From a28df6f00ccbeb9050c3a5c906986413ac1c2a1d Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 8 Jul 2025 23:27:32 +0200 Subject: [PATCH 05/18] sycl --- ggml/src/ggml-sycl/ggml-sycl.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 21c81e99a19aa..00bd5ecb552fa 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -1695,7 +1695,7 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX; } -static void scale_f32(const float * x, float * dst, const float scale, const int k, +static void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -1704,7 +1704,7 @@ static void scale_f32(const float * x, float * dst, const float scale, const int return; } - dst[i] = scale * x[i]; + dst[i] = scale * x[i] + bias; } @@ -1842,7 +1842,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl( -static void scale_f32_sycl(const float *x, float *dst, const float scale, +static void scale_f32_sycl(const float *x, float *dst, const float scale, const float bias, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_SCALE_BLOCK_SIZE - 1) / SYCL_SCALE_BLOCK_SIZE; stream->parallel_for( @@ -1850,7 +1850,7 @@ static void scale_f32_sycl(const float *x, float *dst, const float scale, sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - scale_f32(x, dst, scale, k, item_ct1); + scale_f32(x, dst, scale, bias, k, item_ct1); }); } @@ -2318,10 +2318,10 @@ inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * ds const float * src0_dd = static_cast(dst->src[0]->data); float * dst_dd = static_cast(dst->data); - float scale; - memcpy(&scale, dst->op_params, sizeof(float)); + float scale = ((const float *)(dst->op_params))[0]; + float bias = ((const float *)(dst->op_params))[1]; - scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(dst->src[0]), main_stream); + scale_f32_sycl(src0_dd, dst_dd, scale, bias, ggml_nelements(dst->src[0]), main_stream); /* DPCT1010:87: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. From 782b58fa065145a9e8d7edf99538e57a2362046e Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 8 Jul 2025 23:31:04 +0200 Subject: [PATCH 06/18] vulkan --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 2 +- ggml/src/ggml-vulkan/vulkan-shaders/scale.comp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 2245a655498c5..c36e1a6d3bfc2 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -7508,7 +7508,7 @@ static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, con (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, - op_params[0], 0.0f, + op_params[0], op_params[1], 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, }, dryrun); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp b/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp index 4663428dee0a2..f10b0a02b5076 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp @@ -18,7 +18,7 @@ void main() { continue; } - data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1)); + data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1) + FLOAT_TYPE(p.param2)); idx += num_threads; } } From 477a97ad876dcb1ccad35f0b6a104cb7ebd7fedb Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 8 Jul 2025 23:34:15 +0200 Subject: [PATCH 07/18] cann (placeholder) --- ggml/src/ggml-cann/ggml-cann.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index eae575cc040cd..b60a83c5c6fc8 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2188,7 +2188,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_RMS_NORM: - case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_SQRT: case GGML_OP_CLAMP: @@ -2210,6 +2209,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_OP_PAD_REFLECT_1D: case GGML_OP_COUNT_EQUAL: return true; + case GGML_OP_SCALE: + float bias = ((const float *)(dst->op_params))[1]; + return bias == 0.0f; // TODO: support bias != 0.0f case GGML_OP_SOFT_MAX: // TODO: support broadcast // ref: https://github.com/ggml-org/llama.cpp/pull/14435 From 0e51a0a8b061e108705193798a7188a75633fb9b Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 8 Jul 2025 23:36:47 +0200 Subject: [PATCH 08/18] opencl --- ggml/src/ggml-opencl/ggml-opencl.cpp | 5 +++-- ggml/src/ggml-opencl/kernels/scale.cl | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index a9fc039038705..0485f8d38ed88 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -5586,8 +5586,8 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - float scale; - memcpy(&scale, dst->op_params, sizeof(scale)); + float scale = ((const float *)(dst->op_params))[0]; + float bias = ((const float *)(dst->op_params))[1]; ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; @@ -5602,6 +5602,7 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float), &scale)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(float), &bias)); int n = ggml_nelements(dst)/4; diff --git a/ggml/src/ggml-opencl/kernels/scale.cl b/ggml/src/ggml-opencl/kernels/scale.cl index 8cfd518fa5a3e..aeca8a456e4fe 100644 --- a/ggml/src/ggml-opencl/kernels/scale.cl +++ b/ggml/src/ggml-opencl/kernels/scale.cl @@ -8,9 +8,10 @@ kernel void kernel_scale( ulong offset0, global float4 * dst, ulong offsetd, - float scale + float scale, + float bias ) { src0 = (global float4*)((global char*)src0 + offset0); dst = (global float4*)((global char*)dst + offsetd); - dst[get_global_id(0)] = src0[get_global_id(0)] * scale; + dst[get_global_id(0)] = src0[get_global_id(0)] * scale + bias; } From 4d0195324e33d2bbd17d88baefaa22fa12796f76 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 9 Jul 2025 00:00:31 +0200 Subject: [PATCH 09/18] will this fix cpu? --- ggml/src/ggml-cpu/ops.cpp | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 7c50271732ba0..5a07819038d30 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -4666,12 +4666,22 @@ static void ggml_compute_forward_scale_f32( const size_t nb1 = dst->nb[1]; - for (int i1 = ir0; i1 < ir1; i1++) { - if (dst->data != src0->data) { - // src0 is same shape as dst => same indices - memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float)); + if (b == 0.0f) { + for (int i1 = ir0; i1 < ir1; i1++) { + if (dst->data != src0->data) { + // src0 is same shape as dst => same indices + memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float)); + } + ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s); + } + } else { + for (int i1 = ir0; i1 < ir1; i1++) { + if (dst->data != src0->data) { + // src0 is same shape as dst => same indices + memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float)); + } + ggml_vec_mad1_f32(nc, (float *) ((char *) dst->data + i1*nb1), s, b); } - ggml_vec_mad1_f32(nc, (float *) ((char *) dst->data + i1*nb1), s, b); } } From b22708fd90e482636e1d9e1d13b78ffac3fada20 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 9 Jul 2025 00:00:44 +0200 Subject: [PATCH 10/18] fix cuda --- ggml/src/ggml-cuda/scale.cu | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/scale.cu b/ggml/src/ggml-cuda/scale.cu index eb4a5f0fcdf06..f915bf5faa23a 100644 --- a/ggml/src/ggml-cuda/scale.cu +++ b/ggml/src/ggml-cuda/scale.cu @@ -24,8 +24,10 @@ void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - float scale = ((const float *)(dst->op_params))[0]; - float bias = ((const float *)(dst->op_params))[1]; + float scale; + float bias; + memcpy(&scale, dst->op_params, sizeof(float)); + memcpy(&bias, (float *) dst->op_params + 1, sizeof(float)); scale_f32_cuda(src0_d, dst_d, scale, bias, ggml_nelements(src0), stream); } From c8d89317c96c3ee08aa72992b308f52ea74ea5aa Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 9 Jul 2025 00:06:53 +0200 Subject: [PATCH 11/18] suggestions from coderabbit --- ggml/src/ggml-cann/ggml-cann.cpp | 2 +- ggml/src/ggml-cpu/vec.h | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index b60a83c5c6fc8..55542ec45b005 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2210,7 +2210,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_OP_COUNT_EQUAL: return true; case GGML_OP_SCALE: - float bias = ((const float *)(dst->op_params))[1]; + float bias = ((const float *)(op->op_params))[1]; return bias == 0.0f; // TODO: support bias != 0.0f case GGML_OP_SOFT_MAX: // TODO: support broadcast diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 78c7ed2d157e0..80f2eb550e1ec 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -404,13 +404,13 @@ inline static void ggml_vec_mad1_f32(const int n, float * y, const float s, cons // leftovers for (int i = np; i < n; ++i) { - y[i] = y[i]*s + b; + y[i] = y[i]*s + b; } #endif #else // scalar for (int i = 0; i < n; ++i) { - y[i] *= y[i]*s + b; + y[i] = y[i]*s + b; } #endif } From 265cb4353832b2a2d7926e593e7f204172839f4e Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 9 Jul 2025 11:52:58 +0200 Subject: [PATCH 12/18] fix cann compile error --- ggml/src/ggml-cann/ggml-cann.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 55542ec45b005..ccb17eb072eb2 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2210,7 +2210,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_OP_COUNT_EQUAL: return true; case GGML_OP_SCALE: - float bias = ((const float *)(op->op_params))[1]; + float bias; + memcpy(&bias, (float*)op->op_params + 1, sizeof(float)); return bias == 0.0f; // TODO: support bias != 0.0f case GGML_OP_SOFT_MAX: // TODO: support broadcast From 563aca0b561aa354cd88e51ca4f83d2638bcac3a Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 9 Jul 2025 11:55:56 +0200 Subject: [PATCH 13/18] vDSP_vsmsa --- ggml/src/ggml-cpu/vec.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 80f2eb550e1ec..91b441d33a094 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -353,8 +353,7 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int inline static void ggml_vec_mad1_f32(const int n, float * y, const float s, const float b) { #if defined(GGML_USE_ACCELERATE) - vDSP_vsmul(y, 1, &s, y, 1, n); - vDSP_vsadd(y, 1, &b, y, 1, n); + vDSP_vsmsa(y, 1, &s, &b, y, 1, n); #elif defined(GGML_SIMD) #if defined(__ARM_FEATURE_SVE) const int sve_register_length = ggml_cpu_get_sve_cnt() * 8; From 50c678f6da54a4b227028808c896718134a69b0b Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 9 Jul 2025 11:56:48 +0200 Subject: [PATCH 14/18] rm __ARM_FEATURE_SVE --- ggml/src/ggml-cpu/vec.h | 60 +++++++++++------------------------------ 1 file changed, 15 insertions(+), 45 deletions(-) diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 91b441d33a094..66cdb619e777f 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -355,57 +355,27 @@ inline static void ggml_vec_mad1_f32(const int n, float * y, const float s, cons #if defined(GGML_USE_ACCELERATE) vDSP_vsmsa(y, 1, &s, &b, y, 1, n); #elif defined(GGML_SIMD) - #if defined(__ARM_FEATURE_SVE) - const int sve_register_length = ggml_cpu_get_sve_cnt() * 8; - const int ggml_f32_epr = sve_register_length / 32;//8;//svcntw(); // SVE128:4, SVE256:8, SVE512:16 - const int ggml_f32_step = 2 * ggml_f32_epr; - - GGML_F32_VEC vs = GGML_F32_VEC_SET1(s); - GGML_F32_VEC vb = GGML_F32_VEC_SET1(b); - - const int np = (n & ~(ggml_f32_step - 1)); - svfloat32_t ay1; - svfloat32_t ay2; - for (int i = 0; i < np; i += ggml_f32_step) { - ay1 = GGML_F32_VEC_LOAD(y + i); - ay1 = GGML_F32_VEC_FMA(ay1, vs, vb); - GGML_F32_VEC_STORE(y + i, ay1); - - ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr); - ay2 = GGML_F32_VEC_FMA(ay2, vs, vb); - GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2); - } - // leftovers - // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only - if (np < n) { - svbool_t pg = svwhilelt_b32(np, n); - ay1 = svld1_f32(pg, y + np); - ay1 = svmul_f32_m(pg, ay1, vs); - ay1 = svadd_f32_m(pg, ay1, vb); - svst1_f32(pg, y + np, ay1); - } - #else - const int np = (n & ~(GGML_F32_STEP - 1)); + // TODO: #if defined(__ARM_FEATURE_SVE) + const int np = (n & ~(GGML_F32_STEP - 1)); - GGML_F32_VEC vs = GGML_F32_VEC_SET1(s); - GGML_F32_VEC vb = GGML_F32_VEC_SET1(b); + GGML_F32_VEC vs = GGML_F32_VEC_SET1(s); + GGML_F32_VEC vb = GGML_F32_VEC_SET1(b); - GGML_F32_VEC ay[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; - for (int i = 0; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; j++) { - ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_FMA(ay[j], vs, vb); + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_FMA(ay[j], vs, vb); - GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); - } + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); } + } - // leftovers - for (int i = np; i < n; ++i) { - y[i] = y[i]*s + b; - } - #endif + // leftovers + for (int i = np; i < n; ++i) { + y[i] = y[i]*s + b; + } #else // scalar for (int i = 0; i < n; ++i) { From 0d70ca81e81c9ad52689786614339b9171dfc4d6 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 9 Jul 2025 12:05:34 +0200 Subject: [PATCH 15/18] use memcpy for op params --- ggml/src/ggml-metal/ggml-metal.m | 6 ++++-- ggml/src/ggml-opencl/ggml-opencl.cpp | 6 ++++-- ggml/src/ggml-sycl/ggml-sycl.cpp | 6 ++++-- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 65093628fc021..83a0739809a6e 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2255,8 +2255,10 @@ static bool ggml_metal_encode_node( { GGML_ASSERT(ggml_is_contiguous(src0)); - float scale = ((const float *)(dst->op_params))[0]; - float bias = ((const float *)(dst->op_params))[1]; + float scale; + float bias; + memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(float)); + memcpy(&bias, ((const int32_t *) dst->op_params) + 1, sizeof(float)); int64_t n = ggml_nelements(dst); diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 0485f8d38ed88..43d8e5c72c937 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -5586,8 +5586,10 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - float scale = ((const float *)(dst->op_params))[0]; - float bias = ((const float *)(dst->op_params))[1]; + float scale; + float bias; + memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(float)); + memcpy(&bias, ((int32_t *) dst->op_params) + 1, sizeof(float)); ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 00bd5ecb552fa..199182b98e64c 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -2318,8 +2318,10 @@ inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * ds const float * src0_dd = static_cast(dst->src[0]->data); float * dst_dd = static_cast(dst->data); - float scale = ((const float *)(dst->op_params))[0]; - float bias = ((const float *)(dst->op_params))[1]; + float scale; + float bias; + memcpy(&scale, dst->op_params, sizeof(float)); + memcpy(&bias, (float *) dst->op_params + 1, sizeof(float)); scale_f32_sycl(src0_dd, dst_dd, scale, bias, ggml_nelements(dst->src[0]), main_stream); /* From 4ea74b04e59025630468acb0e8b5984166615f0b Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 9 Jul 2025 12:07:05 +0200 Subject: [PATCH 16/18] make code looks more consistent --- ggml/src/ggml-cuda/scale.cu | 4 ++-- ggml/src/ggml-sycl/ggml-sycl.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/scale.cu b/ggml/src/ggml-cuda/scale.cu index f915bf5faa23a..2ee9e588992f4 100644 --- a/ggml/src/ggml-cuda/scale.cu +++ b/ggml/src/ggml-cuda/scale.cu @@ -26,8 +26,8 @@ void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float scale; float bias; - memcpy(&scale, dst->op_params, sizeof(float)); - memcpy(&bias, (float *) dst->op_params + 1, sizeof(float)); + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&bias, (float *) dst->op_params + 1, sizeof(float)); scale_f32_cuda(src0_d, dst_d, scale, bias, ggml_nelements(src0), stream); } diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 199182b98e64c..cd15bbdb29fa2 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -2320,7 +2320,7 @@ inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * ds float scale; float bias; - memcpy(&scale, dst->op_params, sizeof(float)); + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); memcpy(&bias, (float *) dst->op_params + 1, sizeof(float)); scale_f32_sycl(src0_dd, dst_dd, scale, bias, ggml_nelements(dst->src[0]), main_stream); From cd1703a3bcebb5319cc8964d8bd6c52470025336 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 9 Jul 2025 12:16:40 +0200 Subject: [PATCH 17/18] use scalar for __ARM_FEATURE_SVE --- ggml/src/ggml-cpu/vec.h | 36 +++++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 66cdb619e777f..4652598ead13c 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -355,27 +355,33 @@ inline static void ggml_vec_mad1_f32(const int n, float * y, const float s, cons #if defined(GGML_USE_ACCELERATE) vDSP_vsmsa(y, 1, &s, &b, y, 1, n); #elif defined(GGML_SIMD) - // TODO: #if defined(__ARM_FEATURE_SVE) - const int np = (n & ~(GGML_F32_STEP - 1)); + #if defined(__ARM_FEATURE_SVE) + // scalar ; TODO: Write SVE code + for (int i = 0; i < n; ++i) { + y[i] = y[i]*s + b; + } + #else + const int np = (n & ~(GGML_F32_STEP - 1)); - GGML_F32_VEC vs = GGML_F32_VEC_SET1(s); - GGML_F32_VEC vb = GGML_F32_VEC_SET1(b); + GGML_F32_VEC vs = GGML_F32_VEC_SET1(s); + GGML_F32_VEC vb = GGML_F32_VEC_SET1(b); - GGML_F32_VEC ay[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; - for (int i = 0; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; j++) { - ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_FMA(ay[j], vs, vb); + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_FMA(ay[j], vs, vb); - GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } } - } - // leftovers - for (int i = np; i < n; ++i) { - y[i] = y[i]*s + b; - } + // leftovers + for (int i = np; i < n; ++i) { + y[i] = y[i]*s + b; + } + #endif #else // scalar for (int i = 0; i < n; ++i) { From ebbad7796df3caf794587a168372cf589d665930 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 9 Jul 2025 14:11:53 +0200 Subject: [PATCH 18/18] add x param to ggml_vec_mad1_f32 --- ggml/src/ggml-cpu/ops.cpp | 10 +++++----- ggml/src/ggml-cpu/vec.h | 12 ++++++------ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 5a07819038d30..fd77e9a6abad5 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -4670,17 +4670,17 @@ static void ggml_compute_forward_scale_f32( for (int i1 = ir0; i1 < ir1; i1++) { if (dst->data != src0->data) { // src0 is same shape as dst => same indices + // TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float)); } ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s); } } else { for (int i1 = ir0; i1 < ir1; i1++) { - if (dst->data != src0->data) { - // src0 is same shape as dst => same indices - memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float)); - } - ggml_vec_mad1_f32(nc, (float *) ((char *) dst->data + i1*nb1), s, b); + ggml_vec_mad1_f32(nc, + (float *) ((char *) dst->data + i1*nb1), + (float *) ((char *) src0->data + i1*nb1), + s, b); } } } diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 4652598ead13c..d18783a00a1a5 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -351,14 +351,14 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int #endif } -inline static void ggml_vec_mad1_f32(const int n, float * y, const float s, const float b) { +inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, const float s, const float b) { #if defined(GGML_USE_ACCELERATE) - vDSP_vsmsa(y, 1, &s, &b, y, 1, n); + vDSP_vsmsa(x, 1, &s, &b, y, 1, n); #elif defined(GGML_SIMD) #if defined(__ARM_FEATURE_SVE) // scalar ; TODO: Write SVE code for (int i = 0; i < n; ++i) { - y[i] = y[i]*s + b; + y[i] = x[i]*s + b; } #else const int np = (n & ~(GGML_F32_STEP - 1)); @@ -370,7 +370,7 @@ inline static void ggml_vec_mad1_f32(const int n, float * y, const float s, cons for (int i = 0; i < np; i += GGML_F32_STEP) { for (int j = 0; j < GGML_F32_ARR; j++) { - ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); ay[j] = GGML_F32_VEC_FMA(ay[j], vs, vb); GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); @@ -379,13 +379,13 @@ inline static void ggml_vec_mad1_f32(const int n, float * y, const float s, cons // leftovers for (int i = np; i < n; ++i) { - y[i] = y[i]*s + b; + y[i] = x[i]*s + b; } #endif #else // scalar for (int i = 0; i < n; ++i) { - y[i] = y[i]*s + b; + y[i] = x[i]*s + b; } #endif }