Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Performance improvements for transcription (up to 20% faster transcription on CPU) #2516

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Fix forward method overload in TextDecoder
eleanorTurintech committed Mar 11, 2025
commit d00d4868116b5cacc84efd1427a3a3fb2d534678
89 changes: 28 additions & 61 deletions whisper/model.py
Original file line number Diff line number Diff line change
@@ -229,91 +229,58 @@ def __init__(
self.register_buffer("mask_cuda", mask.cuda(), persistent=False)


def forward(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
def forward(self, tokens: Tensor, audio_features: Tensor, kv_cache: Optional[dict] = None) -> Tensor:
"""
Args:
tokens: (n_batch, n_token)
audio_features: (n_batch, n_audio_ctx, n_audio_state)
kv_cache: Optional cache for key/value tensors

Returns:
logits: (n_batch, n_token, n_vocab)
"""
n_batch, n_token = tokens.shape
n_audio_ctx, n_audio_state = audio_features.shape[1:]

x = self.token_embedding(tokens) + self.positional_embedding[:n_token]

# Get the dtype of audio_features to ensure consistency
dtype = audio_features.dtype

# Handle kv_cache for token embedding offset
if kv_cache is not None:
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
x = self.token_embedding(tokens) + self.positional_embedding[offset:offset + tokens.shape[1]]
else:
x = self.token_embedding(tokens) + self.positional_embedding[:n_token]

# Convert to the same dtype as audio_features
x = x.to(dtype)

# Optimisation: Move audio_features to GPU once here.
if torch.cuda.is_available():
audio_features = audio_features.cuda()


# Process through attention blocks
for block in self.blocks:
x = block(x, audio_features)
x = block(x, audio_features, kv_cache=kv_cache)

x = self.ln(x)
logits = x @ self.token_embedding.weight.T

# Optimisation: Apply the precomputed CUDA mask if available.
if torch.cuda.is_available():
mask = self.mask_cuda[:n_token, :n_token]
else:
mask = self.mask[:n_token, :n_token]

logits = logits + mask

return logits


def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
"""
Args:
tokens: (n_batch, n_token) or x tensor
audio_features: (n_batch, n_audio_ctx, n_audio_state) or xa tensor
kv_cache: Optional cache for key/value tensors
"""
if kv_cache is not None:
# Handle the kv_cache case
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
x = (
self.token_embedding(tokens)
+ self.positional_embedding[offset : offset + tokens.shape[-1]]
)
x = x.to(audio_features.dtype)

for block in self.blocks:
x = block(x, audio_features, mask=self.mask, kv_cache=kv_cache)

x = self.ln(x)
logits = (
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
).float()

return logits
else:
# Handle the non-kv_cache case
n_batch, n_token = tokens.shape
x = self.token_embedding(tokens) + self.positional_embedding[:n_token]

if torch.cuda.is_available():
audio_features = audio_features.cuda()

for block in self.blocks:
x = block(x, audio_features)

x = self.ln(x)
logits = x @ self.token_embedding.weight.T

# Ensure consistent dtype for matrix multiplication
# Convert token_embedding weight to the same dtype as x
embedding_weights = self.token_embedding.weight.to(x.dtype)
logits = x @ embedding_weights.T

# Apply mask if not using kv_cache (inference)
if kv_cache is None:
# Optimisation: Apply the precomputed CUDA mask if available.
if torch.cuda.is_available():
mask = self.mask_cuda[:n_token, :n_token]
else:
mask = self.mask[:n_token, :n_token]

logits = logits + mask

return logits


return logits

# The Whisper class has been moved outside of TextDecoder and is now a top-level class
class Whisper(nn.Module):
def __init__(self, dims: ModelDimensions):