
Commit e99335b

Add mamba part of plamo2
1 parent 98c0e98 commit e99335b

2 files changed: +99, -10 lines

convert_hf_to_gguf.py

Lines changed: 30 additions & 7 deletions
@@ -2312,6 +2312,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         # Handle Plamo2 specific tensor naming
         # The model has both attention and Mamba layers

+        # Debug: log all Mamba-related tensors
+        if "mixer" in name or "ssm" in name or "A_log" in name or ".D" in name:
+            logger.info(f"Processing Mamba tensor: {name}, shape: {data_torch.shape}")
+
         # Handle the nested layer structure: layers.layers.X
         if name.startswith("model.layers.layers."):
             # Extract the layer number and rest of the name
@@ -2341,30 +2345,49 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         ]

         # Handle Mamba-specific A_log tensor transformation
-        if name.endswith(".A_log"):
+        if name.endswith(".A_log") or name.endswith(".mixer.A_log"):
             # Map the A_log tensor directly to ssm_a
             new_name = self.map_tensor_name(name)
             # Add .weight suffix if not present
             if not new_name.endswith(".weight"):
                 new_name += ".weight"
-            logger.debug(f"A_log --> A ==> {new_name}, original shape: {data_torch.shape}")
+            logger.info(f"A_log --> A ==> {new_name}, original shape: {data_torch.shape}")

             # Transform A_log to A: A = -exp(A_log)
             data_torch = -torch.exp(data_torch)

-            # PLaMo2 A_log is shape {d_state} but llama.cpp expects {d_state, d_inner}
-            # Expand the tensor to the correct shape
+            # Handle different tensor shapes for A_log
+            logger.info(f"A_log tensor after exp transform, shape: {data_torch.shape}")
+
+            # Ensure we have a 2D tensor
+            while len(data_torch.shape) > 2:
+                data_torch = data_torch.squeeze(0)
+
             if len(data_torch.shape) == 1:
-                d_state = data_torch.shape[0]  # 64
-                d_inner = 8192  # SSM inner size for PLaMo2
+                # PLaMo2 A_log is shape {d_state} but llama.cpp expects {d_state, d_inner}
+                d_state = data_torch.shape[0]  # 64 for PLaMo2
+                # Get d_inner from model hyperparameters
+                mamba_num_heads = self.hparams.get("mamba_num_heads", 64)
+                hidden_size_per_head = self.hparams.get("hidden_size_per_head", 128)
+                d_inner = mamba_num_heads * hidden_size_per_head  # 64 * 128 = 8192

                 # Create tensor with correct shape {d_state, d_inner} = {64, 8192}
                 # Each row of the matrix should contain the same value from the original 1D tensor
                 new_tensor = data_torch.new_zeros((d_state, d_inner))
                 for i in range(d_state):
                     new_tensor[i, :] = data_torch[i]  # Broadcast the single value across the inner dimension
                 data_torch = new_tensor
-                logger.debug(f"Expanded A tensor from {d_state} to shape: {data_torch.shape}")
+                logger.info(f"Expanded A tensor from {d_state} to shape: {data_torch.shape}")
+            elif len(data_torch.shape) == 2:
+                # Check if dimensions need to be transposed
+                # Expected shape is [d_state, d_inner] where d_state = 64, d_inner = 8192
+                if data_torch.shape[0] == 8192 and data_torch.shape[1] == 64:
+                    data_torch = data_torch.transpose(0, 1)
+                    logger.info(f"Transposed A tensor to shape: {data_torch.shape}")
+
+            # Final shape check and reshape if needed
+            if data_torch.shape != torch.Size([64, 8192]):
+                logger.warning(f"Unexpected A tensor shape after processing: {data_torch.shape}, expected [64, 8192]")

             return [(new_name, data_torch)]
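Note: the A_log handling above can be exercised in isolation. The sketch below is a minimal stand-alone version of the same transform (not the converter itself), assuming the hyperparameters shown in the diff: d_state = 64, mamba_num_heads = 64, hidden_size_per_head = 128, so d_inner = 8192. The per-row loop is replaced by an equivalent broadcast.

import torch

def expand_a_log(a_log: torch.Tensor,
                 mamba_num_heads: int = 64,
                 hidden_size_per_head: int = 128) -> torch.Tensor:
    # A = -exp(A_log), as in the converter above
    a = -torch.exp(a_log)
    # drop leading singleton dims until the tensor is at most 2D
    while a.dim() > 2:
        a = a.squeeze(0)
    d_inner = mamba_num_heads * hidden_size_per_head  # 64 * 128 = 8192
    if a.dim() == 1:
        # 1D {d_state}: broadcast each value across the inner dimension
        # (equivalent to the per-row loop in the diff, but vectorized)
        a = a.unsqueeze(1).expand(a.shape[0], d_inner).contiguous()
    elif a.shape[0] == d_inner:
        # stored as {d_inner, d_state}: transpose to {d_state, d_inner}
        a = a.transpose(0, 1).contiguous()
    return a

print(expand_a_log(torch.randn(64)).shape)  # torch.Size([64, 8192])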

src/llama-model.cpp

Lines changed: 69 additions & 3 deletions
@@ -7792,9 +7792,75 @@ struct llm_build_plamo2 : public llm_graph_context {
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
                 }
             } else if (layer_type == "mamba") {
-                // Mamba layer processing - simplified implementation for now
-                // TODO: Implement full mamba layer logic
-                GGML_ASSERT(false && "Mamba layers not yet fully implemented for PLaMo2");
+                // Mamba layer processing
+                const int64_t d_conv  = hparams.ssm_d_conv;
+                const int64_t d_inner = hparams.ssm_d_inner;
+                const int64_t d_state = hparams.ssm_d_state;
+                const int64_t dt_rank = hparams.ssm_dt_rank;
+
+                // Apply linear transformation: n_embd -> 2*d_inner
+                ggml_tensor * xz = build_lora_mm(model.layers[il].ssm_in, mixer_norm);
+                cb(xz, "ssm_in", il);
+
+                // Split into x and z
+                ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, n_tokens, xz->nb[1], 0);
+                ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, n_tokens, xz->nb[1], d_inner*ggml_element_size(xz));
+
+                // For simplified PLaMo2 implementation without state caching,
+                // we use a basic convolution approach
+                // Reshape x for convolution: {d_inner, n_tokens} -> {n_tokens, d_inner}
+                x = ggml_cont(ctx0, ggml_transpose(ctx0, x));
+
+                // Apply 1D convolution with proper padding
+                // Note: PLaMo2 conv1d weight shape is {d_inner, d_conv}
+                ggml_tensor * conv_w = ggml_reshape_2d(ctx0, model.layers[il].ssm_conv1d, d_conv, d_inner);
+                x = ggml_conv_1d(ctx0, conv_w, x, 1, d_conv - 1, 1);
+
+                // Transpose back and apply SiLU
+                x = ggml_cont(ctx0, ggml_transpose(ctx0, x));
+                x = ggml_silu(ctx0, x);
+                cb(x, "ssm_conv", il);
+
+                // SSM sequence transformation
+                {
+                    // Project x to dt, B, C
+                    ggml_tensor * x_db = build_lora_mm(model.layers[il].ssm_bcdt, x);
+                    cb(x_db, "ssm_bcdt", il);
+
+                    // Split into dt, B, C
+                    ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tokens, x_db->nb[1], 0);
+                    ggml_tensor * B  = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
+                    ggml_tensor * C  = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
+
+                    // Project dt_rank to d_inner
+                    dt = build_lora_mm(model.layers[il].ssm_dt, dt);
+                    cb(dt, "ssm_dt", il);
+
+                    // For simplified implementation without full SSM scan,
+                    // we'll create a basic selective scan approximation
+                    // Note: This is a simplified version and may not capture all SSM dynamics
+
+                    // Create dummy state tensors for ggml_ssm_scan
+                    ggml_tensor * dummy_s = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, d_state, d_inner, 1);
+
+                    // Use ggml_ssm_scan for proper SSM computation
+                    ggml_tensor * y_ssm = ggml_ssm_scan(ctx0, dummy_s, x, dt, model.layers[il].ssm_a, B, C);
+
+                    // Extract the output (first part of y_ssm)
+                    ggml_tensor * y = ggml_view_2d(ctx0, y_ssm, d_inner, n_tokens, y_ssm->nb[1], 0);
+
+                    // Add D parameter contribution
+                    y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
+                    x = y;
+                }
+
+                // Gated output
+                x = ggml_mul(ctx0, x, ggml_silu(ctx0, ggml_cont(ctx0, z)));
+                cb(x, "ssm_gate", il);
+
+                // Output projection
+                cur = build_lora_mm(model.layers[il].ssm_out, x);
+                cb(cur, "ssm_out", il);
             } else {
                 // Default to attention-like processing for unknown layer types
                 cur = build_lora_mm(model.layers[il].wqkv, mixer_norm);
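Note: to make the graph construction above easier to follow, here is a rough PyTorch sketch of the same data flow: in-projection and split, causal depthwise conv with SiLU, dt/B/C projection, a selective scan from a zero state (like the dummy state tensor), the D skip connection, SiLU gating, and the output projection. The weight names (w_in, conv_w, w_bcdt, w_dt, w_out) are hypothetical placeholders rather than llama.cpp structures, A is assumed to be stored as {d_state, d_inner} as the converter writes it, and the scan follows the standard Mamba recurrence (including softplus on dt) rather than mirroring ggml_ssm_scan line for line.

import torch
import torch.nn.functional as F

def mamba_mixer_sketch(x_in, w_in, conv_w, w_bcdt, w_dt, A, D, w_out,
                       d_inner, d_state, dt_rank, d_conv):
    n_tokens = x_in.shape[0]

    # in-projection n_embd -> 2*d_inner, split into x (SSM path) and z (gate)
    xz = x_in @ w_in.T
    x, z = xz.split(d_inner, dim=-1)

    # depthwise causal conv over the token axis (pad d_conv - 1, trim the tail), then SiLU
    x = F.conv1d(x.T.unsqueeze(0), conv_w.unsqueeze(1),
                 padding=d_conv - 1, groups=d_inner)[..., :n_tokens]
    x = F.silu(x.squeeze(0).T)

    # project to dt, B, C; lift dt from dt_rank to d_inner (softplus as in reference Mamba)
    dt, B, C = (x @ w_bcdt.T).split([dt_rank, d_state, d_state], dim=-1)
    dt = F.softplus(dt @ w_dt.T)

    # selective scan starting from a zero state, like the dummy state above
    h = x.new_zeros(d_inner, d_state)
    ys = []
    for t in range(n_tokens):
        dA  = torch.exp(dt[t].unsqueeze(-1) * A.T)      # A stored as {d_state, d_inner}, hence A.T
        dBx = dt[t].unsqueeze(-1) * B[t] * x[t].unsqueeze(-1)
        h   = dA * h + dBx
        ys.append((h * C[t]).sum(-1) + D * x[t])        # D skip connection
    y = torch.stack(ys)

    # gate with SiLU(z), then project back to n_embd
    return (y * F.silu(z)) @ w_out.T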
