Skip to content

Commit 5264945

Browse files
author
eleanorTurintech
committed
Performance improvements
1 parent 517a43e commit 5264945

File tree

2 files changed

+134
-49
lines changed

2 files changed

+134
-49
lines changed

whisper/__init__.py

+20 -8
Original file line number | Diff line number | Diff line change
@@ -60,11 +60,20 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
6060
if os.path.exists(download_target) and not os.path.isfile(download_target):
6161
raise RuntimeError(f"{download_target} exists and is not a regular file")
6262

63+
def compute_sha256(file_path: str) -> str:
64+
sha256 = hashlib.sha256()
65+
with open(file_path, "rb") as f:
66+
for chunk in iter(lambda: f.read(8192), b""):
67+
sha256.update(chunk)
68+
return sha256.hexdigest()
69+
6370
if os.path.isfile(download_target):
64-
with open(download_target, "rb") as f:
65-
model_bytes = f.read()
66-
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
67-
return model_bytes if in_memory else download_target
71+
if compute_sha256(download_target) == expected_sha256:
72+
if in_memory:
73+
with open(download_target, "rb") as f:
74+
return f.read()
75+
else:
76+
return download_target
6877
else:
6978
warnings.warn(
7079
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
@@ -86,13 +95,16 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
8695
output.write(buffer)
8796
loop.update(len(buffer))
8897

89-
model_bytes = open(download_target, "rb").read()
90-
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
98+
if compute_sha256(download_target) != expected_sha256:
9199
raise RuntimeError(
92100
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
93101
)
94102

95-
return model_bytes if in_memory else download_target
103+
if in_memory:
104+
with open(download_target, "rb") as f:
105+
return f.read()
106+
else:
107+
return download_target
96108

97109

98110
def available_models() -> List[str]:
@@ -157,4 +169,4 @@ def load_model(
157169
if alignment_heads is not None:
158170
model.set_alignment_heads(alignment_heads)
159171

160-
return model.to(device)
172+
return model.to(device)

whisper/model.py

+114 -41
Original file line number | Diff line number | Diff line change
@@ -224,56 +224,129 @@ def __init__(
224224
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
225225
self.register_buffer("mask", mask, persistent=False)
226226

227-
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
227+
# Optimisation: pre-compute and register the mask in CUDA if available
228+
if torch.cuda.is_available():
229+
self.register_buffer("mask_cuda", mask.cuda(), persistent=False)
230+
231+
<<<<<<< Updated upstream
232+
233+
def forward(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
228234
"""
229-
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
230-
the text tokens
231-
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
232-
the encoded audio features to be attended on
235+
Args:
236+
tokens: (n_batch, n_token)
237+
audio_features: (n_batch, n_audio_ctx, n_audio_state)
238+
239+
Returns:
240+
logits: (n_batch, n_token, n_vocab)
233241
"""
234-
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
235-
x = (
236-
self.token_embedding(x)
237-
+ self.positional_embedding[offset : offset + x.shape[-1]]
238-
)
239-
x = x.to(xa.dtype)
242+
n_batch, n_token = tokens.shape
243+
n_audio_ctx, n_audio_state = audio_features.shape[1:]
244+
245+
x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
246+
247+
# Optimisation: Move audio_features to GPU once here.
248+
if torch.cuda.is_available():
249+
audio_features = audio_features.cuda()
250+
240251

241252
for block in self.blocks:
242-
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
253+
x = block(x, audio_features)
243254

244255
x = self.ln(x)
245-
logits = (
246-
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
247-
).float()
256+
logits = x @ self.token_embedding.weight.T
257+
258+
# Optimisation: Apply the precomputed CUDA mask if available.
259+
if torch.cuda.is_available():
260+
mask = self.mask_cuda[:n_token, :n_token]
261+
else:
262+
mask = self.mask[:n_token, :n_token]
263+
264+
logits = logits + mask
248265

249266
return logits
250267

251268

252-
class Whisper(nn.Module):
253-
def __init__(self, dims: ModelDimensions):
254-
super().__init__()
255-
self.dims = dims
256-
self.encoder = AudioEncoder(
257-
self.dims.n_mels,
258-
self.dims.n_audio_ctx,
259-
self.dims.n_audio_state,
260-
self.dims.n_audio_head,
261-
self.dims.n_audio_layer,
262-
)
263-
self.decoder = TextDecoder(
264-
self.dims.n_vocab,
265-
self.dims.n_text_ctx,
266-
self.dims.n_text_state,
267-
self.dims.n_text_head,
268-
self.dims.n_text_layer,
269-
)
270-
# use the last half among the decoder layers for time alignment by default;
271-
# to use a specific set of heads, see `set_alignment_heads()` below.
272-
all_heads = torch.zeros(
273-
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
274-
)
275-
all_heads[self.dims.n_text_layer // 2 :] = True
276-
self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
269+
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
270+
=======
271+
def forward(
272+
self,
273+
tokens: Tensor,
274+
audio_features: Tensor,
275+
kv_cache: Optional[dict] = None
276+
) -> Tensor:
277+
>>>>>>> Stashed changes
278+
"""
279+
Args:
280+
tokens: (n_batch, n_token) or x tensor
281+
audio_features: (n_batch, n_audio_ctx, n_audio_state) or xa tensor
282+
kv_cache: Optional cache for key/value tensors
283+
"""
284+
if kv_cache is not None:
285+
# Handle the kv_cache case
286+
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
287+
x = (
288+
self.token_embedding(tokens)
289+
+ self.positional_embedding[offset : offset + tokens.shape[-1]]
290+
)
291+
x = x.to(audio_features.dtype)
292+
293+
for block in self.blocks:
294+
x = block(x, audio_features, mask=self.mask, kv_cache=kv_cache)
295+
296+
x = self.ln(x)
297+
logits = (
298+
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
299+
).float()
300+
301+
return logits
302+
else:
303+
# Handle the non-kv_cache case
304+
n_batch, n_token = tokens.shape
305+
x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
306+
307+
if torch.cuda.is_available():
308+
audio_features = audio_features.cuda()
309+
310+
for block in self.blocks:
311+
x = block(x, audio_features)
312+
313+
x = self.ln(x)
314+
logits = x @ self.token_embedding.weight.T
315+
316+
if torch.cuda.is_available():
317+
mask = self.mask_cuda[:n_token, :n_token]
318+
else:
319+
mask = self.mask[:n_token, :n_token]
320+
321+
logits = logits + mask
322+
323+
return logits
324+
325+
class Whisper(nn.Module):
326+
def __init__(self, dims: ModelDimensions):
327+
super().__init__()
328+
self.dims = dims
329+
self.encoder = AudioEncoder(
330+
self.dims.n_mels,
331+
self.dims.n_audio_ctx,
332+
self.dims.n_audio_state,
333+
self.dims.n_audio_head,
334+
self.dims.n_audio_layer,
335+
)
336+
self.decoder = TextDecoder(
337+
self.dims.n_vocab,
338+
self.dims.n_text_ctx,
339+
self.dims.n_text_state,
340+
self.dims.n_text_head,
341+
self.dims.n_text_layer,
342+
)
343+
# use the last half among the decoder layers for time alignment by default;
344+
# to use a specific set of heads, see `set_alignment_heads()` below.
345+
all_heads = torch.zeros(
346+
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
347+
)
348+
all_heads[self.dims.n_text_layer // 2 :] = True
349+
self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
277350

278351
def set_alignment_heads(self, dump: bytes):
279352
array = np.frombuffer(
@@ -342,4 +415,4 @@ def install_hooks(layer: nn.Module):
342415

343416
detect_language = detect_language_function
344417
transcribe = transcribe_function
345-
decode = decode_function
418+
decode = decode_function

0 commit comments

Comments
 (0)