Skip to content

Commit d00d486

Browse files
Fix forward method overload in TextDecoder
1 parent: 7a552cb · commit: d00d486

File tree

1 file changed

+28
-61
lines changed

1 file changed

+28
-61
lines changed

whisper/model.py

+28 -61
Original file line number | Diff line number | Diff line change
@@ -229,91 +229,58 @@ def __init__(
229229
self.register_buffer("mask_cuda", mask.cuda(), persistent=False)
230230

231231

232-
def forward(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
232+
def forward(self, tokens: Tensor, audio_features: Tensor, kv_cache: Optional[dict] = None) -> Tensor:
233233
"""
234234
Args:
235235
tokens: (n_batch, n_token)
236236
audio_features: (n_batch, n_audio_ctx, n_audio_state)
237+
kv_cache: Optional cache for key/value tensors
237238
238239
Returns:
239240
logits: (n_batch, n_token, n_vocab)
240241
"""
241242
n_batch, n_token = tokens.shape
242-
n_audio_ctx, n_audio_state = audio_features.shape[1:]
243-
244-
x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
243+
244+
# Get the dtype of audio_features to ensure consistency
245+
dtype = audio_features.dtype
246+
247+
# Handle kv_cache for token embedding offset
248+
if kv_cache is not None:
249+
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
250+
x = self.token_embedding(tokens) + self.positional_embedding[offset:offset + tokens.shape[1]]
251+
else:
252+
x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
253+
254+
# Convert to the same dtype as audio_features
255+
x = x.to(dtype)
245256

246257
# Optimisation: Move audio_features to GPU once here.
247258
if torch.cuda.is_available():
248259
audio_features = audio_features.cuda()
249260

250-
261+
# Process through attention blocks
251262
for block in self.blocks:
252-
x = block(x, audio_features)
263+
x = block(x, audio_features, kv_cache=kv_cache)
253264

254265
x = self.ln(x)
255-
logits = x @ self.token_embedding.weight.T
256-
257-
# Optimisation: Apply the precomputed CUDA mask if available.
258-
if torch.cuda.is_available():
259-
mask = self.mask_cuda[:n_token, :n_token]
260-
else:
261-
mask = self.mask[:n_token, :n_token]
262266

263-
logits = logits + mask
264-
265-
return logits
266-
267-
268-
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
269-
"""
270-
Args:
271-
tokens: (n_batch, n_token) or x tensor
272-
audio_features: (n_batch, n_audio_ctx, n_audio_state) or xa tensor
273-
kv_cache: Optional cache for key/value tensors
274-
"""
275-
if kv_cache is not None:
276-
# Handle the kv_cache case
277-
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
278-
x = (
279-
self.token_embedding(tokens)
280-
+ self.positional_embedding[offset : offset + tokens.shape[-1]]
281-
)
282-
x = x.to(audio_features.dtype)
283-
284-
for block in self.blocks:
285-
x = block(x, audio_features, mask=self.mask, kv_cache=kv_cache)
286-
287-
x = self.ln(x)
288-
logits = (
289-
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
290-
).float()
291-
292-
return logits
293-
else:
294-
# Handle the non-kv_cache case
295-
n_batch, n_token = tokens.shape
296-
x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
297-
298-
if torch.cuda.is_available():
299-
audio_features = audio_features.cuda()
300-
301-
for block in self.blocks:
302-
x = block(x, audio_features)
303-
304-
x = self.ln(x)
305-
logits = x @ self.token_embedding.weight.T
306-
267+
# Ensure consistent dtype for matrix multiplication
268+
# Convert token_embedding weight to the same dtype as x
269+
embedding_weights = self.token_embedding.weight.to(x.dtype)
270+
logits = x @ embedding_weights.T
271+
272+
# Apply mask if not using kv_cache (inference)
273+
if kv_cache is None:
274+
# Optimisation: Apply the precomputed CUDA mask if available.
307275
if torch.cuda.is_available():
308276
mask = self.mask_cuda[:n_token, :n_token]
309277
else:
310278
mask = self.mask[:n_token, :n_token]
311-
279+
312280
logits = logits + mask
313281

314-
return logits
315-
316-
282+
return logits
283+
317284
# The Whisper class has been moved outside of TextDecoder and is now a top-level class
318285
class Whisper(nn.Module):
319286
def __init__(self, dims: ModelDimensions):

0 commit comments

Comments (0)