@@ -229,91 +229,58 @@ def __init__(
         self.register_buffer("mask_cuda", mask.cuda(), persistent=False)


-    def forward(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
+    def forward(self, tokens: Tensor, audio_features: Tensor, kv_cache: Optional[dict] = None) -> Tensor:
         """
         Args:
             tokens: (n_batch, n_token)
             audio_features: (n_batch, n_audio_ctx, n_audio_state)
+            kv_cache: Optional cache for key/value tensors

         Returns:
             logits: (n_batch, n_token, n_vocab)
         """
         n_batch, n_token = tokens.shape
-        n_audio_ctx, n_audio_state = audio_features.shape[1:]
-
-        x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
+
+        # Get the dtype of audio_features to ensure consistency
+        dtype = audio_features.dtype
+
+        # Handle kv_cache for token embedding offset
+        if kv_cache is not None:
+            offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+            x = self.token_embedding(tokens) + self.positional_embedding[offset:offset + tokens.shape[1]]
+        else:
+            x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
+
+        # Convert to the same dtype as audio_features
+        x = x.to(dtype)

         # Optimisation: Move audio_features to GPU once here.
         if torch.cuda.is_available():
             audio_features = audio_features.cuda()

-
+        # Process through attention blocks
         for block in self.blocks:
-            x = block(x, audio_features)
+            x = block(x, audio_features, kv_cache=kv_cache)

         x = self.ln(x)
-        logits = x @ self.token_embedding.weight.T
-
-        # Optimisation: Apply the precomputed CUDA mask if available.
-        if torch.cuda.is_available():
-            mask = self.mask_cuda[:n_token, :n_token]
-        else:
-            mask = self.mask[:n_token, :n_token]

-        logits = logits + mask
-
-        return logits
-
-
-    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
-        """
-        Args:
-            tokens: (n_batch, n_token) or x tensor
-            audio_features: (n_batch, n_audio_ctx, n_audio_state) or xa tensor
-            kv_cache: Optional cache for key/value tensors
-        """
-        if kv_cache is not None:
-            # Handle the kv_cache case
-            offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
-            x = (
-                self.token_embedding(tokens)
-                + self.positional_embedding[offset : offset + tokens.shape[-1]]
-            )
-            x = x.to(audio_features.dtype)
-
-            for block in self.blocks:
-                x = block(x, audio_features, mask=self.mask, kv_cache=kv_cache)
-
-            x = self.ln(x)
-            logits = (
-                x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
-            ).float()
-
-            return logits
-        else:
-            # Handle the non-kv_cache case
-            n_batch, n_token = tokens.shape
-            x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
-
-            if torch.cuda.is_available():
-                audio_features = audio_features.cuda()
-
-            for block in self.blocks:
-                x = block(x, audio_features)
-
-            x = self.ln(x)
-            logits = x @ self.token_embedding.weight.T
-
+        # Ensure consistent dtype for matrix multiplication
+        # Convert token_embedding weight to the same dtype as x
+        embedding_weights = self.token_embedding.weight.to(x.dtype)
+        logits = x @ embedding_weights.T
+
+        # Apply mask if not using kv_cache (inference)
+        if kv_cache is None:
+            # Optimisation: Apply the precomputed CUDA mask if available.
             if torch.cuda.is_available():
                 mask = self.mask_cuda[:n_token, :n_token]
             else:
                 mask = self.mask[:n_token, :n_token]
-
+
             logits = logits + mask

-            return logits
-
-
+        return logits
+
 # The Whisper class has been moved outside of TextDecoder and is now a top-level class
 class Whisper(nn.Module):
     def __init__(self, dims: ModelDimensions):
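# Note: a minimal sketch of the offset logic introduced in the hunk above, not
# part of the diff itself. It assumes only what the new forward() shows: when
# kv_cache is non-empty, the number of already-decoded positions is read from
# dim 1 of any cached tensor and used to slice positional_embedding. The names
# n_batch, n_state, fake_cache and the string key are illustrative placeholders.
import torch

n_batch, n_state = 1, 8

fake_cache = {}  # empty cache -> offset 0, i.e. the first decoding step
offset = next(iter(fake_cache.values())).shape[1] if fake_cache else 0
assert offset == 0

# Pretend an attention block has already cached tensors for 5 decoded
# positions; the next token then gets positional embedding index 5.
fake_cache["some_attention_layer"] = torch.zeros(n_batch, 5, n_state)
offset = next(iter(fake_cache.values())).shape[1] if fake_cache else 0
assert offset == 5

# In an actual decode loop the attention blocks grow the cache themselves each
# step (which is why forward() simply passes kv_cache through to every block),
# and only the newly generated token is fed in as `tokens`.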