From 4914f5b8a4adf7c48b1f93348df109e724244db3 Mon Sep 17 00:00:00 2001 From: Faras Siddiqui Date: Fri, 13 Mar 2026 04:59:06 +0500 Subject: [PATCH] feat: NEON-vectorized flash attention with hardware FP16 on AArch64 The flash attention inner loops were scalar with software FP16 conversion (~12 instructions per fp16_to_fp32 call). At 100 tokens of context, this costs 117M scalar instructions per token across 32 heads x 22 layers, nearly half the cost of the NEON-optimized matmuls. Hardware vcvt_f32_f16 does 4 conversions in 1 instruction (vs 48 scalar), collapsing the attention overhead to <2% of matmul cost. Changes: - quant.h: add fp16x4_to_f32 / f32x4_to_fp16 helpers, guarded by PICOLM_NEON && __aarch64__. Software FP16 untouched as fallback. - model.c: NEON paths for Q.K dot product, V accumulation (both branches), normalization, and KV cache FP16 writes. All under #ifdef PICOLM_FP16_HW with original scalar code in #else. - tensor.c: NEON path for RoPE K heads (mirrors existing Q-head pattern, for code symmetry). Precondition: head_dim and kv_dim must be multiples of 4. True for all LLaMA-architecture models (head_dim is always 64 or 128). 3 files changed, 71 insertions(+). Binary size unchanged (87736 bytes with -O3 -ffast-math). Tested on Apple M4 Pro, TinyLlama 1.1B Q4_K_M, -t 0 greedy: -n 20: 23.9 -> 29.6 tok/s (+24% vs baseline, +11% vs flags-only) -n 100: 20.9 -> 27.2 tok/s (+30% vs baseline, +23% vs flags-only) Output character-identical to baseline at all context lengths. --json mode and --cache round-trip verified. 
--- picolm/model.c | 43 +++++++++++++++++++++++++++++++++++++++++++ picolm/quant.h | 9 +++++++++ picolm/tensor.c | 19 +++++++++++++++++++ 3 files changed, 71 insertions(+) diff --git a/picolm/model.c b/picolm/model.c index 4b4040d..c27f3d2 100644 --- a/picolm/model.c +++ b/picolm/model.c @@ -601,17 +601,29 @@ float *model_forward(model_t *m, int token, int pos) { rope(s->q, k_tmp, head_dim, n_heads, n_kv_heads, cos_pos, sin_pos); /* Convert K to FP16 and store */ +#ifdef PICOLM_FP16_HW + for (int d = 0; d < kv_dim; d += 4) { + f32x4_to_fp16(key_pos_fp16 + d, vld1q_f32(k_tmp + d)); + } +#else for (int d = 0; d < kv_dim; d++) { key_pos_fp16[d] = fp32_to_fp16(k_tmp[d]); } +#endif /* V projection -> store directly as FP16 */ float *v_tmp = s->xb2; matmul(v_tmp, s->xb, lw->attn_v, dim, kv_dim, lw->type_attn_v); uint16_t *val_pos_fp16 = vcache_layer + (size_t)pos * kv_dim; +#ifdef PICOLM_FP16_HW + for (int d = 0; d < kv_dim; d += 4) { + f32x4_to_fp16(val_pos_fp16 + d, vld1q_f32(v_tmp + d)); + } +#else for (int d = 0; d < kv_dim; d++) { val_pos_fp16[d] = fp32_to_fp16(v_tmp[d]); } +#endif /* ---- Flash Attention (online softmax) ---- * @@ -647,10 +659,18 @@ float *model_forward(model_t *m, int token, int pos) { for (int t = 0; t <= pos; t++) { /* Compute score: dot(Q_h, K_t) / sqrt(head_dim) */ const uint16_t *kt = kcache_layer + (size_t)t * kv_dim + kv_h * head_dim; +#ifdef PICOLM_FP16_HW + float32x4_t dot_acc = vdupq_n_f32(0); + for (int d = 0; d < head_dim; d += 4) { + dot_acc = vmlaq_f32(dot_acc, vld1q_f32(qh + d), fp16x4_to_f32(kt + d)); + } + float score = vaddvq_f32(dot_acc); +#else float score = 0.0f; for (int d = 0; d < head_dim; d++) { score += qh[d] * fp16_to_fp32(kt[d]); } +#endif score /= sqrtf((float)head_dim); /* Online softmax update */ @@ -659,24 +679,47 @@ float *model_forward(model_t *m, int token, int pos) { if (score > max_score) { float correction = expf(max_score - score); sum_exp = sum_exp * correction + 1.0f; +#ifdef PICOLM_FP16_HW + float32x4_t 
corr_v = vdupq_n_f32(correction); + for (int d = 0; d < head_dim; d += 4) { + float32x4_t a = vld1q_f32(acc + d); + vst1q_f32(acc + d, vmlaq_f32(fp16x4_to_f32(vt + d), a, corr_v)); + } +#else for (int d = 0; d < head_dim; d++) { acc[d] = acc[d] * correction + fp16_to_fp32(vt[d]); } +#endif max_score = score; } else { float w = expf(score - max_score); sum_exp += w; +#ifdef PICOLM_FP16_HW + float32x4_t w_v = vdupq_n_f32(w); + for (int d = 0; d < head_dim; d += 4) { + float32x4_t a = vld1q_f32(acc + d); + vst1q_f32(acc + d, vmlaq_f32(a, fp16x4_to_f32(vt + d), w_v)); + } +#else for (int d = 0; d < head_dim; d++) { acc[d] += w * fp16_to_fp32(vt[d]); } +#endif } } /* Normalize */ float inv_sum = 1.0f / sum_exp; +#ifdef PICOLM_FP16_HW + float32x4_t inv_v = vdupq_n_f32(inv_sum); + for (int d = 0; d < head_dim; d += 4) { + vst1q_f32(xbh + d, vmulq_f32(vld1q_f32(acc + d), inv_v)); + } +#else for (int d = 0; d < head_dim; d++) { xbh[d] = acc[d] * inv_sum; } +#endif } /* Output projection */ diff --git a/picolm/quant.h b/picolm/quant.h index e35095c..a433033 100644 --- a/picolm/quant.h +++ b/picolm/quant.h @@ -16,6 +16,15 @@ static inline float vaddvq_f32_compat(float32x4_t v) { return vget_lane_f32(vpadd_f32(r, r), 0); #endif } +#if defined(__aarch64__) +#define PICOLM_FP16_HW 1 +static inline float32x4_t fp16x4_to_f32(const uint16_t *p) { + return vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(p))); +} +static inline void f32x4_to_fp16(uint16_t *p, float32x4_t v) { + vst1_u16(p, vreinterpret_u16_f16(vcvt_f16_f32(v))); +} +#endif #endif #if defined(__SSE2__) || (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_AMD64))) diff --git a/picolm/tensor.c b/picolm/tensor.c index 59a68e6..85d5119 100644 --- a/picolm/tensor.c +++ b/picolm/tensor.c @@ -246,12 +246,31 @@ void rope(float *q, float *k, int head_dim, int n_heads, int n_kv_heads, /* Apply RoPE to all KV heads */ for (int h = 0; h < n_kv_heads; h++) { float *kh = k + h * head_dim; +#ifdef PICOLM_NEON + int i = 0; + for (; i + 
3 < half; i += 4) { + float32x4x2_t kv = vld2q_f32(kh + i * 2); + float32x4_t cv = vld1q_f32(cos_pos + i); + float32x4_t sv = vld1q_f32(sin_pos + i); + float32x4_t new_even = vmlsq_f32(vmulq_f32(kv.val[0], cv), kv.val[1], sv); + float32x4_t new_odd = vmlaq_f32(vmulq_f32(kv.val[0], sv), kv.val[1], cv); + float32x4x2_t result = {{ new_even, new_odd }}; + vst2q_f32(kh + i * 2, result); + } + for (; i < half; i++) { + float k0 = kh[i * 2]; + float k1 = kh[i * 2 + 1]; + kh[i * 2] = k0 * cos_pos[i] - k1 * sin_pos[i]; + kh[i * 2 + 1] = k0 * sin_pos[i] + k1 * cos_pos[i]; + } +#else for (int i = 0; i < half; i++) { float k0 = kh[i * 2]; float k1 = kh[i * 2 + 1]; kh[i * 2] = k0 * cos_pos[i] - k1 * sin_pos[i]; kh[i * 2 + 1] = k0 * sin_pos[i] + k1 * cos_pos[i]; } +#endif } }