
Fixes and Enhancements for Mamba Inference and Reference Implementations #743

Open · wants to merge 1 commit into base: main
72 changes: 44 additions & 28 deletions mamba_ssm/models/mixer_seq_simple.py
@@ -26,20 +26,25 @@
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None


from typing import Dict, List, Optional, Tuple, Union
import torch
from torch import nn, Tensor


def create_block(
d_model,
d_intermediate,
ssm_cfg=None,
attn_layer_idx=None,
attn_cfg=None,
norm_epsilon=1e-5,
rms_norm=False,
residual_in_fp32=False,
fused_add_norm=False,
layer_idx=None,
device=None,
dtype=None,
):
d_model: int,
d_intermediate: int,
ssm_cfg: Optional[Dict] = None,
attn_layer_idx: Optional[List[int]] = None,
attn_cfg: Optional[Dict] = None,
norm_epsilon: float = 1e-5,
rms_norm: bool = False,
residual_in_fp32: bool = False,
fused_add_norm: bool = False,
layer_idx: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[torch.dtype] = None,
) -> nn.Module:
if ssm_cfg is None:
ssm_cfg = {}
if attn_layer_idx is None:
@@ -88,7 +93,7 @@ def _init_weights(
n_layer,
initializer_range=0.02, # Now only used for embedding layer.
rescale_prenorm_residual=True,
n_residuals_per_layer=1, # Change to 2 if we have MLP
n_residuals_per_layer=1,
):
if isinstance(module, nn.Linear):
if module.bias is not None:
@@ -122,16 +127,16 @@ def __init__(
n_layer: int,
d_intermediate: int,
vocab_size: int,
ssm_cfg=None,
attn_layer_idx=None,
attn_cfg=None,
ssm_cfg: Optional[Dict] = None,
attn_layer_idx: Optional[List[int]] = None,
attn_cfg: Optional[Dict] = None,
norm_epsilon: float = 1e-5,
rms_norm: bool = False,
initializer_cfg=None,
fused_add_norm=False,
residual_in_fp32=False,
device=None,
dtype=None,
initializer_cfg: Optional[Dict] = None,
fused_add_norm: bool = False,
residual_in_fp32: bool = False,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
@@ -187,7 +192,12 @@ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs)
for i, layer in enumerate(self.layers)
}

def forward(self, input_ids, inference_params=None, **mixer_kwargs):
def forward(
self,
input_ids: Tensor,
inference_params = None,
**mixer_kwargs
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
hidden_states = self.embedding(input_ids)
residual = None
for layer in self.layers:
@@ -213,13 +223,12 @@ def forward(self, input_ids, inference_params=None, **mixer_kwargs):


class MambaLMHeadModel(nn.Module, GenerationMixin):

def __init__(
self,
config: MambaConfig,
initializer_cfg=None,
device=None,
dtype=None,
initializer_cfg: Optional[Dict] = None,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
self.config = config
d_model = config.d_model
@@ -271,7 +280,14 @@ def tie_weights(self):
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs):
def forward(
self,
input_ids: Tensor,
position_ids: Optional[Tensor] = None,
inference_params = None,
num_last_tokens: int = 0,
**mixer_kwargs
) -> Union[Tensor, Tuple[Tensor, Dict[str, Tensor]]]:
"""
"position_ids" is just to be compatible with Transformer generation. We don't use it.
num_last_tokens: if > 0, only return the logits for the last n tokens
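
A note on the `Optional[Dict] = None` defaults added throughout this file: keeping `None` as the default and building the dict inside the function (as `create_block` already does with `if ssm_cfg is None: ssm_cfg = {}`) avoids Python's shared-mutable-default pitfall. A minimal sketch of the difference, using a hypothetical `make_cfg` helper rather than the real API:

```python
from typing import Dict, Optional

def make_cfg(ssm_cfg: Optional[Dict] = None) -> Dict:
    # A fresh dict is created on every call, so callers never share state.
    if ssm_cfg is None:
        ssm_cfg = {}
    ssm_cfg.setdefault("layer", "Mamba2")
    return ssm_cfg

def make_cfg_buggy(ssm_cfg: Dict = {}) -> Dict:
    # The single default dict object is reused across calls, so mutations leak.
    ssm_cfg.setdefault("layer", "Mamba2")
    return ssm_cfg

a = make_cfg()
a["layer"] = "Mamba1"
assert make_cfg()["layer"] == "Mamba2"        # unaffected by the mutation above

b = make_cfg_buggy()
b["layer"] = "Mamba1"
assert make_cfg_buggy()["layer"] == "Mamba1"  # the shared default dict was mutated
```
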
105 changes: 89 additions & 16 deletions mamba_ssm/modules/mamba2.py
@@ -10,18 +10,31 @@

try:
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
causal_conv1d_fn, causal_conv1d_update = None, None
except ImportError as e:
raise ImportError(
"causal_conv1d package not found. Please install it with: "
"pip install causal-conv1d>=1.4.0"
) from e

try:
from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
except ImportError:
except ImportError as e:
causal_conv1d_varlen_states = None
import warnings
warnings.warn(
"causal_conv1d_varlen module not found. Variable length sequences will not be supported. "
"Install the latest causal_conv1d package for full functionality."
)

try:
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
except ImportError as e:
selective_state_update = None
import warnings
warnings.warn(
"selective_state_update module not found. Performance may be degraded. "
"Make sure to install with the 'triton' extra: pip install mamba-ssm[triton]"
)

from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated

@@ -221,9 +234,12 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_param
conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W)
else:
assert causal_conv1d_varlen_states is not None, "varlen inference requires causal_conv1d package"
assert batch == 1, "varlen inference only supports batch dimension 1"
# The 'batch' variable here might be misleading when cu_seqlens is used.
# The actual number of sequences is cu_seqlens.shape[0] - 1.
# conv_state is already shaped (inference_batch, ...).
# xBC should be (total_tokens, features) when cu_seqlens is present.
conv_varlen_states = causal_conv1d_varlen_states(
xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1]
xBC, cu_seqlens, state_len=conv_state.shape[-1]
)
conv_state.copy_(conv_varlen_states)
assert self.activation in ["silu", "swish"]
@@ -308,16 +324,55 @@ def step(self, hidden_states, conv_state, ssm_state):

# SSM step
if selective_state_update is None:
assert self.ngroups == 1, "Only support ngroups=1 for this inference code path"
assert self.nheads % self.ngroups == 0, "nheads must be divisible by ngroups for PyTorch step fallback"
k = self.nheads // self.ngroups

# Discretize A and B
# dt is already (batch, nheads) from xBC split and projection
dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads)
dA = torch.exp(dt * A) # (batch, nheads)
x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
y = rearrange(y, "b h p -> b (h p)")
# A is (nheads,)

# Reshape for grouped operations
# x: (B, d_ssm) -> (B, ngroups, k, headdim)
x_r = rearrange(x, "b (g k p) -> b g k p", g=self.ngroups, k=k, p=self.headdim)
# dt: (B, nheads) -> (B, ngroups, k)
dt_r = rearrange(dt, "b (g k) -> b g k", g=self.ngroups, k=k)
# A: (nheads,) -> (ngroups, k)
A_r = rearrange(A, "(g k) -> g k", g=self.ngroups, k=k)
# dA: (B, ngroups, k)
dA_r = torch.exp(dt_r * A_r.unsqueeze(0))

# B: (B, ngroups * d_state) -> (B, ngroups, d_state)
B_r = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
# C: (B, ngroups * d_state) -> (B, ngroups, d_state)
C_r = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
# ssm_state: (B, nheads, headdim, d_state) -> (B, ngroups, k, headdim, d_state)
ssm_state_r = rearrange(ssm_state, "b (g k) p n -> b g k p n", g=self.ngroups, k=k)

# SSM recurrence: h_new = dA * h_old + dB * x
# dB = dt * B
# dB_scaled_by_dt: (B, ngroups, k, d_state)
dB_scaled_by_dt = torch.einsum("bgk,bgn->bgkn", dt_r, B_r)
# dBx: (B, ngroups, k, headdim, d_state)
dBx = torch.einsum("bgkp,bgkn->bgkpn", x_r, dB_scaled_by_dt)

ssm_state_new_r = dA_r.unsqueeze(-1).unsqueeze(-1) * ssm_state_r + dBx
ssm_state.copy_(rearrange(ssm_state_new_r, "b g k p n -> b (g k) p n"))

# Output: y = C * h_new + D * x
# y_interim: (B, ngroups, k, headdim)
y_interim = torch.einsum("bgkpn,bgn->bgkp", ssm_state_new_r.to(dtype), C_r)

D_param = self.D.to(dtype)
if self.D_has_hdim: # D is (d_ssm) = (nheads * headdim)
D_r = rearrange(D_param, "(g k p) -> g k p", g=self.ngroups, k=k, p=self.headdim)
y_r = y_interim + D_r.unsqueeze(0) * x_r
else: # D is (nheads)
D_r = rearrange(D_param, "(g k) -> g k", g=self.ngroups, k=k)
y_r = y_interim + D_r.unsqueeze(0).unsqueeze(-1) * x_r

y = rearrange(y_r, "b g k p -> b (g k p)") # (B, d_ssm)

if not self.rmsnorm:
y = y * self.act(z) # (B D)
else:
@@ -376,8 +431,26 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states
inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
else:
conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
# TODO: What if batch size changes between generation, and we reuse the same states?
if initialize_states:
# Handle batch size changes or explicit initialization
if initialize_states or conv_state.shape[0] != batch_size or ssm_state.shape[0] != batch_size:
# Re-initialize states if batch size changed or if explicitly requested
conv_state = torch.zeros(
batch_size,
self.conv1d.weight.shape[0], # out_channels
self.d_conv, # kernel_size
device=self.conv1d.weight.device,
dtype=self.conv1d.weight.dtype,
)
ssm_state = torch.zeros(
batch_size,
self.nheads,
self.headdim,
self.d_state,
device=self.in_proj.weight.device,
dtype=self.in_proj.weight.dtype,
)
inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
elif initialize_states: # Original condition if batch sizes matched but re-init was true
conv_state.zero_()
ssm_state.zero_()
return conv_state, ssm_state
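
A note on the new PyTorch fallback in `Mamba2.step()` above: per head it implements the single-token recurrence h_new = exp(dt * A) * h + (dt * B) outer x, followed by y = C · h_new + D * x, where B and C are shared by the k = nheads // ngroups heads of each group. A shape-level sketch of that update, with illustrative sizes and tensor names rather than the module's actual attributes:

```python
import torch
from einops import rearrange

# Illustrative sizes only.
batch, ngroups, k, headdim, d_state = 2, 2, 4, 8, 16
nheads = ngroups * k

x = torch.randn(batch, nheads * headdim)          # per-head token input, flattened
dt = torch.rand(batch, nheads)                    # step sizes (already softplus'd)
A = -torch.rand(nheads)                           # negative decay rates, one per head
B = torch.randn(batch, ngroups * d_state)         # input projection, shared per group
C = torch.randn(batch, ngroups * d_state)         # output projection, shared per group
D = torch.randn(nheads)                           # skip connection
h = torch.zeros(batch, nheads, headdim, d_state)  # recurrent SSM state

x_r = rearrange(x, "b (g k p) -> b g k p", g=ngroups, k=k, p=headdim)
dt_r = rearrange(dt, "b (g k) -> b g k", g=ngroups)
A_r = rearrange(A, "(g k) -> g k", g=ngroups)
B_r = rearrange(B, "b (g n) -> b g n", g=ngroups)
C_r = rearrange(C, "b (g n) -> b g n", g=ngroups)
h_r = rearrange(h, "b (g k) p n -> b g k p n", g=ngroups)

dA = torch.exp(dt_r * A_r)                                     # (b, g, k)
dBx = torch.einsum("bgk,bgn,bgkp->bgkpn", dt_r, B_r, x_r)      # (b, g, k, p, n)
h_new = dA[..., None, None] * h_r + dBx                        # state update
y = torch.einsum("bgkpn,bgn->bgkp", h_new, C_r)                # contract over d_state
y = y + rearrange(D, "(g k) -> g k", g=ngroups)[None, ..., None] * x_r
y = rearrange(y, "b g k p -> b (g k p)")
print(y.shape)  # torch.Size([2, 64])
```

The einsum-based version in the diff is the same computation with the group dimension factored out, so that `ngroups > 1` no longer trips the old `assert self.ngroups == 1`.
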
31 changes: 25 additions & 6 deletions mamba_ssm/ops/selective_scan_interface.py
@@ -173,7 +173,7 @@ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta
if y.is_complex():
y = y.real * 2
ys.append(y)
y = torch.stack(ys, dim=2) # (batch dim L)
y = torch.stack(ys, dim=2) # (batch, dim, L)
out = y if D is None else y + u * rearrange(D, "d -> d 1")
if z is not None:
out = out * F.silu(z)
@@ -385,7 +385,8 @@ def mamba_inner_ref(
xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
out_proj_weight, out_proj_bias,
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
C_proj_bias=None, delta_softplus=True
C_proj_bias=None, delta_softplus=True,
b_rms_weight=None, c_rms_weight=None, dt_rms_weight=None, b_c_dt_rms_eps=1e-6
):
assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d."
L = xz.shape[-1]
@@ -399,21 +400,39 @@
x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
delta = rearrange(delta, "d (b l) -> b d l", l=L)

if dt_rms_weight is not None:
delta_reshaped = rearrange(delta, "b d l -> (b l) d").contiguous()
delta_reshaped = rms_norm_forward(delta_reshaped, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps)
delta = rearrange(delta_reshaped, "(b l) d -> b d l", l=L).contiguous()

if B is None: # variable B
B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d)
if B_proj_bias is not None:
B = B + B_proj_bias.to(dtype=B.dtype)
if not A.is_complex():
B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
B = rearrange(B, "(b l) dstate -> b dstate l", l=L)
else:
B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2)
if b_rms_weight is not None:
B_reshaped = rearrange(B, "b dstate l -> (b l) dstate").contiguous()
B_reshaped = rms_norm_forward(B_reshaped, b_rms_weight, bias=None, eps=b_c_dt_rms_eps)
B = rearrange(B_reshaped, "(b l) dstate -> b dstate l", l=L).contiguous()
else:
B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
B = B.contiguous() # Ensure contiguity if not already handled by RMSNorm path
if C is None: # variable C
C = x_dbl[:, -d_state:] # (bl d)
if C_proj_bias is not None:
C = C + C_proj_bias.to(dtype=C.dtype)
if not A.is_complex():
C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
C = rearrange(C, "(b l) dstate -> b dstate l", l=L)
else:
C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2)
if c_rms_weight is not None:
C_reshaped = rearrange(C, "b dstate l -> (b l) dstate").contiguous()
C_reshaped = rms_norm_forward(C_reshaped, c_rms_weight, bias=None, eps=b_c_dt_rms_eps)
C = rearrange(C_reshaped, "(b l) dstate -> b dstate l", l=L).contiguous()
else:
C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
C = C.contiguous() # Ensure contiguity if not already handled by RMSNorm path
y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
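
A note on the new `b_rms_weight` / `c_rms_weight` / `dt_rms_weight` arguments: they RMS-normalize B, C, and delta over the feature dimension before the scan, presumably to keep the reference path in step with the fused path. A plain-PyTorch stand-in for what that normalization computes (a reference sketch, not the Triton `rms_norm_forward` kernel the diff calls):

```python
import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Scale each row by the reciprocal of its root-mean-square, then apply the learned gain.
    rstd = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * rstd * weight

# B flattened to (batch * seqlen, d_state), as in the reference path above.
bl, d_state = 6, 16
B_flat = torch.randn(bl, d_state)
b_rms_weight = torch.ones(d_state)
print(rms_norm_ref(B_flat, b_rms_weight).shape)  # torch.Size([6, 16])
```
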
2 changes: 1 addition & 1 deletion mamba_ssm/utils/hf.py
@@ -15,7 +15,7 @@ def load_state_dict_hf(model_name, device=None, dtype=None):
# If not fp32, then we don't want to load directly to the GPU
mapped_device = "cpu" if dtype not in [torch.float32, None] else device
resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
return torch.load(resolved_archive_file, map_location=mapped_device)
state_dict = torch.load(resolved_archive_file, map_location=mapped_device)
# Convert dtype before moving to GPU to save memory
if dtype is not None:
state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
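
A note on the `hf.py` change: binding the loaded checkpoint to `state_dict` lets the tensors be cast on CPU before anything is moved to the GPU, so a full-precision copy never occupies device memory. A sketch of that order of operations (`load_and_cast` and its arguments are illustrative, not the library function):

```python
import torch

def load_and_cast(path: str, device=None, dtype=None):
    # Load on CPU first, cast to the target dtype, then (optionally) move to the device.
    state_dict = torch.load(path, map_location="cpu")
    if dtype is not None:
        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
    if device is not None:
        state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
    return state_dict
```
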
33 changes: 16 additions & 17 deletions setup.py
@@ -99,27 +99,28 @@ def get_torch_hip_version():
return None


def check_if_hip_home_none(global_option: str) -> None:

if HIP_HOME is not None:
return
# warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary
# in that case.
def check_if_hip_home_none(global_option: str):
if HIP_HOME is None:
raise RuntimeError(
f"{global_option} was requested, but the ROCm/HIP installation is incomplete. "
'Please make sure ROCm is properly installed and HIP_HOME environment variable is set.\n'
'On Ubuntu, you may need to install: rocm-libs hipcc hiprt hipcub rocprim rocrand rocthrust rocblas hipblas rocsolver hipsparse rocsparse hipfft rocfft rocthrust rocrand'
)
warnings.warn(
f"{global_option} was requested, but hipcc was not found. Are you sure your environment has hipcc available?"
)


def check_if_cuda_home_none(global_option: str) -> None:
if CUDA_HOME is not None:
return
# warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
# in that case.
warnings.warn(
f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
"If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
"only images whose names contain 'devel' will provide nvcc."
)
if CUDA_HOME is None:
raise RuntimeError(
f"{global_option} was requested, but CUDA installation was not found. "
'Please ensure CUDA is properly installed and the CUDA_HOME environment variable is set.\n'
'Common solutions include:\n'
'1. Install CUDA from NVIDIA: https://developer.nvidia.com/cuda-downloads\n'
'2. Set CUDA_HOME to your CUDA installation directory (e.g., /usr/local/cuda-11.8)\n'
'3. Add CUDA to your PATH: export PATH=$PATH:$CUDA_HOME/bin'
)


def append_nvcc_threads(nvcc_extra_args):
@@ -158,8 +159,6 @@ def append_nvcc_threads(nvcc_extra_args):
UserWarning
)

cc_flag.append("-DBUILD_PYTHON_PACKAGE")

else:
check_if_cuda_home_none(PACKAGE_NAME)
# Check, if CUDA11 is installed for compute capability 8.0
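
A note on the `setup.py` changes: turning the missing-toolchain warnings into `RuntimeError`s makes a source build fail immediately with an actionable message instead of continuing to an opaque nvcc/hipcc error; it does drop the escape hatch the old comment mentioned for users installing prebuilt wheels. A rough sketch of the enforced behavior, using a hypothetical environment-variable check rather than the `CUDA_HOME` constant the real script uses:

```python
import os

def check_if_cuda_home_none(global_option: str) -> None:
    # Hypothetical stand-in: fail fast when no CUDA installation can be located.
    if os.environ.get("CUDA_HOME") is None:
        raise RuntimeError(
            f"{global_option} was requested, but CUDA installation was not found. "
            "Set CUDA_HOME (e.g. /usr/local/cuda-12.1) or install a prebuilt wheel."
        )

try:
    check_if_cuda_home_none("mamba_ssm")
except RuntimeError as exc:
    print(f"build aborted early: {exc}")
```
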