diff --git a/picolm/model.c b/picolm/model.c
index 4b4040d..be7fc9d 100644
--- a/picolm/model.c
+++ b/picolm/model.c
@@ -705,6 +705,169 @@ float *model_forward(model_t *m, int token, int pos) {
     return s->logits;
 }
 
+float *model_forward_batch(model_t *m, const int *tokens, int num_tokens, int start_pos) {
+    if (num_tokens == 0) return NULL;
+
+    model_config_t *c = &m->config;
+    model_weights_t *w = &m->weights;
+    run_state_t *s = &m->state;
+
+    int dim = c->n_embd;
+    int n_ffn = c->n_ffn;
+    int n_heads = c->n_heads;
+    int n_kv_heads = c->n_kv_heads;
+    int head_dim = c->head_dim;
+    int kv_dim = n_kv_heads * head_dim;
+    int kv_mul = n_heads / n_kv_heads;
+    int seq_len = c->max_seq_len;
+    int half_dim = head_dim / 2;
+
+    /* Refuse batches that would run past the end of the KV cache. */
+    if (start_pos + num_tokens > seq_len) return NULL;
+
+    /* Heap-allocate the batch's intermediate activations, shape [num_tokens, dim]. */
+    float *batch_x = (float *)malloc((size_t)num_tokens * dim * sizeof(float));
+    if (!batch_x) return NULL;
+
+    /* 1. Embedding lookup for all tokens in the batch */
+    for (int i = 0; i < num_tokens; i++) {
+        size_t row_bytes = gguf_type_row_size(w->type_token_embd, dim);
+        const void *embd_row = (const uint8_t *)w->token_embd + (size_t)tokens[i] * row_bytes;
+        dequantize_row(embd_row, batch_x + (size_t)i * dim, dim, w->type_token_embd);
+    }
+
+    /* 2. Transformer layers */
+    for (int l = 0; l < c->n_layers; l++) {
+        layer_weights_t *lw = &w->layers[l];
+
+        for (int i = 0; i < num_tokens; i++) {
+            int pos = start_pos + i;
+
+            /* RoPE table pointers for this position */
+            const float *cos_pos = s->rope_cos + (size_t)pos * half_dim;
+            const float *sin_pos = s->rope_sin + (size_t)pos * half_dim;
+
+            memcpy(s->x, batch_x + (size_t)i * dim, dim * sizeof(float));
+
+            /* ---- Attention ---- */
+            rmsnorm(s->xb, s->x, s->attn_norm_w[l], dim);
+
+            /* QKV projections */
+            matmul(s->q, s->xb, lw->attn_q, dim, dim, lw->type_attn_q);
+
+            /* K and V: project into a float temp, then store as FP16 in the cache */
+            float *k_tmp = s->xb2; /* reuse xb2 as temp for K (kv_dim <= dim) */
+            matmul(k_tmp, s->xb, lw->attn_k, dim, kv_dim, lw->type_attn_k);
+
+            /* KV cache pointers for this layer and position */
+            uint16_t *kcache_layer = s->key_cache + (size_t)l * seq_len * kv_dim;
+            uint16_t *vcache_layer = s->val_cache + (size_t)l * seq_len * kv_dim;
+            uint16_t *key_pos_fp16 = kcache_layer + (size_t)pos * kv_dim;
+
+            /* Apply RoPE to Q and K (using pre-computed tables) */
+            rope(s->q, k_tmp, head_dim, n_heads, n_kv_heads, cos_pos, sin_pos);
+
+            /* Convert K to FP16 and store */
+            for (int d = 0; d < kv_dim; d++) {
+                key_pos_fp16[d] = fp32_to_fp16(k_tmp[d]);
+            }
+
+            /* V projection -> store directly as FP16 (reuses xb2; K is already cached) */
+            float *v_tmp = s->xb2;
+            matmul(v_tmp, s->xb, lw->attn_v, dim, kv_dim, lw->type_attn_v);
+            uint16_t *val_pos_fp16 = vcache_layer + (size_t)pos * kv_dim;
+            for (int d = 0; d < kv_dim; d++) {
+                val_pos_fp16[d] = fp32_to_fp16(v_tmp[d]);
+            }
+
+            /* ---- Flash Attention (online softmax) ---- */
+            for (int h = 0; h < n_heads; h++) {
+                float *qh = s->q + h * head_dim;
+                int kv_h = h / kv_mul;
+                float *xbh = s->xb + h * head_dim;
+
+                float max_score = -1e30f;
+                float sum_exp = 0.0f;
+                /* Accumulator for weighted V values; assumes head_dim <= 256
+                 * (head_dim is typically 64-128). */
+                float acc[256];
+                memset(acc, 0, (size_t)head_dim * sizeof(float));
+
+                for (int t = 0; t <= pos; t++) {
+                    /* Compute score: dot(Q_h, K_t) / sqrt(head_dim) */
+                    const uint16_t *kt = kcache_layer + (size_t)t * kv_dim + kv_h * head_dim;
+                    float score = 0.0f;
+                    for (int d = 0; d < head_dim; d++) {
+                        score += qh[d] * fp16_to_fp32(kt[d]);
+                    }
+                    score /= sqrtf((float)head_dim);
+
+                    /* Online softmax update */
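+                    /* Invariants maintained after processing step t:
+                     *   max_score = m_t   = max(score_0 .. score_t)
+                     *   sum_exp   = sum_k exp(score_k - m_t)
+                     *   acc[d]    = sum_k exp(score_k - m_t) * V_k[d]
+                     * When a new maximum arrives, both running sums are rescaled
+                     * by exp(m_old - m_new) so earlier terms stay consistent;
+                     * otherwise the new term is folded in with weight
+                     * exp(score - m_t). Dividing acc by sum_exp at the end gives
+                     * exactly softmax(scores) . V in a single pass over the cache. */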
+                    const uint16_t *vt = vcache_layer + (size_t)t * kv_dim + kv_h * head_dim;
+
+                    if (score > max_score) {
+                        float correction = expf(max_score - score);
+                        sum_exp = sum_exp * correction + 1.0f;
+                        for (int d = 0; d < head_dim; d++) {
+                            acc[d] = acc[d] * correction + fp16_to_fp32(vt[d]);
+                        }
+                        max_score = score;
+                    } else {
+                        float wt = expf(score - max_score); /* 'wt' avoids shadowing the weights pointer 'w' */
+                        sum_exp += wt;
+                        for (int d = 0; d < head_dim; d++) {
+                            acc[d] += wt * fp16_to_fp32(vt[d]);
+                        }
+                    }
+                }
+
+                /* Normalize */
+                float inv_sum = 1.0f / sum_exp;
+                for (int d = 0; d < head_dim; d++) {
+                    xbh[d] = acc[d] * inv_sum;
+                }
+            }
+
+            /* Output projection */
+            matmul(s->xb2, s->xb, lw->attn_output, dim, dim, lw->type_attn_output);
+            vec_add(s->x, s->xb2, dim);
+
+            /* ---- FFN (SwiGLU) ---- */
+            rmsnorm(s->xb, s->x, s->ffn_norm_w[l], dim);
+
+            matmul(s->hb, s->xb, lw->ffn_gate, dim, n_ffn, lw->type_ffn_gate);
+            matmul(s->hb2, s->xb, lw->ffn_up, dim, n_ffn, lw->type_ffn_up);
+
+            silu(s->hb, n_ffn);
+            elemwise_mul(s->hb, s->hb, s->hb2, n_ffn);
+
+            matmul(s->xb, s->hb, lw->ffn_down, n_ffn, dim, lw->type_ffn_down);
+            vec_add(s->x, s->xb, dim);
+
+            memcpy(batch_x + (size_t)i * dim, s->x, dim * sizeof(float));
+        }
+    }
+
+    /* 3. Final RMSNorm (only the last token's activations feed the logits) */
+    memcpy(s->x, batch_x + (size_t)(num_tokens - 1) * dim, dim * sizeof(float));
+    rmsnorm(s->x, s->x, s->output_norm_w, dim);
+
+    /* 4. Output projection -> logits */
+    matmul(s->logits, s->x, w->output, dim, c->vocab_size, w->type_output);
+
+    free(batch_x);
+
+    return s->logits;
+}
+
 void model_free(model_t *m) {
     if (m->state.mem_block) {
         free(m->state.mem_block);
diff --git a/picolm/model.h b/picolm/model.h
index 2a751e9..9c52556 100644
--- a/picolm/model.h
+++ b/picolm/model.h
@@ -130,6 +130,10 @@ int model_load(model_t *m, const char *path, int max_seq_len);
 
 /* Run one forward pass. Returns pointer to logits[vocab_size]. */
 float *model_forward(model_t *m, int token, int pos);
 
+/* Run a forward pass over a contiguous batch of tokens starting at start_pos.
+ * Returns pointer to logits[vocab_size] for the last token, or NULL on failure. */
+float *model_forward_batch(model_t *m, const int *tokens, int num_tokens, int start_pos);
+
 /* Free all resources. */
 void model_free(model_t *m);
diff --git a/picolm/picolm.c b/picolm/picolm.c
index c2f1624..d7d2822 100644
--- a/picolm/picolm.c
+++ b/picolm/picolm.c
@@ -192,6 +192,18 @@ int main(int argc, char **argv) {
         total_steps = model.config.max_seq_len;
     }
 
+    float *logits = model_forward_batch(&model, prompt_tokens, n_prompt, 0); /* prefill the whole prompt in one batch */
+    if (!logits) return 1; /* empty prompt or out-of-memory */
+    grammar_apply(&grammar, logits, model.config.vocab_size);
+    int next = sampler_sample(&sampler, logits, model.config.vocab_size);
+    grammar_advance(&grammar, &tokenizer, next);
+    const char *piece = tokenizer_decode(&tokenizer, prompt_tokens[n_prompt - 1], next);
+    printf("%s", piece);
+    fflush(stdout);
+    pos = n_prompt;
+    token = next;
+    total_gen++;
+
     for (; pos < total_steps; pos++) {
         /* Determine which token to feed */
         if (pos < start_pos) {
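
Usage note (not part of the diff): a minimal sketch of the intended prefill-then-decode flow, assuming only the model.h API above. The generate_greedy and argmax helpers are hypothetical, and greedy argmax stands in for picolm's sampler; a real caller would go through sampler_sample, grammar_apply, and tokenizer_decode as in the picolm.c hunk.

#include "model.h"

/* Hypothetical helper: greedy argmax over the logits. */
static int argmax(const float *v, int n) {
    int best = 0;
    for (int i = 1; i < n; i++) {
        if (v[i] > v[best]) best = i;
    }
    return best;
}

/* Hypothetical driver: prefill the prompt in one batched pass, then decode. */
static int generate_greedy(model_t *m, const int *prompt, int n_prompt, int max_steps) {
    /* One batched pass fills the KV cache for positions 0..n_prompt-1
     * and returns the logits for the last prompt token. */
    float *logits = model_forward_batch(m, prompt, n_prompt, 0);
    if (!logits) return -1;

    int token = argmax(logits, m->config.vocab_size);
    for (int pos = n_prompt; pos < max_steps; pos++) {
        logits = model_forward(m, token, pos); /* one token per step from here on */
        token = argmax(logits, m->config.vocab_size);
    }
    return token; /* last sampled token; a real caller would decode and print each one */
}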