diff --git a/docs/examples/te_gemma/media/calibration.svg b/docs/examples/te_gemma/media/calibration.svg new file mode 100644 index 0000000000..b1e1b5ae4b --- /dev/null +++ b/docs/examples/te_gemma/media/calibration.svg @@ -0,0 +1 @@ +FP8 with initial scaling factorsHighprecisionweightInitialFP8 scalingfactorsFP8WeightFP8InputHighprecisioninputFP8GEMMWeight calibrationHighprecisionweightFP8 scalingfactorsHighprecisioninputHighprecisionGEMMFP8 with calibrated scaling factorsHighprecisionweightCalibratedFP8 scalingfactorsFP8WeightFP8InputHighprecisioninputFP8GEMM \ No newline at end of file diff --git a/docs/examples/te_gemma/media/calibration_1_half.svg b/docs/examples/te_gemma/media/calibration_1_half.svg new file mode 100644 index 0000000000..af2641387f --- /dev/null +++ b/docs/examples/te_gemma/media/calibration_1_half.svg @@ -0,0 +1 @@ +HighprecisionweightInitialFP8 scalingfactorsFP8WeightFP8InputHighprecisioninputFP8GEMMHighprecisionweightFP8 scalingfactorsHighprecisioninputHighprecisionGEMMFP8 with initial scaling factorsWeight calibration \ No newline at end of file diff --git a/docs/examples/te_gemma/media/calibration_2_half.svg b/docs/examples/te_gemma/media/calibration_2_half.svg new file mode 100644 index 0000000000..2d56f7d434 --- /dev/null +++ b/docs/examples/te_gemma/media/calibration_2_half.svg @@ -0,0 +1 @@ +Weight calibrationHighprecisionweightFP8 scalingfactorsHighprecisioninputHighprecisionGEMMFP8 with calibrated scaling factorsHighprecisionweightCalibratedFP8 scalingfactorsFP8WeightFP8InputHighprecisioninputFP8GEMM \ No newline at end of file diff --git a/docs/examples/te_gemma/media/fp8_model_init.svg b/docs/examples/te_gemma/media/fp8_model_init.svg new file mode 100644 index 0000000000..c7fce2120d --- /dev/null +++ b/docs/examples/te_gemma/media/fp8_model_init.svg @@ -0,0 +1 @@ +FP32/BF16FP8FP8 with fp8_model_init()FP8weightFP8GEMMHighprecisionweightHighprecisioninputHighprecisionGEMMHighprecisionweightFP8WeightFP8inputFP8GEMMFP8input \ No newline at end of file diff --git a/docs/examples/te_gemma/media/fp8_model_init_1_half.svg b/docs/examples/te_gemma/media/fp8_model_init_1_half.svg new file mode 100644 index 0000000000..3b217a3eb2 --- /dev/null +++ b/docs/examples/te_gemma/media/fp8_model_init_1_half.svg @@ -0,0 +1 @@ +FP32/BF16HighprecisionweightHighprecisioninputHighprecisionGEMMHighprecisionweightFP8WeightFP8inputFP8GEMMFP8 \ No newline at end of file diff --git a/docs/examples/te_gemma/media/fp8_model_init_2_half.svg b/docs/examples/te_gemma/media/fp8_model_init_2_half.svg new file mode 100644 index 0000000000..46587664fe --- /dev/null +++ b/docs/examples/te_gemma/media/fp8_model_init_2_half.svg @@ -0,0 +1 @@ +FP8FP8 with fp8_model_init()FP8weightFP8GEMMHighprecisionweightFP8WeightFP8inputFP8GEMMFP8input \ No newline at end of file diff --git a/docs/examples/te_gemma/media/generation_animation.gif b/docs/examples/te_gemma/media/generation_animation.gif new file mode 100644 index 0000000000..25150cb9b6 Binary files /dev/null and b/docs/examples/te_gemma/media/generation_animation.gif differ diff --git a/docs/examples/te_gemma/media/graphs.svg b/docs/examples/te_gemma/media/graphs.svg new file mode 100644 index 0000000000..f734637e6d --- /dev/null +++ b/docs/examples/te_gemma/media/graphs.svg @@ -0,0 +1 @@ +Without CUDA GraphsWith CUDA GraphsLaunch 1Kernel 1Launch 2Kernel 2Launch 3Kernel 3Launch Graph 1Kernel 1Kernel 2Kernel 3 \ No newline at end of file diff --git a/docs/examples/te_gemma/media/graphs_1.png b/docs/examples/te_gemma/media/graphs_1.png new file 
mode 100644 index 0000000000..f42b50fe0d Binary files /dev/null and b/docs/examples/te_gemma/media/graphs_1.png differ diff --git a/docs/examples/te_gemma/media/graphs_2.png b/docs/examples/te_gemma/media/graphs_2.png new file mode 100644 index 0000000000..35c34ede55 Binary files /dev/null and b/docs/examples/te_gemma/media/graphs_2.png differ diff --git a/docs/examples/te_gemma/media/plot.svg b/docs/examples/te_gemma/media/plot.svg new file mode 100644 index 0000000000..481f156df6 --- /dev/null +++ b/docs/examples/te_gemma/media/plot.svg @@ -0,0 +1 @@ +87.68 s54.11 s28.22 s16.75 s12.13 s0 s10 s20 s30 s40 s50 s60 s70 s80 s90 s100 sHF (baseline)TE (subsitution ofGemmaDecoderLayer withte.TransformerLayer)TE + THD attentionTE + THD attention + CUDA GraphsTE + THD attention + FP8 \ No newline at end of file diff --git a/docs/examples/te_gemma/media/thd_bshd.svg b/docs/examples/te_gemma/media/thd_bshd.svg new file mode 100644 index 0000000000..47eed69565 --- /dev/null +++ b/docs/examples/te_gemma/media/thd_bshd.svg @@ -0,0 +1 @@ +BSHD LayoutQKVQKVCumulative sequence lengths:3, 3 + 1, 3 + 1 + 3, 3 + 1 + 3 + 1Sequence offsets:0, 4, 8, 12[batch_size,seq_len,head_nr,dim][total_nr_tokens,head_nr,dim]Seq. 1Seq. 2Seq. 4Seq. 3sbtTHD LayoutPad. 1Pad. 2Pad. 4Pad. 3Attention masktokenpadding \ No newline at end of file diff --git a/docs/examples/te_gemma/requirements.txt b/docs/examples/te_gemma/requirements.txt new file mode 100644 index 0000000000..c90fb6dad0 --- /dev/null +++ b/docs/examples/te_gemma/requirements.txt @@ -0,0 +1,4 @@ +transformers==4.41.1 +accelerate==0.30.1 +datasets==2.19.1 +sentencepiece==0.2.0 \ No newline at end of file diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py new file mode 100644 index 0000000000..758f77219f --- /dev/null +++ b/docs/examples/te_gemma/te_gemma.py @@ -0,0 +1,476 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from contextlib import contextmanager + +from typing import Optional + +import torch +import transformer_engine as te +from transformer_engine.pytorch.attention import InferenceParams, RotaryPositionEmbedding +from transformer_engine.common.recipe import Format, DelayedScaling +from torch.cuda.amp import autocast + +import transformers +from transformers.models.gemma.modeling_gemma import GemmaForCausalLM, GemmaConfig, GemmaModel + +import torch.nn.functional as F + + +class TEGemmaDecoderLayer(te.pytorch.TransformerLayer): + """ + Wrapper class over TE's `TransformerLayer`. This makes the wrapper very + similar to HF's `GemmaDecoderLayer` and easier to replace it in the code. + + Args: + config: GemmaConfig + args: positional args (for compatibility with `GemmaDecoderLayer`) + kwargs: keyword args (for compatibility with `GemmaDecoderLayer`) + """ + + def __init__(self, config: GemmaConfig, layer_idx: int, *args, **kwargs): + super().__init__( + hidden_size=config.hidden_size, + ffn_hidden_size=config.intermediate_size, + num_attention_heads=config.num_attention_heads, + bias=False, + layernorm_epsilon=config.rms_norm_eps, + hidden_dropout=0, + attention_dropout=0, + fuse_qkv_params=config.fuse_qkv_params, + normalization="RMSNorm", + activation="geglu", + attn_input_format=config.qkv_format, + num_gqa_groups=config.num_key_value_heads, + kv_channels=256, + layer_number=( + layer_idx + 1 + ), # Layer numbers in TE starts from 1, not 0 like in the HF. 
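+ # Gemma's RMSNorm learns (1 + gamma) rather than gamma, which is what zero_centered_gamma=True below enables in TE.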
+ zero_centered_gamma=True, + ) + self.te_rope_emb = RotaryPositionEmbedding(256)( + max_seq_len=config.max_position_embeddings + ).cuda() + + def forward(self, *args, **kwargs): # We need to additionally pass positional encoding. + # this args cannot be passed to TransformerLayer + keys_to_remove = [ + "position_ids", + "past_key_value", + "output_attentions", + "use_cache", + "cache_position", + ] + for key in keys_to_remove: + kwargs.pop(key, None) + # We need to return tuple to be compatible with HF. + return (super().forward(*args, rotary_pos_emb=self.te_rope_emb, **kwargs),) + + +class StaticGemmaModel(torch.nn.Module): + """ + StaticGemma is based of HF GemmaModel class. + It is adjusted to work properly with CUDA Graphs. + """ + + def __init__( + self, + model: GemmaModel, + dtype: torch.dtype, + mask: torch.Tensor, + lm_head: torch.nn.Module, + ): + super().__init__() + self.model = model + self.normalizer = torch.tensor(self.model.config.hidden_size**0.5, dtype=dtype) + self.mask = mask + self.lm_head = lm_head + + def set_inference_params(self, inference_params): + self.inference_params = inference_params + + def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor = None): + with torch.no_grad(): + # static operation - for CUDA graphs + hidden_states.data[:] = hidden_states.data[:] * self.normalizer + for decoder_layer in self.model.layers: + hidden_states.data[:] = decoder_layer( + hidden_states, + attention_mask=attention_mask, + self_attn_mask_type=self.mask, + inference_params=self.inference_params, + )[ + 0 + ] # static copy - for CUDA graphs + + hidden_states.copy_(self.model.norm(hidden_states)) # static copy - for CUDA graphs + logits = self.lm_head(hidden_states) + logits = logits.float() + return logits + + +class GemmaGenerator(torch.nn.Module): + """ + GemmaGenerator gets one layer of embeddins, + makes forward pass and returns next tokens. + """ + + def __init__( + self, model: GemmaModel, lm_head: torch.nn.Module, dtype: torch.dtype, qkv_format: str + ): + super().__init__() + self.model = model + self.gemma_layers = StaticGemmaModel(model, dtype, "padding", lm_head) + self.qkv_format = qkv_format + + def set_inference_params(self, inference_params): + self.inference_params = inference_params + self.gemma_layers.set_inference_params(inference_params) + + def forward(self, hidden_states: torch.Tensor, mask: torch.Tensor = None): + logits = self.gemma_layers(hidden_states, attention_mask=mask) + + assert logits.shape[0] == hidden_states.shape[0] # b + assert logits.shape[1] == hidden_states.shape[1] # seq_len + # logits.shape[2] = number of tokens + logits = logits[:, -1, :] + next_tokens = torch.argmax(logits, dim=1) + + # static copy for CUDA graphs + hidden_states.copy_(self.model.embed_tokens(next_tokens).unsqueeze(1)) + + # self.inference_params contains for example kv_cache. + # This needs to be called before every pass, + # to update the information of sequence lengths. + # Here we increase sequence offsets by one, + # because we generated one token for every sequence. + if self.qkv_format == "thd": + self.inference_params.setup_before_new_input( + lengths_tensor=torch.ones((next_tokens.shape[0],), device="cuda"), + max_input_length=1, + ) + else: + self.inference_params.setup_before_new_input(length=1) + + return next_tokens + + +@contextmanager +def replace_decoder(te_decoder_cls): + """ + Replace `GemmaDecoderLayer` with custom `TEGemmaDecoderLayer`. 
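+ The original class is restored when the context manager exits, so the patch only affects models constructed inside it.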
+ """ + original_gemma_decoder_cls = transformers.models.gemma.modeling_gemma.GemmaDecoderLayer + transformers.models.gemma.modeling_gemma.GemmaDecoderLayer = te_decoder_cls + try: + yield + finally: + transformers.models.gemma.modeling_gemma.GemmaDecoderLayer = original_gemma_decoder_cls + + +class TEGemmaForCausalLM(GemmaForCausalLM): + """ + Causal LM created with `GemmaModel`. The underlying `GemmaDecoderLayer` + class is monkey-patched with `TEGemmaDecoderLayer` class before + initializing the causal LM with `GemmaForCausalLM`. + + Args: + config: GemmaConfig + """ + + def __init__(self, config: GemmaConfig): + with replace_decoder(te_decoder_cls=TEGemmaDecoderLayer): + super().__init__(config) + self.to(torch.bfloat16).cuda() + self.hidden_size = config.hidden_size + self._model_generation_phase = GemmaGenerator( + lm_head=self.lm_head, + model=self.model, + dtype=torch.bfloat16, + qkv_format=config.qkv_format, + ) + self._model_context_phase = StaticGemmaModel( + self.model, torch.bfloat16, "padding_causal", self.lm_head + ) + + if self.config.fp8: + self.fp8_recipe = DelayedScaling( + fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max" + ) + + @staticmethod + def _padding_to_end(inputs, lengths): + """ + Gets the tensor with sequence padded from the beginning and + return tensor padded from its end. + + Parameters + ---------- + inputs : Tensor, tensor with shape [b, s] containing token numbers. + It's padded from the beggining. + lengths: Tensor, tensor with shape [s] with lengths of the sequences. + + """ + max_seq_len = torch.max(lengths) + batch_size, max_seq_len = inputs.shape + new_input_ids = inputs.clone() + for i in range(batch_size): + new_input_ids[i, : lengths[i]] = inputs[i, (max_seq_len - lengths[i]) : max_seq_len] + new_input_ids[i, lengths[i] :] = inputs[i, 0 : (max_seq_len - lengths[i])] + inputs.copy_(new_input_ids) + + def _next_64_multiply(self, x): + return ((x + 63) // 64) * 64 + + # This function is overriden in TeGEmmaForCausalLMCudaGraphs. + def _create_hidden_states_buffer(self, input_ids: torch.Tensor): + return torch.empty( + (input_ids.shape[0], input_ids.shape[1], self.hidden_size), + device="cuda", + dtype=torch.float32, + ) + + # This function is overriden in TeGEmmaForCausalLMCudaGraphs. + def _create_inference_params(self, max_batch_size: int, max_sequence_length: int): + return InferenceParams( + max_batch_size, max_sequence_length, qkv_format=self.config.qkv_format + ) + + # This function is overriden in TeGEmmaForCausalLMCudaGraphs. + def _get_max_input_seq_len(self, input_ids): + return input_ids.shape[1] + + # The buffer for generation is some part (beginning) of hidden states buffer. + # This function returns pointer to it and also copies there data if provided. + def _get_generation_buffer(self, hidden_states_buffer, data_to_copy=None): + # hidden_states_buffer has shape [b, s, hd] + # generation_buffer will have shape [b, 1, hd] + # Notice that "generation_buffer = hidden_states_buffer[:, 0, :].unsqueeze(1)" + # will return uncontiguous buffer, which we want to avoid. 
+ output = hidden_states_buffer.view(-1)[ + : hidden_states_buffer.shape[0] * hidden_states_buffer.shape[2] + ] + if data_to_copy is not None: + output.copy_(data_to_copy.reshape(-1)) + generation_buffer = output.view( + (hidden_states_buffer.shape[0], 1, hidden_states_buffer.shape[2]) + ) + return generation_buffer + + def _generate_context_phase(self, input_ids: torch.Tensor, inference_params: InferenceParams): + hidden_states = self._create_hidden_states_buffer(input_ids) + hidden_states.data[:] = self.model.embed_tokens(input_ids) + + # We need to update offsets before every forward pass to make cache work properly. + lengths = input_ids.ne(0).sum(dim=1) + if self.config.qkv_format == "thd": + inference_params.setup_before_new_input( + lengths_tensor=lengths, max_input_length=input_ids.shape[1] + ) + else: + inference_params.setup_before_new_input(length=input_ids.shape[1]) + + hidden_states.data[:] = self.model.embed_tokens(input_ids) + logits = self._model_context_phase( + hidden_states, + attention_mask=((input_ids == 0) if self.config.qkv_format != "thd" else None), + ) + + # We choose logits coresponding with last token in each sequence, + # which have various lengths - they are stored in (inference_params.incoming_seq_len - 1) + # Tensor when qkv_format == "thd" and + # they are the last token in the sequence when qkv_format != "thd". + if self.config.qkv_format == "thd": + logits = logits[ + torch.arange(logits.size(0)), inference_params.input_sequence_lengths - 1, : + ] + else: + logits = logits[:, -1, :] + next_tokens = torch.argmax(logits, dim=1) + + # self.hidden_states have shape [b, s, hd]. + # We return hidden state for the last token - output has shape [b, 1, hd] + hidden_states = self._get_generation_buffer( + hidden_states, self.model.embed_tokens(next_tokens) + ) + return hidden_states, next_tokens + + def _make_mask_one_token_longer(self, mask): + return torch.cat( + [mask, torch.zeros(mask.size(0), 1, 1, 1, dtype=torch.bool, device=mask.device)], dim=-1 + ) + + @torch.no_grad() + def generate( + self, + input_ids: Optional[torch.Tensor] = None, + pad_token_id: int = 0, + max_new_tokens: int = 0, + *args, + **kwargs + ): + self.eval() + + # We need both autocasts: FP8 for operations that can run in lower precision + # and BF16 for those that cannot. + with autocast(dtype=torch.bfloat16, cache_enabled=False), te.pytorch.fp8_autocast( + enabled=self.config.fp8, fp8_recipe=self.fp8_recipe if self.config.fp8 else None + ): + + batch_size, max_input_sequence_len = input_ids.shape[0], self._get_max_input_seq_len( + input_ids + ) + lengths = torch.sum(input_ids.ne(pad_token_id), dim=-1).squeeze() # [s] + input_ids = F.pad( + input_ids, (max_input_sequence_len - input_ids.shape[1], 0), "constant", 0 + ) + + # InferenceParams is a cache, where keys and values of previous tokens are stored. + # Moreover it stores length of both already generated and input sequences. + inference_params = self._create_inference_params( + max_batch_size=batch_size, + max_sequence_length=self._next_64_multiply(max_input_sequence_len + max_new_tokens), + ) + + self._model_context_phase.set_inference_params(inference_params) + self._model_generation_phase.set_inference_params(inference_params) + + if self.config.qkv_format == "thd": + # For thd layout padding is at the end, otherwise at the beginning. + TEGemmaForCausalLM._padding_to_end(input_ids, lengths) + + hidden_states, next_tokens = self._generate_context_phase(input_ids, inference_params) + + # Generation phase. 
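+ # From here on, each sequence grows by exactly one token per iteration, so the cache offsets advance by one.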
+ if self.config.qkv_format == "thd": + inference_params.setup_before_new_input( + lengths_tensor=torch.ones((next_tokens.shape[0],), device="cuda"), + max_input_length=1, + ) + else: + inference_params.setup_before_new_input(length=1) + + output_tokens = [next_tokens] + + mask = None + if self.config.qkv_format != "thd": + mask = (input_ids == 0).unsqueeze(1).unsqueeze(1) + + for _ in range(max_new_tokens): + if self.config.qkv_format != "thd": + # It will not work with cuda graphs, but it is not used for thd qkv_format. + mask = self._make_mask_one_token_longer(mask) + + next_tokens = self._model_generation_phase(hidden_states, mask) + # next_tokens is static output tensor, so we need to clone it + # - it gets changed every iteration. + output_tokens.append(next_tokens.clone()) + + result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1) + return result + + +class TEGemmaForCausalLMCudaGraphs(TEGemmaForCausalLM): + """ + TEGemmaForCausalLMCudaGraphs is the version of the class TEGemmaForCausalLM + using CUDA Graphs to speed it up. We need to make one trade-off. + Namely, batch_size, max_seq_len and max_context_seq_len need to be static. + It is necessary to run generation with the same value of + these variables that we recorded graph on. + """ + + def __init__(self, config: GemmaConfig): + super().__init__(config) + assert ( + config.qkv_format == "thd" + ), "Generation with CUDA Graphs are implemented only for thd format." + + # Preparation of the static buffers. + self.config = config + self.hidden_states_buffer = torch.empty( + ( + config.cuda_graphs_static_batch_size, + config.cuda_graphs_static_max_context_len, + config.hidden_size, + ) + ).cuda() + # This is in fact part of the buffer for hidden_states. + self.generation_buffer = self._get_generation_buffer(self.hidden_states_buffer) + self.inference_params = InferenceParams( + max_batch_size=config.cuda_graphs_static_batch_size, + max_sequence_length=config.cuda_graphs_static_max_seq_len, + qkv_format="thd", + ) + + self._model_generation_phase.set_inference_params(self.inference_params) + self._model_context_phase.set_inference_params(self.inference_params) + + def record(self): + # We want to record model in training=False, because it will be used in generation. + self.eval() + + # Here "the trick" happens. We override methods from TEGemmaForCausalLM + # with their recorded version. After invocation of each of them, + # captured graph will be replayed with minimal usage of CPU, + # what will lead to huge speedup. 
+ input_shape = ( + self.config.cuda_graphs_static_batch_size, + self.config.cuda_graphs_static_max_context_len, + ) + self.inference_params.reset() + self.inference_params.setup_before_new_input( + lengths_tensor=torch.tensor(input_shape[0] * [input_shape[1]], device="cuda"), + max_input_length=input_shape[1], + ) + self._model_context_phase = self.record_graph( + self._model_context_phase, self.hidden_states_buffer + ) # CUDA Graphs recording + + input_shape = (self.config.cuda_graphs_static_batch_size, 1) + self.inference_params.reset() + self.inference_params.setup_before_new_input( + lengths_tensor=torch.tensor(input_shape[0] * [input_shape[1]], device="cuda"), + max_input_length=input_shape[1], + ) + self._model_generation_phase = self.record_graph( + self._model_generation_phase, self.generation_buffer + ) # CUDA Graphs recording + + """ + Functions _create_hidden_states_buffer and _create_inference_params + from base class are overriden to make hidden_states and inference_params static + - not changing their position in memory between every invocation. + """ + + def _create_hidden_states_buffer(self, *args, **kwargs): + return self.hidden_states_buffer + + def _create_inference_params(self, *args, **kwargs): + self.inference_params.reset() + return self.inference_params + + def _get_max_input_seq_len(self, _): + return self.config.cuda_graphs_static_max_context_len + + @torch.no_grad() + def record_graph(self, function, input_tensor): + # function is invoked on argument (self.hidden_states,) and all kernels are recorded. + # record_graph() returns captured function, which can be run later with lower of th CPU. + fp8_format = Format.HYBRID + fp8_recipe = DelayedScaling( + fp8_format=fp8_format, amax_history_len=1024, amax_compute_algo="max" + ) + + # We need both autocasts: FP8 for operations that can run in lower precision + # and BF16 for those that cannot. + with autocast(dtype=torch.bfloat16, cache_enabled=False): + graphed_function = te.pytorch.make_graphed_callables( + function, + (input_tensor,), + fp8_enabled=self.config.fp8, + fp8_recipe=fp8_recipe, + allow_unused_input=True, + num_warmup_iters=3, + ) + return graphed_function diff --git a/docs/examples/te_gemma/te_gemma_loading_weights.py b/docs/examples/te_gemma/te_gemma_loading_weights.py new file mode 100644 index 0000000000..87e6667a9b --- /dev/null +++ b/docs/examples/te_gemma/te_gemma_loading_weights.py @@ -0,0 +1,159 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os +import re +import gc +import torch + +from typing import List + +from transformer_engine.pytorch.fp8 import fp8_model_init + +from transformers.modeling_utils import load_state_dict, _load_state_dict_into_model +from transformers.utils.hub import get_checkpoint_shard_files + +""" + This file contains logic of mapping the HuggingFace GemmaModel parameters + with TransformerEngine TransformerLayer. When we have initialized Transformer models + both with HF and with TE, we can copy parameters from the first to the second. +""" + + +def _load_weights_for_fp8_model(vanilla_model, hyperparams): + # The weights are loaded from the file with state_dict + # of model with weights which contains also fp8 parameters. + # The weights are in BF16 precision, but they contain fp8 metadata + # computed by the calibration procedure. 
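+ # Loading this state dict restores both the BF16 weights and the calibrated FP8 scaling factors.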
+ vanilla_model.load_state_dict( + torch.load(hyperparams.fp8_model_weights_filename), + strict=False, + # strict = false, because some parameters have + # multiple pointers to the same weight + # vanilla_model._model_context_phase.model + # and vanilla_model._model_generation_phase.model + ) + + +def _load_weights_for_standard_model(vanilla_model, config): + # The weights are loaded from the file with original weights. + archive_file = os.path.join(config.model_name, "model.safetensors.index.json") + resolved_archive_file, _ = get_checkpoint_shard_files(config.model_name, archive_file) + total_dict = {} + for shard_file in resolved_archive_file: + state_dict = load_state_dict(shard_file) + total_dict.update(state_dict) + + replace_params( + total_dict, + vanilla_model.state_dict(), + config, + qkv_fused_and_interleaved=config.fuse_qkv_params, + ) + # Copy parameters like embedding: + _load_state_dict_into_model(vanilla_model, total_dict, start_prefix="") + + # Force mem release. Taken from huggingface code. + del total_dict + gc.collect() + + +def load_te_model(cls, config): + """ + Custom method adapted from `from_pretrained` method in HuggingFace + Transformers repo: + https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579 + """ + config.use_cache = False # To make TransformerLayer compatible with GemmaModel + with fp8_model_init(config.fp8_model_init): + # there we need only to create model + vanilla_model = cls(config).to(torch.bfloat16).cuda() + + # and now we copy the weights into it + if config.fp8_model_weights_filename is not None: + _load_weights_for_fp8_model(vanilla_model, config) + else: + _load_weights_for_standard_model(vanilla_model, config) + + return vanilla_model + + +def _get_all_layer_prefixes_to_update(hf_state_dict): + """ + There are many parameters in hf_state_dict, whose name start with "model.layers.[number]." + This function extracts all strings like "model.layers.[number]." + that are starting strings of keys in hf_state_dict. + """ + all_layer_prefixes = set() + for param_key in hf_state_dict.keys(): + layer_prefix_pat = "model.layers.\d+." + m = re.match(layer_prefix_pat, param_key) + if m is not None: + all_layer_prefixes.add(m.group()) + return all_layer_prefixes + + +def replace_params(hf_state_dict, te_state_dict, config, qkv_fused_and_interleaved=False): + """ + Replaces params from TE TransformerLayer state_dict with corresponding parameters + from HuggingFace GemmaModel state_dict. 
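+ Returns the set of "model.layers.<number>." prefixes that were updated.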
+ """ + all_layer_prefixes: List[str] = _get_all_layer_prefixes_to_update(hf_state_dict) + + for layer_prefix in all_layer_prefixes: + + def copy_from_ht_to_te(te_name, hf_name, start=None, end=None): + te_state_dict[layer_prefix + te_name].data[start:end].copy_( + hf_state_dict[layer_prefix + hf_name] + ) + + copy_from_ht_to_te( + "self_attention.layernorm_qkv.layer_norm_weight", "input_layernorm.weight" + ) + copy_from_ht_to_te("self_attention.proj.weight", "self_attn.o_proj.weight") + copy_from_ht_to_te("layernorm_mlp.layer_norm_weight", "post_attention_layernorm.weight") + copy_from_ht_to_te("layernorm_mlp.fc2_weight", "mlp.down_proj.weight") + copy_from_ht_to_te( + "layernorm_mlp.fc1_weight", "mlp.gate_proj.weight", end=config.intermediate_size + ) + copy_from_ht_to_te( + "layernorm_mlp.fc1_weight", "mlp.up_proj.weight", start=config.intermediate_size + ) + + if qkv_fused_and_interleaved: + """ + When qkv_fused_and_interleaved=True, key, query and value layers are on one tensor + in TE TransformerLayer. Moreover they are interleaved within each head. + Let q_i, k_i and v_i be query, key and value layers for i-th head respectively. + Then TE stores weight tensor in the form: + [q1 k1 v1 q2 k2 v2 ...] + This is done to maximally optimize performance time. + """ + te_qkv_layer = te_state_dict[layer_prefix + "self_attention.layernorm_qkv.weight"] + + def copy_interleave(hf_name, idx): + src = hf_state_dict[layer_prefix + hf_name] + for head_nr in range(config.num_attention_heads): + dst_offset = head_nr * config.head_dim * 3 + dst_slice = slice( + dst_offset + idx * config.head_dim, dst_offset + (idx + 1) * config.head_dim + ) + src_slice = slice( + head_nr * config.head_dim, head_nr * config.head_dim + config.head_dim + ) + te_qkv_layer[dst_slice, :] = src[src_slice, :] + + copy_interleave("self_attn.q_proj.weight", 0) + copy_interleave("self_attn.k_proj.weight", 1) + copy_interleave("self_attn.v_proj.weight", 2) + else: + copy_from_ht_to_te( + "self_attention.layernorm_qkv.query_weight", "self_attn.q_proj.weight" + ) + copy_from_ht_to_te("self_attention.layernorm_qkv.key_weight", "self_attn.k_proj.weight") + copy_from_ht_to_te( + "self_attention.layernorm_qkv.value_weight", "self_attn.v_proj.weight" + ) + + return all_layer_prefixes diff --git a/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_finetuning_with_te.ipynb b/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_finetuning_with_te.ipynb new file mode 100644 index 0000000000..7875ffc9f3 --- /dev/null +++ b/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_finetuning_with_te.ipynb @@ -0,0 +1,314 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Accelerating a Hugging Face Gemma model finetuning with Transformer Engine" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the previous [tutorial](../te_llama/tutorial_accelerate_hf_llama_finetuning_with_te.ipynb), we demonstrated how to accelerate HF Llama models using the Transformer Engine library. We replaced `LlamaDecoderLayer` with `TransformerLayer` from the Transformer Engine, achieving a speedup. Furthermore, we conducted the finetuning in FP8 precision, which yielded an additional speedup.\n", + "\n", + "Now, we will undertake a similar enhancement for the Google's [Gemma](https://blog.google/technology/developers/gemma-open-models/) model." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Dependencies for this tutorial\n", + "\n", + "Following files and media are necessary to effectively run this tutorial:\n", + "\n", + "1. `te_gemma.py`\n", + " - This file contains the code to load a Hugging Face Gemma checkpoint in Transformer Engine's `TransformerLayer` instead of Hugging Face's `GemmaDecoderLayer`. This is used in the following two sections of the tutorial - \"Improvement 1\" and \"Improvement 2\".\n", + "2. `utils.py`\n", + " - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training and other miscellaneous tasks like restarting the jupyter notebook from within the cell. \n", + "3. `requirements.txt`\n", + " - This file contains necessary Python packages for this tutorial.\n", + "4. `media/`\n", + " - This directory contains the images used in the following tutorial." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install -r requirements.txt\n", + "\n", + "import torch\n", + "cudnn_version = torch.backends.cudnn.version()\n", + "assert cudnn_version >= 90100, \"cuDNN version >= 9.1.0 is needed to run this tutorial.\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Differences between Llama and Gemma" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Thr Llama and the Gemma are very similar models - both are based on Transformer Decoder architecture. The most important architectural differences between them are the following:\n", + "\n", + "\n", + "| Feature | Llama | Gemma |\n", + "|----------------------------------------------|------------------------------------|--------------------------------------------|\n", + "| **Norm Layer** | Standard RMSNorm
$y = \\frac{x - \\mathrm{E}[x]}{ \\sqrt{\\mathrm{Var}[x] + \\varepsilon}} * \\gamma + \\beta$ | RMSNorm with zero centered gamma parameter
$y = \\frac{x - \\mathrm{E}[x]}{ \\sqrt{\\mathrm{Var}[x] + \\varepsilon}} * (\\textcolor{red}{1 +} \\gamma) + \\beta$ |\n", + "| **Embedding Dimension/Head Dimension** | 4096/4096 | 3072/4096 |\n", + "| **Activation Function** | SwiGlu | GeGlu |\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## [Baseline] Running HF `GemmaModel` (Precision: `BF16`)\n", + "\n", + "Similarly to the Llama tutorial, we begin the experiments by running baseline Hugging Face Gemma model finetuning in BF16 precision.\n", + "\n", + "
\n", + "\n", + "Note\n", + " \n", + "This tutorial loads and trains a Gemma 7B model which takes up most of the GPU memory and therefore, we need to restart the jupyter notebook each time before running the following sections. A small utility method `restart_jupyter_notebook` is defined in the accompanying `utils.py` file. This function restarts the jupyter notebook so that the GPU memory is flushed before the model is loaded again from the checkpoint in order to avoid running into OOM (Out Of Memory) errors.\n", + "\n", + "If the utility doesn't work, comment this line `restart_jupyter_notebook()` in the following cell and manually restart the jupyter notebook before running the cell. Repeat the same for other sections in this tutorial.\n", + "\n", + "
\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10 finetuning steps complete!\n", + "\n", + "Average time taken per step: \n", + "298 \n", + "milliseconds\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "\n", + "# Import necessary packages and methods\n", + "from utils import *\n", + "\n", + "\n", + "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", + "## !!! `model_name` attr must point to the location of the model weights !!!\n", + "## Weights can be downloaded from: https://huggingface.co/google/gemma-7b\n", + "hyperparams.model_name = \"../../../../gemma-7b\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", + "hyperparams.mixed_precision = \"bf16\"\n", + "\n", + "\n", + "# Init the model and accelerator wrapper\n", + "model = init_baseline_model(hyperparams).cuda()\n", + "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", + "\n", + "\n", + "# Finetune the model\n", + "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's add this information in a table and keep comparing it with a few possible improvements in future sections:\n", + "\n", + "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", + "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", + "| HF (baseline) | BF16 | 298 | 1 |" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## [Improvement 1] Replace HF's `GemmaDecoderLayer` with TE's `TransformerLayer` (Precision: `BF16`)\n", + "\n", + "We replace *GemmaDecoderLayer* with the highly tuned *TransformerLayer*, similarly to our approach in the [Llama tutorial](../te_llama/tutorial_accelerate_hf_llama_finetuning_with_te.ipynb). Let's observe the impact this change has on the model's speed." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10 finetuning steps complete!\n", + "\n", + "Average time taken per step: \n", + "257 \n", + "milliseconds\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "\n", + "# Import necessary packages and methods\n", + "from utils import *\n", + "\n", + "\n", + "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", + "## !!! `model_name` attr must point to the location of the model weights !!!\n", + "## Weights can be downloaded from: https://huggingface.co/google/gemma-7b\n", + "hyperparams.model_name = \"../../../../gemma-7b\" # <== Add model weight location here e.g. 
\"/path/to/downloaded/gemma/weights\"\n", + "hyperparams.mixed_precision = \"bf16\"\n", + "\n", + "\n", + "# Init the model and accelerator wrapper\n", + "model = init_te_gemma_model(hyperparams).cuda()\n", + "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", + "\n", + "\n", + "# Finetune the model\n", + "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Compared to the \"baseline\" implementation, we see that using Transformer Engine's `TransformerLayer` in place of Huggging Face's `GemmaDecoderLayer` gives a speedup of **16%** even when using only BF16 precision!\n", + "\n", + "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", + "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", + "| HF (baseline) | BF16 | 298 | 1 |\n", + "| TE (replace `GemmaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 257 | 1.16 |" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## [Improvement 2] Replace HF's `GemmaDecoderLayer` with TE's `TransformerLayer` (Precision: `FP8`)\n", + "\n", + "The last improvement is about enabling FP8 precision. Let's see how it works." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10 finetuning steps complete!\n", + "\n", + "Average time taken per step: \n", + "214 \n", + "milliseconds\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "#restart_jupyter_notebook()\n", + "\n", + "\n", + "# Import necessary packages and methods\n", + "from utils import *\n", + "\n", + "\n", + "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", + "## !!! `model_name` attr must point to the location of the model weights !!!\n", + "## Weights can be downloaded from: https://huggingface.co/google/gemma-7b\n", + "hyperparams.model_name = \"../../../../gemma-7b\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", + "hyperparams.mixed_precision = \"fp8\"\n", + "\n", + "\n", + "# Init the model and accelerator wrapper\n", + "model = init_te_gemma_model(hyperparams).cuda()\n", + "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", + "\n", + "\n", + "# Finetune the model\n", + "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", + "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", + "| HF (baseline) | BF16 | 298 | 1 |\n", + "| TE (replace `GemmaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 257 | 1.16 |\n", + "| TE (replace `GemmaDecoderLayer` with `TE.TransformerLayer`) | FP8 | 214 | 1.39 |\n", + "\n", + "\n", + "After turning on FP8 precision, we get even more speedup of almost **39%**!" 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conclusion\n", + "\n", + "As shown in the [Llama tutorial](../te_llama/tutorial_accelerate_hf_llama_finetuning_with_te.ipynb), using the `TransformerLayer` module from Transformer Engine to replace Hugging Face's `GemmaDecoderLayer` results in a speedup compared to Hugging Face's native Gemma implementation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## See more\n", + "\n", + "We also prepared [tutorial](./tutorial_generation_gemma_with_te.ipynb) in which we will show how to speedup the Gemma model generation using Transformer Engine." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb new file mode 100644 index 0000000000..1948a1481b --- /dev/null +++ b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb @@ -0,0 +1,874 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "40364db7", + "metadata": {}, + "source": [ + "# Accelerating token generation of the Hugging Face Gemma Model with Transformer Engine\n", + "\n", + "Generative AI has made remarkable strides in recent years, with Large Language Models (LLMs) like ChatGPT at the forefront. These models have revolutionized how we interact with machine-generated content, providing capabilities that range from writing assistance to complex decision support. The core functionality of these models is the generation process, which involves predicting the next token in a sequence based on the preceding text. This task is critical for applications such as automated content creation, translation, and more, emphasizing the importance of efficient implementation.\n", + "\n", + "\n", + "\n", + "
\n", + "\"\"\n", + "
\n", + "Animation 1: Hugging Face Gemma model token generation.\n", + "
\n", + "
\n", + "\n", + "For those seeking a deeper understanding of text generation mechanisms in Transformers, it is recommended to check out the [HuggingFace generation tutorial](https://huggingface.co/docs/transformers/llm_tutorial).\n", + "\n", + "In the previous tutorials on [Llama](../te_llama/tutorial_accelerate_hf_llama_finetuning_with_te.ipynb) and [Gemma](./tutorial_accelerate_hf_gemma_finetuning_with_te.ipynb), it was demonstrated how finetuning can be accelerated using the Transformer Engine's `TransformerLayer`. Building on this foundation, the current objective is to enhance the generation speed of the Gemma model.\n", + "\n", + "This tutorial will introduce and explain several advanced features of the Transformer Engine that contribute to this goal:\n", + "\n", + "###### **1. THD Attention Layout.**\n", + "\n", + "Addressing the challenge of computing attention for sequences with varying lengths, a common method is to pad these sequences and apply an attention mask. The Transformer Engine, however, offers a more optimized approach—by specifying the lengths and offsets of the sequences, attention can be computed directly. Instead of passing the tensor with shape `[b, s, h, d]` and the attention mask, one can pass a tensor of the shape `[t, h, d]` along with tensors detailing cumulative sequence lengths and offsets to run the attention optimized for this case. This specific attention layout is referred to as the **THD layout**. \n", + "\n", + "\n", + "The letter `t` in the standard `[t, h, d]` layout is equal to the total length of the sequences, namely `t = s_1 + s_2 + ... + s_b`, where `s_i` denotes the length of sequence `i`. TransformerEngine supports a THD layout that incorporates gaps between these sequences - the lengths of the offsets need to be passed in the additional parameter.\n", + "\n", + "
\n", + "\"\"\n", + "
\n", + "Figure 1: The difference between BSHD (default) and THD attention layouts is as follows: with BSHD, one needs to provide the attention mask, while with THD, one needs to provide cumulative sequence lengths and sequence offsets.\n", + "
\n", + "
\n", + "\n", + "###### **2. CUDA Graphs API.**\n", + "\n", + "The speed of GPUs is increasing at a rapid pace. It turns out that sometimes the runtime of kernels is shorter than the time it takes for the CPU to submit them, which can lead to significant overhead. CUDA Graphs can address this issue. When certain kernels are executed repeatedly, it allows us to record and replay them with less CPU involvement. This becomes particularly useful in applications like token generation, where a `TransformerLayer` is run for every token that needs to be generated.\n", + "\n", + "One can read more about CUDA Graphs [here](https://developer.nvidia.com/blog/cuda-graphs/).\n", + "\n", + "PyTorch exposes graphs via a raw `torch.cuda.CUDAGraph` class and two convenience wrappers: `torch.cuda.graph` and `torch.cuda.make_graphed_callables`. More information about the cuda graphs in Pytorch can be found [here](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/).\n", + "\n", + "
\n", + "\"\"\n", + "
\n", + "Figure 2: CUDA Graphs reduce the overhead generated by the long time it takes to launch a single kernel. It enables the recording and replaying of subsequent launches, thus reducing the total time used by the CPU.\n", + "
\n", + "
\n", + "\n", + "\n", + "###### **3. FP8 Weights Calibration.**\n", + "\n", + "Assuming that the model is trained in FP32/BF16 precision and the goal is to execute it in FP8 precision, the process isn't straightforward due to the absence of appropriate FP8 scaling factors. In this scenario, FP8 calibration becomes essential. By conducting several forward passes on sample data, the FP8 scaling parameters can be computed. This calibration allows the model to operate correctly in FP8 precision.\n", + "\n", + "It is highly recommended to familiarize oneself with the [tutorial](../../examples/fp8_primer.ipynb) on FP8 precision to understand the importance of proper scaling factors.\n", + "\n", + "\n", + "
\n", + "\"\"\n", + "
\n", + "Figure 3:\n", + "If the model is trained in BF16/FP32, it does not include the computed FP8 scaling factors. When it is run under fp8_autocast(), the value of these scaling factors will default to their initial values, which can cause numerical errors. Weight calibration involves calculating FP8 scaling factors from higher precision forward passes. Once these factors are computed, the model becomes numerically stable. \n", + "
\n", + "
\n", + "\n", + "###### **4. FP8 Model Weights.**\n", + "\n", + "The typical approach is to store weights in higher precision and then cast them to fp8 before operations. This may prevent accuraccy drops in training. However, for inference, this level of precision is not necessary.\n", + "\n", + "The TransformerEngine includes a wrapper `fp8_model_​init`, which allows for the creation of models that store only the FP8 copy of the weights. This eliminates the need to cast from higher precision to BF16, saving time in this casting process. \n", + "\n", + "
\n", + "\"\"\n", + "
\n", + "Figure 4: Model under fp8_autocast() stores weights in high precision by default, and casts them if needed. It can leads to slowdown and increased memory usage. Using fp8_model_init() results in storing weight in FP8.\n", + "
\n", + "
\n", + "\n", + "###### Benchmarking\n", + "\n", + "We'll evaluate the generation time across one benchmark: generation with context phase max sequence length = 128, batch size = 64 and number of generated tokens = 896 on random texts with random lengths.\n", + "\n", + "
\n", + "Note\n", + " \n", + "This tutorial focuses on showcasing the mentioned features of Transformer Engine in the context of token generation. It's important to note, however, that NVIDIA provides [TensorRT](https://developer.nvidia.com/tensorrt), which is optimized for inference tasks and should be considered for such use cases.\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "b18f91a9", + "metadata": {}, + "source": [ + "## Dependencies for this tutorial" + ] + }, + { + "cell_type": "markdown", + "id": "e5201d77", + "metadata": {}, + "source": [ + "Following files and media are necessary to effectively run this tutorial:\n", + "\n", + "1. `te_gemma.py`\n", + " - This file contains the code to load a Hugging Face Gemma checkpoint in Transformer Engine's `TransformerLayer` instead of Hugging Face's `GemmaDecoderLayer`. It does also contain code for generation with THD attention, CUDA Graphs and weight calibration.\n", + "2. `te_gemma_loading_weights.py`\n", + " - This file contains logic of mapping the parameters from `GemmaDecoderLayer` into the `TransformerLayer`.\n", + "3. `utils.py`\n", + " - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training and other miscellaneous tasks like restarting the jupyter notebook from within the cell. \n", + "4. `requirements.txt`\n", + " - This file contains necessary Python packages for this tutorial.\n", + "5. `media/`\n", + " - This directory contains the images used in the following tutorial." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "31390c76", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install -r requirements.txt\n", + "\n", + "import torch\n", + "cudnn_version = torch.backends.cudnn.version()\n", + "assert cudnn_version >= 90100, \"cuDNN version >= 9.1.0 is needed to run this tutorial.\"" + ] + }, + { + "cell_type": "markdown", + "id": "e8dfabbf", + "metadata": {}, + "source": [ + "\n", + "|\n", + "## [Baseline] Running Hugging Face generation with Gemma model" + ] + }, + { + "cell_type": "markdown", + "id": "59560bff", + "metadata": {}, + "source": [ + "HuggingFace Transformers library offers generation API. \n", + "HuggingFace generation for the Gemma model will be used as a baseline." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "2803e0ec", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Prompt:\n", + "Here are the two facts about GPUs:\n", + "\n", + "Generated text:\n", + "\n", + "1. They are very good at doing the same thing over and over again.\n", + "2. They are very bad at doing different things at the same time.\n", + "\n", + "The first fact is why GPUs are so good at rendering video games. The second\n", + "============================== Generation example 2 ==============================\n", + "Prompt:\n", + "Some facts about NVIDIA:\n", + "\n", + "Generated text:\n", + "\n", + "* NVIDIA is a global technology leader in the design and manufacture of \n", + " advanced microprocessors for the PC and mobile computing markets.\n", + "* NVIDIA is a leading provider of graphics processing units (GPUs) for the PC and mobile computing markets.\n", + "*\n", + "============================== Benchmarking ==============================\n", + "Benchmarking for batch_size = 64 and max total tokens = 1024\n", + "Time: 87.68 s.\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "from utils import *\n", + "\n", + "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", + "# !!! 
`model_name` attr must point to the location of the model weights !!!\n", + "# Weights can be downloaded from: https://huggingface.co/google/gemma-7b.\n", + "# Weights should be in the *.safetensors HF format, not in the original format.\n", + "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", + "\n", + "model = init_baseline_model(hyperparams)\n", + "\n", + "print_sample_of_generated_texts(model)\n", + "benchmark_generation(model)" + ] + }, + { + "cell_type": "markdown", + "id": "b3698dc6", + "metadata": {}, + "source": [ + "Let's put this time into the table for later comparison.\n", + "\n", + "| Models | Time (s) | Speedup | \n", + "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", + "| HF (baseline) | 87.68 | 1 |" + ] + }, + { + "cell_type": "markdown", + "id": "8bb40f45", + "metadata": {}, + "source": [ + "## [Improvement 1] Using TransformerLayer from Transformer Engine instead of GemmaDecoderLayer." + ] + }, + { + "cell_type": "markdown", + "id": "263b40f2", + "metadata": {}, + "source": [ + "As in the [Gemma](./tutorial_accelerate_hf_gemma_finetuning_with_te.ipynb) finetuning tutorial, a GemmaDecoderLayer is substituted by a tuned TransformerLayer from the Transformer Engine. Let's run it and compare the time with the baseline." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "9dceef93", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Prompt:\n", + "Here are the two facts about GPUs:\n", + "\n", + "Generated text:\n", + "\n", + "1. GPUs are very good at doing the same thing over and over again.\n", + "2. GPUs are very bad at doing different things at the same time.\n", + "\n", + "The first fact is why GPUs are so good at graphics. The second fact is why\n", + "============================== Generation example 2 ==============================\n", + "Prompt:\n", + "Some facts about NVIDIA:\n", + "\n", + "Generated text:\n", + "\n", + "* NVIDIA is a global technology company that designs and develops high-performance computer\n", + "* graphics and video processing chips.\n", + "* The company was founded in 1993 by Jen-Hsun Huang, Chris Malachowsky, and Curtis Priem.\n", + "============================== Benchmarking ==============================\n", + "Benchmarking for batch_size = 64 and max total tokens = 1024\n", + "Time: 54.11 s.\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "from utils import *\n", + "\n", + "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", + "\n", + "model = init_te_gemma_model(hyperparams)\n", + "\n", + "print_sample_of_generated_texts(model)\n", + "benchmark_generation(model)" + ] + }, + { + "cell_type": "markdown", + "id": "b5d40836", + "metadata": {}, + "source": [ + "The speedup of **62%** was obtained." 
+ ] + }, + { + "cell_type": "markdown", + "id": "006d18e8", + "metadata": {}, + "source": [ + "| Models | Time (s) | Speedup | \n", + "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", + "| HF (baseline) | 87.68 | 1 |\n", + "| TE (subsitution of GemmaDecoderLayer with te.TransformerLayer) | 54.11 | 1.62 | " + ] + }, + { + "cell_type": "markdown", + "id": "2bbf3d47", + "metadata": {}, + "source": [ + "## [Improvement 2] Use of THD attention layout.\n", + "\n", + "Input sequences can have various lengths. Hugging Face generation – as can be seen in Animation 1 – pads the sequences and then uses attention mask. In the THD attention layout cumulative sequence lengths and offsets need to be provided, instead of attention mask. The THD attention layout is much more optimized than BSHD layout.\n", + "\n", + "The class `transformer_engine.pytorch.DotProductAttention` supports this format. One need to pass the following things as the arguments to the forward:\n", + "- `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` – offsets of the beginnings of the next sequences,\n", + "- `cu_seqlens_q`, `cu_seqlens_kv` – cumulative sum of the lengths of the sequences of query and values,\n", + "- `max_seqlen_q` – maximum sequence length in query layer,\n", + "- `max_seqlen_kv` – maximum sequence length in key-value layer.\n", + "\n", + "
\n", + "Note\n", + "\n", + "Currently, the THD attention for `TransformerLayer` is supported only for token generation.\n", + "
\n", + "\n", + "Let's look how using TransformerEngine with THD attention impacts the speed of token generation:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "4fc5e1cd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Prompt:\n", + "Here are the two facts about GPUs:\n", + "\n", + "Generated text:\n", + "\n", + "1. They are very good at doing the same thing over and over again.\n", + "2. They are very bad at doing different things at the same time.\n", + "\n", + "The first fact is why GPUs are so good at rendering video games. The second fact\n", + "============================== Generation example 2 ==============================\n", + "Prompt:\n", + "Some facts about NVIDIA:\n", + "\n", + "Generated text:\n", + "\n", + "* NVIDIA is a global technology company that designs and develops high-performance computing \n", + " and graphics processing units (GPUs) for the gaming, professional visualization, and data center markets.\n", + "* The company was founded in 1993 and is headquartered\n", + "============================== Benchmarking ==============================\n", + "Benchmarking for batch_size = 64 and max total tokens = 1024\n", + "Time: 28.22 s.\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "from utils import *\n", + "\n", + "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", + "hyperparams.qkv_format = \"thd\"\n", + "\n", + "model = init_te_gemma_model(hyperparams)\n", + "\n", + "print_sample_of_generated_texts(model)\n", + "benchmark_generation(model)" + ] + }, + { + "cell_type": "markdown", + "id": "8e397a65", + "metadata": {}, + "source": [ + "By using THD attention, the following speedup was obtained:\n", + "\n", + "| Models | Time (s) | Speedup | \n", + "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", + "| HF (baseline) | 87.68 | 1 |\n", + "| TE (subsitution of GemmaDecoderLayer with te.TransformerLayer) | 54.11 | 1.62 | \n", + "| TE + THD attention | 28.22 | 3.11 | " + ] + }, + { + "cell_type": "markdown", + "id": "21a89d9c", + "metadata": {}, + "source": [ + "## [Improvement 3] Speeding up generation with CUDA Graphs" + ] + }, + { + "cell_type": "markdown", + "id": "e2d53e7b", + "metadata": {}, + "source": [ + "TransformerEngine includes a function `transformer_engine.pytorch.make_graphed_callables`, which functions similarly to the corresponding feature in PyTorch. It is capable of recording any modules from the Transformer Engine. Below is a code excerpt from `te_gemma.py` from class `TEGemmaForCausalLMCudaGraphs`:\n", + "```\n", + " def __init__(self, config : GemmaConfig):\n", + " (...)\n", + " \n", + " # Here \"the trick\" happens. We override methods from TEGemmaForCausalLM\n", + " # with their recorded version. After invocation of each of them,\n", + " # captured graph will be replayed with minimal usage of CPU,\n", + " # what will lead to huge speedup.\n", + " (...)\n", + " self._model_context_phase = \n", + " self.record_graph(self._model_context_phase, self.hidden_states_buffer) # CUDA Graphs recording\n", + "\n", + " (...) 
\n", + " self._model_generation_phase = \n", + " self.record_graph(self._model_generation_phase, self.generation_buffer) # CUDA Graphs recording\n", + "\n", + " @torch.no_grad()\n", + " def record_graph(self, function, input_tensor):\n", + " (...)\n", + " # function is invoked on argument (self.hidden_states,) and all kernels are recorded.\n", + " # record_graph() returns captured function, which can be run later with minimal use of th CPU.\n", + " fp8_format = Format.HYBRID\n", + " fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo=\"max\")\n", + " with autocast(dtype=torch.bfloat16, cache_enabled=False):\n", + " graphed_function = te.pytorch.make_graphed_callables(\n", + " function, \n", + " (input_tensor,), \n", + " fp8_enabled=True, \n", + " fp8_recipe=fp8_recipe, \n", + " allow_unused_input=True,\n", + " num_warmup_iters=3\n", + " )\n", + " return graphed_function\n", + "```\n", + "\n", + "It is strongly reccomended to review the entire code of the class `TEGemmaForCausalLMCudaGraphs`. Let's now proceed to evaluate the performance improvement offered by CUDA Graphs." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "31a3a8a3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Two facts about GPUs:\n", + "\n", + "1. They are very good at doing the same thing over and over again.\n", + "2. They are very bad at doing different things at the same time.\n", + "\n", + "This is why they are so good at rendering graphics.\n", + "\n", + "The first fact is the\n", + "============================== Generation example 2 ==============================\n", + "Two facts about NVIDIA:\n", + "\n", + "1. It is the world’s largest manufacturer of graphics processing units (GPUs) for the gaming industry.\n", + "2. It is the world’s largest manufacturer of GPUs for the data center industry.\n", + "\n", + "The company’s stock price has\n", + "============================== Benchmarking ==============================\n", + "Benchmarking for batch_size = 64 and max total tokens = 1024\n", + "Time: 16.75 s.\n" + ] + } + ], + "source": [ + "#Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "from utils import *\n", + "\n", + "hyperparams.model_name = \"\" # <== Add model weight location here e.g. 
\"/path/to/downloaded/gemma/weights\"\n", + "hyperparams.qkv_format = \"thd\"\n", + "\n", + "hyperparams.generation_cuda_graphs = True\n", + "\n", + "# It is necessary to preallocate a static buffer.\n", + "# CUDA graphs require static input tensors for every kernel.\n", + "# This approach may result in a slight increase in memory consumption;\n", + "# however, the substantial speedup achieved makes it worthwhile.\n", + "hyperparams.cuda_graphs_static_batch_size = 64\n", + "hyperparams.cuda_graphs_static_max_seq_len = 1024\n", + "hyperparams.cuda_graphs_static_max_context_len = 128\n", + "model = init_te_gemma_model(hyperparams)\n", + "\n", + "print_sample_of_generated_texts(model)\n", + "benchmark_generation(model)" + ] + }, + { + "cell_type": "markdown", + "id": "53bb430f", + "metadata": {}, + "source": [ + "The **5.23x** speedup was obtained.\n", + "\n", + "| Models | Time (s) | Speedup | \n", + "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", + "| HF (baseline) | 87.68 | 1 |\n", + "| TE (subsitution of GemmaDecoderLayer with te.TransformerLayer) | 54.11 | 1.62 | \n", + "| TE + THD attention | 28.22 | 3.11 | \n", + "| TE + THD attention + CUDA Graphs | 16.75 | 5.23 | \n" + ] + }, + { + "cell_type": "markdown", + "id": "0a11b75c", + "metadata": {}, + "source": [ + "Let's look at the screenshots from *NVIDIA Nsight System* profiler to see where this speedup comes from:\n", + "\n", + "
\n", + "\n", + "
\n", + "Figure 5: Without CUDA Graphs. One can see that GPU (blue) is idle for big portion of the time.\n", + "
\n", + "
\n", + "\n", + "
\n", + "\n", + "
\n", + "Figure 6: With CUDA Graphs. One can see that GPU (orange) is fully utilized.\n", + "
\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "e6b171a0", + "metadata": {}, + "source": [ + "## [Improvement 4] Running generation in FP8 of the model trained in higher precision " + ] + }, + { + "cell_type": "markdown", + "id": "1a80288b", + "metadata": {}, + "source": [ + "Implementing FP8 generation with the Gemma model is not straightforward, because this model was initially trained using BF16 precision, and the necessary FP8 scaling factors are missing. Running the model at this lower precision without proper scaling could lead to significant errors and incorrect results.\n", + "\n", + "It is highly recommended to familiarize oneself with the [tutorial](../../examples/fp8_primer.ipynb) on FP8 precision to understand the necessity of scaling.\n", + "\n", + "\n", + "
\n", + "\n", + "
\n", + " Figure 8: The FP8 scaling factors are incorrect and that leads to numerical errors. The weight calibration allows us to compute FP8 metadata during the forwards in higher precision.\n", + "
\n", + "
\n", + "\n", + "### Weight Calibration\n", + "\n", + "To address the issue outlined above, weight calibration will be used. This involves running several forward iterations at BF16 precision within the context `te.fp8_autocast(enabled=False, calibration=True)`. This setup allows the forward pass to operate at higher precision, while simultaneously collecting `amax_history` and other parameters related to the FP8 precision, which are essential for calculating the FP8 scaling well.\n", + "\n", + "The code below outlines the steps to initialize the BF16 model and conduct several forward iterations within the specified context. After these iterations, the model is saved, and these weights will be utilized in subsequent chapters." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "aecee0e1", + "metadata": {}, + "outputs": [], + "source": [ + "#Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "from utils import *\n", + "import transformer_engine.pytorch as te\n", + "\n", + "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", + "hyperparams.fuse_qkv_params = True # This is needed by the last improvement.\n", + "\n", + "model = init_te_gemma_model(hyperparams)\n", + "\n", + "# Calibration\n", + "with te.fp8_autocast(enabled=False, calibrating=True), \\\n", + " torch.autocast(device_type='cuda', dtype=torch.bfloat16):\n", + " model.train()\n", + " run_forward_pass(model, hyperparams, num_iters=512)\n", + "\n", + "# Compute scale_fwd with enabled fp8 autocast\n", + "with te.fp8_autocast(enabled=True), \\\n", + " torch.autocast(device_type='cuda', dtype=torch.bfloat16):\n", + " run_forward_pass(model, hyperparams, 1)\n", + "\n", + "# Some parameters are in pointing to the same tensors, double save is avoided here.\n", + "dict_to_save = {k: v for k, v in model.state_dict().items() \\\n", + " if (\"_context_phase\" not in k and \"_generation_phase\" not in k)}\n", + "torch.save(dict_to_save, '') # <== Add path to save calibrated weights." + ] + }, + { + "cell_type": "markdown", + "id": "b6dcd135", + "metadata": {}, + "source": [ + "|\n", + "### Generation in FP8\n", + "\n", + "
\n", + "\n", + "
\n", + " Figure 8: After the weight calibration FP8 scaling factors are correct and prevent numerical errors.\n", + "
\n", + "
\n", + "\n", + "Now FP8 inference is ready to be run." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "a913f54d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Two facts about GPUs:\n", + "\n", + "1. They are exorbitantly expensive.\n", + "2. They are exorbitantly powerful.\n", + "\n", + "The first fact is a bummer, but the second fact is a boon. GPUs are exorbitantly powerful \n", + "because they are exorbitantly expensive. GPUs are exorbitantly expensive\n", + "============================== Generation example 2 ==============================\n", + "Two facts about NVIDIA:\n", + "\n", + "1. NVIDIA is a company that makes graphics cards for computers.\n", + "2. NVIDIA is a company that makes graphics cards for computers.\n", + "\n", + "The first fact is true. The second fact is true.\n", + "\n", + "

NVIDIA is a company that makes graphics cards\n", + "============================== Benchmarking ==============================\n", + "Benchmarking for batch_size = 64 and max total tokens = 1024\n", + "Time: 19.31 s.\n" + ] + } + ], + "source": [ + "#Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "from utils import *\n", + "\n", + "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", + "hyperparams.qkv_format = \"thd\"\n", + "hyperparams.fuse_qkv_params = True # This is needed by the last improvement.\n", + "\n", + "hyperparams.fp8 = True\n", + "# Calibrated fp8 weights are loaded directly from the file.\n", + "\n", + "hyperparams.fp8_model_weights_filename = \"\" # <== Add calibrated weights location here.\n", + "\n", + "hyperparams.generation_cuda_graphs = True\n", + "hyperparams.cuda_graphs_static_batch_size = 64\n", + "hyperparams.cuda_graphs_static_max_seq_len = 1024\n", + "hyperparams.cuda_graphs_static_max_context_len = 128\n", + "model = init_te_gemma_model(hyperparams)\n", + "\n", + "print_sample_of_generated_texts(model)\n", + "benchmark_generation(model)" + ] + }, + { + "cell_type": "markdown", + "id": "8cdbb56c", + "metadata": {}, + "source": [ + "One can observe that the outputs are coherent; however, the generation time has increased. Why is this the case?\n", + "\n", + "\n", + "
\n", + "\n", + "
\n", + " Figure 9: Running the model at higher precision involves only one GEMM operation. However, when the model operates in FP8, it requires not just the low-precision GEMM but also weight casting.\n", + "
\n", + "
\n", + "\n", + "Running the model in FP8 does not imply that all weights are stored in FP8. By default, they are stored in higher precision and are cast to FP8, using saved scaling factors, before operations such as GEMMs.\n", + "\n", + "This approach is beneficial during training: one can perform one cast for both backward and forward passes, leading to speedups. However, performing a single cast for each forward pass introduces too much overhead to achieve a speedup. This issue will be addressed in the next section of the tutorial." + ] + }, + { + "cell_type": "markdown", + "id": "8d3945e3", + "metadata": {}, + "source": [ + "### Use of only FP8 model weights" + ] + }, + { + "cell_type": "markdown", + "id": "2dd0cba9", + "metadata": {}, + "source": [ + "TransformerEngine stores parameters in higher precision and only casts them to FP8. It may be necessary to maintain accucacy during training. However, high precision is not needed when doing inference. \n", + "\n", + "Transformer Engine supports maintaining only FP8 weights with `fp8_model_init` decorator. Let's see an example\n", + "```\n", + "linear = te.Linear(1024, 1024) # this module is initialized with full precision weights\n", + "with te.fp8_model_init(enabled=True):\n", + " linear_fp8 = te.Linear(1024, 1024) # this module is initialized only with fp8 weights\n", + "\n", + "assert type(linear.weight.data) is torch.Tensor\n", + "assert type(linear_fp8.weight.data) is te.float8_tensor.Float8Tensor\n", + "```\n", + "\n", + "
\n", + "\n", + "
\n", + " Figure 10: Using fp8_model_init stores the weights directly in FP8 format, which reduces both time and memory usage.\n", + "
\n", + "
\n", + "\n", + "Let's run the code with `fp8_model_init`:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "96264b9c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Prompt:\n", + "Here are the two facts about GPUs:\n", + "\n", + "Generated text:\n", + "\n", + "1. GPUs are exorbitantly expensive.\n", + "2. GPUs are exorbitantly powerful.\n", + "\n", + "The first fact frustrates me. The second excites me.\n", + "\n", + "I’ve been using GPUs for a while now, and I’ve been using them for\n", + "============================== Generation example 2 ==============================\n", + "Prompt:\n", + "Some facts about NVIDIA:\n", + "\n", + "Generated text:\n", + "\n", + "* NVIDIA is a global technology company that designs and manufactures graphics processing units (GPUs)\n", + " for the gaming, professional visualization, and data center markets.\n", + "* NVIDIA is headquartered in Santa Clara, California, and has offices in more than 25\n", + "============================== Benchmarking ==============================\n", + "Benchmarking for batch_size = 64 and max total tokens = 1024\n", + "Time: 12.13 s.\n" + ] + } + ], + "source": [ + "#Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "# Import necessary packages and methods\n", + "from utils import *\n", + "\n", + "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", + "hyperparams.fuse_qkv_params = True # Needed for fp8_model_init().\n", + "hyperparams.qkv_format = \"thd\"\n", + "\n", + "hyperparams.fp8 = True\n", + "hyperparams.fp8_model_init = True # This will result in storing only fp8 weights.\n", + "hyperparams.fp8_model_weights_filename = \"\" # <== Add calibrated weights location here.\n", + "\n", + "hyperparams.generation_cuda_graphs = True\n", + "hyperparams.cuda_graphs_static_batch_size = 64\n", + "hyperparams.cuda_graphs_static_max_seq_len = 1024\n", + "hyperparams.cuda_graphs_static_max_context_len = 128\n", + "model = init_te_gemma_model(hyperparams)\n", + "\n", + "print_sample_of_generated_texts(model)\n", + "benchmark_generation(model)" + ] + }, + { + "cell_type": "markdown", + "id": "3e30ca5a", + "metadata": {}, + "source": [ + "| Models | Time (s) | Speedup | \n", + "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", + "| HF (baseline) | 87.68 | 1 |\n", + "| TE (subsitution of GemmaDecoderLayer with te.TransformerLayer) | 54.11 | 1.62 | \n", + "| TE + THD attention | 28.22 | 3.11 | \n", + "| TE + THD attention + CUDA Graphs | 16.75 | 5.23 | \n", + "| TE + THD attention + FP8 | 12.13 | 7.23 | \n", + "\n", + "The final speedup is **7.23x**." + ] + }, + { + "cell_type": "markdown", + "id": "c6e87275", + "metadata": {}, + "source": [ + "## Conclusions" + ] + }, + { + "cell_type": "markdown", + "id": "7bb2452d", + "metadata": {}, + "source": [ + "\n", + "
\n", + "\n", + "
\n", + " Figure 11: Times obtained with optimizations using TransformerEngine (seconds).\n", + "
\n", + "
\n", + "\n", + "In this tutorial, we've explored three features of the Transformer Engine:\n", + "1. Support for the THD attention layout,\n", + "2. Integration with CUDA Graphs,\n", + "3. FP8 weights calibration,\n", + "4. Models containing only FP8 version of their parameters.\n", + "\n", + "Each of these features can be applied in various contexts, such as fast token generation. It's important to note that the fastest possible inference speeds can be achieved using NVIDIA's inference-optimized [TensorRT](https://developer.nvidia.com/tensorrt) library." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/examples/te_gemma/utils.py b/docs/examples/te_gemma/utils.py new file mode 100644 index 0000000000..292a452f42 --- /dev/null +++ b/docs/examples/te_gemma/utils.py @@ -0,0 +1,320 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import time +import sys +import IPython +import random +import string + +from te_gemma_loading_weights import load_te_model + +import torch +from torch.optim import AdamW +from torch.utils.data import DataLoader + +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + get_linear_schedule_with_warmup, + AutoConfig, +) +from transformers import DataCollatorForLanguageModeling +from datasets import load_dataset +from accelerate import Accelerator +from accelerate.utils.dataclasses import FP8RecipeKwargs + + +from te_gemma import TEGemmaForCausalLM, TEGemmaForCausalLMCudaGraphs + + +class HyperParameters: + def __init__(self): + self.mixed_precision = "bf16" + self.model_name = None + + self.fp8 = False + + # Weights in fp8 + self.fp8_model_weights_filename = None + self.fp8_model_init = False + + # Cuda graphs + self.generation_cuda_graphs = False + self.cuda_graphs_static_batch_size = 16 + self.cuda_graphs_static_max_seq_len = 256 + self.cuda_graphs_static_max_context_len = 16 + + # Finetuning settings. + self.dataset_name = "timdettmers/openassistant-guanaco" + self.dataset_text_field = "text" + self.learning_rate = 1.41e-5 + self.batch_size = 8 + self.max_seq_length = 256 + self.gradient_accumulation_steps = 1 + self.num_warmup_steps = 5 + self.num_training_steps = 10 + + # QKV format. + self.fuse_qkv_params = False + self.qkv_format = "bshd" + + +hyperparams = HyperParameters() + +assert ( + torch.backends.cudnn.version() >= 9100 +), "cuDNN version >= 9.1.0 is needed to run this tutorial." 
+ + +def get_dataloaders(accelerator: Accelerator, hyperparams): + dataset = load_dataset(hyperparams.dataset_name, split="train") + tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) + + def tokenize(element): + outputs = tokenizer( + element["text"], + truncation=True, + padding=False, + max_length=hyperparams.max_seq_length, + return_overflowing_tokens=False, + return_length=False, + ) + return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]} + + with accelerator.main_process_first(): + dataset = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names) + + # Simply pad to the multiple of 16 for both FP8 and BF16 precision + pad_to_multiple_of = 16 + data_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm=False, + pad_to_multiple_of=pad_to_multiple_of, + ) + + dataloader_params = { + "batch_size": hyperparams.batch_size, + "collate_fn": data_collator, + "drop_last": True, + } + train_dataloader = DataLoader(dataset, **dataloader_params) + return train_dataloader + + +def init_baseline_model(hyperparams): + # Init the model + config = AutoConfig.from_pretrained(hyperparams.model_name) + # make sure to use flash_attention to do iso comparison with TEGemmaModel + config._attn_implementation = "flash_attention_2" + model = AutoModelForCausalLM.from_pretrained( + hyperparams.model_name, + config=config, + torch_dtype=torch.bfloat16, + ) + return model.cuda() + + +def init_te_gemma_model(hyperparams): + cls = TEGemmaForCausalLMCudaGraphs if hyperparams.generation_cuda_graphs else TEGemmaForCausalLM + config = AutoConfig.from_pretrained(hyperparams.model_name) + config._attn_implementation = "flash_attention_2" + # Adding all params from the hyperparams to the config to make the code simpler. 
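+    # For example, the TE Gemma layers read fuse_qkv_params and qkv_format from this config.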
+ for key, value in hyperparams.__dict__.items(): + setattr(config, key, value) + model = load_te_model(cls, config) + if hyperparams.generation_cuda_graphs: + model.record() + return model.cuda() + + +def wrap_with_accelerator(model, hyperparams): + # Create FP8 kwarg handler if required + fp8_kwarg_handler = ( + [FP8RecipeKwargs(backend="te")] if hyperparams.mixed_precision == "fp8" else None + ) + + # Init HF accelerator that's used for training + accelerator = Accelerator( + log_with="wandb", + gradient_accumulation_steps=hyperparams.gradient_accumulation_steps, + mixed_precision=hyperparams.mixed_precision, + kwargs_handlers=fp8_kwarg_handler, + ) + # accelerator.print(f'State: {accelerator.state}') + train_dataloader = get_dataloaders(accelerator, hyperparams) + + # Wrap model, optimizer/scheduler, dataloaders in accelerate + optimizer = AdamW(params=model.parameters(), lr=hyperparams.learning_rate, fused=True) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=100, + num_training_steps=hyperparams.num_training_steps, + ) + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + + return accelerator, model, optimizer, train_dataloader, lr_scheduler + + +def finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler): + model.train() + optimizer.zero_grad() + train_dataloader = enumerate(train_dataloader) + + def run_iters(num_iters): + for _ in range(num_iters): + _, batch = next(train_dataloader) + with accelerator.accumulate(model): + outputs = model(**batch) + loss = outputs.loss + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + run_iters(hyperparams.num_warmup_steps) # Warmup iters + + # Get the timers ready + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + + start.record() + run_iters(hyperparams.num_training_steps) # Training iters + torch.cuda.synchronize() + end.record() + accelerator.end_training() + + print( + f"""{hyperparams.num_training_steps} finetuning steps complete!\n + Average time taken per step: + {(start.elapsed_time(end)/hyperparams.num_training_steps):.0f} + milliseconds""" + ) + + +def restart_jupyter_notebook(): + # Try restarting the Jupyter kernel + IPython.Application.instance().kernel.do_shutdown(True) + + # Check whether the device memory has been flushed + if torch.cuda.memory_allocated() != 0: + import warnings + + warnings.warn("The device memory hasn't been flushed, trying with a second method!") + + # Try restarting the Jupyter kernel another way + # Restart the kernel + from IPython.core.display import HTML + + HTML("") + + if torch.cuda.memory_allocated() != 0: + print( + "The device memory hasn't been flushed, try manually restarting the Jupyter kernel!" + ) + + # Suppress the warnings + if not sys.warnoptions: + import warnings + + warnings.simplefilter("ignore") + torch.set_warn_always(False) + + +@torch.no_grad() +def run_forward_pass(model, hyperparams, num_iters): + """ + It runs num_iters forward passes with sample data. 
+ """ + accelerator = Accelerator( + log_with="wandb", + gradient_accumulation_steps=hyperparams.gradient_accumulation_steps, + mixed_precision="no", + ) + train_dataloader = get_dataloaders(accelerator, hyperparams) + + model.train() + train_dataloader = enumerate(train_dataloader) + + for _ in range(num_iters): + _, batch = next(train_dataloader) + batch["input_ids"] = batch["input_ids"].cuda() + model(batch["input_ids"]) + + +""" + Benchmarking and example generation functions. +""" + + +def print_sample_of_generated_texts(model): + tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) + prompts = ["Here are the two facts about GPUs:", "Some facts about NVIDIA:"] + inputs = tokenizer(prompts * 32, return_tensors="pt", padding=True) + + max_length = inputs["input_ids"].size(1) + new_length = ((max_length + 63) // 64) * 128 + inputs["input_ids"] = torch.nn.functional.pad( + inputs["input_ids"], (new_length - max_length, 0), value=tokenizer.pad_token_id + ) + inputs["attention_mask"] = torch.nn.functional.pad( + inputs["attention_mask"], (new_length - max_length, 0), value=0 + ) + + inputs["input_ids"] = inputs["input_ids"].cuda() + inputs["attention_mask"] = inputs["attention_mask"].cuda() + + outputs = model.generate(**inputs, max_new_tokens=50) + generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) + + print("=" * 30 + " Generation example 1 " + "=" * 30) + print("Prompt:") + print(generated_texts[0][: len(prompts[0])]) + print("Generated text:") + print(generated_texts[0][len(prompts[0]) :]) + print("=" * 30 + " Generation example 2 " + "=" * 30) + print("Prompt:") + print(generated_texts[1][: len(prompts[1])]) + print("") + print("Generated text:") + print(generated_texts[1][len(prompts[1]) :]) + + +def _generate_random_words(num_words, max_word_length): + words = [] + for _ in range(num_words): + word_length = random.randint(1, max_word_length) + word = "".join(random.choices(string.ascii_lowercase, k=word_length)) + words.append(word) + return words + + +def benchmark_generation(model): + batch_size = 64 + context_length = 128 + max_new_tokens = 1024 - 128 + print("=" * 30 + " Benchmarking " + "=" * 30) + print( + f"Benchmarking for batch_size = {batch_size} and max total tokens =" + f" {context_length + max_new_tokens}" + ) + + input_str = _generate_random_words(batch_size, context_length) + + tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) + inputs = tokenizer(input_str, return_tensors="pt", padding=True) + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start.record() + + model.generate(inputs["input_ids"].cuda(), max_new_tokens=max_new_tokens) + torch.cuda.synchronize() + end.record() + + print(f"Time: {start.elapsed_time(end)/1000:.2f} s.") diff --git a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb b/docs/examples/te_llama/tutorial_accelerate_hf_llama_finetuning_with_te.ipynb similarity index 99% rename from docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb rename to docs/examples/te_llama/tutorial_accelerate_hf_llama_finetuning_with_te.ipynb index 57c1bf6601..0d3ada8a12 100644 --- a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb +++ b/docs/examples/te_llama/tutorial_accelerate_hf_llama_finetuning_with_te.ipynb @@ -5,7 +5,7 @@ "id": "6a5b2993", "metadata": {}, "source": [ - "# Accelerating a Hugging Face Llama 2 and Llama 3 models with Transformer Engine\n", + "# Accelerating a Hugging Face Llama 2 
and Llama 3 models finetuning with Transformer Engine\n", "\n", "
\n", "\n", diff --git a/docs/index.rst b/docs/index.rst index d64cebbfa2..316c2ded59 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -44,7 +44,9 @@ Transformer Engine documentation examples/fp8_primer.ipynb examples/advanced_optimizations.ipynb - examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb + examples/te_llama/tutorial_accelerate_hf_llama_finetuning_with_te.ipynb + examples/te_gemma/tutorial_accelerate_hf_gemma_finetuning_with_te.ipynb + examples/te_gemma/tutorial_generation_gemma_with_te.ipynb .. toctree:: :hidden: diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 90c5e499f3..985f92cedd 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -22,5 +22,6 @@ pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py +pytest -v -s $TE_PATH/tests/pytorch/test_generation.py pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops_distributed.py diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py index d6ba66cbbc..a2ce84293c 100644 --- a/tests/pytorch/test_fused_rope.py +++ b/tests/pytorch/test_fused_rope.py @@ -11,7 +11,7 @@ def apply_rotary_pos_emb_thd( - t: torch.Tensor, cu_seqlens: torch.Tensor, freqs: torch.Tensor + t: torch.Tensor, cu_seqlens: torch.Tensor, freqs: torch.Tensor, start_positions: torch.Tensor ) -> torch.Tensor: """A baseline implementation of applying RoPE for `thd` format. @@ -20,14 +20,106 @@ def apply_rotary_pos_emb_thd( cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and dtype torch.int32. freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d] + start_positions (Tensor): Tensor of shape [b] determining the beginning offsets + of frequeuncies applied to sequences. Returns: Tensor: Shape [t, h, d]. The input tensor after applying RoPE. """ seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() - return torch.cat( - [apply_rotary_pos_emb(x.unsqueeze(1), freqs[: x.size(0)]) for x in torch.split(t, seqlens)] - ).squeeze(1) + if start_positions is None: + return torch.cat( + [ + apply_rotary_pos_emb(x.unsqueeze(1), freqs[: x.size(0)]) + for x in torch.split(t, seqlens) + ] + ).squeeze(1) + else: + return torch.cat( + [ + apply_rotary_pos_emb( + x.unsqueeze(1), freqs[start_positions[i] : (x.size(0) + start_positions[i])] + ) + for i, x in enumerate(torch.split(t, seqlens)) + ] + ).squeeze(1) + + +def apply_rotary_pos_emb_with_start_positions( + t: torch.Tensor, + freqs: torch.Tensor, + tensor_format: str = "sbhd", + start_positions: Union[torch.Tensor, None] = None, +) -> torch.Tensor: + """ + Apply rotary positional embedding tensor to the input tensor. + This is non-fused version which supports start_positions parameters. + Non-fused implementation with start_positions is slow, thus it is not included in the + Transformer Engine directly. + + Parameters + ---------- + t: torch.Tensor + Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which + rotary positional embedding will be applied. + freqs: torch.Tensor + Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float', + with `s2 >= s` and `d2 <= d`. + tensor_format: {'sbhd', 'bshd'}, default = 'sbhd' + start_positions: torch.Tensor, default = None. 
+ We may not want begin all the sequences from the 0 embedding. + This tensor argument allows that. + """ + + def _rotate_half(x: torch.Tensor) -> torch.Tensor: + """ + change sign so the last dimension becomes [-odd, +even] + """ + x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2))) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + if start_positions is None: + return apply_rotary_pos_emb(t, freqs, tensor_format=tensor_format) + + max_seq_len = freqs.shape[0] + cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0] + + # Only apply the rotary embeddings up to the sequence length of the running + # input. + assert ( + cur_seq_len <= max_seq_len + ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!" + + if tensor_format == "bshd": + t = t.transpose(0, 1) + # cos/sin first then dtype conversion for better precision + cos_ = torch.cos(freqs).to(t.dtype) + sin_ = torch.sin(freqs).to(t.dtype) + + rot_dim = freqs.shape[-1] + # ideally t_pass is empty so rotary pos embedding is applied to all tensor t + t, t_pass = t[..., :rot_dim], t[..., rot_dim:] + + # shifted_sin, shifted_cos will have the same shape as t. They will contain + # scaling factors shifted for each sequence by the corresponding start_positions offset. + + shifted_sin = sin_[:cur_seq_len].expand(t.shape).clone() + shifted_cos = cos_[:cur_seq_len].expand(t.shape).clone() + + for b in range(start_positions.shape[0]): + assert max_seq_len >= start_positions[b] + shifted_freq = slice(start_positions[b], (start_positions[b] + cur_seq_len)) + shifted_sin[:, b, :] = sin_[shifted_freq, 0, ...] + shifted_cos[:, b, :] = cos_[shifted_freq, 0, ...] + + t = (t * shifted_cos) + (_rotate_half(t) * shifted_sin) + out = torch.cat((t, t_pass), dim=-1) + + if tensor_format == "bshd": + out = out.transpose(0, 1).contiguous() + + return out def get_tol(dtype: torch.dtype) -> Dict: @@ -54,8 +146,9 @@ def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor: @pytest.mark.parametrize("hidden_size", [128, 256]) @pytest.mark.parametrize("rotary_percent", [0.5, 1.0]) @pytest.mark.parametrize("margin", [0, 10]) +@pytest.mark.parametrize("start_positions", [True, False]) @pytest.mark.parametrize("transpose", [None, (0, 1), (2, 3)]) -@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"]) +@pytest.mark.parametrize("tensor_format", ["bshd", "sbhd"]) @pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad]) def test_fused_rope( dtype: torch.dtype, @@ -63,6 +156,7 @@ def test_fused_rope( hidden_size: int, rotary_percent: float, margin: int, + start_positions: bool, transpose: Union[Tuple, None], tensor_format: str, loss_func: Callable, @@ -80,11 +174,24 @@ def test_fused_rope( t = t.transpose(*transpose).contiguous().transpose(*transpose) t.requires_grad = True + if margin == 0 and start_positions == True: + # If sequence to encode has the same length as length of encoding + # there is no space left for starting with positions >0. 
+ pytest.skip("Skipping test with margin=0 and start_positions=True") + + start_positions = ( + torch.randint(0, margin, (batch_size,), dtype=torch.int32, device=device) + if start_positions + else None + ) + rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent) emb = rotary_pos_emb(seq_length) # unfused - output_unfused = apply_rotary_pos_emb(t, emb, tensor_format=tensor_format, fused=False) + output_unfused = apply_rotary_pos_emb_with_start_positions( + t, emb, tensor_format=tensor_format, start_positions=start_positions + ) loss_unfused = loss_func(output_unfused) loss_unfused.backward() grad_unfused = t.grad.detach().clone() @@ -92,10 +199,7 @@ def test_fused_rope( # fused output_fused = apply_rotary_pos_emb( - t, - emb, - tensor_format=tensor_format, - fused=True, + t, emb, tensor_format=tensor_format, fused=True, start_positions=start_positions ) loss_fused = loss_func(output_fused) loss_fused.backward() @@ -112,12 +216,14 @@ def test_fused_rope( @pytest.mark.parametrize("rotary_percent", [0.5, 1.0]) @pytest.mark.parametrize("transpose", [None, (1, 2)]) @pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad]) +@pytest.mark.parametrize("start_positions", [True, False]) def test_fused_rope_thd( dtype: torch.dtype, hidden_size: int, rotary_percent: float, transpose: Union[Tuple, None], loss_func: Callable, + start_positions: bool, ) -> None: device = torch.device("cuda:0") batch_size, head_num = 2, 64 @@ -135,11 +241,17 @@ def test_fused_rope_thd( t = t.transpose(*transpose).contiguous().transpose(*transpose) t.requires_grad = True + start_positions = ( + torch.randint(0, 20, (cu_seqlens.shape[-1],), dtype=torch.int32, device=device) + if start_positions + else None + ) + rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent) emb = rotary_pos_emb(cu_seqlens[-1]) # unfused - output_unfused = apply_rotary_pos_emb_thd(t, cu_seqlens, emb) + output_unfused = apply_rotary_pos_emb_thd(t, cu_seqlens, emb, start_positions=start_positions) loss_unfused = loss_func(output_unfused) loss_unfused.backward() grad_unfused = t.grad.detach().clone() @@ -147,7 +259,12 @@ def test_fused_rope_thd( # fused output_fused = apply_rotary_pos_emb( - t, emb, fused=True, tensor_format="thd", cu_seqlens=cu_seqlens + t, + emb, + fused=True, + tensor_format="thd", + cu_seqlens=cu_seqlens, + start_positions=start_positions, ) loss_fused = loss_func(output_fused) loss_fused.backward() diff --git a/tests/pytorch/test_generation.py b/tests/pytorch/test_generation.py new file mode 100644 index 0000000000..343dd4db1d --- /dev/null +++ b/tests/pytorch/test_generation.py @@ -0,0 +1,210 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest +import torch + +import transformer_engine.pytorch as te + + +class TestInferenceParams: + def test_setup_before_new_input_bshd(self): + inference_params = te.attention.InferenceParams(64, 128, qkv_format="bshd") + + inference_params.setup_before_new_input(length=16) + # Offset before first sequence is equal to 0. + assert inference_params.sequence_len_offset == 0 + + # Offset before second sequence is equal to 16. 
+ inference_params.setup_before_new_input(length=4) + assert inference_params.sequence_len_offset == 16 + + def test_setup_before_new_input_thd(self): + inference_params = te.attention.InferenceParams(4, 128, qkv_format="thd") + + inference_params.setup_before_new_input( + lengths_tensor=torch.Tensor([1, 0, 2, 4]).cuda(), max_input_length=20 + ) + + assert torch.equal( + inference_params.cached_sequence_lengths, torch.Tensor([0, 0, 0, 0]).cuda() + ) + assert torch.equal( + inference_params.input_sequence_lengths, torch.Tensor([1, 0, 2, 4]).cuda() + ) + assert inference_params.max_incoming_seq_len == 20 + + inference_params.setup_before_new_input( + lengths_tensor=torch.Tensor([2, 3, 5, 1]).cuda(), max_input_length=10 + ) + assert torch.equal( + inference_params.cached_sequence_lengths, torch.Tensor([1, 0, 2, 4]).cuda() + ) + assert torch.equal( + inference_params.input_sequence_lengths, torch.Tensor([2, 3, 5, 1]).cuda() + ) + assert inference_params.max_incoming_seq_len == 10 + + @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) + @pytest.mark.parametrize("batch_size", [64, 128, 256]) + @pytest.mark.parametrize("max_seq_len", [128, 256, 512]) + @pytest.mark.parametrize("max_input_len", [32, 128]) + def test_save_to_kv_cache_thd(self, batch_size, max_seq_len, max_input_len, dtype): + h, d = 16, 256 + + inference_params = te.attention.InferenceParams(batch_size, max_seq_len, qkv_format="thd") + inference_params.allocate_memory_for_kv_cache_if_empty(1, h, d, dtype) + + t = batch_size * max_input_len + key_layer = torch.randn((t, h, d)).cuda().to(dtype) + value_layer = torch.randn((t, h, d)).cuda().to(dtype) + + sequence_lengths = [1, 2] * (batch_size // 2) + + # We save the same sequences two time, which should result in sequences of lentgh 2 and 4 + # in the cache + inference_params.reset() + inference_params.setup_before_new_input( + lengths_tensor=torch.tensor(sequence_lengths).cuda(), max_input_length=max_input_len + ) + inference_params.save_to_kv_cache(1, key_layer, value_layer) + + inference_params.setup_before_new_input( + lengths_tensor=torch.tensor(sequence_lengths).cuda(), max_input_length=max_input_len + ) + inference_params.save_to_kv_cache(1, key_layer, value_layer) + + key_memory, value_memory = inference_params.key_value_memory_dict[1] + + # Chcek whether the sequences were copied properly. + + def check(memory, layer, b, idx1, idx2): + # Check if sequence idx in batch b in memory corresponds + # to the sequence idx2 in batch b in layer. 
+ assert torch.equal(memory[b * max_seq_len + idx1], layer[b * max_input_len + idx2, :]) + + # even indices + for b in range(0, batch_size, 2): + check(key_memory, key_layer, b, 0, 0) + check(key_memory, key_layer, b, 1, 0) + assert (key_memory[b * max_seq_len + 2 : ((b + 1) * max_seq_len)] == 0).all() + + check(value_memory, value_layer, b, 0, 0) + check(value_memory, value_layer, b, 1, 0) + assert (value_memory[b * max_seq_len + 2 : ((b + 1) * max_seq_len)] == 0).all() + + # odd indices + for b in range(1, batch_size, 2): + check(key_memory, key_layer, b, 0, 0) + check(key_memory, key_layer, b, 1, 1) + check(key_memory, key_layer, b, 2, 0) + check(key_memory, key_layer, b, 3, 1) + assert (key_memory[b * max_seq_len + 4 : ((b + 1) * max_seq_len)] == 0).all() + + check(value_memory, value_layer, b, 0, 0) + check(value_memory, value_layer, b, 1, 1) + check(value_memory, value_layer, b, 2, 0) + check(value_memory, value_layer, b, 3, 1) + assert (value_memory[b * max_seq_len + 4 : ((b + 1) * max_seq_len)] == 0).all() + + @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) + @pytest.mark.parametrize("batch_size", [64, 128, 256]) + @pytest.mark.parametrize("max_seq_len", [128, 256, 512]) + def test_save_to_kv_cache_bshd(self, batch_size, max_seq_len, dtype): + # This test checks if key_layer and value_layer are copied to cache. + # Cache size is equal to the size of one key/value layer. + h, d = 16, 256 + + inference_params = te.attention.InferenceParams(batch_size, max_seq_len, qkv_format="bshd") + + inference_params.allocate_memory_for_kv_cache_if_empty(1, h, d, dtype) + key_layer = torch.randn((max_seq_len, batch_size, h, d)).cuda().to(dtype) + value_layer = torch.randn((max_seq_len, batch_size, h, d)).cuda().to(dtype) + + inference_params.setup_before_new_input(length=0) + inference_params.save_to_kv_cache(1, key_layer, value_layer) + + key_memory, value_memory = inference_params.key_value_memory_dict[1] + + assert torch.equal(key_memory, key_layer) + assert torch.equal(value_memory, value_layer) + + @pytest.mark.parametrize("layer_number", [1, 100]) + @pytest.mark.parametrize("batch_size", [1, 128]) + @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) + def test_allocate_memory_for_kv_cache_if_empty(self, layer_number, batch_size, dtype): + nr_heads = 16 + head_dim = 256 + max_sequence_len = 128 + inference_params = te.attention.InferenceParams( + batch_size, max_sequence_len, qkv_format="bshd" + ) + + assert layer_number not in inference_params.key_value_memory_dict + + inference_params.allocate_memory_for_kv_cache_if_empty( + layer_number, nr_heads, head_dim, dtype + ) + + key_memory, value_memory = inference_params.key_value_memory_dict[layer_number] + + assert key_memory.shape == (max_sequence_len, batch_size, nr_heads, head_dim) + assert value_memory.shape == (max_sequence_len, batch_size, nr_heads, head_dim) + + # Should not allocate new buffers. + inference_params.allocate_memory_for_kv_cache_if_empty(layer_number, 100, 100, dtype) + + assert key_memory.shape == (max_sequence_len, batch_size, nr_heads, head_dim) + assert value_memory.shape == (max_sequence_len, batch_size, nr_heads, head_dim) + + def test_set_params_to_thd_attention(self): + # This test check whether parameteres needed to run thd attention + # are computed correcly. This parameters are passed to the fused_attn_fwd(..) + # to indicate which parts of the key/query/value layers are sequences and + # which of them are offsets. 
+ batch_size = 4 + channels = 1024 + max_sequence_len = 128 + max_input_len = 20 + inference_params = te.attention.InferenceParams( + batch_size, max_sequence_len, qkv_format="thd" + ) + + inference_params.setup_before_new_input( + lengths_tensor=torch.Tensor([1, 1, 1, 1]).cuda(), max_input_length=max_input_len + ) + inference_params.setup_before_new_input( + lengths_tensor=torch.Tensor([1, 0, 2, 4]).cuda(), max_input_length=max_input_len + ) + + buffers = [torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda") for _ in range(6)] + max_q_len, max_kv_len, buffers = inference_params.set_params_to_thd_attention( + buffers, channels + ) + + cu_seqlens_q, cu_seqlens_kv, seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o = ( + buffers + ) + + assert max_q_len == max_input_len + assert max_kv_len == max_sequence_len + assert torch.equal(cu_seqlens_q, torch.tensor([0, 1, 1, 3, 7]).cuda()) + assert torch.equal(cu_seqlens_kv, torch.tensor([0, 2, 3, 6, 11]).cuda()) + + assert torch.equal( + seq_offsets_q, + torch.tensor([k * max_input_len * channels for k in range(batch_size + 1)]).cuda(), + ) + assert torch.equal( + seq_offsets_k, + torch.tensor([k * max_sequence_len * channels for k in range(batch_size + 1)]).cuda(), + ) + assert torch.equal( + seq_offsets_v, + torch.tensor([k * max_sequence_len * channels for k in range(batch_size + 1)]).cuda(), + ) + assert torch.equal( + seq_offsets_o, + torch.tensor([k * max_input_len * channels for k in range(batch_size + 1)]).cuda(), + ) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 7eed97a0ca..8e20957384 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -3,8 +3,9 @@ # See LICENSE for license information. import math +import functools import os -from typing import Dict, List, Optional +from typing import Dict, List, Tuple, Optional import pytest import copy @@ -12,6 +13,8 @@ import torch.nn as nn from torch.nn import Parameter +import transformer_engine.pytorch.cpp_extensions as ext + from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager, fp8_model_init from transformer_engine.pytorch.utils import ( init_method_normal, @@ -40,6 +43,22 @@ fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +@functools.cache +def _cudnn_version() -> Tuple[int, int, int]: + """Runtime cuDNN version (major, minor, patch)""" + encoded_version = ext.get_cudnn_version() + major_version_magnitude = 1000 if encoded_version < 90000 else 10000 + major, encoded_version = divmod(encoded_version, major_version_magnitude) + minor, patch = divmod(encoded_version, 100) + return (major, minor, patch) + + +def get_device_compute_capability() -> Tuple[int, int]: + """CUDA compute capability of current GPU""" + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + return (props.major, props.minor) + + seed = 1234 torch.manual_seed(seed) torch.cuda.manual_seed(seed) @@ -1682,6 +1701,139 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, assert_allclose(full_output, incremental_output, atol[dtype]) +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model_key", model_configs_inference.keys()) +@pytest.mark.parametrize("use_RoPE", all_boolean) +@pytest.mark.parametrize("module", module_inference) +@pytest.mark.skipif( + get_device_compute_capability() < (9, 0), reason="THD is only supported on Hopper+." 
+) +@pytest.mark.skipif(_cudnn_version() < (9, 0, 0), reason="cuDNN 9.0.0+ is required.") +def test_kv_cache_accuracy_thd(dtype, bs, model_key, use_RoPE, module): + """ + In thd attention sequences can have various lengths, + different that 's' dimension of input to the Transformer Layer. + + The test contains of: + - one context phase when sequences with various lengths(!) are passed through the model, + - 2 phases when sequences with length 1 are passed through the model. + + The output is compared with the case when all this sequences are passed at one. + """ + if dtype == torch.float32: + pytest.skip("torch.float32 does not support thd") + + fused_attn_env = os.environ["NVTE_FUSED_ATTN"] + os.environ["NVTE_FUSED_ATTN"] = "1" # Only fused attention supports thd. + + if not fp8_available: + pytest.skip(reason_for_no_fp8) + + config = model_configs_inference[model_key] + + S = config.seq_len + B = bs + H = config.num_attention_heads + D = config.hidden_size + G = 2 # generation phase length + S_max = S + G + head_size = config.embed + + layer_number = 1 + rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda") + + # Tensors have shapes [b, s, h, d] and the seqlens are the tensor of shapes [b] + # which indicate the length of sequences - sequences starts from the begining. + # This function copies sequences from tensor into dst_tensor. + # dst_tensor should be big enough to fit this sequences. + def _concat_thd(dst_tensor, dst_seqlens, tensor, seqlens): + for b in range(B): + dst_tensor[b, dst_seqlens[b] : (dst_seqlens[b] + seqlens[b]), :] = tensor[ + b, : seqlens[b], : + ] + dst_seqlens.copy_(dst_seqlens + seqlens) + + if module == "TransformerLayer": + model = TransformerLayer( + hidden_size=D, + ffn_hidden_size=4 * D, + num_attention_heads=H, + attn_input_format="thd", + self_attn_mask_type="padding_causal", + layer_number=layer_number, + params_dtype=dtype, + device="cuda", + ).eval() + attn_name = "self_attn_mask_type" + else: + model = ( + MultiheadAttention( + hidden_size=D, + num_attention_heads=H, + qkv_format="thd", + layer_number=layer_number, + params_dtype=dtype, + attn_mask_type="padding_causal", + ) + .cuda() + .eval() + ) + attn_name = "attn_mask_type" + + inference_params = InferenceParams(B, S_max, qkv_format="thd") + + kwargs = { + "inference_params": inference_params, + "rotary_pos_emb": rotary_freqs if use_RoPE else None, + } + + total_sequence_lengths = torch.zeros((B,)).cuda().to(torch.int32) + total_tensor = torch.zeros((B, S_max, D)).cuda().to(dtype) + + # Sequences split into chunks. + + # context phase + sequence_lengths = torch.randint(1, S, (B,)).cuda().to(torch.int32) + chunk = torch.randn((B, S, D)).cuda().to(dtype) + inference_params.setup_before_new_input(max_input_length=S, lengths_tensor=sequence_lengths) + model( + chunk, inference_params=inference_params, rotary_pos_emb=rotary_freqs if use_RoPE else None + ) + _concat_thd(total_tensor, total_sequence_lengths, chunk, sequence_lengths) + + # generation phase + for _ in range(G): + sequence_lengths = torch.ones((B,)).cuda().to(torch.int32) + chunk = torch.randn((B, 1, D)).cuda().to(dtype) + inference_params.setup_before_new_input(max_input_length=1, lengths_tensor=sequence_lengths) + # we need to remove 'causal' from mask + # otherwise tokens we add will be considered as a first in the sequence, + # but they need to interact with all tokens from key-value cache. 
+ # after removing this line, tests should fail + kwargs[attn_name] = "padding" + output = model(chunk, **kwargs) + _concat_thd(total_tensor, total_sequence_lengths, chunk, sequence_lengths) + incremental_logits = output[:, -1, :] # last element of each seq. + + # Sequences passed in one, concatenated chunk. + + kwargs[attn_name] = "padding_causal" # add 'causal' back to the mask + inference_params.reset() + inference_params.setup_before_new_input( + max_input_length=S_max, lengths_tensor=total_sequence_lengths + ) + full_output = model(total_tensor, **kwargs) + full_logits = full_output[ + torch.arange(0, B), total_sequence_lengths - 1, : + ] # last element of each seq. + + # Final result should be close. + torch.testing.assert_close(full_logits, incremental_logits, atol=1e-2, rtol=1e-2) + + os.environ["NVTE_FUSED_ATTN"] = fused_attn_env + + @pytest.mark.parametrize( "shape", [ diff --git a/transformer_engine/common/fused_rope/fused_rope.cu b/transformer_engine/common/fused_rope/fused_rope.cu index e7cf940a57..560b7b55d3 100644 --- a/transformer_engine/common/fused_rope/fused_rope.cu +++ b/transformer_engine/common/fused_rope/fused_rope.cu @@ -15,11 +15,11 @@ namespace transformer_engine { template __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs, scalar_t *dst, - const int offset_block, const int offset_block_dst, - const int h, const int d, const int d2, const int stride_h, - const int stride_d, const int o_stride_h, - const int o_stride_d) { - int s_id = blockIdx.x; + const int begin_offset, const int offset_block, + const int offset_block_dst, const int h, const int d, + const int d2, const int stride_h, const int stride_d, + const int o_stride_h, const int o_stride_d) { + int s_id = blockIdx.x + begin_offset; #pragma unroll for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { float v_cos, v_sin; @@ -52,11 +52,11 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs template __device__ void fused_rope_block_backward(const scalar_t *src, const float *freqs, scalar_t *dst, - const int offset_block, const int offset_block_dst, - const int h, const int d, const int d2, - const int stride_h, const int stride_d, + const int begin_offset, const int offset_block, + const int offset_block_dst, const int h, const int d, + const int d2, const int stride_h, const int stride_d, const int o_stride_h, const int o_stride_d) { - int s_id = blockIdx.x; + int s_id = blockIdx.x + begin_offset; #pragma unroll for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { float v_cos = cosf(freqs[s_id * d2 + d_id]); @@ -88,68 +88,75 @@ __device__ void fused_rope_block_backward(const scalar_t *src, const float *freq } template -__global__ void fused_rope_forward_kernel(const scalar_t *src, const float *freqs, scalar_t *dst, - const int h, const int d, const int d2, - const int stride_s, const int stride_b, - const int stride_h, const int stride_d, - const int o_stride_s, const int o_stride_b, - const int o_stride_h, const int o_stride_d) { +__global__ void fused_rope_forward_kernel(const scalar_t *src, const float *freqs, + const int *start_positions, scalar_t *dst, const int h, + const int d, const int d2, const int stride_s, + const int stride_b, const int stride_h, + const int stride_d, const int o_stride_s, + const int o_stride_b, const int o_stride_h, + const int o_stride_d) { int s_id = blockIdx.x, b_id = blockIdx.y; + int begin_offset = (start_positions == 0) ? 
0 : start_positions[b_id]; int offset_block = s_id * stride_s + b_id * stride_b; int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b; - fused_rope_block_forward(src, freqs, dst, offset_block, offset_block_dst, h, d, d2, stride_h, - stride_d, o_stride_h, o_stride_d); + fused_rope_block_forward(src, freqs, dst, begin_offset, offset_block, offset_block_dst, h, d, d2, + stride_h, stride_d, o_stride_h, o_stride_d); } template -__global__ void fused_rope_backward_kernel(const scalar_t *src, const float *freqs, scalar_t *dst, - const int h, const int d, const int d2, - const int stride_s, const int stride_b, - const int stride_h, const int stride_d, - const int o_stride_s, const int o_stride_b, - const int o_stride_h, const int o_stride_d) { +__global__ void fused_rope_backward_kernel(const scalar_t *src, const float *freqs, + const int *start_positions, scalar_t *dst, const int h, + const int d, const int d2, const int stride_s, + const int stride_b, const int stride_h, + const int stride_d, const int o_stride_s, + const int o_stride_b, const int o_stride_h, + const int o_stride_d) { int s_id = blockIdx.x, b_id = blockIdx.y; + int begin_offset = (start_positions == 0) ? 0 : start_positions[b_id]; int offset_block = s_id * stride_s + b_id * stride_b; int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b; - fused_rope_block_backward(src, freqs, dst, offset_block, offset_block_dst, h, d, d2, stride_h, - stride_d, o_stride_h, o_stride_d); + fused_rope_block_backward(src, freqs, dst, begin_offset, offset_block, offset_block_dst, h, d, d2, + stride_h, stride_d, o_stride_h, o_stride_d); } template __global__ void fused_rope_thd_forward_kernel(const scalar_t *src, const int *cu_seqlens, - const float *freqs, scalar_t *dst, const int h, - const int d, const int d2, const int stride_t, - const int stride_h, const int stride_d, - const int o_stride_t, const int o_stride_h, - const int o_stride_d) { + const float *freqs, const int *start_positions, + scalar_t *dst, const int h, const int d, const int d2, + const int stride_t, const int stride_h, + const int stride_d, const int o_stride_t, + const int o_stride_h, const int o_stride_d) { int s_id = blockIdx.x, b_id = blockIdx.y; int t_id = s_id + cu_seqlens[b_id]; if (t_id >= cu_seqlens[b_id + 1]) return; int offset_block = t_id * stride_t; int offset_block_dst = t_id * o_stride_t; - fused_rope_block_forward(src, freqs, dst, offset_block, offset_block_dst, h, d, d2, stride_h, - stride_d, o_stride_h, o_stride_d); + int begin_offset = (start_positions == 0) ? 
0 : start_positions[b_id]; + fused_rope_block_forward(src, freqs, dst, begin_offset, offset_block, offset_block_dst, h, d, d2, + stride_h, stride_d, o_stride_h, o_stride_d); } template __global__ void fused_rope_thd_backward_kernel(const scalar_t *src, const int *cu_seqlens, - const float *freqs, scalar_t *dst, const int h, - const int d, const int d2, const int stride_t, - const int stride_h, const int stride_d, - const int o_stride_t, const int o_stride_h, - const int o_stride_d) { + const float *freqs, const int *start_positions, + scalar_t *dst, const int h, const int d, + const int d2, const int stride_t, const int stride_h, + const int stride_d, const int o_stride_t, + const int o_stride_h, const int o_stride_d) { int s_id = blockIdx.x, b_id = blockIdx.y; int t_id = s_id + cu_seqlens[b_id]; if (t_id >= cu_seqlens[b_id + 1]) return; int offset_block = t_id * stride_t; int offset_block_dst = t_id * o_stride_t; - fused_rope_block_backward(src, freqs, dst, offset_block, offset_block_dst, h, d, d2, stride_h, - stride_d, o_stride_h, o_stride_d); + int begin_offset = (start_positions == 0) ? 0 : start_positions[b_id]; + fused_rope_block_backward(src, freqs, dst, begin_offset, offset_block, offset_block_dst, h, d, d2, + stride_h, stride_d, o_stride_h, o_stride_d); } template -void fused_rope_forward_launcher(const scalar_t *input, const float *freqs, scalar_t *output, - const int s, const int b, const int h, const int d, const int d2, +void fused_rope_forward_launcher(const scalar_t *input, const float *freqs, + const int *start_positions, scalar_t *output, const int s, + const int b, const int h, const int d, const int d2, const int stride_s, const int stride_b, const int stride_h, const int stride_d, const int o_stride_s, const int o_stride_b, const int o_stride_h, const int o_stride_d, cudaStream_t stream) { @@ -158,115 +165,123 @@ void fused_rope_forward_launcher(const scalar_t *input, const float *freqs, scal dim3 threads(THREADS_PER_WARP, warps_per_block); fused_rope_forward_kernel<<>>( - input, freqs, output, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, - o_stride_b, o_stride_h, o_stride_d); + input, freqs, start_positions, output, h, d, d2, stride_s, stride_b, stride_h, stride_d, + o_stride_s, o_stride_b, o_stride_h, o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); } template void fused_rope_backward_launcher(const scalar_t *output_grads, const float *freqs, - scalar_t *input_grads, const int s, const int b, const int h, - const int d, const int d2, const int stride_s, const int stride_b, - const int stride_h, const int stride_d, const int o_stride_s, - const int o_stride_b, const int o_stride_h, const int o_stride_d, - cudaStream_t stream) { + const int *start_positions, scalar_t *input_grads, const int s, + const int b, const int h, const int d, const int d2, + const int stride_s, const int stride_b, const int stride_h, + const int stride_d, const int o_stride_s, const int o_stride_b, + const int o_stride_h, const int o_stride_d, cudaStream_t stream) { int warps_per_block = h < 16 ? 
4 : 8; dim3 blocks(s, b); dim3 threads(THREADS_PER_WARP, warps_per_block); fused_rope_backward_kernel<<>>( - output_grads, freqs, input_grads, h, d, d2, stride_s, stride_b, stride_h, stride_d, - o_stride_s, o_stride_b, o_stride_h, o_stride_d); + output_grads, freqs, start_positions, input_grads, h, d, d2, stride_s, stride_b, stride_h, + stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); } template void fused_rope_thd_forward_launcher(const scalar_t *input, const int *cu_seqlens, - const float *freqs, scalar_t *output, const int max_s, - const int b, const int h, const int d, const int d2, - const int stride_t, const int stride_h, const int stride_d, - const int o_stride_t, const int o_stride_h, - const int o_stride_d, cudaStream_t stream) { + const float *freqs, const int *start_positions, + scalar_t *output, const int max_s, const int b, const int h, + const int d, const int d2, const int stride_t, + const int stride_h, const int stride_d, const int o_stride_t, + const int o_stride_h, const int o_stride_d, + cudaStream_t stream) { int warps_per_block = h < 16 ? 4 : 8; dim3 blocks(max_s, b); dim3 threads(THREADS_PER_WARP, warps_per_block); - fused_rope_thd_forward_kernel<<>>(input, cu_seqlens, freqs, output, h, - d, d2, stride_t, stride_h, stride_d, - o_stride_t, o_stride_h, o_stride_d); + fused_rope_thd_forward_kernel<<>>( + input, cu_seqlens, freqs, start_positions, output, h, d, d2, stride_t, stride_h, stride_d, + o_stride_t, o_stride_h, o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); } template void fused_rope_thd_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens, - const float *freqs, scalar_t *input_grads, const int max_s, - const int b, const int h, const int d, const int d2, - const int stride_t, const int stride_h, const int stride_d, - const int o_stride_t, const int o_stride_h, - const int o_stride_d, cudaStream_t stream) { + const float *freqs, const int *start_positions, + scalar_t *input_grads, const int max_s, const int b, + const int h, const int d, const int d2, const int stride_t, + const int stride_h, const int stride_d, const int o_stride_t, + const int o_stride_h, const int o_stride_d, + cudaStream_t stream) { int warps_per_block = h < 16 ? 
4 : 8; dim3 blocks(max_s, b); dim3 threads(THREADS_PER_WARP, warps_per_block); fused_rope_thd_backward_kernel<<>>( - output_grads, cu_seqlens, freqs, input_grads, h, d, d2, stride_t, stride_h, stride_d, - o_stride_t, o_stride_h, o_stride_d); + output_grads, cu_seqlens, freqs, start_positions, input_grads, h, d, d2, stride_t, stride_h, + stride_d, o_stride_t, o_stride_h, o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); } -void fused_rope_forward(const Tensor &input, const Tensor &freqs, Tensor *output, const int s, - const int b, const int h, const int d, const int d2, const int stride_s, - const int stride_b, const int stride_h, const int stride_d, - const int o_stride_s, const int o_stride_b, const int o_stride_h, - const int o_stride_d, cudaStream_t stream) { +void fused_rope_forward(const Tensor &input, const Tensor &freqs, const Tensor &start_positions, + Tensor *output, const int s, const int b, const int h, const int d, + const int d2, const int stride_s, const int stride_b, const int stride_h, + const int stride_d, const int o_stride_s, const int o_stride_b, + const int o_stride_h, const int o_stride_d, cudaStream_t stream) { TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.data.dtype, scalar_t, fused_rope_forward_launcher(reinterpret_cast(input.data.dptr), reinterpret_cast(freqs.data.dptr), + reinterpret_cast(start_positions.data.dptr), reinterpret_cast(output->data.dptr), s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d, stream);); } -void fused_rope_backward(const Tensor &output_grads, const Tensor &freqs, Tensor *input_grads, - const int s, const int b, const int h, const int d, const int d2, - const int stride_s, const int stride_b, const int stride_h, - const int stride_d, const int o_stride_s, const int o_stride_b, - const int o_stride_h, const int o_stride_d, cudaStream_t stream) { +void fused_rope_backward(const Tensor &output_grads, const Tensor &freqs, + const Tensor &start_positions, Tensor *input_grads, const int s, + const int b, const int h, const int d, const int d2, const int stride_s, + const int stride_b, const int stride_h, const int stride_d, + const int o_stride_s, const int o_stride_b, const int o_stride_h, + const int o_stride_d, cudaStream_t stream) { TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( output_grads.data.dtype, scalar_t, fused_rope_backward_launcher(reinterpret_cast(output_grads.data.dptr), reinterpret_cast(freqs.data.dptr), + reinterpret_cast(start_positions.data.dptr), reinterpret_cast(input_grads->data.dptr), s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d, stream);); } void fused_rope_thd_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs, - Tensor *output, const int max_s, const int b, const int h, const int d, - const int d2, const int stride_t, const int stride_h, - const int stride_d, const int o_stride_t, const int o_stride_h, - const int o_stride_d, cudaStream_t stream) { + const Tensor &start_positions, Tensor *output, const int max_s, + const int b, const int h, const int d, const int d2, const int stride_t, + const int stride_h, const int stride_d, const int o_stride_t, + const int o_stride_h, const int o_stride_d, cudaStream_t stream) { TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.data.dtype, scalar_t, fused_rope_thd_forward_launcher(reinterpret_cast(input.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), reinterpret_cast(freqs.data.dptr), + reinterpret_cast(start_positions.data.dptr), 
reinterpret_cast(output->data.dptr), max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream);); } void fused_rope_thd_backward(const Tensor &output_grads, const Tensor &cu_seqlens, - const Tensor &freqs, Tensor *input_grads, const int max_s, const int b, - const int h, const int d, const int d2, const int stride_t, - const int stride_h, const int stride_d, const int o_stride_t, - const int o_stride_h, const int o_stride_d, cudaStream_t stream) { + const Tensor &freqs, const Tensor &start_positions, + Tensor *input_grads, const int max_s, const int b, const int h, + const int d, const int d2, const int stride_t, const int stride_h, + const int stride_d, const int o_stride_t, const int o_stride_h, + const int o_stride_d, cudaStream_t stream) { TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( output_grads.data.dtype, scalar_t, fused_rope_thd_backward_launcher(reinterpret_cast(output_grads.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), reinterpret_cast(freqs.data.dptr), + reinterpret_cast(start_positions.data.dptr), reinterpret_cast(input_grads->data.dptr), max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream);); @@ -274,58 +289,62 @@ void fused_rope_thd_backward(const Tensor &output_grads, const Tensor &cu_seqlen } // end namespace transformer_engine -void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVTETensor output, - const int s, const int b, const int h, const int d, const int d2, +void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, + const NVTETensor start_positions, NVTETensor output, const int s, + const int b, const int h, const int d, const int d2, const int stride_s, const int stride_b, const int stride_h, const int stride_d, const int o_stride_s, const int o_stride_b, const int o_stride_h, const int o_stride_d, cudaStream_t stream) { NVTE_API_CALL(nvte_fused_rope_forward); using namespace transformer_engine; fused_rope_forward(*reinterpret_cast(input), - *reinterpret_cast(freqs), reinterpret_cast(output), - s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, - o_stride_h, o_stride_d, stream); + *reinterpret_cast(freqs), + *reinterpret_cast(start_positions), + reinterpret_cast(output), s, b, h, d, d2, stride_s, stride_b, + stride_h, stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d, stream); } void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs, - NVTETensor input_grads, const int s, const int b, const int h, - const int d, const int d2, const int stride_s, const int stride_b, - const int stride_h, const int stride_d, const int o_stride_s, - const int o_stride_b, const int o_stride_h, const int o_stride_d, - cudaStream_t stream) { + const NVTETensor start_positions, NVTETensor input_grads, const int s, + const int b, const int h, const int d, const int d2, + const int stride_s, const int stride_b, const int stride_h, + const int stride_d, const int o_stride_s, const int o_stride_b, + const int o_stride_h, const int o_stride_d, cudaStream_t stream) { NVTE_API_CALL(nvte_fused_rope_backward); using namespace transformer_engine; fused_rope_backward(*reinterpret_cast(output_grads), *reinterpret_cast(freqs), + *reinterpret_cast(start_positions), reinterpret_cast(input_grads), s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d, stream); } void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens, - const NVTETensor freqs, 
NVTETensor output, const int max_s, - const int b, const int h, const int d, const int d2, - const int stride_t, const int stride_h, const int stride_d, - const int o_stride_t, const int o_stride_h, const int o_stride_d, - cudaStream_t stream) { + const NVTETensor freqs, const NVTETensor start_positions, + NVTETensor output, const int max_s, const int b, const int h, + const int d, const int d2, const int stride_t, const int stride_h, + const int stride_d, const int o_stride_t, const int o_stride_h, + const int o_stride_d, cudaStream_t stream) { NVTE_API_CALL(nvte_fused_rope_thd_forward); using namespace transformer_engine; fused_rope_thd_forward( *reinterpret_cast(input), *reinterpret_cast(cu_seqlens), - *reinterpret_cast(freqs), reinterpret_cast(output), max_s, b, h, d, - d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream); + *reinterpret_cast(freqs), *reinterpret_cast(start_positions), + reinterpret_cast(output), max_s, b, h, d, d2, stride_t, stride_h, stride_d, + o_stride_t, o_stride_h, o_stride_d, stream); } void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor input_grads, const int max_s, - const int b, const int h, const int d, const int d2, - const int stride_t, const int stride_h, const int stride_d, - const int o_stride_t, const int o_stride_h, const int o_stride_d, - cudaStream_t stream) { + const NVTETensor freqs, const NVTETensor start_positions, + NVTETensor input_grads, const int max_s, const int b, const int h, + const int d, const int d2, const int stride_t, const int stride_h, + const int stride_d, const int o_stride_t, const int o_stride_h, + const int o_stride_d, cudaStream_t stream) { NVTE_API_CALL(nvte_fused_rope_thd_backward); using namespace transformer_engine; - fused_rope_thd_backward(*reinterpret_cast(output_grads), - *reinterpret_cast(cu_seqlens), - *reinterpret_cast(freqs), - reinterpret_cast(input_grads), max_s, b, h, d, d2, stride_t, - stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream); + fused_rope_thd_backward( + *reinterpret_cast(output_grads), + *reinterpret_cast(cu_seqlens), *reinterpret_cast(freqs), + *reinterpret_cast(start_positions), reinterpret_cast(input_grads), + max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream); } diff --git a/transformer_engine/common/include/transformer_engine/fused_rope.h b/transformer_engine/common/include/transformer_engine/fused_rope.h index b92de88eca..01305c1e6d 100644 --- a/transformer_engine/common/include/transformer_engine/fused_rope.h +++ b/transformer_engine/common/include/transformer_engine/fused_rope.h @@ -17,6 +17,7 @@ extern "C" { * * \param[in] input Input tensor for fused rope. * \param[in] freqs The freqs tensor. + * \param[in] start_positions The beginning offsets. * \param[out] output Output tensor. * \param[in] s Length of the s dimension of input. * \param[in] b Length of the b dimension of input. @@ -33,8 +34,9 @@ extern "C" { * \param[in] o_stride_d Stride of the d dimension of output. * \param[in] stream CUDA stream used for the operation. 
*/ -void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVTETensor output, - const int s, const int b, const int h, const int d, const int d2, +void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, + const NVTETensor start_positions, NVTETensor output, const int s, + const int b, const int h, const int d, const int d2, const int stride_s, const int stride_b, const int stride_h, const int stride_d, const int o_stride_s, const int o_stride_b, const int o_stride_h, const int o_stride_d, cudaStream_t stream); @@ -43,6 +45,7 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVT * * \param[in] output_grads Incoming gradient tensor for backward. * \param[in] freqs The freqs tensor. + * \param[in] start_positions The tensor with positions of first tokens in sequences. * \param[out] input_grads Input gradient tensor to calculate. * \param[in] s Length of the s dimension of output_grads. * \param[in] b Length of the b dimension of output_grads. @@ -60,43 +63,45 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVT * \param[in] stream CUDA stream used for the operation. */ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs, - NVTETensor input_grads, const int s, const int b, const int h, - const int d, const int d2, const int stride_s, const int stride_b, - const int stride_h, const int stride_d, const int o_stride_s, - const int o_stride_b, const int o_stride_h, const int o_stride_d, - cudaStream_t stream); + const NVTETensor start_positions, NVTETensor input_grads, const int s, + const int b, const int h, const int d, const int d2, + const int stride_s, const int stride_b, const int stride_h, + const int stride_d, const int o_stride_s, const int o_stride_b, + const int o_stride_h, const int o_stride_d, cudaStream_t stream); /*! \brief Apply rotary positional embedding to the input tensor in thd format. * - * \param[in] input Input tensor for fused rope. - * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor. - * \param[in] freqs The freqs tensor. - * \param[out] output Output tensor. - * \param[in] max_s Max sequence length. - * \param[in] b Batch size. - * \param[in] h Length of the h dimension of input. - * \param[in] d Length of the d dimension of input. - * \param[in] d2 Length of the d dimension of freqs. - * \param[in] stride_t Stride of the t dimension of input. - * \param[in] stride_h Stride of the h dimension of input. - * \param[in] stride_d Stride of the d dimension of input. - * \param[in] o_stride_t Stride of the t dimension of output. - * \param[in] o_stride_h Stride of the h dimension of output. - * \param[in] o_stride_d Stride of the d dimension of output. - * \param[in] stream CUDA stream used for the operation. + * \param[in] input Input tensor for fused rope. + * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor. + * \param[in] freqs The freqs tensor. + * \param[in] start_positions The tensor with positions of first tokens in sequences. + * \param[out] output Output tensor. + * \param[in] max_s Max sequence length. + * \param[in] b Batch size. + * \param[in] h Length of the h dimension of input. + * \param[in] d Length of the d dimension of input. + * \param[in] d2 Length of the d dimension of freqs. + * \param[in] stride_t Stride of the t dimension of input. + * \param[in] stride_h Stride of the h dimension of input. + * \param[in] stride_d Stride of the d dimension of input. 
+ * \param[in] o_stride_t Stride of the t dimension of output. + * \param[in] o_stride_h Stride of the h dimension of output. + * \param[in] o_stride_d Stride of the d dimension of output. + * \param[in] stream CUDA stream used for the operation. */ void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor output, const int max_s, - const int b, const int h, const int d, const int d2, - const int stride_t, const int stride_h, const int stride_d, - const int o_stride_t, const int o_stride_h, const int o_stride_d, - cudaStream_t stream); + const NVTETensor freqs, NVTETensor start_positions, + NVTETensor output, const int max_s, const int b, const int h, + const int d, const int d2, const int stride_t, const int stride_h, + const int stride_d, const int o_stride_t, const int o_stride_h, + const int o_stride_d, cudaStream_t stream); /*! \brief Compute the backward of the fused rope in thd format. * * \param[in] output_grads Incoming gradient tensor for backward. * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor. * \param[in] freqs The freqs tensor. + * \param[in] start_positions The beginning offsets. * \param[out] input_grads Input gradient to calculate. * \param[in] max_s Max sequence length. * \param[in] b Batch size. @@ -112,11 +117,11 @@ void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seq * \param[in] stream CUDA stream used for the operation. */ void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor input_grads, const int max_s, - const int b, const int h, const int d, const int d2, - const int stride_t, const int stride_h, const int stride_d, - const int o_stride_t, const int o_stride_h, const int o_stride_d, - cudaStream_t stream); + const NVTETensor freqs, NVTETensor start_positions, + NVTETensor input_grads, const int max_s, const int b, const int h, + const int d, const int d2, const int stride_t, const int stride_h, + const int stride_d, const int o_stride_t, const int o_stride_h, + const int o_stride_d, cudaStream_t stream); #ifdef __cplusplus } // extern "C" diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index fa72ecfa33..7430027335 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -703,18 +703,43 @@ class InferenceParams: # pylint: disable=too-few-public-methods Parameters ---------- - max_batch_size : int + max_batch_size: int maximum batch size during inference. - max_sequence_length : int - maximum sequence length during inference. + max_sequence_length: int + maximum sequence length during inference. + qkv_format: str + Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. + `s` stands for the sequence length dimension, + `b` batch size, `h` the number of attention heads, + `d` head size, and `t` the total number of sequences in a batch, i.e. + `t = sum(s_i) for i = 0...b-1`. 
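A quick sketch of how a `thd` `InferenceParams` object is typically driven across a context phase and a few generation steps may help here. This is an illustration only, not part of the patch; it uses the `setup_before_new_input` and `reset` helpers added further below and assumes a CUDA device, since the length buffers are allocated on `"cuda"`.

```python
import torch
from transformer_engine.pytorch.attention import InferenceParams

# Batch of 2 sequences, cache sized for at most 128 tokens each (illustrative numbers).
inference_params = InferenceParams(max_batch_size=2, max_sequence_length=128, qkv_format="thd")

# Context phase: the two prompts hold 3 and 5 tokens, padded to max_input_length=5.
prompt_lengths = torch.tensor([3, 5], dtype=torch.int32, device="cuda")
inference_params.setup_before_new_input(lengths_tensor=prompt_lengths, max_input_length=5)
# ... run the forward pass over the prompts here ...

# Generation phase: every sequence receives exactly one new token per step.
ones = torch.ones(2, dtype=torch.int32, device="cuda")
for _ in range(2):
    inference_params.setup_before_new_input(lengths_tensor=ones, max_input_length=1)
    # ... run the forward pass for the newly generated tokens here ...

# cached_sequence_lengths is now [3 + 1, 5 + 1]; input_sequence_lengths is [1, 1].
# Reuse the same object (and its buffers) for the next batch of prompts:
inference_params.reset()
```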
""" - def __init__(self, max_batch_size, max_sequence_length): + def __init__(self, max_batch_size, max_sequence_length, qkv_format="bshd"): + assert qkv_format in ["bshd", "sbhd", "thd"] + self.max_sequence_length = max_sequence_length self.max_batch_size = max_batch_size - self.sequence_len_offset = 0 - self.batch_size_offset = 0 + + # self.key_value_memory_dict[layer number] = (key_cache, value_cache) + # if qkv_format in ["bshd", "sbhd"]: (key/value)_cache.shape = [b/s, s/b, h, d] + # # if qkv_format = "thd": (key/value)_cache.shape = [t, h, d] self.key_value_memory_dict = {} + self.qkv_format = qkv_format + + if qkv_format == "thd": + # In thd attention layout input sequences can have different lenghts. + # self.input_sequence_lengths stores tensor of shape [b] with lengths of input sequences + # and self.cached_sequence_lengths is the sum of all previous input lengths tensors - + # equivalently it contains total lengths of cached sequences. + self.cached_sequence_lengths = torch.zeros( + (max_batch_size,), device="cuda", dtype=torch.int32) + self.input_sequence_lengths = torch.zeros( + (max_batch_size,), device="cuda", dtype=torch.int32) + else: + self.sequence_len_offset = 0 + self.batch_size_offset = 0 + self.input_sequence_length = None def swap_key_value_dict(self, batch_indices): """ @@ -742,6 +767,214 @@ def swap_key_value_dict(self, batch_indices): ) + def setup_before_new_input(self, lengths_tensor=None, max_input_length=None, length=None): + """ + Updates parameters representing incoming sequence lengths and lengths + of sequences in the cache. Should be called before every forward pass in the inference. + + Parameters + ---------- + lengths_tensor: torch.Tensor + 1d tensor with sequence lengths in new input. + Should be used only when self.qkv_format = "thd". + max_input_length: int + Should be used only when self.qkv_format = "thd". + If the incoming sequences tensor has shape [b * s, h, d], + this should be equal to s. + length: int + Length of the incoming sequences. + Should be used only when self.qkv_format in ["bshd", "sbhd"]. + """ + if self.qkv_format == "thd": + assert lengths_tensor is not None and max_input_length is not None, \ + "lengths_tensor and max_input_length should not be none for qkv_format = \"thd\"" + torch.add( + self.cached_sequence_lengths, + self.input_sequence_lengths, + out=self.cached_sequence_lengths) + self.input_sequence_lengths.copy_(lengths_tensor) + self.max_incoming_seq_len = max_input_length + + else: + assert length is not None, \ + "length should not be none for qkv_format in [\"bshd\", \"sbhd\"]" + if self.input_sequence_length is not None: + self.sequence_len_offset += self.input_sequence_length + self.input_sequence_length = length + + def reset(self): + """ + Resets the parameters to allow the use of this object in a new generation iteration. + This method does not reallocate buffers, + making it more efficient than creating a new InferenceParams object. + Moreover, reusing the same object with the same buffers is compatible + with the CUDA Graphs. + """ + if self.qkv_format == "thd": + self.cached_sequence_lengths.zero_() + self.input_sequence_lengths.zero_() + else: + self.input_sequence_length = None + self.sequence_len_offset = 0 + + def save_to_kv_cache(self, layer_number, key_layer, value_layer): + """ + Saves key_layer and value_layer in the cache. + + Parameters + ---------- + layer_number: input + layer number of the current `TransformerLayer` when multiple such modules are + concatenated to form a transformer block. 
+ key_layer: torch.Tensor + Tensor - in the format corresponding to self.qkv_format - + representing key_layer. + Notice: if self.qkv_format in ["bshd", "sbhd"] then both layers are in the sbhd format + Notice: if self.qkv_format = "thd", we assume that offsets of the sequences + are of the form k * self.max_incoming_seq_len for k = 0, ..., batch_size-1. + value_layer: torch.Tensor + Tensor - in the format corresponding to self.qkv_format - + representing value_layer. + Notice: if self.qkv_format in ["bshd", "sbhd"] both layers are in the sbhd format + Notice: if self.qkv_format = "thd", we assume that offsets of the sequences + are of the form k * self.max_incoming_seq_len for k = 0, ..., batch_size-1. + """ + # Current kernels work only with contiguous tensors; this can be made faster in the future. + key_layer, value_layer = key_layer.contiguous(), value_layer.contiguous() + inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number] + if self.qkv_format == "thd": + channels = inference_key_memory.shape[1] * inference_key_memory.shape[2] # h * d + # This kernel copies the key/value layer into the cache, + # taking into account the thd format and sequence lengths. + tex.attention_copy( + inference_key_memory, + self.cached_sequence_lengths, + self.input_sequence_lengths, + key_layer, + self.max_incoming_seq_len, + self.max_sequence_length, + self.max_batch_size, + channels) + + tex.attention_copy( + inference_value_memory, + self.cached_sequence_lengths, + self.input_sequence_lengths, + value_layer, + self.max_incoming_seq_len, + self.max_sequence_length, + self.max_batch_size, + channels) + key_layer, value_layer = inference_key_memory, inference_value_memory + else: + assert self.qkv_format in ["bshd", "sbhd"], \ + "Attention format not supported by inference." + batch_start = self.batch_size_offset + batch_end = batch_start + key_layer.size(1) + assert batch_end <= inference_key_memory.size(1) + + sequence_start = self.sequence_len_offset + sequence_end = sequence_start + key_layer.size(0) + assert sequence_end <= inference_key_memory.size(0) + + # Copy keys and values into KV-cache + seq_offsets = slice(sequence_start, sequence_end) + batch_offsets = slice(batch_start, batch_end) + inference_key_memory[seq_offsets, batch_offsets, ...] = key_layer + inference_value_memory[seq_offsets, batch_offsets, ...] = value_layer + key_layer = inference_key_memory[:sequence_end, batch_offsets, ...] + value_layer = inference_value_memory[:sequence_end, batch_offsets, ...] + return key_layer, value_layer + + def allocate_memory_for_kv_cache_if_empty( + self, + layer_number, + num_gqa_groups_per_partition, + hidden_size_per_attention_head, + dtype): + """ + Allocates memory for the KV cache of a given layer, if it hasn't been allocated before. + + Parameters + ---------- + layer_number: int + layer number of the current `TransformerLayer` when multiple such modules are + concatenated to form a transformer block. + num_gqa_groups_per_partition: int + This will be the third dimension of the cache tensor. + hidden_size_per_attention_head: int + This will be the fourth dimension of the cache tensor.
+ """ + + if layer_number in self.key_value_memory_dict: + return # Already allocated + + b, s = self.max_batch_size, self.max_sequence_length + + def _allocate_memory(dims): + return torch.zeros( + *dims, + num_gqa_groups_per_partition, + hidden_size_per_attention_head, + dtype=dtype, + device=torch.cuda.current_device(), + ) + + if self.qkv_format == "thd": + inference_key_memory = _allocate_memory((b * s,)) + inference_value_memory = _allocate_memory((b * s,)) + else: + inference_key_memory = _allocate_memory((s, b)) + inference_value_memory = _allocate_memory((s, b)) + self.key_value_memory_dict[layer_number] = ( + inference_key_memory, + inference_value_memory, + ) + + def set_params_to_thd_attention(self, buffers, channels): + """ + Fused attention with q/k/v of thd layout with offsets needs some parameters informing + about sequence lengths. This function computes them and + saves them into the provided buffers. + + Parameters + ---------- + buffers: List[torch.Tensor] + buffers of size [batch_size + 1] for the parameters: + cu_seqlens_q, cu_seqlens_kv, seq_offsets_q, + seq_offsets_k, seq_offsets_v, seq_offsets_o + respectively. + channels: int + value of num_heads * hidden_dim_for_each_head. + + Returns + ---------- + max_seqlen_q: int + Maximal value of query sequence length. + max_seqlen_kv: int + Maximal value of key/value sequence length. + buffers: torch.Tensor + Tensor with filled buffers. + """ + max_seqlen_q, max_seqlen_kv = self.max_incoming_seq_len, self.max_sequence_length + + cu_seqlens_q, cu_seqlens_kv, seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o = \ + buffers + + torch.cumsum(self.input_sequence_lengths, dim=0, out=cu_seqlens_q[1:]) + torch.cumsum( + self.cached_sequence_lengths + self.input_sequence_lengths, + dim=0, out=cu_seqlens_kv[1:]) + # If layer has shape [b * s_layer, h, d] + # offsets are of the form [k * s_layer * h * d for k = 0, ..., batch_size] + seq_offsets_q.copy_( + torch.arange(0, self.max_batch_size + 1, device="cuda") * channels * max_seqlen_q) + seq_offsets_k.copy_( + torch.arange(0, self.max_batch_size + 1, device="cuda") * channels * max_seqlen_kv) + seq_offsets_v.copy_(seq_offsets_k) + seq_offsets_o.copy_(seq_offsets_q) + + return max_seqlen_q, max_seqlen_kv, buffers @torch.no_grad() def get_swa_mask( window_size: Tuple[int, int], @@ -2460,33 +2693,44 @@ def forward( freqs: torch.Tensor, tensor_format: str = "sbhd", cu_seqlens: Union[torch.Tensor, None] = None, + beginning_offsets: Union[torch.Tensor, None] = None, ) -> torch.Tensor: + if beginning_offsets is None: + # Each sequence will start from positional encoding corresponding to 0. + # Otherwise sequence i will start from positional encoding + # corresponding to beginning_offsets[i]. 
+ beginning_offsets = torch.Tensor() if freqs.dtype != torch.float32: freqs = freqs.float() if tensor_format == "sbhd": - output = tex.fused_rope_forward(t, freqs, False) + output = tex.fused_rope_forward(t, freqs, beginning_offsets, False) elif tensor_format == "bshd": - output = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1) + output = tex.fused_rope_forward( + t.transpose(0, 1), freqs, beginning_offsets, True + ).transpose(0, 1) elif tensor_format == "thd": - output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs) + output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs, beginning_offsets) else: raise ValueError(f"Unsupported tensor_format: {tensor_format}.") - ctx.save_for_backward(freqs, cu_seqlens) + ctx.save_for_backward(freqs, cu_seqlens, beginning_offsets) ctx.tensor_format = tensor_format return output @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: - freqs, cu_seqlens = ctx.saved_tensors + def backward( + ctx, grad_output: torch.Tensor + ) -> Tuple[Union[torch.Tensor, None], ...]: + freqs, cu_seqlens, start_positions = ctx.saved_tensors if ctx.tensor_format == "sbhd": - grad_input = tex.fused_rope_backward(grad_output, freqs, False) + grad_input = tex.fused_rope_backward(grad_output, freqs, start_positions, False) elif ctx.tensor_format == "bshd": grad_input = tex.fused_rope_backward( - grad_output.transpose(0, 1), freqs, True + grad_output.transpose(0, 1), freqs, start_positions, True ).transpose(0, 1) elif ctx.tensor_format == "thd": - grad_input = tex.fused_rope_thd_backward(grad_output, cu_seqlens, freqs) + grad_input = tex.fused_rope_thd_backward( + grad_output, cu_seqlens, freqs, start_positions) else: raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.") @@ -2508,6 +2752,7 @@ def apply_rotary_pos_emb( tensor_format: str = "sbhd", fused: bool = False, cu_seqlens: Union[torch.Tensor, None] = None, + start_positions: Union[torch.Tensor, None] = None, ) -> torch.Tensor: """ Apply rotary positional embedding tensor to the input tensor. @@ -2528,12 +2773,18 @@ def apply_rotary_pos_emb( cu_seqlens: torch.Tensor, default = None. Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and dtype torch.int32. Only valid when `tensor_format` is 'thd'. + start_positions: torch.Tensor, default = None. + Token i from sequence s have position encoding corresponding to + position start_positions[i]. If start_positions=None, then this token has position i. """ + assert not (start_positions is not None and not fused), \ + """start_positions != None and fused=False is not supported""" + if fused: assert ( tensor_format != "thd" or cu_seqlens is not None ), "cu_seqlens must not be None when tensor_format is 'thd'." 
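A minimal usage sketch of the new `start_positions` argument follows. The shapes, the rotary-embedding helper call, and the concrete lengths are illustrative assumptions; the only hard requirement, per the assertion above, is `fused=True` (plus `cu_seqlens` when the format is `thd`).

```python
import torch
from transformer_engine.pytorch.attention import RotaryPositionEmbedding, apply_rotary_pos_emb

b, s, h, d = 2, 1, 8, 64   # one new token per sequence, as in a generation step
t = torch.randn(b, s, h, d, device="cuda")

rope = RotaryPositionEmbedding(d)
freqs = rope(1024).to("cuda")   # rotary frequencies for up to 1024 positions

# Sequence 0 already holds 2 tokens and sequence 1 holds 5, so their new tokens
# should be rotated as positions 2 and 5 rather than both as position 0.
start_positions = torch.tensor([2, 5], dtype=torch.int32, device="cuda")

out = apply_rotary_pos_emb(
    t, freqs, tensor_format="bshd", fused=True, start_positions=start_positions
)
```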
- return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens) + return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens, start_positions) assert tensor_format in ("sbhd", "bshd"), ( "Only formats `sbhd` or `bshd` are supported for input tensor `t` " @@ -5121,6 +5372,7 @@ def __init__( self.cp_group = cp_group self.cp_global_ranks = cp_global_ranks self.cp_stream = cp_stream + self.channels = kv_channels * num_attention_heads self.hidden_size_per_attention_head = kv_channels @@ -5210,6 +5462,16 @@ def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unuse self.register_load_state_dict_post_hook(remove_extra_states_check) + self._allocator = StaticBufferAllocator() + + + def alloc(self, size, dtype, device): + """ + Allocated the buffer and works correctly with CUDA Graphs. + """ + return self._allocator(size, dtype, device) + + def _checkpointed_attention_forward( self, attention_func: Callable, @@ -5413,21 +5675,7 @@ def forward( first microbatch (since it is the first gradient being produced) """ - with self.prepare_forward( - query_layer, - is_first_microbatch, - num_gemms=3, - allow_non_contiguous=True, - ) as query_layer: - - if self.fp8: - if self.fp8_meta["recipe"].fp8_mha: - if not self.fp8_meta["recipe"].fp8_dpa: - self.fp8_meta["recipe"].fp8_dpa = True - self.logger.WARNING( - """Forcing fp8_meta["recipe"].fp8_dpa=True due to """ - """fp8_meta["recipe"].fp8_mha=True""" - ) + batch_size = key_layer.shape[0] if self.fp8 and self.fp8_meta["recipe"].fp8_dpa: forward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=True) @@ -5484,28 +5732,26 @@ def forward( key_layer = key_layer.transpose(0, 1) value_layer = value_layer.transpose(0, 1) - ( - inference_key_memory, - inference_value_memory, - ) = inference_params.key_value_memory_dict[self.layer_number] + key_layer, value_layer = inference_params.save_to_kv_cache( + self.layer_number, key_layer, value_layer + ) - batch_start = inference_params.batch_size_offset - batch_end = batch_start + key_layer.size(1) - assert batch_end <= inference_key_memory.size(1) + if qkv_format == "thd": + # Allocation of buffers, it works correctly with CUDA Graphs. + NR_BUFFERS = 6 + buffers = [ + self.alloc(batch_size + 1, dtype=torch.int32, device="cuda") + for _ in range(NR_BUFFERS) + ] - sequence_start = inference_params.sequence_len_offset - sequence_end = sequence_start + key_layer.size(0) - assert sequence_end <= inference_key_memory.size(0) + max_seqlen_q, max_seqlen_kv, buffers = \ + inference_params.set_params_to_thd_attention(buffers, self.channels) + cu_seqlens_q, cu_seqlens_kv, seq_offsets_q, \ + seq_offsets_k, seq_offsets_v, seq_offsets_o = buffers - # Copy keys and values into KV-cache - inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = ( - key_layer - ) - inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = ( - value_layer - ) - key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...] - value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...] + # query_layer is reshaped to the format [t, h, d] + # and make contiguous - needed by the THD attention + query_layer = query_layer.view(-1, *query_layer.shape[2:]).contiguous() if qkv_format == "bshd": key_layer = key_layer.transpose(0, 1) @@ -5618,18 +5864,55 @@ def forward( assert ( core_attention_bias is None ), "core_attention_bias must be None when core_attention_bias_type is alibi!" 
- if ( - _alibi_cache["_num_heads"] != query_layer.shape[-2] - or _alibi_cache["_max_seqlen_q"] != max_seqlen_q - or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv - or _alibi_cache["_bottom_right_alignment"] != bottom_right_alignment - or _alibi_cache["_alibi_slopes"] is None - ): - _alibi_cache["_alibi_slopes_require_update"] = True - _alibi_cache["_alibi_bias_require_update"] = True + if (_alibi_cache["_num_heads"] != query_layer.shape[-2] + or _alibi_cache["_max_seqlen_q"] != max_seqlen_q + or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv + or _alibi_cache["_alibi_slopes"] is None): + _alibi_cache["_alibi_slopes_require_update"] = True + _alibi_cache["_alibi_bias_require_update"] = True + + if core_attention_bias_type not in ["no_bias", "alibi"] or core_attention_bias is not None: + use_flash_attention = False - context_parallel = ( - self.cp_group is not None and get_distributed_world_size(self.cp_group) != 1 + fu_core_attention_bias_type = core_attention_bias_type + fu_core_attention_bias = core_attention_bias + if core_attention_bias_type == "alibi" and use_fused_attention and alibi_slopes is not None: + fu_core_attention_bias_type = "post_scale_bias" + _, fu_core_attention_bias = get_alibi( + query_layer.shape[-2], max_seqlen_q, max_seqlen_kv, alibi_slopes=alibi_slopes, + bias_dtype=query_layer.dtype) + if (use_fused_attention + and fu_core_attention_bias_type == "post_scale_bias" + and (fu_core_attention_bias.shape[0] != 1 + or fu_core_attention_bias.shape[1] != query_layer.shape[-2])): + if fu_core_attention_bias.requires_grad: + # remove this line when cuDNN adds bwd support for + # [1, 1, s, s], [b, 1, s, s] and [b, h, s, s] + use_fused_attention = False + else: + # max512 backend will only support [1, h, s, s] + os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1" + + if query_layer.shape[-1] == 256 and query_layer.requires_grad: + # Fused attention is not supported for backward with head_dim = 256. 
+ # to do (cyang): move it to the tex.get_fused_attn_backend + use_fused_attention = False + + if use_fused_attention: + fused_attention_backend = tex.get_fused_attn_backend( + TE_DType[query_layer.dtype] + if not isinstance(query_layer, Float8Tensor) else query_layer._fp8_dtype, + TE_DType[key_layer.dtype] + if not isinstance(key_layer, Float8Tensor) else key_layer._fp8_dtype, + QKVLayout[qkv_layout], + AttnBiasType[fu_core_attention_bias_type], + AttnMaskType[attn_mask_type], + self.attention_dropout, + query_layer.shape[-2], # num_attn_heads + key_layer.shape[-2], # num_gqa_groups + max_seqlen_q, + max_seqlen_kv, + query_layer.shape[-1], # head_dim ) core_attention_bias_shape = None @@ -5663,87 +5946,33 @@ def forward( and not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv) ) - attention_params = AttentionParams( - qkv_type=type(query_layer), - qkv_dtype=query_layer.dtype, - qkv_layout=qkv_layout, - batch_size=batch_size, - num_heads=query_layer.shape[-2], - num_gqa_groups=key_layer.shape[-2], - max_seqlen_q=max_seqlen_q, - max_seqlen_kv=max_seqlen_kv, - head_dim=query_layer.shape[-1], - attn_mask_type=attn_mask_type, - window_size=window_size, - alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None, - core_attention_bias_type=core_attention_bias_type, - core_attention_bias_shape=core_attention_bias_shape, - core_attention_bias_requires_grad=( - core_attention_bias.requires_grad if core_attention_bias is not None else False - ), - pad_between_seqs=pad_between_seqs, - attention_dropout=self.attention_dropout, - context_parallel=context_parallel, - deterministic=self.deterministic, - is_training=self.training, - fp8=self.fp8, - fp8_meta=self.fp8_meta, - ) - global _attention_backends - if ( - _attention_backends["attention_params"] is None - or attention_params != _attention_backends["attention_params"] - ): - _attention_backends["attention_params"] = attention_params - _attention_backends["backend_selection_requires_update"] = True - if _attention_backends["backend_selection_requires_update"]: - ( - use_flash_attention, - use_fused_attention, - fused_attention_backend, - use_unfused_attention, - _, - ) = get_attention_backend(attention_params) - if use_flash_attention: - self.logger.info("Running with FlashAttention backend") - elif use_fused_attention: - self.logger.info( - "Running with FusedAttention backend (sub-backend %s)", - int(fused_attention_backend), - ) - elif use_unfused_attention: - self.logger.info("Running with UnfusedDotProductAttention backend") - else: - use_flash_attention = _attention_backends["use_flash_attention"] - use_fused_attention = _attention_backends["use_fused_attention"] - fused_attention_backend = _attention_backends["fused_attention_backend"] - use_unfused_attention = _attention_backends["use_unfused_attention"] - - if use_flash_attention: - if core_attention_bias_type == "alibi": - alibi_slopes, _ = get_alibi( - query_layer.shape[-2], - max_seqlen_q, - max_seqlen_kv, - alibi_slopes=alibi_slopes, - ) - return self.flash_attention( - query_layer, - key_layer, - value_layer, - attention_mask=attention_mask, - qkv_layout=qkv_layout, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - attn_mask_type=attn_mask_type, - window_size=window_size, - alibi_slopes=alibi_slopes, - cp_group=self.cp_group, - cp_global_ranks=self.cp_global_ranks, - cp_stream=self.cp_stream, - max_seqlen_q=max_seqlen_q, - max_seqlen_kv=max_seqlen_kv, - ) + if self.attention_type == "self": + if self.qkv_format == "bshd" and query_layer.shape[1] != 
value_layer.shape[1] or \ + self.qkv_format == "sbhd" and query_layer.shape[0] != value_layer.shape[0]: + # Flash attention does not self-support max_seqlen_q != max_seqlen_kv + use_flash_attention = False + + if use_flash_attention: + if _NVTE_DEBUG: + print("[DotProductAttention]: using flash-attn",_flash_attn_version) + if core_attention_bias_type == "alibi": + alibi_slopes, _ = get_alibi( + query_layer.shape[-2], max_seqlen_q, max_seqlen_kv, alibi_slopes=alibi_slopes) + return self.flash_attention(query_layer, + key_layer, + value_layer, + attention_mask=attention_mask, + qkv_layout=qkv_layout, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + attn_mask_type=attn_mask_type, + window_size=window_size, + alibi_slopes=alibi_slopes, + cp_group=self.cp_group, + cp_global_ranks=self.cp_global_ranks, + cp_stream=self.cp_stream, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv) if use_fused_attention: fu_core_attention_bias_type = core_attention_bias_type @@ -5845,15 +6074,26 @@ def forward( query_layer, key_layer, value_layer, - qkv_layout=qkv_layout, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - attn_mask_type=attn_mask_type, - attention_mask=attention_mask, - core_attention_bias_type=core_attention_bias_type, - core_attention_bias=core_attention_bias, - alibi_slopes=alibi_slopes, - ) + qkv_layout = qkv_layout, + cu_seqlens_q = cu_seqlens_q, + cu_seqlens_kv = cu_seqlens_kv, + attn_mask_type = attn_mask_type, + attention_mask = attention_mask, + core_attention_bias_type = core_attention_bias_type, + core_attention_bias = core_attention_bias, + alibi_slopes = alibi_slopes) + + return self.unfused_attention(query_layer, + key_layer, + value_layer, + qkv_layout = qkv_layout, + cu_seqlens_q = cu_seqlens_q, + cu_seqlens_kv = cu_seqlens_kv, + attn_mask_type = attn_mask_type, + attention_mask = attention_mask, + core_attention_bias_type = core_attention_bias_type, + core_attention_bias = core_attention_bias, + alibi_slopes = alibi_slopes) raise Exception("No dot product attention support for the provided inputs!") @@ -6206,17 +6446,13 @@ def __init__( **common_gemm_kwargs, ) - def _allocate_memory( - self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype - ) -> torch.Tensor: - return torch.empty( - inference_max_sequence_len, - batch_size, - self.num_gqa_groups_per_partition, - self.hidden_size_per_attention_head, - dtype=dtype, - device=torch.cuda.current_device(), - ) + self._allocator = StaticBufferAllocator() + + def alloc(self, size, dtype, device): + """ + Allocated the buffer and works correctly with CUDA Graphs. 
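As background for why these allocations are routed through a module instead of a plain `torch.zeros` call: a captured CUDA graph replays fixed kernels on fixed memory, so every tensor touched inside the captured region must be static and reused across replays. A small, self-contained sketch with stock PyTorch (independent of the classes in this patch) shows the pattern:

```python
import torch

lin = torch.nn.Linear(16, 16).cuda()
static_in = torch.randn(8, 16, device="cuda")   # static input buffer, reused on every replay

# Warm up on a side stream, then capture a single iteration into a graph.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        lin(static_in)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out = lin(static_in)                 # static output buffer owned by the capture

# Subsequent steps only copy fresh data into the static input and replay the graph;
# nothing inside the captured region may allocate new tensors.
static_in.copy_(torch.randn(8, 16, device="cuda"))
g.replay()
result = static_out.clone()
```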
+ """ + return self._allocator(size, dtype, device) def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: """ @@ -6360,25 +6596,13 @@ def forward( # Pre-allocate memory for key-values for inference # ================================================= - if inference_params and self.layer_number is not None: - if self.layer_number not in inference_params.key_value_memory_dict: - inf_max_seq_len = inference_params.max_sequence_length - inf_max_batch_size = inference_params.max_batch_size - inference_key_memory = self._allocate_memory( - inf_max_seq_len, inf_max_batch_size, hidden_states.dtype - ) - inference_value_memory = self._allocate_memory( - inf_max_seq_len, inf_max_batch_size, hidden_states.dtype - ) - inference_params.key_value_memory_dict[self.layer_number] = ( - inference_key_memory, - inference_value_memory, - ) - else: - ( - inference_key_memory, - inference_value_memory, - ) = inference_params.key_value_memory_dict[self.layer_number] + if inference_params is not None: + inference_params.allocate_memory_for_kv_cache_if_empty( + self.layer_number, + self.num_gqa_groups_per_partition, + self.hidden_size_per_attention_head, + hidden_states.dtype + ) # ====================== # Query, Key, and Value @@ -6538,21 +6762,42 @@ def forward( q_pos_emb, k_pos_emb = rotary_pos_emb - # adjust key and value for inference - if inference_params is not None: - if self.qkv_format == "sbhd": - sequence_length = key_layer.size(0) - elif self.qkv_format == "bshd": - sequence_length = key_layer.size(1) + if self.qkv_format == "thd" and inference_params is not None: + # For thd attention incoming tokens can be on different positions, + # so we need to copy different positional encoding freqency + # for every sequence in a batch. + # + # For example if sequence lengths in context phase are: 2 and 5 (batch size=2), + # in first generation phase key_layer have shape [2, 1, d]. + # key_layer[0, :] corresponds to the token with position 3 = 2 + 1, + # and key_layer [1, :] corresponds to the token with position 6 = 5 + 1. + + query_layer = apply_rotary_pos_emb( + query_layer, q_pos_emb, "bshd", fused=True, + start_positions=inference_params.cached_sequence_lengths) + key_layer = apply_rotary_pos_emb( + key_layer, k_pos_emb, "bshd", fused=True, + start_positions=inference_params.cached_sequence_lengths) + + else: + # adjust key and value for inference + if inference_params is not None: + if self.qkv_format == "sbhd": + sequence_length = key_layer.size(0) + elif self.qkv_format == "bshd": + sequence_length = key_layer.size(1) + + sequence_start = inference_params.sequence_len_offset + sequence_end = sequence_start + sequence_length - sequence_start = inference_params.sequence_len_offset - sequence_end = sequence_start + sequence_length + q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...] + k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...] - q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...] - k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...] 
+ query_layer = apply_rotary_pos_emb( + query_layer, q_pos_emb, self.qkv_format, fused=True) + key_layer = apply_rotary_pos_emb( + key_layer, k_pos_emb, self.qkv_format, fused=True) - query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True) - key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True) # =========================== # Core attention computation @@ -6576,6 +6821,12 @@ def forward( inference_params=inference_params, ) + if self.qkv_format == "thd": + # [b * sq, h] -> [qs, b, h] + context_layer = context_layer.view( + (inference_params.max_batch_size, -1, context_layer.shape[1]) + ).contiguous() + # =================== # Output. [sq, b, h] # =================== @@ -6596,3 +6847,20 @@ def forward( if self.input_layernorm and self.return_layernorm_output: outputs += (layernorm_output,) return outputs if len(outputs) > 1 else outputs[0] + + +class StaticBufferAllocator(torch.nn.Module): + """ + This class is used when we use te.make_graphed_callable(). + CUDA Graphs require all tensors to be static. Neverthless, + torch API make_graphed_callable() takes care of output of torch modules, + and makes them static. Thus by wrapping allocation of memory into + torch.nn.Module, we can greatly simplify our code. + """ + + # pylint: disable=no-self-use + def forward(self, size, dtype, device): + """ + Return buffer of given size, dtype and device. + """ + return torch.zeros(size, dtype=dtype, device=device) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index f06b0cb197..40ec6959d2 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -357,16 +357,18 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reductio **************************************************************************************************/ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, + const at::Tensor &start_positions, const bool transpose_output_memory); at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, + const at::Tensor &start_positions, const bool transpose_output_memory); at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens, - const at::Tensor &freqs); + const at::Tensor &freqs, const at::Tensor &start_positions); at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens, - const at::Tensor &freqs); + const at::Tensor &freqs, const at::Tensor &start_positions); /*************************************************************************************************** * Miscellaneous @@ -376,6 +378,17 @@ size_t get_cublasLt_version(); size_t get_cudnn_version(); +bool userbuf_comm_available(); + +void placeholder(); + +/*************************************************************************************************** + * Generation + **************************************************************************************************/ + +void attention_copy(torch::Tensor A, torch::Tensor seq_len, torch::Tensor incoming_seq_len, + torch::Tensor B, int max_incoming_seq_len, int max_seq_len, int b, int s); + /*************************************************************************************************** * Support THD format for Context Parallel **************************************************************************************************/ diff --git 
a/transformer_engine/pytorch/csrc/extensions/apply_rope.cu b/transformer_engine/pytorch/csrc/extensions/apply_rope.cu index c58ba91d5e..8dc0545e26 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cu +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cu @@ -7,6 +7,7 @@ #include "extensions.h" at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, + const at::Tensor &start_positions, const bool transpose_output_memory) { using namespace transformer_engine; TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); @@ -55,16 +56,19 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, auto input_cu = makeTransformerEngineTensor(input); auto freqs_cu = makeTransformerEngineTensor(freqs); + auto start_positions_cu = makeTransformerEngineTensor(start_positions); auto output_cu = makeTransformerEngineTensor(output); - nvte_fused_rope_forward(input_cu.data(), freqs_cu.data(), output_cu.data(), s, b, h, d, d2, - stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, - o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream()); + nvte_fused_rope_forward(input_cu.data(), freqs_cu.data(), start_positions_cu.data(), + output_cu.data(), s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, + o_stride_s, o_stride_b, o_stride_h, o_stride_d, + at::cuda::getCurrentCUDAStream()); return output; } at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, + const at::Tensor &start_positions, const bool transpose_output_memory) { using namespace transformer_engine; TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor"); @@ -111,17 +115,19 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor auto output_grads_cu = makeTransformerEngineTensor(output_grads); auto freqs_cu = makeTransformerEngineTensor(freqs); + auto start_positions_cu = makeTransformerEngineTensor(start_positions); auto input_grads_cu = makeTransformerEngineTensor(input_grads); - nvte_fused_rope_backward(output_grads_cu.data(), freqs_cu.data(), input_grads_cu.data(), s, b, h, - d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, - o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream()); + nvte_fused_rope_backward(output_grads_cu.data(), freqs_cu.data(), start_positions_cu.data(), + input_grads_cu.data(), s, b, h, d, d2, stride_s, stride_b, stride_h, + stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d, + at::cuda::getCurrentCUDAStream()); return input_grads; } at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens, - const at::Tensor &freqs) { + const at::Tensor &freqs, const at::Tensor &start_positions) { using namespace transformer_engine; TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor"); @@ -163,16 +169,18 @@ at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_ auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens); auto freqs_cu = makeTransformerEngineTensor(freqs); auto output_cu = makeTransformerEngineTensor(output); + auto start_positions_cu = makeTransformerEngineTensor(start_positions); nvte_fused_rope_thd_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), - output_cu.data(), max_s, b, h, d, d2, stride_t, stride_h, stride_d, - o_stride_t, o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream()); + start_positions_cu.data(), output_cu.data(), max_s, b, h, d, d2, + stride_t, stride_h, stride_d, o_stride_t, o_stride_h, 
o_stride_d, + at::cuda::getCurrentCUDAStream()); return output; } at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens, - const at::Tensor &freqs) { + const at::Tensor &freqs, const at::Tensor &start_positions) { using namespace transformer_engine; TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor"); @@ -212,10 +220,11 @@ at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Ten auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens); auto freqs_cu = makeTransformerEngineTensor(freqs); auto input_grads_cu = makeTransformerEngineTensor(input_grads); + auto start_positions_cu = makeTransformerEngineTensor(start_positions); nvte_fused_rope_thd_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), - input_grads_cu.data(), max_s, b, h, d, d2, stride_t, stride_h, - stride_d, o_stride_t, o_stride_h, o_stride_d, + start_positions_cu.data(), input_grads_cu.data(), max_s, b, h, d, d2, + stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream()); return input_grads; diff --git a/transformer_engine/pytorch/csrc/extensions/generation.cu b/transformer_engine/pytorch/csrc/extensions/generation.cu new file mode 100644 index 0000000000..5a162f1af6 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/generation.cu @@ -0,0 +1,55 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "extensions.h" + +// Kernel used to update KV chache when attention layout is "thd". 
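A rough PyTorch transcription of the copy that the kernel below performs may be easier to read than the CUDA indexing (for illustration only; the extension uses the CUDA version): for every sequence in the batch it appends the freshly computed keys or values right after that sequence's already-cached tokens.

```python
import torch

def attention_copy_reference(cache, cached_lens, incoming_lens, hidden,
                             max_incoming_seq_len, max_seq_len, batch_size, channels):
    # cache:  [batch_size * max_seq_len, channels]          per-sequence slots, flattened
    # hidden: [batch_size * max_incoming_seq_len, channels] padded new keys/values
    cache = cache.view(batch_size, max_seq_len, channels)
    hidden = hidden.view(batch_size, max_incoming_seq_len, channels)
    for b in range(batch_size):
        n_new = int(incoming_lens[b])
        start = int(cached_lens[b])
        cache[b, start:start + n_new] = hidden[b, :n_new]
    return cache.view(-1, channels)
```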
+template +__global__ void attention_copy_kernel(scalar_t* cache_tensor, int* seq_len, int* incoming_seq_len, + scalar_t* hidden_tensor, int max_incoming_seq_len, + int max_seq_len, int b, int s) { + for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int to_copy = s * incoming_seq_len[batch_idx]; + int offset = seq_len[batch_idx]; + + scalar_t* begin_cache_copy = cache_tensor + max_seq_len * s * batch_idx + s * offset; + scalar_t* begin_hidden_copy = hidden_tensor + s * batch_idx * max_incoming_seq_len; + + for (int i = threadIdx.x; i < to_copy; i += blockDim.x) { + *(begin_cache_copy + i) = *(begin_hidden_copy + i); + } + } +} + +template +void attention_copy_launcher(torch::Tensor A, torch::Tensor seq_len, torch::Tensor incoming_seq_len, + torch::Tensor B, int max_incoming_seq_len, int max_seq_len, int b, + int s) { + attention_copy_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(A.data_ptr()), seq_len.data_ptr(), + incoming_seq_len.data_ptr(), reinterpret_cast(B.data_ptr()), + max_incoming_seq_len, max_seq_len, b, s); +} + +void attention_copy(torch::Tensor A, torch::Tensor seq_len, torch::Tensor incoming_seq_len, + torch::Tensor B, int max_incoming_seq_len, int max_seq_len, int b, int s) { + if (A.scalar_type() == at::ScalarType::Half) { + using dtype = at::Half; + attention_copy_launcher(A, seq_len, incoming_seq_len, B, max_incoming_seq_len, + max_seq_len, b, s); + + } else if (A.scalar_type() == at::ScalarType::BFloat16) { + using dtype = at::BFloat16; + attention_copy_launcher(A, seq_len, incoming_seq_len, B, max_incoming_seq_len, + max_seq_len, b, s); + } else if (A.scalar_type() == at::ScalarType::Float) { + using dtype = float; + attention_copy_launcher(A, seq_len, incoming_seq_len, B, max_incoming_seq_len, + max_seq_len, b, s); + } else { + NVTE_ERROR("Unsupported dtype of out\n"); + } +} diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 89bce77ded..d250ce4484 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -155,6 +155,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::call_guard()); m.attr("_num_cublas_streams") = py::int_(transformer_engine::num_streams); + // Generation + m.def("attention_copy", &attention_copy, "attention_copy"); + // Support THD format for Context Parallel m.def("thd_read_half_tensor", &thd_read_half_tensor, "Read the first half(half_idx=0) or the second half(half_idx=1) of each sequence in a THD " diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 130cf91f0e..3e077a4c07 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -184,6 +184,10 @@ class TransformerLayer(torch.nn.Module): head size. Note that these formats are very closely related to the `qkv_format` in the `MultiHeadAttention` and `DotProductAttention` modules. + Notion: The experimental version of the 'thd' attention is supported + when :attr:`inference_params` is passed to the forward function. + + Parallelism parameters ---------------------- @@ -280,6 +284,9 @@ def __init__( ) -> None: super().__init__() + if ub_tp_comm_overlap: + assert tex.userbuf_comm_available(), "Userbuffer communication backend not available." 
+ self.self_attn_mask_type = self_attn_mask_type self.window_size = check_set_window_size(self_attn_mask_type, window_size) self.enc_dec_attn_mask_type = enc_dec_attn_mask_type
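Putting the pieces together, a THD generation loop over a single layer could look roughly like the sketch below. It is only an outline under several assumptions: the hyperparameters and hidden-state shapes are made up, embedding and sampling are elided, and constructor arguments such as `attn_input_format` come from `TransformerLayer`'s existing API rather than from this patch.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.attention import InferenceParams

hidden_size, heads = 1024, 16
layer = te.TransformerLayer(
    hidden_size, 4 * hidden_size, heads,
    attn_input_format="thd", self_attn_mask_type="padding_causal",
).cuda().eval()

batch, max_seq = 2, 128
inference_params = InferenceParams(batch, max_seq, qkv_format="thd")

with torch.no_grad():
    # Context phase: prompts of length 3 and 5, padded into a [batch * 5, hidden] tensor.
    prompt_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda")
    hidden = torch.randn(batch * 5, hidden_size, device="cuda")
    inference_params.setup_before_new_input(lengths_tensor=prompt_lens, max_input_length=5)
    out = layer(hidden, inference_params=inference_params)

    # Generation phase: one new token per sequence and per step.
    ones = torch.ones(batch, dtype=torch.int32, device="cuda")
    for _ in range(16):
        new_hidden = torch.randn(batch, hidden_size, device="cuda")
        inference_params.setup_before_new_input(lengths_tensor=ones, max_input_length=1)
        out = layer(new_hidden, inference_params=inference_params)

inference_params.reset()   # reuse the same buffers for the next batch of prompts
```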