import torch
- import torch.nn.functional as F
from einops import rearrange
from typing import Optional, Tuple

- import selective_scan_cuda
+ from mamba_ssm.ops.selective_scan_interface import selective_scan_cuda


@torch.library.custom_op(
    "custom_ops::selective_scan_fwd",
    device_types=["cuda"],
    mutates_args=(),
+     schema="(Tensor u, Tensor delta, Tensor A, Tensor B, Tensor C, Tensor? D, Tensor? z, Tensor? delta_bias, bool delta_softplus, bool return_last_state) -> (Tensor, Tensor, Tensor, Tensor, bool, bool, bool)",
)
def custom_selective_scan_fwd(
    u: torch.Tensor,
@@ -22,28 +22,33 @@ def custom_selective_scan_fwd(
    delta_bias: Optional[torch.Tensor],
    delta_softplus: bool,
    return_last_state: bool,
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, bool, bool, bool]:
+ ):
    pass

+

@torch .library .register_fake ("custom_ops::selective_scan_fwd" )
29
30
def custom_selective_scan_fwd_fake (
30
- u ,
31
- delta ,
32
- A ,
33
- B ,
34
- C ,
35
- D ,
36
- z ,
37
- delta_bias ,
38
- delta_softplus ,
39
- return_last_state ,
31
+ u , delta , A , B , C , D , z , delta_bias , delta_softplus , return_last_state
40
32
):
41
- final_out = torch .empty_like (u )
42
33
dstate = A .size (1 ) * (2 if A .is_complex () else 1 )
43
- last_state_fake = u .new_empty ((u .size (0 ), u .size (1 ), dstate )) if return_last_state else u .new_empty (0 )
44
- out_fake = torch .empty_like (u )
45
- x_fake = u .new_empty ((u .size (0 ), u .size (1 ), u .size (2 ), 2 * dstate ))
46
- return final_out , last_state_fake , out_fake , x_fake , False , False , z is not None
34
+ seqlen = u .size (2 )
35
+ n_chunks = (seqlen + 2048 - 1 ) // 2048
36
+
37
+ squeeze_B = B .dim () == 3
38
+ squeeze_C = C .dim () == 3
39
+ has_z = z is not None
40
+
41
+ final_out = torch .empty_like (delta )
42
+ out_fake = torch .empty_like (delta )
43
+ last_state_fake = (
44
+ u .new_empty ((u .size (0 ), u .size (1 ), dstate ))
45
+ if return_last_state
46
+ else u .new_empty (0 )
47
+ )
48
+ x_fake = u .new_empty ((u .size (0 ), u .size (1 ), n_chunks , 2 * A .size (1 )), dtype = A .dtype )
49
+
50
+ return final_out , last_state_fake , out_fake , x_fake , squeeze_B , squeeze_C , has_z
51
+


@torch.library.register_kernel("custom_ops::selective_scan_fwd", "cuda")
def custom_selective_scan_fwd_cuda(
@@ -81,16 +86,23 @@ def custom_selective_scan_fwd_cuda(
        C = rearrange(C, "b dstate l -> b 1 dstate l").contiguous()
        squeeze_C = True

-     out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
+     out, x, *rest = selective_scan_cuda.fwd(
+         u, delta, A, B, C, D, z, delta_bias, delta_softplus
+     )
    has_z = z is not None
-     final_out = rest[0].clone() if has_z else out.clone()
+     if has_z:
+         final_out = rest[0].clone()
+     else:
+         final_out = out.clone()
    last_state = x[:, :, -1, 1::2].clone() if return_last_state else u.new_empty(0)
    return final_out, last_state, out, x, squeeze_B, squeeze_C, has_z

+

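Note: once the real and fake forward kernels are both registered, `torch.library.opcheck` (PyTorch 2.4+) can verify that the fake kernel's output shapes and dtypes and the declared schema agree with the CUDA kernel. A minimal smoke test, as a sketch only — the shapes follow the usual selective-scan layout (`u`/`delta` as `(batch, dim, seqlen)`, `A` as `(dim, dstate)`, `B`/`C` in the 3-D "variable" layout), and the exact sizes are illustrative assumptions:

```python
import torch

# Illustrative shapes only; requires a CUDA device with the
# selective_scan_cuda extension built.
b, d, l, n = 2, 4, 64, 8
u = torch.randn(b, d, l, device="cuda")
delta = torch.randn(b, d, l, device="cuda")
A = torch.randn(d, n, device="cuda")
B = torch.randn(b, n, l, device="cuda")  # 3-D layout -> squeeze_B is True
C = torch.randn(b, n, l, device="cuda")
D = torch.randn(d, device="cuda")

# opcheck runs the op under FakeTensor and compares against the real
# kernel, flagging schema or fake-kernel mismatches.
torch.library.opcheck(
    torch.ops.custom_ops.selective_scan_fwd,
    (u, delta, A, B, C, D, None, None, True, True),
)
```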
@torch.library.custom_op(
    "custom_ops::selective_scan_bwd",
    device_types=["cuda"],
    mutates_args=(),
+     schema="(Tensor dout, Tensor u, Tensor delta, Tensor A, Tensor B, Tensor C, Tensor? D, Tensor? z, Tensor? delta_bias, bool delta_softplus, Tensor out, Tensor x, bool squeeze_B, bool squeeze_C, bool recompute_out_z) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor?, Tensor?, Tensor?)",
)
def custom_selective_scan_bwd(
    dout: torch.Tensor,
@@ -107,9 +119,11 @@ def custom_selective_scan_bwd(
    x: torch.Tensor,
    squeeze_B: bool,
    squeeze_C: bool,
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+     recompute_out_z: bool,
+ ):
    pass

+

@torch .library .register_fake ("custom_ops::selective_scan_bwd" )
114
128
def custom_selective_scan_bwd_fake (
115
129
dout ,
@@ -126,16 +140,33 @@ def custom_selective_scan_bwd_fake(
    x,
    squeeze_B,
    squeeze_C,
+     recompute_out_z,
):
+     # Here we only need to return shape-compatible fake tensors
    du = torch.empty_like(u)
    ddelta = torch.empty_like(delta)
    dA = torch.empty_like(A)
-     dB = torch.empty_like(B)
-     dC = torch.empty_like(C)
-     dD = torch.empty_like(D) if (D is not None and D.numel() > 0) else u.new_empty(0)
-     dz = torch.empty_like(z) if (z is not None and z.numel() > 0) else u.new_empty(0)
-     ddelta_bias = torch.empty_like(delta_bias) if (delta_bias is not None and delta_bias.numel() > 0) else u.new_empty(0)
-     return du, ddelta, dA, dB, dC, dD, dz, ddelta_bias
+
+     # Decide if B/C are input-dependent ("variable") layouts
+     is_variable_B = B.dim() > 3
+     is_variable_C = C.dim() > 3
+
+     dB = torch.empty_like(
+         B, dtype=B.dtype
+     )  # even if variable_B, float32 is fine for a fake tensor
+     dC = torch.empty_like(C, dtype=C.dtype)
+
+     dD = torch.empty_like(D) if (D is not None) else None
+     ddelta_bias_out = torch.empty_like(delta_bias) if (delta_bias is not None) else None
+     dz = torch.empty_like(z) if (z is not None) else None
+
+     if squeeze_B and dB.numel() > 0:
+         dB = dB.squeeze(1)
+     if squeeze_C and dC.numel() > 0:
+         dC = dC.squeeze(1)
+
+     return du, ddelta, dA, dB, dC, dD, ddelta_bias_out, dz
+


@torch.library.register_kernel("custom_ops::selective_scan_bwd", "cuda")
def custom_selective_scan_bwd_cuda(
@@ -153,68 +184,101 @@ def custom_selective_scan_bwd_cuda(
    x: torch.Tensor,
    squeeze_B: bool,
    squeeze_C: bool,
+     recompute_out_z: bool,
):
    if dout.stride(-1) != 1:
        dout = dout.contiguous()
-     B = B.contiguous()
-     C = C.contiguous()

    results = selective_scan_cuda.bwd(
-         u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, delta_softplus, False
+         u,
+         delta,
+         A,
+         B,
+         C,
+         D,
+         z,
+         delta_bias,
+         dout,
+         x,
+         out,
+         None,
+         delta_softplus,
+         recompute_out_z,
    )
+
    has_z = z is not None
    if has_z:
-         du, ddelta, dA, dB, dC, dD, ddelta_bias, dz = results
+         du, ddelta, dA, dB, dC, dD, ddelta_bias_out, dz = results
    else:
-         du, ddelta, dA, dB, dC, dD, ddelta_bias = results
-         dz = u.new_empty(0)
+         du, ddelta, dA, dB, dC, dD, ddelta_bias_out = results
+         dz = None

    if squeeze_B and dB.numel() > 0:
        dB = dB.squeeze(1)
    if squeeze_C and dC.numel() > 0:
        dC = dC.squeeze(1)

-     return du, ddelta, dA, dB, dC, dD, dz, ddelta_bias
+     return du, ddelta, dA, dB, dC, dD, ddelta_bias_out, dz
+

def custom_bridge(ctx, *grads):
    dout = grads[0] if grads else ctx.saved_tensors[0].new_empty(0)
    saved = ctx.saved_tensors
+
    if not ctx.has_z:
        u, delta, A, B, C, D, delta_bias, x, out = saved
        z = None
    else:
        u, delta, A, B, C, D, z, delta_bias, x, out = saved

-     du, ddelta, dA, dB, dC, dD, dz, ddelta_bias = torch.ops.custom_ops.selective_scan_bwd(
-         dout,
-         u,
-         delta,
-         A,
-         B,
-         C,
-         D,
-         z,
-         delta_bias,
-         ctx.delta_softplus,
-         out,
-         x,
-         ctx.squeeze_B,
-         ctx.squeeze_C
+     du, ddelta, dA, dB, dC, dD, ddelta_bias_out, dz = (
+         torch.ops.custom_ops.selective_scan_bwd(
+             dout,
+             u,
+             delta,
+             A,
+             B,
+             C,
+             D,
+             z,
+             delta_bias,
+             ctx.delta_softplus,
+             out,
+             x,
+             ctx.squeeze_B,
+             ctx.squeeze_C,
+             False,
+         )
    )

+     # For optional inputs, return None if not provided in the forward
+     if D is None:
+         dD = None
+     if z is None:
+         dz = None
+     if delta_bias is None:
+         ddelta_bias_out = None
+
+     # Return gradients in the order of the forward inputs:
+     # (u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
+     # `delta_softplus` and `return_last_state` are bools -> gradient = None
+     d_delta_softplus = None
+     d_return_last_state = None
+
    return (
        du,
        ddelta,
        dA,
        dB,
        dC,
-         dD if D is not None else None,
-         dz if z is not None else None,
-         ddelta_bias if delta_bias is not None else None,
-         None,
-         None,
+         dD,
+         dz,
+         ddelta_bias_out,
+         d_delta_softplus,
+         d_return_last_state,
    )

+
def custom_setup_context(ctx, inputs, output):
    (u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) = inputs
    (final_out, last_state, out, x, squeeze_B, squeeze_C, has_z) = output
@@ -236,10 +300,12 @@ def custom_setup_context(ctx, inputs, output):
    else:
        ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)

+
torch.library.register_autograd(
    "custom_ops::selective_scan_fwd", custom_bridge, setup_context=custom_setup_context
)

+
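With `register_autograd` in place, the full forward/backward path can be compared numerically against the pure-PyTorch reference. A sketch of such a check, assuming `selective_scan_ref` from `mamba_ssm.ops.selective_scan_interface` is available; since the kernel computes in float32, this uses loose tolerances instead of `torch.autograd.gradcheck` (which expects float64):

```python
import torch
from mamba_ssm.ops.selective_scan_interface import selective_scan_ref

# Illustrative sizes; all inputs are leaf tensors.
b, d, l, n = 2, 4, 64, 8
u = torch.randn(b, d, l, device="cuda", requires_grad=True)
delta = torch.randn(b, d, l, device="cuda", requires_grad=True)
A = torch.randn(d, n, device="cuda", requires_grad=True)
B = torch.randn(b, n, l, device="cuda", requires_grad=True)
C = torch.randn(b, n, l, device="cuda", requires_grad=True)
D = torch.randn(d, device="cuda", requires_grad=True)

out, *_ = torch.ops.custom_ops.selective_scan_fwd(
    u, delta, A, B, C, D, None, None, True, False
)
ref = selective_scan_ref(u, delta, A, B, C, D=D, delta_softplus=True)
torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)

# Backward runs through custom_bridge via the registration above.
grads = torch.autograd.grad(out.sum(), (u, delta, A, B, C, D), retain_graph=True)
grads_ref = torch.autograd.grad(ref.sum(), (u, delta, A, B, C, D))
for g, gr in zip(grads, grads_ref):
    torch.testing.assert_close(g, gr, rtol=1e-2, atol=1e-2)
```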
def selective_scan_fn_custom_op(
    u: torch.Tensor,
    delta: torch.Tensor,
@@ -252,20 +318,9 @@ def selective_scan_fn_custom_op(
    delta_softplus: bool,
    return_last_state: bool,
) -> torch.Tensor:
-     # Pass all arguments positionally, exactly in schema order:
    final_out, last_state, _, _, _, _, _ = torch.ops.custom_ops.selective_scan_fwd(
-         u,
-         delta,
-         A,
-         B,
-         C,
-         D,
-         z,
-         delta_bias,
-         delta_softplus,
-         return_last_state
+         u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state
    )
-
    if return_last_state:
        return final_out, last_state
    else:
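Wrapping the raw extension calls in a registered custom op with fake kernels is presumably what makes this path traceable: `torch.compile` no longer has to graph-break on the opaque `selective_scan_cuda` calls. A usage sketch (illustrative shapes, CUDA required):

```python
import torch

# fullgraph=True makes any remaining graph break fail loudly.
@torch.compile(fullgraph=True)
def run_scan(u, delta, A, B, C, D):
    return selective_scan_fn_custom_op(
        u, delta, A, B, C, D, None, None, True, False
    )

b, d, l, n = 2, 4, 64, 8
out = run_scan(
    torch.randn(b, d, l, device="cuda"),
    torch.randn(b, d, l, device="cuda"),
    torch.randn(d, n, device="cuda"),
    torch.randn(b, n, l, device="cuda"),
    torch.randn(b, n, l, device="cuda"),
    torch.randn(d, device="cuda"),
)
```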