Commit f9945b2

Add op_check tests
1 parent 265d4e2 commit f9945b2

2 files changed: +146 −66 lines changed

mamba_ssm/ops/selective_scan_interface_compilable.py

Lines changed: 121 additions & 66 deletions
@@ -1,15 +1,15 @@
 import torch
-import torch.nn.functional as F
 from einops import rearrange
 from typing import Optional, Tuple

-import selective_scan_cuda
+from mamba_ssm.ops.selective_scan_interface import selective_scan_cuda


 @torch.library.custom_op(
     "custom_ops::selective_scan_fwd",
     device_types=["cuda"],
     mutates_args=(),
+    schema="(Tensor u, Tensor delta, Tensor A, Tensor B, Tensor C, Tensor? D, Tensor? z, Tensor? delta_bias, bool delta_softplus, bool return_last_state) -> (Tensor, Tensor, Tensor, Tensor, bool, bool, bool)",
 )
 def custom_selective_scan_fwd(
     u: torch.Tensor,
@@ -22,28 +22,33 @@ def custom_selective_scan_fwd(
     delta_bias: Optional[torch.Tensor],
     delta_softplus: bool,
     return_last_state: bool,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, bool, bool, bool]:
+):
     pass

+
 @torch.library.register_fake("custom_ops::selective_scan_fwd")
 def custom_selective_scan_fwd_fake(
-    u,
-    delta,
-    A,
-    B,
-    C,
-    D,
-    z,
-    delta_bias,
-    delta_softplus,
-    return_last_state,
+    u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state
 ):
-    final_out = torch.empty_like(u)
     dstate = A.size(1) * (2 if A.is_complex() else 1)
-    last_state_fake = u.new_empty((u.size(0), u.size(1), dstate)) if return_last_state else u.new_empty(0)
-    out_fake = torch.empty_like(u)
-    x_fake = u.new_empty((u.size(0), u.size(1), u.size(2), 2 * dstate))
-    return final_out, last_state_fake, out_fake, x_fake, False, False, z is not None
+    seqlen = u.size(2)
+    n_chunks = (seqlen + 2048 - 1) // 2048
+
+    squeeze_B = B.dim() == 3
+    squeeze_C = C.dim() == 3
+    has_z = z is not None
+
+    final_out = torch.empty_like(delta)
+    out_fake = torch.empty_like(delta)
+    last_state_fake = (
+        u.new_empty((u.size(0), u.size(1), dstate))
+        if return_last_state
+        else u.new_empty(0)
+    )
+    x_fake = u.new_empty((u.size(0), u.size(1), n_chunks, 2 * A.size(1)), dtype=A.dtype)
+
+    return final_out, last_state_fake, out_fake, x_fake, squeeze_B, squeeze_C, has_z
+

 @torch.library.register_kernel("custom_ops::selective_scan_fwd", "cuda")
 def custom_selective_scan_fwd_cuda(
@@ -81,16 +86,23 @@ def custom_selective_scan_fwd_cuda(
         C = rearrange(C, "b dstate l -> b 1 dstate l").contiguous()
         squeeze_C = True

-    out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
+    out, x, *rest = selective_scan_cuda.fwd(
+        u, delta, A, B, C, D, z, delta_bias, delta_softplus
+    )
     has_z = z is not None
-    final_out = rest[0].clone() if has_z else out.clone()
+    if has_z:
+        final_out = rest[0].clone()
+    else:
+        final_out = out.clone()
     last_state = x[:, :, -1, 1::2].clone() if return_last_state else u.new_empty(0)
     return final_out, last_state, out, x, squeeze_B, squeeze_C, has_z

+
 @torch.library.custom_op(
     "custom_ops::selective_scan_bwd",
     device_types=["cuda"],
     mutates_args=(),
+    schema="(Tensor dout, Tensor u, Tensor delta, Tensor A, Tensor B, Tensor C, Tensor? D, Tensor? z, Tensor? delta_bias, bool delta_softplus, Tensor out, Tensor x, bool squeeze_B, bool squeeze_C, bool recompute_out_z) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor?, Tensor?, Tensor?)",
 )
 def custom_selective_scan_bwd(
     dout: torch.Tensor,
@@ -107,9 +119,11 @@ def custom_selective_scan_bwd(
     x: torch.Tensor,
     squeeze_B: bool,
     squeeze_C: bool,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    recompute_out_z: bool,
+):
     pass

+
 @torch.library.register_fake("custom_ops::selective_scan_bwd")
 def custom_selective_scan_bwd_fake(
     dout,
@@ -126,16 +140,33 @@ def custom_selective_scan_bwd_fake(
     x,
     squeeze_B,
     squeeze_C,
+    recompute_out_z,
 ):
+    # Here we just return shape-compatible fake tensors
     du = torch.empty_like(u)
     ddelta = torch.empty_like(delta)
     dA = torch.empty_like(A)
-    dB = torch.empty_like(B)
-    dC = torch.empty_like(C)
-    dD = torch.empty_like(D) if (D is not None and D.numel() > 0) else u.new_empty(0)
-    dz = torch.empty_like(z) if (z is not None and z.numel() > 0) else u.new_empty(0)
-    ddelta_bias = torch.empty_like(delta_bias) if (delta_bias is not None and delta_bias.numel() > 0) else u.new_empty(0)
-    return du, ddelta, dA, dB, dC, dD, dz, ddelta_bias
+
+    # Decide if variable B/C
+    is_variable_B = B.dim() > 3
+    is_variable_C = C.dim() > 3
+
+    dB = torch.empty_like(
+        B, dtype=B.dtype
+    )  # If variable_B, still float32 is okay for fake
+    dC = torch.empty_like(C, dtype=C.dtype)
+
+    dD = torch.empty_like(D) if (D is not None) else None
+    ddelta_bias_out = torch.empty_like(delta_bias) if (delta_bias is not None) else None
+    dz = torch.empty_like(z) if (z is not None) else None
+
+    if squeeze_B and dB.numel() > 0:
+        dB = dB.squeeze(1)
+    if squeeze_C and dC.numel() > 0:
+        dC = dC.squeeze(1)
+
+    return du, ddelta, dA, dB, dC, dD, ddelta_bias_out, dz
+

 @torch.library.register_kernel("custom_ops::selective_scan_bwd", "cuda")
 def custom_selective_scan_bwd_cuda(
@@ -153,68 +184,101 @@ def custom_selective_scan_bwd_cuda(
     x: torch.Tensor,
     squeeze_B: bool,
     squeeze_C: bool,
+    recompute_out_z: bool,
 ):
     if dout.stride(-1) != 1:
         dout = dout.contiguous()
-    B = B.contiguous()
-    C = C.contiguous()

     results = selective_scan_cuda.bwd(
-        u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, delta_softplus, False
+        u,
+        delta,
+        A,
+        B,
+        C,
+        D,
+        z,
+        delta_bias,
+        dout,
+        x,
+        out,
+        None,
+        delta_softplus,
+        recompute_out_z,
     )
+
     has_z = z is not None
     if has_z:
-        du, ddelta, dA, dB, dC, dD, ddelta_bias, dz = results
+        du, ddelta, dA, dB, dC, dD, ddelta_bias_out, dz = results
     else:
-        du, ddelta, dA, dB, dC, dD, ddelta_bias = results
-        dz = u.new_empty(0)
+        du, ddelta, dA, dB, dC, dD, ddelta_bias_out = results
+        dz = None

     if squeeze_B and dB.numel() > 0:
         dB = dB.squeeze(1)
     if squeeze_C and dC.numel() > 0:
         dC = dC.squeeze(1)

-    return du, ddelta, dA, dB, dC, dD, dz, ddelta_bias
+    return du, ddelta, dA, dB, dC, dD, ddelta_bias_out, dz
+

 def custom_bridge(ctx, *grads):
     dout = grads[0] if grads else ctx.saved_tensors[0].new_empty(0)
     saved = ctx.saved_tensors
+
     if not ctx.has_z:
         u, delta, A, B, C, D, delta_bias, x, out = saved
         z = None
     else:
         u, delta, A, B, C, D, z, delta_bias, x, out = saved

-    du, ddelta, dA, dB, dC, dD, dz, ddelta_bias = torch.ops.custom_ops.selective_scan_bwd(
-        dout,
-        u,
-        delta,
-        A,
-        B,
-        C,
-        D,
-        z,
-        delta_bias,
-        ctx.delta_softplus,
-        out,
-        x,
-        ctx.squeeze_B,
-        ctx.squeeze_C
+    du, ddelta, dA, dB, dC, dD, ddelta_bias_out, dz = (
+        torch.ops.custom_ops.selective_scan_bwd(
+            dout,
+            u,
+            delta,
+            A,
+            B,
+            C,
+            D,
+            z,
+            delta_bias,
+            ctx.delta_softplus,
+            out,
+            x,
+            ctx.squeeze_B,
+            ctx.squeeze_C,
+            False,
+        )
     )

+    # For optional inputs, return None if not provided in forward
+    if D is None:
+        dD = None
+    if z is None:
+        dz = None
+    if delta_bias is None:
+        ddelta_bias_out = None
+
+    # Return gradients in the order of forward inputs:
+    # (u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
+    # `delta_softplus` and `return_last_state` are bools -> gradient = None
+    d_delta_softplus = None
+    d_return_last_state = None
+
     return (
         du,
         ddelta,
         dA,
         dB,
         dC,
-        dD if D is not None else None,
-        dz if z is not None else None,
-        ddelta_bias if delta_bias is not None else None,
-        None,
-        None,
+        dD,
+        dz,
+        ddelta_bias_out,
+        d_delta_softplus,
+        d_return_last_state,
     )

+
 def custom_setup_context(ctx, inputs, output):
     (u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) = inputs
     (final_out, last_state, out, x, squeeze_B, squeeze_C, has_z) = output
@@ -236,10 +300,12 @@ def custom_setup_context(ctx, inputs, output):
     else:
         ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)

+
 torch.library.register_autograd(
     "custom_ops::selective_scan_fwd", custom_bridge, setup_context=custom_setup_context
 )

+
 def selective_scan_fn_custom_op(
     u: torch.Tensor,
     delta: torch.Tensor,
@@ -252,20 +318,9 @@ def selective_scan_fn_custom_op(
     delta_softplus: bool,
     return_last_state: bool,
 ) -> torch.Tensor:
-    # Pass all arguments positionally, exactly in schema order:
     final_out, last_state, _, _, _, _, _ = torch.ops.custom_ops.selective_scan_fwd(
-        u,
-        delta,
-        A,
-        B,
-        C,
-        D,
-        z,
-        delta_bias,
-        delta_softplus,
-        return_last_state
+        u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state
     )
-
     if return_last_state:
         return final_out, last_state
     else:
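
With the explicit schemas, the fake kernels, and the register_autograd bridge above, the forward op should now trace cleanly under torch.compile. A minimal smoke test of that path could look like the sketch below; it is not part of this commit, the shapes simply mirror the opcheck test, and it assumes a CUDA device with a built selective_scan_cuda extension:

    import torch
    from mamba_ssm.ops.selective_scan_interface_compilable import selective_scan_fn_custom_op

    # Assumed sizes: batch=1, dim=2, seqlen=8, dstate=8 (same as the opcheck test).
    u = torch.randn(1, 2, 8, device="cuda", requires_grad=True)
    delta = torch.randn(1, 2, 8, device="cuda", requires_grad=True)
    A = torch.randn(2, 8, device="cuda", requires_grad=True)
    B = torch.randn(1, 1, 8, 8, device="cuda", requires_grad=True)
    C = torch.randn(1, 1, 8, 8, device="cuda", requires_grad=True)
    D = torch.randn(2, device="cuda", requires_grad=True)
    z = torch.randn(1, 2, 8, device="cuda", requires_grad=True)
    delta_bias = torch.randn(2, device="cuda", requires_grad=True)

    # fullgraph=True turns any graph break into an error, so a mis-registered
    # custom op (bad schema, missing fake kernel) surfaces immediately.
    compiled_fn = torch.compile(selective_scan_fn_custom_op, fullgraph=True)
    out = compiled_fn(u, delta, A, B, C, D, z, delta_bias, False, False)
    out.sum().backward()  # exercises custom_bridge via register_autograd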

tests/ops/test_selective_scan.py

Lines changed: 25 additions & 0 deletions
@@ -254,3 +254,28 @@ def test_mamba_inner_fn(is_variable_B, is_variable_C, seqlen, itype, wtype):
     #                     atol=atolw if not is_variable_C else atol)
     # assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw)
     # assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw)
+
+def test_selective_scan_opcheck():
+    from torch.library import opcheck
+
+    device = "cuda"
+    # small inputs for opcheck
+    u = torch.randn(1, 2, 8, device=device, requires_grad=True)
+    delta = torch.randn(1, 2, 8, device=device, requires_grad=True)
+    A = torch.randn(2, 8, device=device, requires_grad=True)
+    B = torch.randn(1, 1, 8, 8, device=device, requires_grad=True)
+    C = torch.randn(1, 1, 8, 8, device=device, requires_grad=True)
+    D = torch.randn(2, device=device, requires_grad=True)
+    z = torch.randn(1, 2, 8, device=device, requires_grad=True)
+    delta_bias = torch.randn(2, device=device, requires_grad=True)
+    delta_softplus = False
+    return_last_state = False
+
+    # Run opcheck
+    result = opcheck(
+        torch.ops.custom_ops.selective_scan_fwd,
+        (u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state),
+        test_utils=("test_schema", "test_faketensor", "test_aot_dispatch_dynamic", "test_autograd_registration"),
+        raise_exception=True,
+    )
+    print("Opcheck result:", result)
