\n",
"\n",
diff --git a/docs/index.rst b/docs/index.rst
index d64cebbfa2..316c2ded59 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -44,7 +44,9 @@ Transformer Engine documentation
examples/fp8_primer.ipynb
examples/advanced_optimizations.ipynb
- examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb
+ examples/te_llama/tutorial_accelerate_hf_llama_finetuning_with_te.ipynb
+ examples/te_gemma/tutorial_accelerate_hf_gemma_finetuning_with_te.ipynb
+ examples/te_gemma/tutorial_generation_gemma_with_te.ipynb
.. toctree::
:hidden:
diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh
index 90c5e499f3..985f92cedd 100644
--- a/qa/L0_pytorch_unittest/test.sh
+++ b/qa/L0_pytorch_unittest/test.sh
@@ -22,5 +22,6 @@ pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py
pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py
pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py
+pytest -v -s $TE_PATH/tests/pytorch/test_generation.py
pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py
pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops_distributed.py
diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py
index d6ba66cbbc..a2ce84293c 100644
--- a/tests/pytorch/test_fused_rope.py
+++ b/tests/pytorch/test_fused_rope.py
@@ -11,7 +11,7 @@
def apply_rotary_pos_emb_thd(
- t: torch.Tensor, cu_seqlens: torch.Tensor, freqs: torch.Tensor
+ t: torch.Tensor, cu_seqlens: torch.Tensor, freqs: torch.Tensor, start_positions: torch.Tensor
) -> torch.Tensor:
"""A baseline implementation of applying RoPE for `thd` format.
@@ -20,14 +20,106 @@ def apply_rotary_pos_emb_thd(
cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`,
with shape [b + 1] and dtype torch.int32.
freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d]
+ start_positions (Tensor): Tensor of shape [b] determining the beginning offsets
+ of frequencies applied to sequences.
Returns:
Tensor: Shape [t, h, d]. The input tensor after applying RoPE.
"""
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
- return torch.cat(
- [apply_rotary_pos_emb(x.unsqueeze(1), freqs[: x.size(0)]) for x in torch.split(t, seqlens)]
- ).squeeze(1)
+ if start_positions is None:
+ return torch.cat(
+ [
+ apply_rotary_pos_emb(x.unsqueeze(1), freqs[: x.size(0)])
+ for x in torch.split(t, seqlens)
+ ]
+ ).squeeze(1)
+ else:
+ return torch.cat(
+ [
+ apply_rotary_pos_emb(
+ x.unsqueeze(1), freqs[start_positions[i] : (x.size(0) + start_positions[i])]
+ )
+ for i, x in enumerate(torch.split(t, seqlens))
+ ]
+ ).squeeze(1)
+
+
+def apply_rotary_pos_emb_with_start_positions(
+ t: torch.Tensor,
+ freqs: torch.Tensor,
+ tensor_format: str = "sbhd",
+ start_positions: Union[torch.Tensor, None] = None,
+) -> torch.Tensor:
+ """
+ Apply rotary positional embedding tensor to the input tensor.
+ This is a non-fused version that supports the start_positions parameter.
+ The non-fused implementation with start_positions is slow, so it is not included in
+ Transformer Engine directly.
+
+ Parameters
+ ----------
+ t: torch.Tensor
+ Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
+ rotary positional embedding will be applied.
+ freqs: torch.Tensor
+ Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
+ with `s2 >= s` and `d2 <= d`.
+ tensor_format: {'sbhd', 'bshd'}, default = 'sbhd'
+ start_positions: torch.Tensor, default = None.
+ Per-sequence offsets into the frequency table, allowing each sequence
+ to start from a positional embedding other than position 0.
+ """
+
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
+ """
+ change sign so the last dimension becomes [-odd, +even]
+ """
+ x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2)))
+ x1, x2 = x.unbind(dim=-2)
+ return torch.cat((-x2, x1), dim=-1)
+
+ if start_positions is None:
+ return apply_rotary_pos_emb(t, freqs, tensor_format=tensor_format)
+
+ max_seq_len = freqs.shape[0]
+ cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]
+
+ # Only apply the rotary embeddings up to the sequence length of the running
+ # input.
+ assert (
+ cur_seq_len <= max_seq_len
+ ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
+
+ if tensor_format == "bshd":
+ t = t.transpose(0, 1)
+ # cos/sin first then dtype conversion for better precision
+ cos_ = torch.cos(freqs).to(t.dtype)
+ sin_ = torch.sin(freqs).to(t.dtype)
+
+ rot_dim = freqs.shape[-1]
+ # ideally t_pass is empty so the rotary pos embedding is applied to the whole tensor t
+ t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
+
+ # shifted_sin, shifted_cos will have the same shape as t. They will contain
+ # scaling factors shifted for each sequence by the corresponding start_positions offset.
+
+ shifted_sin = sin_[:cur_seq_len].expand(t.shape).clone()
+ shifted_cos = cos_[:cur_seq_len].expand(t.shape).clone()
+
+ for b in range(start_positions.shape[0]):
+ assert max_seq_len >= start_positions[b]
+ shifted_freq = slice(start_positions[b], (start_positions[b] + cur_seq_len))
+ shifted_sin[:, b, :] = sin_[shifted_freq, 0, ...]
+ shifted_cos[:, b, :] = cos_[shifted_freq, 0, ...]
+
+ t = (t * shifted_cos) + (_rotate_half(t) * shifted_sin)
+ out = torch.cat((t, t_pass), dim=-1)
+
+ if tensor_format == "bshd":
+ out = out.transpose(0, 1).contiguous()
+
+ return out
def get_tol(dtype: torch.dtype) -> Dict:
@@ -54,8 +146,9 @@ def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("margin", [0, 10])
+@pytest.mark.parametrize("start_positions", [True, False])
@pytest.mark.parametrize("transpose", [None, (0, 1), (2, 3)])
-@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"])
+@pytest.mark.parametrize("tensor_format", ["bshd", "sbhd"])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
def test_fused_rope(
dtype: torch.dtype,
@@ -63,6 +156,7 @@ def test_fused_rope(
hidden_size: int,
rotary_percent: float,
margin: int,
+ start_positions: bool,
transpose: Union[Tuple, None],
tensor_format: str,
loss_func: Callable,
@@ -80,11 +174,24 @@ def test_fused_rope(
t = t.transpose(*transpose).contiguous().transpose(*transpose)
t.requires_grad = True
+ if margin == 0 and start_positions == True:
+ # If the sequence to encode is as long as the frequency table,
+ # there is no room left to start from positions > 0.
+ pytest.skip("Skipping test with margin=0 and start_positions=True")
+
+ start_positions = (
+ torch.randint(0, margin, (batch_size,), dtype=torch.int32, device=device)
+ if start_positions
+ else None
+ )
+
rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent)
emb = rotary_pos_emb(seq_length)
# unfused
- output_unfused = apply_rotary_pos_emb(t, emb, tensor_format=tensor_format, fused=False)
+ output_unfused = apply_rotary_pos_emb_with_start_positions(
+ t, emb, tensor_format=tensor_format, start_positions=start_positions
+ )
loss_unfused = loss_func(output_unfused)
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
@@ -92,10 +199,7 @@ def test_fused_rope(
# fused
output_fused = apply_rotary_pos_emb(
- t,
- emb,
- tensor_format=tensor_format,
- fused=True,
+ t, emb, tensor_format=tensor_format, fused=True, start_positions=start_positions
)
loss_fused = loss_func(output_fused)
loss_fused.backward()
@@ -112,12 +216,14 @@ def test_fused_rope(
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("transpose", [None, (1, 2)])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
+@pytest.mark.parametrize("start_positions", [True, False])
def test_fused_rope_thd(
dtype: torch.dtype,
hidden_size: int,
rotary_percent: float,
transpose: Union[Tuple, None],
loss_func: Callable,
+ start_positions: bool,
) -> None:
device = torch.device("cuda:0")
batch_size, head_num = 2, 64
@@ -135,11 +241,17 @@ def test_fused_rope_thd(
t = t.transpose(*transpose).contiguous().transpose(*transpose)
t.requires_grad = True
+ start_positions = (
+ torch.randint(0, 20, (cu_seqlens.shape[-1],), dtype=torch.int32, device=device)
+ if start_positions
+ else None
+ )
+
rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent)
emb = rotary_pos_emb(cu_seqlens[-1])
# unfused
- output_unfused = apply_rotary_pos_emb_thd(t, cu_seqlens, emb)
+ output_unfused = apply_rotary_pos_emb_thd(t, cu_seqlens, emb, start_positions=start_positions)
loss_unfused = loss_func(output_unfused)
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
@@ -147,7 +259,12 @@ def test_fused_rope_thd(
# fused
output_fused = apply_rotary_pos_emb(
- t, emb, fused=True, tensor_format="thd", cu_seqlens=cu_seqlens
+ t,
+ emb,
+ fused=True,
+ tensor_format="thd",
+ cu_seqlens=cu_seqlens,
+ start_positions=start_positions,
)
loss_fused = loss_func(output_fused)
loss_fused.backward()
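For reference, the baseline change above boils down to shifting which rows of the frequency table each sequence reads. The following standalone sketch (not part of the diff; `rope_freqs` and `shifted_slice` are hypothetical helpers) illustrates the indexing: without `start_positions` a sequence of length `s` uses `freqs[0:s]`, while with `start_positions[i] = k` it uses `freqs[k:k+s]`.

```python
import torch

def rope_freqs(max_seq_len: int, dim: int, base: float = 10000.0) -> torch.Tensor:
    # Standard RoPE angle table of shape [max_seq_len, 1, 1, dim].
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    angles = torch.outer(torch.arange(max_seq_len).float(), inv_freq)  # [max_seq_len, dim // 2]
    return torch.cat((angles, angles), dim=-1)[:, None, None, :]

def shifted_slice(freqs: torch.Tensor, seq_len: int, start: int) -> torch.Tensor:
    # Without start_positions the baseline reads freqs[0:seq_len];
    # with start_positions it reads freqs[start:start + seq_len] for that sequence.
    return freqs[start : start + seq_len]

freqs = rope_freqs(max_seq_len=16, dim=8)
print(shifted_slice(freqs, seq_len=4, start=0).shape)       # torch.Size([4, 1, 1, 8])
print(torch.equal(shifted_slice(freqs, 4, 3), freqs[3:7]))  # True
```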
diff --git a/tests/pytorch/test_generation.py b/tests/pytorch/test_generation.py
new file mode 100644
index 0000000000..343dd4db1d
--- /dev/null
+++ b/tests/pytorch/test_generation.py
@@ -0,0 +1,210 @@
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import pytest
+import torch
+
+import transformer_engine.pytorch as te
+
+
+class TestInferenceParams:
+ def test_setup_before_new_input_bshd(self):
+ inference_params = te.attention.InferenceParams(64, 128, qkv_format="bshd")
+
+ inference_params.setup_before_new_input(length=16)
+ # Offset before first sequence is equal to 0.
+ assert inference_params.sequence_len_offset == 0
+
+ # Offset before second sequence is equal to 16.
+ inference_params.setup_before_new_input(length=4)
+ assert inference_params.sequence_len_offset == 16
+
+ def test_setup_before_new_input_thd(self):
+ inference_params = te.attention.InferenceParams(4, 128, qkv_format="thd")
+
+ inference_params.setup_before_new_input(
+ lengths_tensor=torch.Tensor([1, 0, 2, 4]).cuda(), max_input_length=20
+ )
+
+ assert torch.equal(
+ inference_params.cached_sequence_lengths, torch.Tensor([0, 0, 0, 0]).cuda()
+ )
+ assert torch.equal(
+ inference_params.input_sequence_lengths, torch.Tensor([1, 0, 2, 4]).cuda()
+ )
+ assert inference_params.max_incoming_seq_len == 20
+
+ inference_params.setup_before_new_input(
+ lengths_tensor=torch.Tensor([2, 3, 5, 1]).cuda(), max_input_length=10
+ )
+ assert torch.equal(
+ inference_params.cached_sequence_lengths, torch.Tensor([1, 0, 2, 4]).cuda()
+ )
+ assert torch.equal(
+ inference_params.input_sequence_lengths, torch.Tensor([2, 3, 5, 1]).cuda()
+ )
+ assert inference_params.max_incoming_seq_len == 10
+
+ @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
+ @pytest.mark.parametrize("batch_size", [64, 128, 256])
+ @pytest.mark.parametrize("max_seq_len", [128, 256, 512])
+ @pytest.mark.parametrize("max_input_len", [32, 128])
+ def test_save_to_kv_cache_thd(self, batch_size, max_seq_len, max_input_len, dtype):
+ h, d = 16, 256
+
+ inference_params = te.attention.InferenceParams(batch_size, max_seq_len, qkv_format="thd")
+ inference_params.allocate_memory_for_kv_cache_if_empty(1, h, d, dtype)
+
+ t = batch_size * max_input_len
+ key_layer = torch.randn((t, h, d)).cuda().to(dtype)
+ value_layer = torch.randn((t, h, d)).cuda().to(dtype)
+
+ sequence_lengths = [1, 2] * (batch_size // 2)
+
+ # We save the same sequences twice, which should result in sequences of length 2 and 4
+ # in the cache.
+ inference_params.reset()
+ inference_params.setup_before_new_input(
+ lengths_tensor=torch.tensor(sequence_lengths).cuda(), max_input_length=max_input_len
+ )
+ inference_params.save_to_kv_cache(1, key_layer, value_layer)
+
+ inference_params.setup_before_new_input(
+ lengths_tensor=torch.tensor(sequence_lengths).cuda(), max_input_length=max_input_len
+ )
+ inference_params.save_to_kv_cache(1, key_layer, value_layer)
+
+ key_memory, value_memory = inference_params.key_value_memory_dict[1]
+
+ # Check whether the sequences were copied properly.
+
+ def check(memory, layer, b, idx1, idx2):
+ # Check that token idx1 of sequence b in the cache memory corresponds
+ # to token idx2 of sequence b in the input layer.
+ assert torch.equal(memory[b * max_seq_len + idx1], layer[b * max_input_len + idx2, :])
+
+ # even indices
+ for b in range(0, batch_size, 2):
+ check(key_memory, key_layer, b, 0, 0)
+ check(key_memory, key_layer, b, 1, 0)
+ assert (key_memory[b * max_seq_len + 2 : ((b + 1) * max_seq_len)] == 0).all()
+
+ check(value_memory, value_layer, b, 0, 0)
+ check(value_memory, value_layer, b, 1, 0)
+ assert (value_memory[b * max_seq_len + 2 : ((b + 1) * max_seq_len)] == 0).all()
+
+ # odd indices
+ for b in range(1, batch_size, 2):
+ check(key_memory, key_layer, b, 0, 0)
+ check(key_memory, key_layer, b, 1, 1)
+ check(key_memory, key_layer, b, 2, 0)
+ check(key_memory, key_layer, b, 3, 1)
+ assert (key_memory[b * max_seq_len + 4 : ((b + 1) * max_seq_len)] == 0).all()
+
+ check(value_memory, value_layer, b, 0, 0)
+ check(value_memory, value_layer, b, 1, 1)
+ check(value_memory, value_layer, b, 2, 0)
+ check(value_memory, value_layer, b, 3, 1)
+ assert (value_memory[b * max_seq_len + 4 : ((b + 1) * max_seq_len)] == 0).all()
+
+ @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
+ @pytest.mark.parametrize("batch_size", [64, 128, 256])
+ @pytest.mark.parametrize("max_seq_len", [128, 256, 512])
+ def test_save_to_kv_cache_bshd(self, batch_size, max_seq_len, dtype):
+ # This test checks if key_layer and value_layer are copied to cache.
+ # Cache size is equal to the size of one key/value layer.
+ h, d = 16, 256
+
+ inference_params = te.attention.InferenceParams(batch_size, max_seq_len, qkv_format="bshd")
+
+ inference_params.allocate_memory_for_kv_cache_if_empty(1, h, d, dtype)
+ key_layer = torch.randn((max_seq_len, batch_size, h, d)).cuda().to(dtype)
+ value_layer = torch.randn((max_seq_len, batch_size, h, d)).cuda().to(dtype)
+
+ inference_params.setup_before_new_input(length=0)
+ inference_params.save_to_kv_cache(1, key_layer, value_layer)
+
+ key_memory, value_memory = inference_params.key_value_memory_dict[1]
+
+ assert torch.equal(key_memory, key_layer)
+ assert torch.equal(value_memory, value_layer)
+
+ @pytest.mark.parametrize("layer_number", [1, 100])
+ @pytest.mark.parametrize("batch_size", [1, 128])
+ @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
+ def test_allocate_memory_for_kv_cache_if_empty(self, layer_number, batch_size, dtype):
+ nr_heads = 16
+ head_dim = 256
+ max_sequence_len = 128
+ inference_params = te.attention.InferenceParams(
+ batch_size, max_sequence_len, qkv_format="bshd"
+ )
+
+ assert layer_number not in inference_params.key_value_memory_dict
+
+ inference_params.allocate_memory_for_kv_cache_if_empty(
+ layer_number, nr_heads, head_dim, dtype
+ )
+
+ key_memory, value_memory = inference_params.key_value_memory_dict[layer_number]
+
+ assert key_memory.shape == (max_sequence_len, batch_size, nr_heads, head_dim)
+ assert value_memory.shape == (max_sequence_len, batch_size, nr_heads, head_dim)
+
+ # Should not allocate new buffers.
+ inference_params.allocate_memory_for_kv_cache_if_empty(layer_number, 100, 100, dtype)
+
+ assert key_memory.shape == (max_sequence_len, batch_size, nr_heads, head_dim)
+ assert value_memory.shape == (max_sequence_len, batch_size, nr_heads, head_dim)
+
+ def test_set_params_to_thd_attention(self):
+ # This test checks whether the parameters needed to run thd attention
+ # are computed correctly. These parameters are passed to fused_attn_fwd(..)
+ # to indicate which parts of the query/key/value layers are sequences and
+ # which are offsets.
+ batch_size = 4
+ channels = 1024
+ max_sequence_len = 128
+ max_input_len = 20
+ inference_params = te.attention.InferenceParams(
+ batch_size, max_sequence_len, qkv_format="thd"
+ )
+
+ inference_params.setup_before_new_input(
+ lengths_tensor=torch.Tensor([1, 1, 1, 1]).cuda(), max_input_length=max_input_len
+ )
+ inference_params.setup_before_new_input(
+ lengths_tensor=torch.Tensor([1, 0, 2, 4]).cuda(), max_input_length=max_input_len
+ )
+
+ buffers = [torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda") for _ in range(6)]
+ max_q_len, max_kv_len, buffers = inference_params.set_params_to_thd_attention(
+ buffers, channels
+ )
+
+ cu_seqlens_q, cu_seqlens_kv, seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o = (
+ buffers
+ )
+
+ assert max_q_len == max_input_len
+ assert max_kv_len == max_sequence_len
+ assert torch.equal(cu_seqlens_q, torch.tensor([0, 1, 1, 3, 7]).cuda())
+ assert torch.equal(cu_seqlens_kv, torch.tensor([0, 2, 3, 6, 11]).cuda())
+
+ assert torch.equal(
+ seq_offsets_q,
+ torch.tensor([k * max_input_len * channels for k in range(batch_size + 1)]).cuda(),
+ )
+ assert torch.equal(
+ seq_offsets_k,
+ torch.tensor([k * max_sequence_len * channels for k in range(batch_size + 1)]).cuda(),
+ )
+ assert torch.equal(
+ seq_offsets_v,
+ torch.tensor([k * max_sequence_len * channels for k in range(batch_size + 1)]).cuda(),
+ )
+ assert torch.equal(
+ seq_offsets_o,
+ torch.tensor([k * max_input_len * channels for k in range(batch_size + 1)]).cuda(),
+ )
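As a cross-check of `test_set_params_to_thd_attention` above, the following standalone sketch (the helper `thd_attention_params` is hypothetical, not TE API) recomputes the cumulative sequence lengths and flat buffer offsets in plain PyTorch and reproduces the asserted values.

```python
import torch

def thd_attention_params(cached_lens, input_lens, max_input_len, max_seq_len, channels):
    # cu_seqlens_* are cumulative sequence lengths with a leading 0; seq_offsets_*
    # are flat element offsets assuming each sequence occupies a fixed-size slot of
    # max_input_len (query/output) or max_seq_len (key/value cache) rows of
    # `channels` elements each.
    b = cached_lens.numel()
    cu_seqlens_q = torch.zeros(b + 1, dtype=torch.int32)
    cu_seqlens_kv = torch.zeros(b + 1, dtype=torch.int32)
    cu_seqlens_q[1:] = torch.cumsum(input_lens, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(cached_lens + input_lens, dim=0)
    slots = torch.arange(b + 1)
    seq_offsets_q = slots * channels * max_input_len
    seq_offsets_kv = slots * channels * max_seq_len
    return cu_seqlens_q, cu_seqlens_kv, seq_offsets_q, seq_offsets_kv

cached = torch.tensor([1, 1, 1, 1])    # lengths already in the KV cache
incoming = torch.tensor([1, 0, 2, 4])  # lengths of the new input chunk
q, kv, off_q, off_kv = thd_attention_params(cached, incoming, 20, 128, 1024)
print(q.tolist())   # [0, 1, 1, 3, 7]
print(kv.tolist())  # [0, 2, 3, 6, 11]
```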
diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py
index 7eed97a0ca..8e20957384 100644
--- a/tests/pytorch/test_numerics.py
+++ b/tests/pytorch/test_numerics.py
@@ -3,8 +3,9 @@
# See LICENSE for license information.
import math
+import functools
import os
-from typing import Dict, List, Optional
+from typing import Dict, List, Tuple, Optional
import pytest
import copy
@@ -12,6 +13,8 @@
import torch.nn as nn
from torch.nn import Parameter
+import transformer_engine.pytorch.cpp_extensions as ext
+
from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager, fp8_model_init
from transformer_engine.pytorch.utils import (
init_method_normal,
@@ -40,6 +43,22 @@
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+@functools.cache
+def _cudnn_version() -> Tuple[int, int, int]:
+ """Runtime cuDNN version (major, minor, patch)"""
+ encoded_version = ext.get_cudnn_version()
+ major_version_magnitude = 1000 if encoded_version < 90000 else 10000
+ major, encoded_version = divmod(encoded_version, major_version_magnitude)
+ minor, patch = divmod(encoded_version, 100)
+ return (major, minor, patch)
+
+
+def get_device_compute_capability() -> Tuple[int, int]:
+ """CUDA compute capability of current GPU"""
+ props = torch.cuda.get_device_properties(torch.cuda.current_device())
+ return (props.major, props.minor)
+
+
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@@ -1682,6 +1701,139 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
assert_allclose(full_output, incremental_output, atol[dtype])
+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("bs", batch_sizes)
+@pytest.mark.parametrize("model_key", model_configs_inference.keys())
+@pytest.mark.parametrize("use_RoPE", all_boolean)
+@pytest.mark.parametrize("module", module_inference)
+@pytest.mark.skipif(
+ get_device_compute_capability() < (9, 0), reason="THD is only supported on Hopper+."
+)
+@pytest.mark.skipif(_cudnn_version() < (9, 0, 0), reason="cuDNN 9.0.0+ is required.")
+def test_kv_cache_accuracy_thd(dtype, bs, model_key, use_RoPE, module):
+ """
+ In thd attention, sequences can have various lengths, possibly different
+ from the 's' dimension of the input to the TransformerLayer.
+
+ The test consists of:
+ - one context phase in which sequences of various lengths are passed through the model,
+ - 2 generation phases in which sequences of length 1 are passed through the model.
+
+ The output is compared with the case where all these sequences are passed at once.
+ """
+ if dtype == torch.float32:
+ pytest.skip("torch.float32 does not support thd")
+
+ fused_attn_env = os.environ["NVTE_FUSED_ATTN"]
+ os.environ["NVTE_FUSED_ATTN"] = "1" # Only fused attention supports thd.
+
+ if not fp8_available:
+ pytest.skip(reason_for_no_fp8)
+
+ config = model_configs_inference[model_key]
+
+ S = config.seq_len
+ B = bs
+ H = config.num_attention_heads
+ D = config.hidden_size
+ G = 2 # generation phase length
+ S_max = S + G
+ head_size = config.embed
+
+ layer_number = 1
+ rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda")
+
+ # Tensors have shape [b, s, d] and the seqlens are tensors of shape [b]
+ # indicating the length of each sequence; sequences start from the beginning.
+ # This function appends the sequences from tensor to those already in dst_tensor.
+ # dst_tensor should be big enough to fit these sequences.
+ def _concat_thd(dst_tensor, dst_seqlens, tensor, seqlens):
+ for b in range(B):
+ dst_tensor[b, dst_seqlens[b] : (dst_seqlens[b] + seqlens[b]), :] = tensor[
+ b, : seqlens[b], :
+ ]
+ dst_seqlens.copy_(dst_seqlens + seqlens)
+
+ if module == "TransformerLayer":
+ model = TransformerLayer(
+ hidden_size=D,
+ ffn_hidden_size=4 * D,
+ num_attention_heads=H,
+ attn_input_format="thd",
+ self_attn_mask_type="padding_causal",
+ layer_number=layer_number,
+ params_dtype=dtype,
+ device="cuda",
+ ).eval()
+ attn_name = "self_attn_mask_type"
+ else:
+ model = (
+ MultiheadAttention(
+ hidden_size=D,
+ num_attention_heads=H,
+ qkv_format="thd",
+ layer_number=layer_number,
+ params_dtype=dtype,
+ attn_mask_type="padding_causal",
+ )
+ .cuda()
+ .eval()
+ )
+ attn_name = "attn_mask_type"
+
+ inference_params = InferenceParams(B, S_max, qkv_format="thd")
+
+ kwargs = {
+ "inference_params": inference_params,
+ "rotary_pos_emb": rotary_freqs if use_RoPE else None,
+ }
+
+ total_sequence_lengths = torch.zeros((B,)).cuda().to(torch.int32)
+ total_tensor = torch.zeros((B, S_max, D)).cuda().to(dtype)
+
+ # Sequences split into chunks.
+
+ # context phase
+ sequence_lengths = torch.randint(1, S, (B,)).cuda().to(torch.int32)
+ chunk = torch.randn((B, S, D)).cuda().to(dtype)
+ inference_params.setup_before_new_input(max_input_length=S, lengths_tensor=sequence_lengths)
+ model(
+ chunk, inference_params=inference_params, rotary_pos_emb=rotary_freqs if use_RoPE else None
+ )
+ _concat_thd(total_tensor, total_sequence_lengths, chunk, sequence_lengths)
+
+ # generation phase
+ for _ in range(G):
+ sequence_lengths = torch.ones((B,)).cuda().to(torch.int32)
+ chunk = torch.randn((B, 1, D)).cuda().to(dtype)
+ inference_params.setup_before_new_input(max_input_length=1, lengths_tensor=sequence_lengths)
+ # We need to remove 'causal' from the mask; otherwise the tokens we add
+ # would be treated as the first tokens of their sequences, while they need
+ # to attend to all tokens already in the key-value cache.
+ # If this line is removed, the test should fail.
+ kwargs[attn_name] = "padding"
+ output = model(chunk, **kwargs)
+ _concat_thd(total_tensor, total_sequence_lengths, chunk, sequence_lengths)
+ incremental_logits = output[:, -1, :] # last element of each seq.
+
+ # Sequences passed in one concatenated chunk.
+
+ kwargs[attn_name] = "padding_causal" # add 'causal' back to the mask
+ inference_params.reset()
+ inference_params.setup_before_new_input(
+ max_input_length=S_max, lengths_tensor=total_sequence_lengths
+ )
+ full_output = model(total_tensor, **kwargs)
+ full_logits = full_output[
+ torch.arange(0, B), total_sequence_lengths - 1, :
+ ] # last element of each seq.
+
+ # Final result should be close.
+ torch.testing.assert_close(full_logits, incremental_logits, atol=1e-2, rtol=1e-2)
+
+ os.environ["NVTE_FUSED_ATTN"] = fused_attn_env
+
+
@pytest.mark.parametrize(
"shape",
[
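The mask switch in the generation loop above ('padding_causal' to 'padding') is the subtle part of this test. A toy illustration follows, under the assumption that a causal mask rebuilt from the incoming single-token chunk alone would align the new token to position 0; this is an assumption-level sketch, not TE's internal mask handling.

```python
import torch

# With a KV cache of length 4 and a single new query token:
cache_len, q_len = 4, 1
kv_len = cache_len + q_len

# A causal mask rebuilt from the incoming chunk alone treats the new token as
# position 0, so it may only attend to the first cached key:
causal_from_chunk = torch.ones(q_len, kv_len, dtype=torch.bool).tril(diagonal=0)
print(causal_from_chunk.int().tolist())  # [[1, 0, 0, 0, 0]]

# A padding-only mask lets the new token attend to every (valid) cached key:
padding_only = torch.ones(q_len, kv_len, dtype=torch.bool)
print(padding_only.int().tolist())       # [[1, 1, 1, 1, 1]]
```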
diff --git a/transformer_engine/common/fused_rope/fused_rope.cu b/transformer_engine/common/fused_rope/fused_rope.cu
index e7cf940a57..560b7b55d3 100644
--- a/transformer_engine/common/fused_rope/fused_rope.cu
+++ b/transformer_engine/common/fused_rope/fused_rope.cu
@@ -15,11 +15,11 @@ namespace transformer_engine {
template
__device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs, scalar_t *dst,
- const int offset_block, const int offset_block_dst,
- const int h, const int d, const int d2, const int stride_h,
- const int stride_d, const int o_stride_h,
- const int o_stride_d) {
- int s_id = blockIdx.x;
+ const int begin_offset, const int offset_block,
+ const int offset_block_dst, const int h, const int d,
+ const int d2, const int stride_h, const int stride_d,
+ const int o_stride_h, const int o_stride_d) {
+ int s_id = blockIdx.x + begin_offset;
#pragma unroll
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
float v_cos, v_sin;
@@ -52,11 +52,11 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs
template
__device__ void fused_rope_block_backward(const scalar_t *src, const float *freqs, scalar_t *dst,
- const int offset_block, const int offset_block_dst,
- const int h, const int d, const int d2,
- const int stride_h, const int stride_d,
+ const int begin_offset, const int offset_block,
+ const int offset_block_dst, const int h, const int d,
+ const int d2, const int stride_h, const int stride_d,
const int o_stride_h, const int o_stride_d) {
- int s_id = blockIdx.x;
+ int s_id = blockIdx.x + begin_offset;
#pragma unroll
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
float v_cos = cosf(freqs[s_id * d2 + d_id]);
@@ -88,68 +88,75 @@ __device__ void fused_rope_block_backward(const scalar_t *src, const float *freq
}
template
-__global__ void fused_rope_forward_kernel(const scalar_t *src, const float *freqs, scalar_t *dst,
- const int h, const int d, const int d2,
- const int stride_s, const int stride_b,
- const int stride_h, const int stride_d,
- const int o_stride_s, const int o_stride_b,
- const int o_stride_h, const int o_stride_d) {
+__global__ void fused_rope_forward_kernel(const scalar_t *src, const float *freqs,
+ const int *start_positions, scalar_t *dst, const int h,
+ const int d, const int d2, const int stride_s,
+ const int stride_b, const int stride_h,
+ const int stride_d, const int o_stride_s,
+ const int o_stride_b, const int o_stride_h,
+ const int o_stride_d) {
int s_id = blockIdx.x, b_id = blockIdx.y;
+ int begin_offset = (start_positions == 0) ? 0 : start_positions[b_id];
int offset_block = s_id * stride_s + b_id * stride_b;
int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;
- fused_rope_block_forward(src, freqs, dst, offset_block, offset_block_dst, h, d, d2, stride_h,
- stride_d, o_stride_h, o_stride_d);
+ fused_rope_block_forward(src, freqs, dst, begin_offset, offset_block, offset_block_dst, h, d, d2,
+ stride_h, stride_d, o_stride_h, o_stride_d);
}
template
-__global__ void fused_rope_backward_kernel(const scalar_t *src, const float *freqs, scalar_t *dst,
- const int h, const int d, const int d2,
- const int stride_s, const int stride_b,
- const int stride_h, const int stride_d,
- const int o_stride_s, const int o_stride_b,
- const int o_stride_h, const int o_stride_d) {
+__global__ void fused_rope_backward_kernel(const scalar_t *src, const float *freqs,
+ const int *start_positions, scalar_t *dst, const int h,
+ const int d, const int d2, const int stride_s,
+ const int stride_b, const int stride_h,
+ const int stride_d, const int o_stride_s,
+ const int o_stride_b, const int o_stride_h,
+ const int o_stride_d) {
int s_id = blockIdx.x, b_id = blockIdx.y;
+ int begin_offset = (start_positions == 0) ? 0 : start_positions[b_id];
int offset_block = s_id * stride_s + b_id * stride_b;
int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;
- fused_rope_block_backward(src, freqs, dst, offset_block, offset_block_dst, h, d, d2, stride_h,
- stride_d, o_stride_h, o_stride_d);
+ fused_rope_block_backward(src, freqs, dst, begin_offset, offset_block, offset_block_dst, h, d, d2,
+ stride_h, stride_d, o_stride_h, o_stride_d);
}
template
__global__ void fused_rope_thd_forward_kernel(const scalar_t *src, const int *cu_seqlens,
- const float *freqs, scalar_t *dst, const int h,
- const int d, const int d2, const int stride_t,
- const int stride_h, const int stride_d,
- const int o_stride_t, const int o_stride_h,
- const int o_stride_d) {
+ const float *freqs, const int *start_positions,
+ scalar_t *dst, const int h, const int d, const int d2,
+ const int stride_t, const int stride_h,
+ const int stride_d, const int o_stride_t,
+ const int o_stride_h, const int o_stride_d) {
int s_id = blockIdx.x, b_id = blockIdx.y;
int t_id = s_id + cu_seqlens[b_id];
if (t_id >= cu_seqlens[b_id + 1]) return;
int offset_block = t_id * stride_t;
int offset_block_dst = t_id * o_stride_t;
- fused_rope_block_forward(src, freqs, dst, offset_block, offset_block_dst, h, d, d2, stride_h,
- stride_d, o_stride_h, o_stride_d);
+ int begin_offset = (start_positions == 0) ? 0 : start_positions[b_id];
+ fused_rope_block_forward(src, freqs, dst, begin_offset, offset_block, offset_block_dst, h, d, d2,
+ stride_h, stride_d, o_stride_h, o_stride_d);
}
template
__global__ void fused_rope_thd_backward_kernel(const scalar_t *src, const int *cu_seqlens,
- const float *freqs, scalar_t *dst, const int h,
- const int d, const int d2, const int stride_t,
- const int stride_h, const int stride_d,
- const int o_stride_t, const int o_stride_h,
- const int o_stride_d) {
+ const float *freqs, const int *start_positions,
+ scalar_t *dst, const int h, const int d,
+ const int d2, const int stride_t, const int stride_h,
+ const int stride_d, const int o_stride_t,
+ const int o_stride_h, const int o_stride_d) {
int s_id = blockIdx.x, b_id = blockIdx.y;
int t_id = s_id + cu_seqlens[b_id];
if (t_id >= cu_seqlens[b_id + 1]) return;
int offset_block = t_id * stride_t;
int offset_block_dst = t_id * o_stride_t;
- fused_rope_block_backward(src, freqs, dst, offset_block, offset_block_dst, h, d, d2, stride_h,
- stride_d, o_stride_h, o_stride_d);
+ int begin_offset = (start_positions == 0) ? 0 : start_positions[b_id];
+ fused_rope_block_backward(src, freqs, dst, begin_offset, offset_block, offset_block_dst, h, d, d2,
+ stride_h, stride_d, o_stride_h, o_stride_d);
}
template
-void fused_rope_forward_launcher(const scalar_t *input, const float *freqs, scalar_t *output,
- const int s, const int b, const int h, const int d, const int d2,
+void fused_rope_forward_launcher(const scalar_t *input, const float *freqs,
+ const int *start_positions, scalar_t *output, const int s,
+ const int b, const int h, const int d, const int d2,
const int stride_s, const int stride_b, const int stride_h,
const int stride_d, const int o_stride_s, const int o_stride_b,
const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
@@ -158,115 +165,123 @@ void fused_rope_forward_launcher(const scalar_t *input, const float *freqs, scal
dim3 threads(THREADS_PER_WARP, warps_per_block);
fused_rope_forward_kernel<<>>(
- input, freqs, output, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
- o_stride_b, o_stride_h, o_stride_d);
+ input, freqs, start_positions, output, h, d, d2, stride_s, stride_b, stride_h, stride_d,
+ o_stride_s, o_stride_b, o_stride_h, o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError());
}
template
void fused_rope_backward_launcher(const scalar_t *output_grads, const float *freqs,
- scalar_t *input_grads, const int s, const int b, const int h,
- const int d, const int d2, const int stride_s, const int stride_b,
- const int stride_h, const int stride_d, const int o_stride_s,
- const int o_stride_b, const int o_stride_h, const int o_stride_d,
- cudaStream_t stream) {
+ const int *start_positions, scalar_t *input_grads, const int s,
+ const int b, const int h, const int d, const int d2,
+ const int stride_s, const int stride_b, const int stride_h,
+ const int stride_d, const int o_stride_s, const int o_stride_b,
+ const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
fused_rope_backward_kernel<<>>(
- output_grads, freqs, input_grads, h, d, d2, stride_s, stride_b, stride_h, stride_d,
- o_stride_s, o_stride_b, o_stride_h, o_stride_d);
+ output_grads, freqs, start_positions, input_grads, h, d, d2, stride_s, stride_b, stride_h,
+ stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError());
}
template
void fused_rope_thd_forward_launcher(const scalar_t *input, const int *cu_seqlens,
- const float *freqs, scalar_t *output, const int max_s,
- const int b, const int h, const int d, const int d2,
- const int stride_t, const int stride_h, const int stride_d,
- const int o_stride_t, const int o_stride_h,
- const int o_stride_d, cudaStream_t stream) {
+ const float *freqs, const int *start_positions,
+ scalar_t *output, const int max_s, const int b, const int h,
+ const int d, const int d2, const int stride_t,
+ const int stride_h, const int stride_d, const int o_stride_t,
+ const int o_stride_h, const int o_stride_d,
+ cudaStream_t stream) {
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(max_s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
- fused_rope_thd_forward_kernel<<>>(input, cu_seqlens, freqs, output, h,
- d, d2, stride_t, stride_h, stride_d,
- o_stride_t, o_stride_h, o_stride_d);
+ fused_rope_thd_forward_kernel<<>>(
+ input, cu_seqlens, freqs, start_positions, output, h, d, d2, stride_t, stride_h, stride_d,
+ o_stride_t, o_stride_h, o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError());
}
template
void fused_rope_thd_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens,
- const float *freqs, scalar_t *input_grads, const int max_s,
- const int b, const int h, const int d, const int d2,
- const int stride_t, const int stride_h, const int stride_d,
- const int o_stride_t, const int o_stride_h,
- const int o_stride_d, cudaStream_t stream) {
+ const float *freqs, const int *start_positions,
+ scalar_t *input_grads, const int max_s, const int b,
+ const int h, const int d, const int d2, const int stride_t,
+ const int stride_h, const int stride_d, const int o_stride_t,
+ const int o_stride_h, const int o_stride_d,
+ cudaStream_t stream) {
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(max_s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
fused_rope_thd_backward_kernel<<>>(
- output_grads, cu_seqlens, freqs, input_grads, h, d, d2, stride_t, stride_h, stride_d,
- o_stride_t, o_stride_h, o_stride_d);
+ output_grads, cu_seqlens, freqs, start_positions, input_grads, h, d, d2, stride_t, stride_h,
+ stride_d, o_stride_t, o_stride_h, o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError());
}
-void fused_rope_forward(const Tensor &input, const Tensor &freqs, Tensor *output, const int s,
- const int b, const int h, const int d, const int d2, const int stride_s,
- const int stride_b, const int stride_h, const int stride_d,
- const int o_stride_s, const int o_stride_b, const int o_stride_h,
- const int o_stride_d, cudaStream_t stream) {
+void fused_rope_forward(const Tensor &input, const Tensor &freqs, const Tensor &start_positions,
+ Tensor *output, const int s, const int b, const int h, const int d,
+ const int d2, const int stride_s, const int stride_b, const int stride_h,
+ const int stride_d, const int o_stride_s, const int o_stride_b,
+ const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, scalar_t,
fused_rope_forward_launcher(reinterpret_cast(input.data.dptr),
reinterpret_cast(freqs.data.dptr),
+ reinterpret_cast(start_positions.data.dptr),
reinterpret_cast(output->data.dptr), s, b, h, d, d2,
stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
o_stride_h, o_stride_d, stream););
}
-void fused_rope_backward(const Tensor &output_grads, const Tensor &freqs, Tensor *input_grads,
- const int s, const int b, const int h, const int d, const int d2,
- const int stride_s, const int stride_b, const int stride_h,
- const int stride_d, const int o_stride_s, const int o_stride_b,
- const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
+void fused_rope_backward(const Tensor &output_grads, const Tensor &freqs,
+ const Tensor &start_positions, Tensor *input_grads, const int s,
+ const int b, const int h, const int d, const int d2, const int stride_s,
+ const int stride_b, const int stride_h, const int stride_d,
+ const int o_stride_s, const int o_stride_b, const int o_stride_h,
+ const int o_stride_d, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
output_grads.data.dtype, scalar_t,
fused_rope_backward_launcher(reinterpret_cast(output_grads.data.dptr),
reinterpret_cast(freqs.data.dptr),
+ reinterpret_cast(start_positions.data.dptr),
reinterpret_cast(input_grads->data.dptr), s, b, h, d,
d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
o_stride_b, o_stride_h, o_stride_d, stream););
}
void fused_rope_thd_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs,
- Tensor *output, const int max_s, const int b, const int h, const int d,
- const int d2, const int stride_t, const int stride_h,
- const int stride_d, const int o_stride_t, const int o_stride_h,
- const int o_stride_d, cudaStream_t stream) {
+ const Tensor &start_positions, Tensor *output, const int max_s,
+ const int b, const int h, const int d, const int d2, const int stride_t,
+ const int stride_h, const int stride_d, const int o_stride_t,
+ const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, scalar_t,
fused_rope_thd_forward_launcher(reinterpret_cast(input.data.dptr),
reinterpret_cast(cu_seqlens.data.dptr),
reinterpret_cast(freqs.data.dptr),
+ reinterpret_cast(start_positions.data.dptr),
reinterpret_cast(output->data.dptr), max_s, b, h,
d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h,
o_stride_d, stream););
}
void fused_rope_thd_backward(const Tensor &output_grads, const Tensor &cu_seqlens,
- const Tensor &freqs, Tensor *input_grads, const int max_s, const int b,
- const int h, const int d, const int d2, const int stride_t,
- const int stride_h, const int stride_d, const int o_stride_t,
- const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
+ const Tensor &freqs, const Tensor &start_positions,
+ Tensor *input_grads, const int max_s, const int b, const int h,
+ const int d, const int d2, const int stride_t, const int stride_h,
+ const int stride_d, const int o_stride_t, const int o_stride_h,
+ const int o_stride_d, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
output_grads.data.dtype, scalar_t,
fused_rope_thd_backward_launcher(reinterpret_cast(output_grads.data.dptr),
reinterpret_cast(cu_seqlens.data.dptr),
reinterpret_cast(freqs.data.dptr),
+ reinterpret_cast(start_positions.data.dptr),
reinterpret_cast(input_grads->data.dptr), max_s,
b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t,
o_stride_h, o_stride_d, stream););
@@ -274,58 +289,62 @@ void fused_rope_thd_backward(const Tensor &output_grads, const Tensor &cu_seqlen
} // end namespace transformer_engine
-void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVTETensor output,
- const int s, const int b, const int h, const int d, const int d2,
+void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs,
+ const NVTETensor start_positions, NVTETensor output, const int s,
+ const int b, const int h, const int d, const int d2,
const int stride_s, const int stride_b, const int stride_h,
const int stride_d, const int o_stride_s, const int o_stride_b,
const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_forward);
using namespace transformer_engine;
fused_rope_forward(*reinterpret_cast(input),
- *reinterpret_cast(freqs), reinterpret_cast(output),
- s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
- o_stride_h, o_stride_d, stream);
+ *reinterpret_cast(freqs),
+ *reinterpret_cast(start_positions),
+ reinterpret_cast(output), s, b, h, d, d2, stride_s, stride_b,
+ stride_h, stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d, stream);
}
void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs,
- NVTETensor input_grads, const int s, const int b, const int h,
- const int d, const int d2, const int stride_s, const int stride_b,
- const int stride_h, const int stride_d, const int o_stride_s,
- const int o_stride_b, const int o_stride_h, const int o_stride_d,
- cudaStream_t stream) {
+ const NVTETensor start_positions, NVTETensor input_grads, const int s,
+ const int b, const int h, const int d, const int d2,
+ const int stride_s, const int stride_b, const int stride_h,
+ const int stride_d, const int o_stride_s, const int o_stride_b,
+ const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_backward);
using namespace transformer_engine;
fused_rope_backward(*reinterpret_cast(output_grads),
*reinterpret_cast(freqs),
+ *reinterpret_cast(start_positions),
reinterpret_cast(input_grads), s, b, h, d, d2, stride_s, stride_b,
stride_h, stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d, stream);
}
void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens,
- const NVTETensor freqs, NVTETensor output, const int max_s,
- const int b, const int h, const int d, const int d2,
- const int stride_t, const int stride_h, const int stride_d,
- const int o_stride_t, const int o_stride_h, const int o_stride_d,
- cudaStream_t stream) {
+ const NVTETensor freqs, const NVTETensor start_positions,
+ NVTETensor output, const int max_s, const int b, const int h,
+ const int d, const int d2, const int stride_t, const int stride_h,
+ const int stride_d, const int o_stride_t, const int o_stride_h,
+ const int o_stride_d, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_thd_forward);
using namespace transformer_engine;
fused_rope_thd_forward(
*reinterpret_cast(input), *reinterpret_cast(cu_seqlens),
- *reinterpret_cast(freqs), reinterpret_cast(output), max_s, b, h, d,
- d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream);
+ *reinterpret_cast(freqs), *reinterpret_cast(start_positions),
+ reinterpret_cast(output), max_s, b, h, d, d2, stride_t, stride_h, stride_d,
+ o_stride_t, o_stride_h, o_stride_d, stream);
}
void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
- const NVTETensor freqs, NVTETensor input_grads, const int max_s,
- const int b, const int h, const int d, const int d2,
- const int stride_t, const int stride_h, const int stride_d,
- const int o_stride_t, const int o_stride_h, const int o_stride_d,
- cudaStream_t stream) {
+ const NVTETensor freqs, const NVTETensor start_positions,
+ NVTETensor input_grads, const int max_s, const int b, const int h,
+ const int d, const int d2, const int stride_t, const int stride_h,
+ const int stride_d, const int o_stride_t, const int o_stride_h,
+ const int o_stride_d, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_thd_backward);
using namespace transformer_engine;
- fused_rope_thd_backward(*reinterpret_cast(output_grads),
- *reinterpret_cast(cu_seqlens),
- *reinterpret_cast(freqs),
- reinterpret_cast(input_grads), max_s, b, h, d, d2, stride_t,
- stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream);
+ fused_rope_thd_backward(
+ *reinterpret_cast(output_grads),
+ *reinterpret_cast(cu_seqlens), *reinterpret_cast(freqs),
+ *reinterpret_cast(start_positions), reinterpret_cast(input_grads),
+ max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream);
}
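The kernel changes above thread a per-batch `begin_offset` into the frequency lookup, so position `s` of batch element `b` reads row `s + start_positions[b]` of `freqs` (or row `s` when the pointer is null). Below is a plain-PyTorch sketch of this behaviour for the `sbhd` path using the usual rotate-half convention; it is an illustration only, not the CUDA kernel.

```python
import torch

def rope_forward_reference(x, freqs, start_positions=None):
    # x: [s, b, h, d], freqs: [s2, 1, 1, d2] (angles). The frequency row used for
    # position s of batch element b is s + start_positions[b] (0 if None).
    s, b, _, _ = x.shape
    d2 = freqs.shape[-1]
    out = x.clone()
    for b_id in range(b):
        begin = 0 if start_positions is None else int(start_positions[b_id])
        ang = freqs[begin : begin + s, 0, 0, :]                  # [s, d2]
        cos, sin = torch.cos(ang), torch.sin(ang)
        x_rot = x[:, b_id, :, :d2]                               # [s, h, d2]
        x1, x2 = x_rot[..., : d2 // 2], x_rot[..., d2 // 2 :]
        rotated = torch.cat((-x2, x1), dim=-1)                   # rotate-half
        out[:, b_id, :, :d2] = x_rot * cos[:, None, :] + rotated * sin[:, None, :]
    return out

x = torch.randn(6, 2, 4, 16)        # [s, b, h, d]
freqs = torch.randn(32, 1, 1, 16)   # any [s2, 1, 1, d2] angle table with s2 >= s
starts = torch.tensor([0, 3])
print(rope_forward_reference(x, freqs, starts).shape)  # torch.Size([6, 2, 4, 16])
```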
diff --git a/transformer_engine/common/include/transformer_engine/fused_rope.h b/transformer_engine/common/include/transformer_engine/fused_rope.h
index b92de88eca..01305c1e6d 100644
--- a/transformer_engine/common/include/transformer_engine/fused_rope.h
+++ b/transformer_engine/common/include/transformer_engine/fused_rope.h
@@ -17,6 +17,7 @@ extern "C" {
*
* \param[in] input Input tensor for fused rope.
* \param[in] freqs The freqs tensor.
+ * \param[in] start_positions The beginning offsets.
* \param[out] output Output tensor.
* \param[in] s Length of the s dimension of input.
* \param[in] b Length of the b dimension of input.
@@ -33,8 +34,9 @@ extern "C" {
* \param[in] o_stride_d Stride of the d dimension of output.
* \param[in] stream CUDA stream used for the operation.
*/
-void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVTETensor output,
- const int s, const int b, const int h, const int d, const int d2,
+void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs,
+ const NVTETensor start_positions, NVTETensor output, const int s,
+ const int b, const int h, const int d, const int d2,
const int stride_s, const int stride_b, const int stride_h,
const int stride_d, const int o_stride_s, const int o_stride_b,
const int o_stride_h, const int o_stride_d, cudaStream_t stream);
@@ -43,6 +45,7 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVT
*
* \param[in] output_grads Incoming gradient tensor for backward.
* \param[in] freqs The freqs tensor.
+ * \param[in] start_positions The tensor with positions of first tokens in sequences.
* \param[out] input_grads Input gradient tensor to calculate.
* \param[in] s Length of the s dimension of output_grads.
* \param[in] b Length of the b dimension of output_grads.
@@ -60,43 +63,45 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVT
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs,
- NVTETensor input_grads, const int s, const int b, const int h,
- const int d, const int d2, const int stride_s, const int stride_b,
- const int stride_h, const int stride_d, const int o_stride_s,
- const int o_stride_b, const int o_stride_h, const int o_stride_d,
- cudaStream_t stream);
+ const NVTETensor start_positions, NVTETensor input_grads, const int s,
+ const int b, const int h, const int d, const int d2,
+ const int stride_s, const int stride_b, const int stride_h,
+ const int stride_d, const int o_stride_s, const int o_stride_b,
+ const int o_stride_h, const int o_stride_d, cudaStream_t stream);
/*! \brief Apply rotary positional embedding to the input tensor in thd format.
*
- * \param[in] input Input tensor for fused rope.
- * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor.
- * \param[in] freqs The freqs tensor.
- * \param[out] output Output tensor.
- * \param[in] max_s Max sequence length.
- * \param[in] b Batch size.
- * \param[in] h Length of the h dimension of input.
- * \param[in] d Length of the d dimension of input.
- * \param[in] d2 Length of the d dimension of freqs.
- * \param[in] stride_t Stride of the t dimension of input.
- * \param[in] stride_h Stride of the h dimension of input.
- * \param[in] stride_d Stride of the d dimension of input.
- * \param[in] o_stride_t Stride of the t dimension of output.
- * \param[in] o_stride_h Stride of the h dimension of output.
- * \param[in] o_stride_d Stride of the d dimension of output.
- * \param[in] stream CUDA stream used for the operation.
+ * \param[in] input Input tensor for fused rope.
+ * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor.
+ * \param[in] freqs The freqs tensor.
+ * \param[in] start_positions The tensor with positions of first tokens in sequences.
+ * \param[out] output Output tensor.
+ * \param[in] max_s Max sequence length.
+ * \param[in] b Batch size.
+ * \param[in] h Length of the h dimension of input.
+ * \param[in] d Length of the d dimension of input.
+ * \param[in] d2 Length of the d dimension of freqs.
+ * \param[in] stride_t Stride of the t dimension of input.
+ * \param[in] stride_h Stride of the h dimension of input.
+ * \param[in] stride_d Stride of the d dimension of input.
+ * \param[in] o_stride_t Stride of the t dimension of output.
+ * \param[in] o_stride_h Stride of the h dimension of output.
+ * \param[in] o_stride_d Stride of the d dimension of output.
+ * \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens,
- const NVTETensor freqs, NVTETensor output, const int max_s,
- const int b, const int h, const int d, const int d2,
- const int stride_t, const int stride_h, const int stride_d,
- const int o_stride_t, const int o_stride_h, const int o_stride_d,
- cudaStream_t stream);
+ const NVTETensor freqs, const NVTETensor start_positions,
+ NVTETensor output, const int max_s, const int b, const int h,
+ const int d, const int d2, const int stride_t, const int stride_h,
+ const int stride_d, const int o_stride_t, const int o_stride_h,
+ const int o_stride_d, cudaStream_t stream);
/*! \brief Compute the backward of the fused rope in thd format.
*
* \param[in] output_grads Incoming gradient tensor for backward.
* \param[in] cu_seqlens The cumulative sum of sequence lengths tensor.
* \param[in] freqs The freqs tensor.
+ * \param[in] start_positions The beginning offsets.
* \param[out] input_grads Input gradient to calculate.
* \param[in] max_s Max sequence length.
* \param[in] b Batch size.
@@ -112,11 +117,11 @@ void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seq
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
- const NVTETensor freqs, NVTETensor input_grads, const int max_s,
- const int b, const int h, const int d, const int d2,
- const int stride_t, const int stride_h, const int stride_d,
- const int o_stride_t, const int o_stride_h, const int o_stride_d,
- cudaStream_t stream);
+ const NVTETensor freqs, const NVTETensor start_positions,
+ NVTETensor input_grads, const int max_s, const int b, const int h,
+ const int d, const int d2, const int stride_t, const int stride_h,
+ const int stride_d, const int o_stride_t, const int o_stride_h,
+ const int o_stride_d, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index fa72ecfa33..7430027335 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -703,18 +703,43 @@ class InferenceParams: # pylint: disable=too-few-public-methods
Parameters
----------
- max_batch_size : int
+ max_batch_size: int
maximum batch size during inference.
- max_sequence_length : int
- maximum sequence length during inference.
+ max_sequence_length: int
+ maximum sequence length during inference.
+ qkv_format: str
+ Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}.
+ `s` stands for the sequence length dimension,
+ `b` batch size, `h` the number of attention heads,
+ `d` head size, and `t` the total number of tokens in the batch, i.e.
+ `t = sum(s_i) for i = 0...b-1`.
"""
- def __init__(self, max_batch_size, max_sequence_length):
+ def __init__(self, max_batch_size, max_sequence_length, qkv_format="bshd"):
+ assert qkv_format in ["bshd", "sbhd", "thd"]
+
self.max_sequence_length = max_sequence_length
self.max_batch_size = max_batch_size
- self.sequence_len_offset = 0
- self.batch_size_offset = 0
+
+ # self.key_value_memory_dict[layer number] = (key_cache, value_cache)
+ # if qkv_format in ["bshd", "sbhd"]: (key/value)_cache.shape = [s, b, h, d]
+ # if qkv_format = "thd": (key/value)_cache.shape = [t, h, d]
self.key_value_memory_dict = {}
+ self.qkv_format = qkv_format
+
+ if qkv_format == "thd":
+ # In the thd attention layout, input sequences can have different lengths.
+ # self.input_sequence_lengths is a tensor of shape [b] with the lengths of the input
+ # sequences, and self.cached_sequence_lengths is the sum of all previous input length
+ # tensors - equivalently, it contains the total lengths of the cached sequences.
+ self.cached_sequence_lengths = torch.zeros(
+ (max_batch_size,), device="cuda", dtype=torch.int32)
+ self.input_sequence_lengths = torch.zeros(
+ (max_batch_size,), device="cuda", dtype=torch.int32)
+ else:
+ self.sequence_len_offset = 0
+ self.batch_size_offset = 0
+ self.input_sequence_length = None
def swap_key_value_dict(self, batch_indices):
"""
@@ -742,6 +767,214 @@ def swap_key_value_dict(self, batch_indices):
)
+ def setup_before_new_input(self, lengths_tensor=None, max_input_length=None, length=None):
+ """
+ Updates parameters representing incoming sequence lengths and lengths
+ of sequences in the cache. Should be called before every forward pass during inference.
+
+ Parameters
+ ----------
+ lengths_tensor: torch.Tensor
+ 1d tensor with sequence lengths in new input.
+ Should be used only when self.qkv_format = "thd".
+ max_input_length: int
+ Should be used only when self.qkv_format = "thd".
+ If the incoming sequences tensor has shape [b * s, h, d],
+ this should be equal to s.
+ length: int
+ Length of the incoming sequences.
+ Should be used only when self.qkv_format in ["bshd", "sbhd"].
+ """
+ if self.qkv_format == "thd":
+ assert lengths_tensor is not None and max_input_length is not None, \
+ "lengths_tensor and max_input_length should not be none for qkv_format = \"thd\""
+ torch.add(
+ self.cached_sequence_lengths,
+ self.input_sequence_lengths,
+ out=self.cached_sequence_lengths)
+ self.input_sequence_lengths.copy_(lengths_tensor)
+ self.max_incoming_seq_len = max_input_length
+
+ else:
+ assert length is not None, \
+ "length should not be none for qkv_format in [\"bshd\", \"sbhd\"]"
+ if self.input_sequence_length is not None:
+ self.sequence_len_offset += self.input_sequence_length
+ self.input_sequence_length = length
+
+ def reset(self):
+ """
+ Resets the parameters to allow the use of this object in a new generation iteration.
+ This method does not reallocate buffers,
+ making it more efficient than creating a new InferenceParams object.
+ Moreover, reusing the same object with the same buffers is compatible
+ with CUDA Graphs.
+ """
+ if self.qkv_format == "thd":
+ self.cached_sequence_lengths.zero_()
+ self.input_sequence_lengths.zero_()
+ else:
+ self.input_sequence_length = None
+ self.sequence_len_offset = 0
+
+ def save_to_kv_cache(self, layer_number, key_layer, value_layer):
+ """
+ Saves key_layer and value_layer in the cache.
+
+ Parameters
+ ----------
+ layer_number: int
+ layer number of the current `TransformerLayer` when multiple such modules are
+ concatenated to form a transformer block.
+ key_layer: torch.Tensor
+ Tensor - of the format corresponding to the self.qkv_format -
+ representing key_layer.
+ Notice: if self.qkv_format in ["bshd", "sbhd"] then both layers are in format sbhd
+ Notice: if self.qkv_format = "thd", we assume that offsets of the sequences
+ are of the form k * self.max_incoming_seq_len for k = 0, ..., batch_size-1.
+ value_layer: torch.Tensor
+ Tensor - of the format corresponding to the self.qkv_format -
+ representing value_layer.
+ Notice: if self.qkv_format in ["bshd", "sbhd"] both layers are in format sbhd
+ Notice: if self.qkv_format = "thd", we assume that offsets of the sequences
+ are of the form k * self.max_incoming_seq_len for k = 0, ..., batch_size-1.
+ """
+ # Current kernels work only with contiguous tensors; this can be made faster in the future.
+ key_layer, value_layer = key_layer.contiguous(), value_layer.contiguous()
+ inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number]
+ if self.qkv_format == "thd":
+ channels = inference_key_memory.shape[1] * inference_key_memory.shape[2] # h * d
+ # This kernel copies sequences from the input layers into the cache,
+ # taking into account the thd format and sequence lengths.
+ tex.attention_copy(
+ inference_key_memory,
+ self.cached_sequence_lengths,
+ self.input_sequence_lengths,
+ key_layer,
+ self.max_incoming_seq_len,
+ self.max_sequence_length,
+ self.max_batch_size,
+ channels)
+
+ tex.attention_copy(
+ inference_value_memory,
+ self.cached_sequence_lengths,
+ self.input_sequence_lengths,
+ value_layer,
+ self.max_incoming_seq_len,
+ self.max_sequence_length,
+ self.max_batch_size,
+ channels)
+ key_layer, value_layer = inference_key_memory, inference_value_memory
+ else:
+ assert self.qkv_format in ["bshd", "sbhd"], \
+ "Attention format not supported by the inference."
+ batch_start = self.batch_size_offset
+ batch_end = batch_start + key_layer.size(1)
+ assert batch_end <= inference_key_memory.size(1)
+
+ sequence_start = self.sequence_len_offset
+ sequence_end = sequence_start + key_layer.size(0)
+ assert sequence_end <= inference_key_memory.size(0)
+
+ # Copy keys and values into KV-cache
+ seq_offsets = slice(sequence_start, sequence_end)
+ batch_offsets = slice(batch_start, batch_end)
+ inference_key_memory[seq_offsets, batch_offsets, ...] = key_layer
+ inference_value_memory[seq_offsets, batch_offsets, ...] = value_layer
+ key_layer = inference_key_memory[:sequence_end, batch_offsets, ...]
+ value_layer = inference_value_memory[:sequence_end, batch_offsets, ...]
+ return key_layer, value_layer
+
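For reference, a slow pure-PyTorch sketch of what the `tex.attention_copy` call does for the `thd` layout (illustrative only; the helper and tensor names below are not part of the patch). Each sequence owns a slot of `max_seq_len` rows in the cache, and the incoming tokens of sequence `i` are appended after its `cached_lens[i]` already-cached rows:

    def attention_copy_reference(cache, cached_lens, incoming_lens, incoming,
                                 max_incoming_seq_len, max_seq_len, batch_size):
        # cache:    [batch_size * max_seq_len, h, d], sequences stored back to back
        # incoming: [batch_size * max_incoming_seq_len, h, d]
        for i in range(batch_size):
            n = int(incoming_lens[i])
            dst = i * max_seq_len + int(cached_lens[i])
            src = i * max_incoming_seq_len
            cache[dst:dst + n] = incoming[src:src + n]
        return cache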
+ def allocate_memory_for_kv_cache_if_empty(
+ self,
+ layer_number,
+ num_gqa_groups_per_partition,
+ hidden_size_per_attention_head,
+ dtype):
+ """
+ Allocates memory for the KV cache of the given layer, if it has not been allocated before.
+
+ Parameters
+ ----------
+ layer_number: int
+ layer number of the current `TransformerLayer` when multiple such modules are
+ concatenated to form a transformer block.
+ num_gqa_groups_per_partition: int
+ This will be the third dimension of the cache tensor.
+ hidden_size_per_attention_head: int
+ This will be the fourth dimension of the cache tensor.
+ dtype: torch.dtype
+ Data type of the allocated cache tensors.
+ """
+
+ if layer_number in self.key_value_memory_dict:
+ return # Already allocated
+
+ b, s = self.max_batch_size, self.max_sequence_length
+
+ def _allocate_memory(dims):
+ return torch.zeros(
+ *dims,
+ num_gqa_groups_per_partition,
+ hidden_size_per_attention_head,
+ dtype=dtype,
+ device=torch.cuda.current_device(),
+ )
+
+ if self.qkv_format == "thd":
+ inference_key_memory = _allocate_memory((b * s,))
+ inference_value_memory = _allocate_memory((b * s,))
+ else:
+ inference_key_memory = _allocate_memory((s, b))
+ inference_value_memory = _allocate_memory((s, b))
+ self.key_value_memory_dict[layer_number] = (
+ inference_key_memory,
+ inference_value_memory,
+ )
+
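For concreteness, the resulting cache shapes under assumed (arbitrary) sizes:

    # Illustrative only; numbers are arbitrary.
    b, s, groups, head = 2, 8, 4, 64
    thd_cache_shape  = (b * s, groups, head)    # (16, 4, 64): sequences stacked along one token axis
    sbhd_cache_shape = (s, b, groups, head)     # (8, 2, 4, 64): sequence-major layout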
+ def set_params_to_thd_attention(self, buffers, channels):
+ """
+ Fused attention with q/k/v in the thd layout with offsets needs several parameters
+ describing the sequence lengths. This function computes them and
+ saves them into the provided buffers.
+
+ Parameters
+ ----------
+ buffers: List[torch.Tensor]
+ buffers of size [batch_size + 1] for the parameters:
+ cu_seqlens_q, cu_seqlens_kv, seq_offsets_q,
+ seq_offsets_k, seq_offsets_v, seq_offsets_o
+ respectively.
+ channels: int
+ value of num_heads * hidden_dim_for_each_head.
+
+ Returns
+ ----------
+ max_seqlen_q: int
+ Maximum query sequence length.
+ max_seqlen_kv: int
+ Maximum key/value sequence length.
+ buffers: List[torch.Tensor]
+ The provided buffers, filled with the computed values.
+ """
+ max_seqlen_q, max_seqlen_kv = self.max_incoming_seq_len, self.max_sequence_length
+
+ cu_seqlens_q, cu_seqlens_kv, seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o = \
+ buffers
+
+ torch.cumsum(self.input_sequence_lengths, dim=0, out=cu_seqlens_q[1:])
+ torch.cumsum(
+ self.cached_sequence_lengths + self.input_sequence_lengths,
+ dim=0, out=cu_seqlens_kv[1:])
+ # If a layer has shape [b * s_layer, h, d], the offsets are of the form
+ # [k * s_layer * h * d for k = 0, ..., batch_size].
+ seq_offsets_q.copy_(
+ torch.arange(0, self.max_batch_size + 1, device="cuda") * channels * max_seqlen_q)
+ seq_offsets_k.copy_(
+ torch.arange(0, self.max_batch_size + 1, device="cuda") * channels * max_seqlen_kv)
+ seq_offsets_v.copy_(seq_offsets_k)
+ seq_offsets_o.copy_(seq_offsets_q)
+
+ return max_seqlen_q, max_seqlen_kv, buffers
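A worked toy example of the bookkeeping above (values are assumed; `channels = num_heads * head_dim`): with cached lengths [2, 5] and one new token per sequence,

    import torch

    cached = torch.tensor([2, 5])        # tokens already in the KV cache per sequence
    incoming = torch.tensor([1, 1])      # new tokens arriving this step
    channels = 4 * 64                    # assumed num_heads * head_dim
    max_seqlen_q, max_seqlen_kv = 1, 8   # assumed buffer sizes

    zero = torch.zeros(1, dtype=torch.long)
    cu_seqlens_q = torch.cat([zero, torch.cumsum(incoming, 0)])            # [0, 1, 2]
    cu_seqlens_kv = torch.cat([zero, torch.cumsum(cached + incoming, 0)])  # [0, 3, 9]
    seq_offsets_q = torch.arange(3) * channels * max_seqlen_q              # [0, 256, 512]
    seq_offsets_k = torch.arange(3) * channels * max_seqlen_kv             # [0, 2048, 4096]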
@torch.no_grad()
def get_swa_mask(
window_size: Tuple[int, int],
@@ -2460,33 +2693,44 @@ def forward(
freqs: torch.Tensor,
tensor_format: str = "sbhd",
cu_seqlens: Union[torch.Tensor, None] = None,
+ beginning_offsets: Union[torch.Tensor, None] = None,
) -> torch.Tensor:
+ # If beginning_offsets is None, each sequence starts from the positional encoding
+ # corresponding to position 0; otherwise, sequence i starts from the positional
+ # encoding corresponding to beginning_offsets[i].
+ if beginning_offsets is None:
+ beginning_offsets = torch.Tensor()
if freqs.dtype != torch.float32:
freqs = freqs.float()
if tensor_format == "sbhd":
- output = tex.fused_rope_forward(t, freqs, False)
+ output = tex.fused_rope_forward(t, freqs, beginning_offsets, False)
elif tensor_format == "bshd":
- output = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1)
+ output = tex.fused_rope_forward(
+ t.transpose(0, 1), freqs, beginning_offsets, True
+ ).transpose(0, 1)
elif tensor_format == "thd":
- output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs)
+ output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs, beginning_offsets)
else:
raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
- ctx.save_for_backward(freqs, cu_seqlens)
+ ctx.save_for_backward(freqs, cu_seqlens, beginning_offsets)
ctx.tensor_format = tensor_format
return output
@staticmethod
- def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
- freqs, cu_seqlens = ctx.saved_tensors
+ def backward(
+ ctx, grad_output: torch.Tensor
+ ) -> Tuple[Union[torch.Tensor, None], ...]:
+ freqs, cu_seqlens, start_positions = ctx.saved_tensors
if ctx.tensor_format == "sbhd":
- grad_input = tex.fused_rope_backward(grad_output, freqs, False)
+ grad_input = tex.fused_rope_backward(grad_output, freqs, start_positions, False)
elif ctx.tensor_format == "bshd":
grad_input = tex.fused_rope_backward(
- grad_output.transpose(0, 1), freqs, True
+ grad_output.transpose(0, 1), freqs, start_positions, True
).transpose(0, 1)
elif ctx.tensor_format == "thd":
- grad_input = tex.fused_rope_thd_backward(grad_output, cu_seqlens, freqs)
+ grad_input = tex.fused_rope_thd_backward(
+ grad_output, cu_seqlens, freqs, start_positions)
else:
raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.")
@@ -2508,6 +2752,7 @@ def apply_rotary_pos_emb(
tensor_format: str = "sbhd",
fused: bool = False,
cu_seqlens: Union[torch.Tensor, None] = None,
+ start_positions: Union[torch.Tensor, None] = None,
) -> torch.Tensor:
"""
Apply rotary positional embedding tensor to the input tensor.
@@ -2528,12 +2773,18 @@ def apply_rotary_pos_emb(
cu_seqlens: torch.Tensor, default = None.
Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
dtype torch.int32. Only valid when `tensor_format` is 'thd'.
+ start_positions: torch.Tensor, default = None.
+ Beginning position offsets of the sequences: token j of sequence i gets the
+ positional encoding corresponding to position start_positions[i] + j.
+ If start_positions is None, the token has position j.
"""
+ assert not (start_positions is not None and not fused), \
+ "start_positions is only supported with fused=True"
+
if fused:
assert (
tensor_format != "thd" or cu_seqlens is not None
), "cu_seqlens must not be None when tensor_format is 'thd'."
- return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens)
+ return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens, start_positions)
assert tensor_format in ("sbhd", "bshd"), (
"Only formats `sbhd` or `bshd` are supported for input tensor `t` "
@@ -5121,6 +5372,7 @@ def __init__(
self.cp_group = cp_group
self.cp_global_ranks = cp_global_ranks
self.cp_stream = cp_stream
+ self.channels = kv_channels * num_attention_heads
self.hidden_size_per_attention_head = kv_channels
@@ -5210,6 +5462,16 @@ def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unuse
self.register_load_state_dict_post_hook(remove_extra_states_check)
+ self._allocator = StaticBufferAllocator()
+
+
+ def alloc(self, size, dtype, device):
+ """
+ Allocates a buffer in a way that is compatible with CUDA Graphs.
+ """
+ return self._allocator(size, dtype, device)
+
+
def _checkpointed_attention_forward(
self,
attention_func: Callable,
@@ -5413,21 +5675,7 @@ def forward(
first microbatch (since it is the first gradient being
produced)
"""
- with self.prepare_forward(
- query_layer,
- is_first_microbatch,
- num_gemms=3,
- allow_non_contiguous=True,
- ) as query_layer:
-
- if self.fp8:
- if self.fp8_meta["recipe"].fp8_mha:
- if not self.fp8_meta["recipe"].fp8_dpa:
- self.fp8_meta["recipe"].fp8_dpa = True
- self.logger.WARNING(
- """Forcing fp8_meta["recipe"].fp8_dpa=True due to """
- """fp8_meta["recipe"].fp8_mha=True"""
- )
+ batch_size = key_layer.shape[0]
if self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
forward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=True)
@@ -5484,28 +5732,26 @@ def forward(
key_layer = key_layer.transpose(0, 1)
value_layer = value_layer.transpose(0, 1)
- (
- inference_key_memory,
- inference_value_memory,
- ) = inference_params.key_value_memory_dict[self.layer_number]
+ key_layer, value_layer = inference_params.save_to_kv_cache(
+ self.layer_number, key_layer, value_layer
+ )
- batch_start = inference_params.batch_size_offset
- batch_end = batch_start + key_layer.size(1)
- assert batch_end <= inference_key_memory.size(1)
+ if qkv_format == "thd":
+ # Allocate buffers in a way that works correctly with CUDA Graphs.
+ NR_BUFFERS = 6
+ buffers = [
+ self.alloc(batch_size + 1, dtype=torch.int32, device="cuda")
+ for _ in range(NR_BUFFERS)
+ ]
- sequence_start = inference_params.sequence_len_offset
- sequence_end = sequence_start + key_layer.size(0)
- assert sequence_end <= inference_key_memory.size(0)
+ max_seqlen_q, max_seqlen_kv, buffers = \
+ inference_params.set_params_to_thd_attention(buffers, self.channels)
+ cu_seqlens_q, cu_seqlens_kv, seq_offsets_q, \
+ seq_offsets_k, seq_offsets_v, seq_offsets_o = buffers
- # Copy keys and values into KV-cache
- inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = (
- key_layer
- )
- inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = (
- value_layer
- )
- key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
- value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...]
+ # Reshape query_layer to the [t, h, d] format and make it contiguous,
+ # as required by THD attention.
+ query_layer = query_layer.view(-1, *query_layer.shape[2:]).contiguous()
if qkv_format == "bshd":
key_layer = key_layer.transpose(0, 1)
@@ -5618,18 +5864,55 @@ def forward(
assert (
core_attention_bias is None
), "core_attention_bias must be None when core_attention_bias_type is alibi!"
- if (
- _alibi_cache["_num_heads"] != query_layer.shape[-2]
- or _alibi_cache["_max_seqlen_q"] != max_seqlen_q
- or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv
- or _alibi_cache["_bottom_right_alignment"] != bottom_right_alignment
- or _alibi_cache["_alibi_slopes"] is None
- ):
- _alibi_cache["_alibi_slopes_require_update"] = True
- _alibi_cache["_alibi_bias_require_update"] = True
+ if (_alibi_cache["_num_heads"] != query_layer.shape[-2]
+ or _alibi_cache["_max_seqlen_q"] != max_seqlen_q
+ or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv
+ or _alibi_cache["_alibi_slopes"] is None):
+ _alibi_cache["_alibi_slopes_require_update"] = True
+ _alibi_cache["_alibi_bias_require_update"] = True
+
+ if core_attention_bias_type not in ["no_bias", "alibi"] or core_attention_bias is not None:
+ use_flash_attention = False
- context_parallel = (
- self.cp_group is not None and get_distributed_world_size(self.cp_group) != 1
+ fu_core_attention_bias_type = core_attention_bias_type
+ fu_core_attention_bias = core_attention_bias
+ if core_attention_bias_type == "alibi" and use_fused_attention and alibi_slopes is not None:
+ fu_core_attention_bias_type = "post_scale_bias"
+ _, fu_core_attention_bias = get_alibi(
+ query_layer.shape[-2], max_seqlen_q, max_seqlen_kv, alibi_slopes=alibi_slopes,
+ bias_dtype=query_layer.dtype)
+ if (use_fused_attention
+ and fu_core_attention_bias_type == "post_scale_bias"
+ and (fu_core_attention_bias.shape[0] != 1
+ or fu_core_attention_bias.shape[1] != query_layer.shape[-2])):
+ if fu_core_attention_bias.requires_grad:
+ # remove this line when cuDNN adds bwd support for
+ # [1, 1, s, s], [b, 1, s, s] and [b, h, s, s]
+ use_fused_attention = False
+ else:
+ # max512 backend will only support [1, h, s, s]
+ os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
+
+ if query_layer.shape[-1] == 256 and query_layer.requires_grad:
+ # Fused attention is not supported for backward with head_dim = 256.
+ # TODO(cyang): move this check to tex.get_fused_attn_backend
+ use_fused_attention = False
+
+ if use_fused_attention:
+ fused_attention_backend = tex.get_fused_attn_backend(
+ TE_DType[query_layer.dtype]
+ if not isinstance(query_layer, Float8Tensor) else query_layer._fp8_dtype,
+ TE_DType[key_layer.dtype]
+ if not isinstance(key_layer, Float8Tensor) else key_layer._fp8_dtype,
+ QKVLayout[qkv_layout],
+ AttnBiasType[fu_core_attention_bias_type],
+ AttnMaskType[attn_mask_type],
+ self.attention_dropout,
+ query_layer.shape[-2], # num_attn_heads
+ key_layer.shape[-2], # num_gqa_groups
+ max_seqlen_q,
+ max_seqlen_kv,
+ query_layer.shape[-1], # head_dim
)
core_attention_bias_shape = None
@@ -5663,87 +5946,33 @@ def forward(
and not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv)
)
- attention_params = AttentionParams(
- qkv_type=type(query_layer),
- qkv_dtype=query_layer.dtype,
- qkv_layout=qkv_layout,
- batch_size=batch_size,
- num_heads=query_layer.shape[-2],
- num_gqa_groups=key_layer.shape[-2],
- max_seqlen_q=max_seqlen_q,
- max_seqlen_kv=max_seqlen_kv,
- head_dim=query_layer.shape[-1],
- attn_mask_type=attn_mask_type,
- window_size=window_size,
- alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None,
- core_attention_bias_type=core_attention_bias_type,
- core_attention_bias_shape=core_attention_bias_shape,
- core_attention_bias_requires_grad=(
- core_attention_bias.requires_grad if core_attention_bias is not None else False
- ),
- pad_between_seqs=pad_between_seqs,
- attention_dropout=self.attention_dropout,
- context_parallel=context_parallel,
- deterministic=self.deterministic,
- is_training=self.training,
- fp8=self.fp8,
- fp8_meta=self.fp8_meta,
- )
- global _attention_backends
- if (
- _attention_backends["attention_params"] is None
- or attention_params != _attention_backends["attention_params"]
- ):
- _attention_backends["attention_params"] = attention_params
- _attention_backends["backend_selection_requires_update"] = True
- if _attention_backends["backend_selection_requires_update"]:
- (
- use_flash_attention,
- use_fused_attention,
- fused_attention_backend,
- use_unfused_attention,
- _,
- ) = get_attention_backend(attention_params)
- if use_flash_attention:
- self.logger.info("Running with FlashAttention backend")
- elif use_fused_attention:
- self.logger.info(
- "Running with FusedAttention backend (sub-backend %s)",
- int(fused_attention_backend),
- )
- elif use_unfused_attention:
- self.logger.info("Running with UnfusedDotProductAttention backend")
- else:
- use_flash_attention = _attention_backends["use_flash_attention"]
- use_fused_attention = _attention_backends["use_fused_attention"]
- fused_attention_backend = _attention_backends["fused_attention_backend"]
- use_unfused_attention = _attention_backends["use_unfused_attention"]
-
- if use_flash_attention:
- if core_attention_bias_type == "alibi":
- alibi_slopes, _ = get_alibi(
- query_layer.shape[-2],
- max_seqlen_q,
- max_seqlen_kv,
- alibi_slopes=alibi_slopes,
- )
- return self.flash_attention(
- query_layer,
- key_layer,
- value_layer,
- attention_mask=attention_mask,
- qkv_layout=qkv_layout,
- cu_seqlens_q=cu_seqlens_q,
- cu_seqlens_kv=cu_seqlens_kv,
- attn_mask_type=attn_mask_type,
- window_size=window_size,
- alibi_slopes=alibi_slopes,
- cp_group=self.cp_group,
- cp_global_ranks=self.cp_global_ranks,
- cp_stream=self.cp_stream,
- max_seqlen_q=max_seqlen_q,
- max_seqlen_kv=max_seqlen_kv,
- )
+ if self.attention_type == "self":
+ if (self.qkv_format == "bshd" and query_layer.shape[1] != value_layer.shape[1]) or \
+ (self.qkv_format == "sbhd" and query_layer.shape[0] != value_layer.shape[0]):
+ # Flash attention does not support max_seqlen_q != max_seqlen_kv for self-attention.
+ use_flash_attention = False
+
+ if use_flash_attention:
+ if _NVTE_DEBUG:
+ print("[DotProductAttention]: using flash-attn",_flash_attn_version)
+ if core_attention_bias_type == "alibi":
+ alibi_slopes, _ = get_alibi(
+ query_layer.shape[-2], max_seqlen_q, max_seqlen_kv, alibi_slopes=alibi_slopes)
+ return self.flash_attention(query_layer,
+ key_layer,
+ value_layer,
+ attention_mask=attention_mask,
+ qkv_layout=qkv_layout,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_kv=cu_seqlens_kv,
+ attn_mask_type=attn_mask_type,
+ window_size=window_size,
+ alibi_slopes=alibi_slopes,
+ cp_group=self.cp_group,
+ cp_global_ranks=self.cp_global_ranks,
+ cp_stream=self.cp_stream,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_kv=max_seqlen_kv)
if use_fused_attention:
fu_core_attention_bias_type = core_attention_bias_type
@@ -5845,15 +6074,26 @@ def forward(
query_layer,
key_layer,
value_layer,
- qkv_layout=qkv_layout,
- cu_seqlens_q=cu_seqlens_q,
- cu_seqlens_kv=cu_seqlens_kv,
- attn_mask_type=attn_mask_type,
- attention_mask=attention_mask,
- core_attention_bias_type=core_attention_bias_type,
- core_attention_bias=core_attention_bias,
- alibi_slopes=alibi_slopes,
- )
+ qkv_layout=qkv_layout,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_kv=cu_seqlens_kv,
+ attn_mask_type=attn_mask_type,
+ attention_mask=attention_mask,
+ core_attention_bias_type=core_attention_bias_type,
+ core_attention_bias=core_attention_bias,
+ alibi_slopes=alibi_slopes)
+
+ return self.unfused_attention(query_layer,
+ key_layer,
+ value_layer,
+ qkv_layout=qkv_layout,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_kv=cu_seqlens_kv,
+ attn_mask_type=attn_mask_type,
+ attention_mask=attention_mask,
+ core_attention_bias_type=core_attention_bias_type,
+ core_attention_bias=core_attention_bias,
+ alibi_slopes=alibi_slopes)
raise Exception("No dot product attention support for the provided inputs!")
@@ -6206,17 +6446,13 @@ def __init__(
**common_gemm_kwargs,
)
- def _allocate_memory(
- self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype
- ) -> torch.Tensor:
- return torch.empty(
- inference_max_sequence_len,
- batch_size,
- self.num_gqa_groups_per_partition,
- self.hidden_size_per_attention_head,
- dtype=dtype,
- device=torch.cuda.current_device(),
- )
+ self._allocator = StaticBufferAllocator()
+
+ def alloc(self, size, dtype, device):
+ """
+ Allocates a buffer in a way that is compatible with CUDA Graphs.
+ """
+ return self._allocator(size, dtype, device)
def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
"""
@@ -6360,25 +6596,13 @@ def forward(
# Pre-allocate memory for key-values for inference
# =================================================
- if inference_params and self.layer_number is not None:
- if self.layer_number not in inference_params.key_value_memory_dict:
- inf_max_seq_len = inference_params.max_sequence_length
- inf_max_batch_size = inference_params.max_batch_size
- inference_key_memory = self._allocate_memory(
- inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
- )
- inference_value_memory = self._allocate_memory(
- inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
- )
- inference_params.key_value_memory_dict[self.layer_number] = (
- inference_key_memory,
- inference_value_memory,
- )
- else:
- (
- inference_key_memory,
- inference_value_memory,
- ) = inference_params.key_value_memory_dict[self.layer_number]
+ if inference_params is not None:
+ inference_params.allocate_memory_for_kv_cache_if_empty(
+ self.layer_number,
+ self.num_gqa_groups_per_partition,
+ self.hidden_size_per_attention_head,
+ hidden_states.dtype
+ )
# ======================
# Query, Key, and Value
@@ -6538,21 +6762,42 @@ def forward(
q_pos_emb, k_pos_emb = rotary_pos_emb
- # adjust key and value for inference
- if inference_params is not None:
- if self.qkv_format == "sbhd":
- sequence_length = key_layer.size(0)
- elif self.qkv_format == "bshd":
- sequence_length = key_layer.size(1)
+ if self.qkv_format == "thd" and inference_params is not None:
+ # For thd attention, incoming tokens can be at different positions,
+ # so we need to apply a different positional-encoding frequency offset
+ # to every sequence in the batch.
+ #
+ # For example, if the sequence lengths in the context phase are 2 and 5 (batch size = 2),
+ # then in the first generation phase key_layer has shape [2, 1, d]:
+ # key_layer[0, :] corresponds to the third token of its sequence (2 cached + 1 new),
+ # and key_layer[1, :] corresponds to the sixth token (5 cached + 1 new).
+
+ query_layer = apply_rotary_pos_emb(
+ query_layer, q_pos_emb, "bshd", fused=True,
+ start_positions=inference_params.cached_sequence_lengths)
+ key_layer = apply_rotary_pos_emb(
+ key_layer, k_pos_emb, "bshd", fused=True,
+ start_positions=inference_params.cached_sequence_lengths)
+
+ else:
+ # adjust key and value for inference
+ if inference_params is not None:
+ if self.qkv_format == "sbhd":
+ sequence_length = key_layer.size(0)
+ elif self.qkv_format == "bshd":
+ sequence_length = key_layer.size(1)
+
+ sequence_start = inference_params.sequence_len_offset
+ sequence_end = sequence_start + sequence_length
- sequence_start = inference_params.sequence_len_offset
- sequence_end = sequence_start + sequence_length
+ q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
+ k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]
- q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
- k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]
+ query_layer = apply_rotary_pos_emb(
+ query_layer, q_pos_emb, self.qkv_format, fused=True)
+ key_layer = apply_rotary_pos_emb(
+ key_layer, k_pos_emb, self.qkv_format, fused=True)
- query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True)
- key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True)
# ===========================
# Core attention computation
@@ -6576,6 +6821,12 @@ def forward(
inference_params=inference_params,
)
+ if self.qkv_format == "thd":
+ # [b * sq, h] -> [b, sq, h]
+ context_layer = context_layer.view(
+ (inference_params.max_batch_size, -1, context_layer.shape[1])
+ ).contiguous()
+
# ===================
# Output. [sq, b, h]
# ===================
@@ -6596,3 +6847,20 @@ def forward(
if self.input_layernorm and self.return_layernorm_output:
outputs += (layernorm_output,)
return outputs if len(outputs) > 1 else outputs[0]
+
+
+class StaticBufferAllocator(torch.nn.Module):
+ """
+ This class is used when the model is wrapped with te.make_graphed_callables().
+ CUDA Graphs require all tensors to be static. Nevertheless,
+ the torch make_graphed_callables() API takes care of the outputs of torch modules
+ and makes them static. Thus, by wrapping memory allocation in a
+ torch.nn.Module, we can greatly simplify our code.
+ """
+
+ # pylint: disable=no-self-use
+ def forward(self, size, dtype, device):
+ """
+ Return buffer of given size, dtype and device.
+ """
+ return torch.zeros(size, dtype=dtype, device=device)
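A short note on the design choice (sketch only, not part of the patch): because the allocator is an `nn.Module`, the buffers it returns are module outputs, which graph capture already keeps static, as the docstring above explains. Callers can therefore "allocate" per step without addresses changing between graph replays:

    import torch
    from transformer_engine.pytorch.attention import StaticBufferAllocator  # added by this patch

    allocator = StaticBufferAllocator()
    # When the surrounding module is captured with make_graphed_callables, this call
    # yields a buffer treated like any other module output, so its address stays
    # fixed across graph replays instead of changing on every bare torch.zeros() call.
    cu_seqlens_q = allocator(8 + 1, torch.int32, "cuda")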
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index f06b0cb197..40ec6959d2 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -357,16 +357,18 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reductio
**************************************************************************************************/
at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
+ const at::Tensor &start_positions,
const bool transpose_output_memory);
at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
+ const at::Tensor &start_positions,
const bool transpose_output_memory);
at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens,
- const at::Tensor &freqs);
+ const at::Tensor &freqs, const at::Tensor &start_positions);
at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens,
- const at::Tensor &freqs);
+ const at::Tensor &freqs, const at::Tensor &start_positions);
/***************************************************************************************************
* Miscellaneous
@@ -376,6 +378,17 @@ size_t get_cublasLt_version();
size_t get_cudnn_version();
+bool userbuf_comm_available();
+
+void placeholder();
+
+/***************************************************************************************************
+ * Generation
+ **************************************************************************************************/
+
+void attention_copy(torch::Tensor A, torch::Tensor seq_len, torch::Tensor incoming_seq_len,
+ torch::Tensor B, int max_incoming_seq_len, int max_seq_len, int b, int s);
+
/***************************************************************************************************
* Support THD format for Context Parallel
**************************************************************************************************/
diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cu b/transformer_engine/pytorch/csrc/extensions/apply_rope.cu
index c58ba91d5e..8dc0545e26 100644
--- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cu
+++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cu
@@ -7,6 +7,7 @@
#include "extensions.h"
at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
+ const at::Tensor &start_positions,
const bool transpose_output_memory) {
using namespace transformer_engine;
TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
@@ -55,16 +56,19 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
auto input_cu = makeTransformerEngineTensor(input);
auto freqs_cu = makeTransformerEngineTensor(freqs);
+ auto start_positions_cu = makeTransformerEngineTensor(start_positions);
auto output_cu = makeTransformerEngineTensor(output);
- nvte_fused_rope_forward(input_cu.data(), freqs_cu.data(), output_cu.data(), s, b, h, d, d2,
- stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
- o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream());
+ nvte_fused_rope_forward(input_cu.data(), freqs_cu.data(), start_positions_cu.data(),
+ output_cu.data(), s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d,
+ o_stride_s, o_stride_b, o_stride_h, o_stride_d,
+ at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
+ const at::Tensor &start_positions,
const bool transpose_output_memory) {
using namespace transformer_engine;
TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
@@ -111,17 +115,19 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
auto output_grads_cu = makeTransformerEngineTensor(output_grads);
auto freqs_cu = makeTransformerEngineTensor(freqs);
+ auto start_positions_cu = makeTransformerEngineTensor(start_positions);
auto input_grads_cu = makeTransformerEngineTensor(input_grads);
- nvte_fused_rope_backward(output_grads_cu.data(), freqs_cu.data(), input_grads_cu.data(), s, b, h,
- d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
- o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream());
+ nvte_fused_rope_backward(output_grads_cu.data(), freqs_cu.data(), start_positions_cu.data(),
+ input_grads_cu.data(), s, b, h, d, d2, stride_s, stride_b, stride_h,
+ stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d,
+ at::cuda::getCurrentCUDAStream());
return input_grads;
}
at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens,
- const at::Tensor &freqs) {
+ const at::Tensor &freqs, const at::Tensor &start_positions) {
using namespace transformer_engine;
TORCH_CHECK(input.dim() == 3, "expected 3D tensor");
TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor");
@@ -163,16 +169,18 @@ at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_
auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens);
auto freqs_cu = makeTransformerEngineTensor(freqs);
auto output_cu = makeTransformerEngineTensor(output);
+ auto start_positions_cu = makeTransformerEngineTensor(start_positions);
nvte_fused_rope_thd_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
- output_cu.data(), max_s, b, h, d, d2, stride_t, stride_h, stride_d,
- o_stride_t, o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream());
+ start_positions_cu.data(), output_cu.data(), max_s, b, h, d, d2,
+ stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d,
+ at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens,
- const at::Tensor &freqs) {
+ const at::Tensor &freqs, const at::Tensor &start_positions) {
using namespace transformer_engine;
TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor");
TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor");
@@ -212,10 +220,11 @@ at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Ten
auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens);
auto freqs_cu = makeTransformerEngineTensor(freqs);
auto input_grads_cu = makeTransformerEngineTensor(input_grads);
+ auto start_positions_cu = makeTransformerEngineTensor(start_positions);
nvte_fused_rope_thd_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
- input_grads_cu.data(), max_s, b, h, d, d2, stride_t, stride_h,
- stride_d, o_stride_t, o_stride_h, o_stride_d,
+ start_positions_cu.data(), input_grads_cu.data(), max_s, b, h, d, d2,
+ stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d,
at::cuda::getCurrentCUDAStream());
return input_grads;
diff --git a/transformer_engine/pytorch/csrc/extensions/generation.cu b/transformer_engine/pytorch/csrc/extensions/generation.cu
new file mode 100644
index 0000000000..5a162f1af6
--- /dev/null
+++ b/transformer_engine/pytorch/csrc/extensions/generation.cu
@@ -0,0 +1,55 @@
+/*************************************************************************
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include "extensions.h"
+
+// Kernel used to update the KV cache when the attention layout is "thd".
+template <typename scalar_t>
+__global__ void attention_copy_kernel(scalar_t* cache_tensor, int* seq_len, int* incoming_seq_len,
+ scalar_t* hidden_tensor, int max_incoming_seq_len,
+ int max_seq_len, int b, int s) {
+ for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
+ int to_copy = s * incoming_seq_len[batch_idx];
+ int offset = seq_len[batch_idx];
+
+ scalar_t* begin_cache_copy = cache_tensor + max_seq_len * s * batch_idx + s * offset;
+ scalar_t* begin_hidden_copy = hidden_tensor + s * batch_idx * max_incoming_seq_len;
+
+ for (int i = threadIdx.x; i < to_copy; i += blockDim.x) {
+ *(begin_cache_copy + i) = *(begin_hidden_copy + i);
+ }
+ }
+}
+
+template <typename scalar_t>
+void attention_copy_launcher(torch::Tensor A, torch::Tensor seq_len, torch::Tensor incoming_seq_len,
+ torch::Tensor B, int max_incoming_seq_len, int max_seq_len, int b,
+ int s) {
+ attention_copy_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
+ reinterpret_cast<scalar_t *>(A.data_ptr()), seq_len.data_ptr<int>(),
+ incoming_seq_len.data_ptr<int>(), reinterpret_cast<scalar_t *>(B.data_ptr()),
+ max_incoming_seq_len, max_seq_len, b, s);
+}
+
+void attention_copy(torch::Tensor A, torch::Tensor seq_len, torch::Tensor incoming_seq_len,
+ torch::Tensor B, int max_incoming_seq_len, int max_seq_len, int b, int s) {
+ if (A.scalar_type() == at::ScalarType::Half) {
+ using dtype = at::Half;
+ attention_copy_launcher<dtype>(A, seq_len, incoming_seq_len, B, max_incoming_seq_len,
+ max_seq_len, b, s);
+
+ } else if (A.scalar_type() == at::ScalarType::BFloat16) {
+ using dtype = at::BFloat16;
+ attention_copy_launcher<dtype>(A, seq_len, incoming_seq_len, B, max_incoming_seq_len,
+ max_seq_len, b, s);
+ } else if (A.scalar_type() == at::ScalarType::Float) {
+ using dtype = float;
+ attention_copy_launcher<dtype>(A, seq_len, incoming_seq_len, B, max_incoming_seq_len,
+ max_seq_len, b, s);
+ } else {
+ NVTE_ERROR("Unsupported dtype of out\n");
+ }
+}
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index 89bce77ded..d250ce4484 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -155,6 +155,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::call_guard());
m.attr("_num_cublas_streams") = py::int_(transformer_engine::num_streams);
+ // Generation
+ m.def("attention_copy", &attention_copy, "attention_copy");
+
// Support THD format for Context Parallel
m.def("thd_read_half_tensor", &thd_read_half_tensor,
"Read the first half(half_idx=0) or the second half(half_idx=1) of each sequence in a THD "
diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py
index 130cf91f0e..3e077a4c07 100644
--- a/transformer_engine/pytorch/transformer.py
+++ b/transformer_engine/pytorch/transformer.py
@@ -184,6 +184,10 @@ class TransformerLayer(torch.nn.Module):
head size. Note that these formats are very closely
related to the `qkv_format` in the `MultiHeadAttention`
and `DotProductAttention` modules.
+ Note: an experimental version of 'thd' attention is supported
+ when :attr:`inference_params` is passed to the forward function.
+
+
Parallelism parameters
----------------------
@@ -280,6 +284,9 @@ def __init__(
) -> None:
super().__init__()
+ if ub_tp_comm_overlap:
+ assert tex.userbuf_comm_available(), "Userbuffer communication backend not available."
+
self.self_attn_mask_type = self_attn_mask_type
self.window_size = check_set_window_size(self_attn_mask_type, window_size)
self.enc_dec_attn_mask_type = enc_dec_attn_mask_type