diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py
index c3596bfe..1a77596d 100644
--- a/mamba_ssm/ops/selective_scan_interface.py
+++ b/mamba_ssm/ops/selective_scan_interface.py
@@ -21,6 +21,9 @@ class SelectiveScanFn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                 return_last_state=False):
+
+        is_variable_B = B.dim() >= 3
+
         if u.stride(-1) != 1:
             u = u.contiguous()
         if delta.stride(-1) != 1:
@@ -43,6 +46,14 @@ def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softp
         ctx.delta_softplus = delta_softplus
         ctx.has_z = z is not None
         last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)
+
+        # The CUDA kernel does a peculiar optimization: it does not multiply the
+        # state by B if B is not variable! This does not impact MambaInnerFn,
+        # because it never returns the state. But SelectiveScanFn may need to
+        # return the last state! Hence the following is needed.
+        if not is_variable_B:
+            last_state = torch.einsum('bdn,dn->bdn', last_state, B)
+
         if not ctx.has_z:
             ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
             return out if not return_last_state else (out, last_state)
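
Note: a minimal sketch of how the fixed path is exercised, assuming the public
selective_scan_fn wrapper from this module, a CUDA build of mamba_ssm, and the
shape conventions used elsewhere in this file (a 2-D B of shape (dim, dstate)
is the non-variable case). Illustrative only, not part of the patch:

    import torch
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn

    batch, dim, dstate, seqlen = 2, 4, 8, 16
    device = "cuda"  # the fused kernel requires a CUDA device

    u = torch.randn(batch, dim, seqlen, device=device)
    delta = torch.randn(batch, dim, seqlen, device=device)
    A = -torch.rand(dim, dstate, device=device)   # negative real A keeps the scan stable
    B = torch.randn(dim, dstate, device=device)   # 2-D B => is_variable_B is False
    C = torch.randn(dim, dstate, device=device)

    # Before this patch, last_state on this path was the kernel's internal
    # state without the multiplication by B; with the einsum added above it
    # matches what a reference (non-fused) scan would return.
    out, last_state = selective_scan_fn(u, delta, A, B, C,
                                        delta_softplus=True, return_last_state=True)
    print(out.shape)         # (batch, dim, seqlen)
    print(last_state.shape)  # (batch, dim, dstate)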