@@ -224,56 +224,129 @@ def __init__(
        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)
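+        # `mask` is an additive causal mask: -inf above the diagonal, so each
+        # position can attend only to itself and earlier positions.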
-    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
+        # Optimisation: pre-compute and register the mask in CUDA if available
+        if torch.cuda.is_available():
+            self.register_buffer("mask_cuda", mask.cuda(), persistent=False)
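+            # persistent=False keeps the GPU copy of the mask out of checkpoints.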
+
+    def forward(
+        self,
+        tokens: Tensor,
+        audio_features: Tensor,
+        kv_cache: Optional[dict] = None,
+    ) -> Tensor:
        """
-        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
-            the text tokens
-        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
-            the encoded audio features to be attended on
+        Args:
+            tokens: (n_batch, n_token), the text tokens
+            audio_features: (n_batch, n_audio_ctx, n_audio_state), the encoded
+                audio features to be attended on
+            kv_cache: optional cache of key/value tensors from earlier decoding steps
+
+        Returns:
+            logits: (n_batch, n_token, n_vocab)
        """
-        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
-        x = (
-            self.token_embedding(x)
-            + self.positional_embedding[offset : offset + x.shape[-1]]
-        )
-        x = x.to(xa.dtype)
-
-        for block in self.blocks:
-            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
-
-        x = self.ln(x)
-        logits = (
-            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
-        ).float()
+        if kv_cache is not None:
+            # Incremental decoding: offset the positional embedding by the number
+            # of tokens already held in the cache.
+            offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+            x = (
+                self.token_embedding(tokens)
+                + self.positional_embedding[offset : offset + tokens.shape[-1]]
+            )
+            x = x.to(audio_features.dtype)
+
+            for block in self.blocks:
+                x = block(x, audio_features, mask=self.mask, kv_cache=kv_cache)
+
+            x = self.ln(x)
+            logits = (
+                x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
+            ).float()
+        else:
+            n_batch, n_token = tokens.shape
+            x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
+
+            # Optimisation: move audio_features onto the same device as the text
+            # activations once here, instead of transferring inside every block.
+            if audio_features.device != x.device:
+                audio_features = audio_features.to(x.device)
+
+            # Optimisation: use the precomputed CUDA mask when running on the GPU.
+            # The mask goes into the attention blocks; adding it to the logits
+            # would be a shape error, since logits are (n_batch, n_token, n_vocab).
+            if x.is_cuda:
+                mask = self.mask_cuda[:n_token, :n_token]
+            else:
+                mask = self.mask[:n_token, :n_token]
+
+            for block in self.blocks:
+                x = block(x, audio_features, mask=mask)
+
+            x = self.ln(x)
+            logits = (x @ self.token_embedding.weight.to(x.dtype).T).float()

        return logits
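+    # Usage sketch (names and shapes are illustrative, not part of this module):
+    #
+    #     logits = decoder(tokens, audio_features)                   # full prompt pass
+    #     logits = decoder(next_token, audio_features, kv_cache=kv)  # one step, cached
+    #
+    # Both calls return logits of shape (n_batch, n_token, n_vocab).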


class Whisper(nn.Module):
    def __init__(self, dims: ModelDimensions):
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
        )
        # use the last half among the decoder layers for time alignment by default;
        # to use a specific set of heads, see `set_alignment_heads()` below.
        all_heads = torch.zeros(
            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
        )
        all_heads[self.dims.n_text_layer // 2 :] = True
        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)

    def set_alignment_heads(self, dump: bytes):
        array = np.frombuffer(
@@ -342,4 +415,4 @@ def install_hooks(layer: nn.Module):
    detect_language = detect_language_function
    transcribe = transcribe_function
-    decode = decode_function
+    decode = decode_function