@@ -7792,9 +7792,75 @@ struct llm_build_plamo2 : public llm_graph_context {
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
}
} else if (layer_type == "mamba") {
- // Mamba layer processing - simplified implementation for now
- // TODO: Implement full mamba layer logic
- GGML_ASSERT(false && "Mamba layers not yet fully implemented for PLaMo2");
+ // Mamba layer processing
+ const int64_t d_conv = hparams.ssm_d_conv;
+ const int64_t d_inner = hparams.ssm_d_inner;
+ const int64_t d_state = hparams.ssm_d_state;
+ const int64_t dt_rank = hparams.ssm_dt_rank;
+
+ // Apply linear transformation: n_embd -> 2*d_inner
+ ggml_tensor * xz = build_lora_mm(model.layers[il].ssm_in, mixer_norm);
+ cb(xz, "ssm_in", il);
+
+ // Split into x and z
+ ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, n_tokens, xz->nb[1], 0);
+ ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, n_tokens, xz->nb[1], d_inner*ggml_element_size(xz));
+
+ // For simplified PLaMo2 implementation without state caching,
+ // we use a basic convolution approach
+ // Reshape x for convolution: {d_inner, n_tokens} -> {n_tokens, d_inner}
+ x = ggml_cont(ctx0, ggml_transpose(ctx0, x));
+
+ // Apply 1D convolution with proper padding
+ // Note: PLaMo2 conv1d weight shape is {d_inner, d_conv}
+ ggml_tensor * conv_w = ggml_reshape_2d(ctx0, model.layers[il].ssm_conv1d, d_conv, d_inner);
+ x = ggml_conv_1d(ctx0, conv_w, x, 1, d_conv - 1, 1);
+
+ // Transpose back and apply SiLU
+ x = ggml_cont(ctx0, ggml_transpose(ctx0, x));
+ x = ggml_silu(ctx0, x);
+ cb(x, "ssm_conv", il);
+
+ // SSM sequence transformation
+ {
+ // Project x to dt, B, C
+ ggml_tensor * x_db = build_lora_mm(model.layers[il].ssm_bcdt, x);
+ cb(x_db, "ssm_bcdt", il);
+
+ // Split into dt, B, C
+ ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tokens, x_db->nb[1], 0);
+ ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
+ ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
+
+ // Project dt_rank to d_inner
+ dt = build_lora_mm(model.layers[il].ssm_dt, dt);
+ cb(dt, "ssm_dt", il);
+
+ // For simplified implementation without full SSM scan,
+ // we'll create a basic selective scan approximation
+ // Note: This is a simplified version and may not capture all SSM dynamics
+
+ // Create dummy state tensors for ggml_ssm_scan
+ ggml_tensor * dummy_s = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, d_state, d_inner, 1);
+
+ // Use ggml_ssm_scan for proper SSM computation
+ ggml_tensor * y_ssm = ggml_ssm_scan(ctx0, dummy_s, x, dt, model.layers[il].ssm_a, B, C);
+
+ // Extract the output (first part of y_ssm)
+ ggml_tensor * y = ggml_view_2d(ctx0, y_ssm, d_inner, n_tokens, y_ssm->nb[1], 0);
+
+ // Add D parameter contribution
+ y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
+ x = y;
+ }
+
+ // Gated output
+ x = ggml_mul(ctx0, x, ggml_silu(ctx0, ggml_cont(ctx0, z)));
+ cb(x, "ssm_gate", il);
+
+ // Output projection
+ cur = build_lora_mm(model.layers[il].ssm_out, x);
+ cb(cur, "ssm_out", il);
} else {
// Default to attention-like processing for unknown layer types
cur = build_lora_mm(model.layers[il].wqkv, mixer_norm);
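For readers following the dt/B/C split in the mamba branch above: the three `ggml_view_2d` calls carve one concatenated projection (laid out as `[dt | B | C]` along dim 0) into sub-views that share the parent's `nb[1]` row stride and differ only in their starting byte offset, the same pattern used earlier to split `xz` into `x` and `z`. Below is a minimal standalone sketch of that offset arithmetic, assuming only the public ggml C API; the sizes are made up for illustration and this is not code from the PR itself:

```cpp
// Illustration only (not part of this PR): how ggml_view_2d byte offsets
// carve dt, B and C out of one concatenated [dt_rank + 2*d_state, n_tokens]
// tensor, mirroring the split in the mamba branch above.
// The sizes below are made up for the example.
#include "ggml.h"

#include <cstdio>

int main() {
    const int64_t dt_rank  = 4;  // hypothetical size
    const int64_t d_state  = 8;  // hypothetical size
    const int64_t n_tokens = 3;  // hypothetical size

    ggml_init_params params = { 16*1024*1024, nullptr, false };
    ggml_context * ctx = ggml_init(params);

    // dim 0 is laid out as [dt | B | C]
    ggml_tensor * x_db = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dt_rank + 2*d_state, n_tokens);

    const size_t es = ggml_element_size(x_db);

    // Same row stride (x_db->nb[1]) for all three views; only the byte offset differs.
    ggml_tensor * dt = ggml_view_2d(ctx, x_db, dt_rank, n_tokens, x_db->nb[1], 0);
    ggml_tensor * B  = ggml_view_2d(ctx, x_db, d_state, n_tokens, x_db->nb[1], es*dt_rank);
    ggml_tensor * C  = ggml_view_2d(ctx, x_db, d_state, n_tokens, x_db->nb[1], es*(dt_rank + d_state));

    printf("row stride nb[1] = %zu bytes\n", x_db->nb[1]);
    printf("dt: ne0=%lld, offset=%zu bytes\n", (long long) dt->ne[0], (size_t) 0);
    printf("B : ne0=%lld, offset=%zu bytes\n", (long long) B->ne[0],  (size_t) (es*dt_rank));
    printf("C : ne0=%lld, offset=%zu bytes\n", (long long) C->ne[0],  (size_t) (es*(dt_rank + d_state)));

    ggml_free(ctx);
    return 0;
}
```

The `x`/`z` split at the top of the branch follows the same layout idea: the second half of `xz` starts at byte offset `d_inner*ggml_element_size(xz)` within each row.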