diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py
index 4be57e08..cf02c6d5 100644
--- a/mamba_ssm/models/mixer_seq_simple.py
+++ b/mamba_ssm/models/mixer_seq_simple.py
@@ -133,6 +133,11 @@ def __init__(
         device=None,
         dtype=None,
     ) -> None:
+        # Ensure head dimension does not exceed hardware limits
+        max_head_dim = 256  # Example limit, adjust based on hardware
+        if d_model > max_head_dim:
+            print(f"Warning: d_model ({d_model}) exceeds the hardware limit. Adjusting to {max_head_dim}.")
+            d_model = max_head_dim
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
         self.residual_in_fp32 = residual_in_fp32
@@ -221,6 +226,11 @@ def __init__(
         device=None,
         dtype=None,
     ) -> None:
+        # Ensure head dimension does not exceed hardware limits
+        max_head_dim = 256  # Example limit, adjust based on hardware
+        if config.d_model > max_head_dim:
+            print(f"Warning: d_model ({config.d_model}) exceeds the hardware limit. Adjusting to {max_head_dim}.")
+            config.d_model = max_head_dim
         self.config = config
         d_model = config.d_model
         n_layer = config.n_layer
diff --git a/mamba_ssm/modules/mamba2.py b/mamba_ssm/modules/mamba2.py
index e60f987d..f0e96403 100644
--- a/mamba_ssm/modules/mamba2.py
+++ b/mamba_ssm/modules/mamba2.py
@@ -85,6 +85,10 @@ def __init__(
         self.use_mem_eff_path = use_mem_eff_path
         self.layer_idx = layer_idx
 
+        # Validate head dimension
+        if self.headdim > 256:
+            raise ValueError("headdim should not exceed 256 due to hardware limits.")
+
         # Order: [z, x, B, C, dt]
         d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
         if self.process_group is None:
@@ -300,6 +304,7 @@ def step(self, hidden_states, conv_state, ssm_state):
             dt = repeat(dt, "b h -> b h p", p=self.headdim)
             dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
             D = repeat(self.D, "h -> h p", p=self.headdim)
+            B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
             C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
             x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
@@ -327,6 +332,8 @@ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs)
         ssm_state = torch.zeros(
             batch_size, self.nheads, self.headdim, self.d_state, device=device, dtype=ssm_dtype
         )
+        if self.headdim > 256:
+            raise ValueError("headdim should not exceed 256 due to hardware limits.")
         return conv_state, ssm_state
 
     def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
diff --git a/mamba_ssm/modules/mamba_simple.py b/mamba_ssm/modules/mamba_simple.py
index 4c8a3882..533dd201 100644
--- a/mamba_ssm/modules/mamba_simple.py
+++ b/mamba_ssm/modules/mamba_simple.py
@@ -58,6 +58,12 @@ def __init__(
         self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
         self.use_fast_path = use_fast_path
         self.layer_idx = layer_idx
+        self.d_inner = int(self.expand * self.d_model)
+        # Ensure d_inner does not exceed the shared memory limit
+        MAX_SAFE_D_INNER = 256  # Safe maximum value for d_inner
+        if self.d_inner > MAX_SAFE_D_INNER:
+            print(f"Warning: d_inner ({self.d_inner}) exceeds the safe maximum value. Setting d_inner to {MAX_SAFE_D_INNER}.")
+            self.d_inner = MAX_SAFE_D_INNER
         self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)