WIP: add docstring for IP
Signed-off-by: Charlene Yang <[email protected]>
cyanguwa committed Feb 24, 2025
1 parent 3cb001d commit 583b76f
Showing 2 changed files with 53 additions and 5 deletions.
10 changes: 5 additions & 5 deletions transformer_engine/pytorch/attention.py
@@ -514,11 +514,11 @@ def get_attention_backend(
         use_unfused_attention = False

     # Filter: KV cache
-    # backend                    | non-paged/paged | precision
-    # ---------------------------------------------------------------------------------
-    # FlashAttention             | non-paged/paged | FP16/BF16
-    # FusedAttention             | non-paged/paged | FP16/BF16 (non-paged/paged), FP8 (non-paged)
-    # UnfusedDotProductAttention | non-paged/paged | FP32/FP16/BF16
+    # backend                    | precision
+    # -------------------------------------------------------------------------
+    # FlashAttention             | FP16/BF16 (non-paged/paged)
+    # FusedAttention             | FP16/BF16 (non-paged/paged), FP8 (non-paged)
+    # UnfusedDotProductAttention | FP32/FP16/BF16 (non-paged/paged)
     if inference_params is not None:
         if context_parallel:
             logger.debug(
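As a quick paraphrase of the updated KV-cache filter table above: the dict and helper below only restate the documented support matrix and are illustrative, not the actual get_attention_backend() logic in Transformer Engine.

    # Hypothetical restatement of the backend | precision table above.
    KV_CACHE_SUPPORT = {
        # backend: {cache type: supported precisions}
        "FlashAttention": {"non-paged": {"fp16", "bf16"}, "paged": {"fp16", "bf16"}},
        "FusedAttention": {"non-paged": {"fp16", "bf16", "fp8"}, "paged": {"fp16", "bf16"}},
        "UnfusedDotProductAttention": {
            "non-paged": {"fp32", "fp16", "bf16"},
            "paged": {"fp32", "fp16", "bf16"},
        },
    }

    def supports_kv_cache(backend: str, precision: str, paged: bool) -> bool:
        """Return True if `backend` supports a KV cache of this type at this precision."""
        return precision in KV_CACHE_SUPPORT[backend]["paged" if paged else "non-paged"]

    # e.g. an FP8 KV cache is only listed for FusedAttention with a non-paged cache
    assert supports_kv_cache("FusedAttention", "fp8", paged=False)
    assert not supports_kv_cache("FusedAttention", "fp8", paged=True)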
48 changes: 48 additions & 0 deletions transformer_engine/pytorch/inference.py
@@ -59,6 +59,54 @@ class InferenceParams:
    to efficiently cache previous tokens and reuse them for the current
    inference iteration.

    A typical KV caching workflow is as follows::

        import torch
        from collections import OrderedDict

        modules = [TransformerLayer() for _ in range(num_layers)]
        model = torch.nn.Sequential(*modules)
        inference_params = InferenceParams(max_batch_size, max_seqlen_kv, ...)

        for i in range(inference_iterations):
            # active sequence ids and their numbers of new tokens in this step
            seq_ids = [0, 2, 3]
            step_lens = [10, 1, 1]
            step_dict = OrderedDict(zip(seq_ids, step_lens))
            inference_params.pre_step(step_dict)
            output = model(
                ...,
                inference_params=inference_params,
                attn_mask_type="padding_causal",
            )
            # assume qkv_format = "bshd"
            if inference_params.is_output_right_aligned:
                output = output[:, -1]
            else:
                # last new token of each sequence when the output is left-aligned
                output = output[torch.arange(output.size(0)), [l - 1 for l in step_dict.values()]]

    The memory allocation and the copying of new KV tokens into the KV cache take place
    in the following locations::

        class TransformerLayer:
            class MultiHeadAttention:
                if self.layer_number not in inference_params:
                    inference_params.allocate_memory(self.layer_number)
                class DotProductAttention:
                    if inference_params is not None:
                        q, k_cache, v_cache, new_qkv_format = inference_params.step(
                            new_q, new_k, new_v, qkv_format)
                    output = attention(q, k_cache, v_cache, new_qkv_format)
                    if inference_params is not None:
                        output = inference_params.post_step(output)
                    return output

    InferenceParams supports cache_qkv_format = "bshd" only, and the step() function may
    change qkv_format depending on the attention backend::

        Backend                    | Before step()     | After step()
        ------------------------------------------------------------------------------------
        FusedAttention             | {bshd, sbhd, thd} | {bshd_2bshd, sbhd_2bshd, thd_2bshd}
        FlashAttention             | {bshd, sbhd, thd} | {bshd, sbhd, thd}
        UnfusedDotProductAttention | {bshd, sbhd, thd} | {bshd, sbhd, bshd}

    Parameters
    ----------
    max_batch_size: int
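To make the docstring's workflow a little more concrete, the sketch below shows how the per-iteration bookkeeping might evolve from a prompt (prefill) step to subsequent one-token decode steps. It only rearranges pieces already shown in the docstring above; treating the first step as a 10-token prompt, and the exact call signature of the model, are assumptions.

    import torch
    from collections import OrderedDict

    # model, inference_params, inference_iterations: as constructed in the docstring above

    # first iteration: sequences 0, 2 and 3 are active; sequence 0 contributes 10 new
    # tokens, sequences 2 and 3 contribute 1 new token each
    step_dict = OrderedDict(zip([0, 2, 3], [10, 1, 1]))

    for _ in range(inference_iterations):
        inference_params.pre_step(step_dict)
        output = model(..., inference_params=inference_params, attn_mask_type="padding_causal")

        # hidden state of the last new token per sequence (qkv_format = "bshd")
        if inference_params.is_output_right_aligned:
            last = output[:, -1]
        else:
            last = output[torch.arange(output.size(0)), [l - 1 for l in step_dict.values()]]

        # in later iterations every active sequence contributes one new token
        step_dict = OrderedDict((seq_id, 1) for seq_id in step_dict)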

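The qkv_format table in the new docstring can also be read as a lookup from (backend, format before step()) to the format after step(). The mapping below restates that table verbatim; the dict and helper are illustrative only and not part of the Transformer Engine API.

    # Restatement of the Backend | Before step() | After step() table above.
    QKV_FORMAT_AFTER_STEP = {
        "FusedAttention": {"bshd": "bshd_2bshd", "sbhd": "sbhd_2bshd", "thd": "thd_2bshd"},
        "FlashAttention": {"bshd": "bshd", "sbhd": "sbhd", "thd": "thd"},
        "UnfusedDotProductAttention": {"bshd": "bshd", "sbhd": "sbhd", "thd": "bshd"},
    }

    def qkv_format_after_step(backend: str, qkv_format: str) -> str:
        """Format the attention call should expect after InferenceParams.step()."""
        return QKV_FORMAT_AFTER_STEP[backend][qkv_format]

    assert qkv_format_after_step("FusedAttention", "thd") == "thd_2bshd"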