diff --git a/src/instructlab/dolomite/enums.py b/src/instructlab/dolomite/enums.py index 3105551..e7aebaf 100644 --- a/src/instructlab/dolomite/enums.py +++ b/src/instructlab/dolomite/enums.py @@ -11,3 +11,83 @@ class ParamsGroupMethod(Enum): class GradientCheckpointingMethod(Enum): block = "block" + + +class LRDecaySchedule(str, Enum): + constant = "constant" + cosine = "cosine" + exponential = "exponential" + linear = "linear" + power = "power" + + +class AttentionImplementation(Enum): + """ + Enum class for attention implementation + """ + + eager = "eager" + sdpa = "sdpa" + flash_attention_2 = "flash_attention_2" + + +class MoEImplementation(Enum): + """ + Enum class for MoE implementation + """ + + eager = "eager" + scattermoe = "scattermoe" + + +class DatasetSplit(str, Enum): + """dataset split""" + + train = "train" + val = "val" + test = "test" + + +class Mode(str, Enum): + """training / inference mode""" + + training = "training" + inference = "inference" + unsharding = "unsharding" + distillation = "distillation" + + +class TuningMethod(str, Enum): + """training method""" + + pretraining = "pretraining" + full_finetuning = "full_finetuning" + prompt_tuning = "prompt_tuning" + lora = "lora" + distillation = "distillation" + + +class FP8Backend(str, Enum): + msamp = "msamp" + nvte = "nvte" + + +class LossMask(str, Enum): + """Type of loss masking method""" + + output_only = "output_only" + no_mask = "no_mask" + + +class KLDivergenceMethod(str, Enum): + """Type of KL divergence""" + + forward = "forward" + backward = "backward" + + +class ExperimentsTrackerName(str, Enum): + """Experiment tracker to use""" + + aim = "aim" + wandb = "wandb" diff --git a/src/instructlab/dolomite/hf_models/__init__.py b/src/instructlab/dolomite/hf_models/__init__.py index 66b024c..86cc443 100644 --- a/src/instructlab/dolomite/hf_models/__init__.py +++ b/src/instructlab/dolomite/hf_models/__init__.py @@ -2,7 +2,7 @@ # Extracted from https://github.com/ibm-granite/dolomite-engine # ---------------------------------------------------------------- # Local -from .config import GPTDolomiteConfig +from .models.gpt_dolomite.config import GPTDolomiteConfig from .model_conversion import export_to_huggingface, import_from_huggingface from .models import GPTDolomiteForCausalLM, GPTDolomiteModel from .register_hf import register_model_classes diff --git a/src/instructlab/dolomite/hf_models/config.py b/src/instructlab/dolomite/hf_models/config.py index 0258489..538dc34 100644 --- a/src/instructlab/dolomite/hf_models/config.py +++ b/src/instructlab/dolomite/hf_models/config.py @@ -1,15 +1,9 @@ -# ---------------------------------------------------------------- -# Extracted from https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- -# Third Party from transformers import PretrainedConfig -# Local -from .enums import AttentionHeadType, PositionEmbeddingType +from .enums import AttentionHeadType, InitMethod, PositionEmbeddingType -class GPTDolomiteConfig(PretrainedConfig): - model_type = "gpt_dolomite" +class CommonConfig(PretrainedConfig): keys_to_ignore_at_inference = ["past_key_values"] attribute_map = { "hidden_size": "n_embd", @@ -18,11 +12,6 @@ class GPTDolomiteConfig(PretrainedConfig): "num_hidden_layers": "n_layer", } - # NOTE: initializer range is kept for backward compatiblity - # but it is not used anymore - # : also rope_scaling is not used anymore but kept for - # same reason. 
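Most of the new enums in enums.py above subclass `str`, so their members compare equal to the raw strings found in JSON or YAML configs, and constructing an enum from a string doubles as validation. This is the same pattern `CommonConfig.__init__` relies on below with `AttentionHeadType`, `PositionEmbeddingType`, and `InitMethod`. A minimal standard-library sketch of both behaviours (the `LossMask` here simply mirrors the enum defined above, so nothing from the package is imported):

```python
from enum import Enum


class LossMask(str, Enum):
    """Type of loss masking method (mirrors the enum added above)."""

    output_only = "output_only"
    no_mask = "no_mask"


# str-valued members compare equal to their raw string form, so values
# read straight from a JSON/YAML config need no special casing.
assert LossMask.output_only == "output_only"

# Constructing an enum from a string doubles as validation: an unknown
# value raises ValueError instead of silently passing through.
try:
    LossMask("everything")
except ValueError as exc:
    print(f"rejected invalid config value: {exc}")
```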
- def __init__( self, vocab_size: int = 50257, @@ -30,8 +19,8 @@ def __init__( n_embd: int = 768, n_layer: int = 12, n_head: int = 12, - num_key_value_heads: int = None, - n_inner: int = None, + num_key_value_heads: int | None = None, + n_inner: int | None = None, activation_function: str = "gelu_pytorch_tanh", attention_head_type: str = "mqa", resid_pdrop: float = 0.1, @@ -41,20 +30,19 @@ def __init__( layer_norm_epsilon: float = 1e-5, initializer_range: float = 0.02, scale_attn_weights: bool = True, - attention_multiplier: float = None, + attention_multiplier: float | None = None, use_cache: bool = True, bos_token_id: int = 50256, eos_token_id: int = 50256, pad_token_id: int = 50256, attention_softmax_in_fp32: bool = True, - scale_attention_softmax_in_fp32: bool = True, add_bias: bool = True, position_embedding_type: str = "learned_absolute", rope_theta: int = 10000, - rope_scaling: dict = None, - m_emb: float = None, - m_width: float = None, - m_residual: float = None, + rope_scaling: dict | None = None, + m_emb: float | None = None, + m_width: float | None = None, + m_residual: float | None = None, init_method: str = "normal", upcast_logits_for_loss: bool = False, **kwargs, @@ -78,7 +66,6 @@ def __init__( self.attention_multiplier = attention_multiplier self.use_cache = use_cache self.attention_softmax_in_fp32 = attention_softmax_in_fp32 - self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32 self.position_embedding_type = position_embedding_type self.add_bias = add_bias self.rope_theta = rope_theta @@ -93,6 +80,7 @@ def __init__( assert self.scale_attn_weights # check if enums are valid + init_method = InitMethod(init_method) attention_head_type = AttentionHeadType(attention_head_type) position_embedding_type = PositionEmbeddingType(position_embedding_type) @@ -110,9 +98,7 @@ def __init__( if self.num_key_value_heads is None: self.num_key_value_heads = 1 - assert ( - self.num_key_value_heads == 1 - ), "MultiQueryAttention should have 1 head for keys and values" + assert self.num_key_value_heads == 1, "MultiQueryAttention should have 1 head for keys and values" elif attention_head_type == AttentionHeadType.gqa: assert ( self.num_key_value_heads is not None @@ -122,9 +108,4 @@ def __init__( self.n_head % self.num_key_value_heads == 0 ), "GroupedQueryAttention should have more than 1 head for keys and values" - super().__init__( - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - pad_token_id=pad_token_id, - **kwargs, - ) + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs) diff --git a/src/instructlab/dolomite/hf_models/defaults.py b/src/instructlab/dolomite/hf_models/defaults.py new file mode 100644 index 0000000..2428033 --- /dev/null +++ b/src/instructlab/dolomite/hf_models/defaults.py @@ -0,0 +1 @@ +DEFAULT_NORMALIZATION_IMPLEMENTATION = "torch" diff --git a/src/instructlab/dolomite/hf_models/enums.py b/src/instructlab/dolomite/hf_models/enums.py index bc5d592..5055bcf 100644 --- a/src/instructlab/dolomite/hf_models/enums.py +++ b/src/instructlab/dolomite/hf_models/enums.py @@ -1,10 +1,11 @@ -# ---------------------------------------------------------------- -# Extracted from https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- -# Standard from enum import Enum +class InitMethod(Enum): + normal = "normal" + mup = "mup" + + class PositionEmbeddingType(Enum): """ Enum class for position embeddings diff --git 
a/src/instructlab/dolomite/hf_models/mixins/__init__.py b/src/instructlab/dolomite/hf_models/mixins/__init__.py new file mode 100644 index 0000000..c4f9102 --- /dev/null +++ b/src/instructlab/dolomite/hf_models/mixins/__init__.py @@ -0,0 +1,4 @@ +from .dense import BaseModelMixin, CausalLMModelMixin, PreTrainedModelMixin +#from .dense_TP import BaseModelMixin_TP, CausalLMModelMixin_TP, PreTrainedModelMixin_TP +from .moe import BaseMoEModelMixin, CausalLMMoEModelMixin, PreTrainedMoEModelMixin +#from .moe_TP import BaseMoEModelMixin_TP, CausalLMMoEModelMixin_TP, PreTrainedMoEModelMixin_TP diff --git a/src/instructlab/dolomite/hf_models/mixins/dense/__init__.py b/src/instructlab/dolomite/hf_models/mixins/dense/__init__.py new file mode 100644 index 0000000..0ee5d10 --- /dev/null +++ b/src/instructlab/dolomite/hf_models/mixins/dense/__init__.py @@ -0,0 +1,2 @@ +from .base import BaseModelMixin, PreTrainedModelMixin +from .main import CausalLMModelMixin diff --git a/src/instructlab/dolomite/hf_models/mixins/dense/base.py b/src/instructlab/dolomite/hf_models/mixins/dense/base.py new file mode 100644 index 0000000..3298682 --- /dev/null +++ b/src/instructlab/dolomite/hf_models/mixins/dense/base.py @@ -0,0 +1,584 @@ +import warnings + +import torch +import torch.nn as nn +from transformers import DynamicCache, PreTrainedModel +from transformers.modeling_outputs import BaseModelOutputWithPast + +from ...config import CommonConfig +from ...defaults import DEFAULT_NORMALIZATION_IMPLEMENTATION +from ...enums import AttentionHeadType, PositionEmbeddingType +from ...modeling_utils import Alibi, ParameterizedEmbedding, RoPE, YaRNScaledRoPE, get_normalization_function +from ...utils import convert_padding_free_lists_to_tensors, divide_if_divisible + + +class PreTrainedModelMixin(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = None + layer_class = None + base_model_prefix = "transformer" + causal = True + _no_split_modules = None + _skip_keys_device_placement = "past_key_values" + _supports_sdpa = True + _supports_flash_attn_2 = True + + def __init__(self, config: CommonConfig, *args, **kwargs) -> None: + super().__init__(config, *args, **kwargs) + + assert self.config_class is not None + + self.normalization_implementation = kwargs.get( + "normalization_implementation", DEFAULT_NORMALIZATION_IMPLEMENTATION + ) + + self.attention_implementation = self.config._attn_implementation + self._use_eager_attention = self.attention_implementation == "eager" + self._use_sdpa = self.attention_implementation == "sdpa" + self._use_flash_attention_2 = self.attention_implementation == "flash_attention_2" + self._use_padding_free_transformer = kwargs.get("use_padding_free_transformer", False) + + self._tied_word_embeddings = config.tie_word_embeddings + + if self._use_padding_free_transformer: + assert self._use_flash_attention_2, "padding free transformer only works with flash attention" + + def _init_weights(self, module: nn.Module) -> None: + if hasattr(module, "reset_parameters"): + module.reset_parameters() + + def prepare_inputs_for_model( + self, + input_ids: torch.Tensor | list[list[int]] | None, + inputs_embeds: torch.Tensor | list[list[float]] | None, + position_ids: torch.Tensor | list[list[int]] | None, + token_type_ids: torch.Tensor | list[list[int]] | None, + labels: torch.Tensor | list[list[int]] | None, + cu_seqlens: torch.Tensor | None, + max_seqlen: torch.Tensor | None, + past_key_values: tuple[tuple[torch.Tensor]], + attention_mask: torch.Tensor | None, + use_cache: bool, + output_attentions: bool, + ) -> tuple[torch.Tensor]: + if self._use_padding_free_transformer: + if isinstance(input_ids, list) or isinstance(inputs_embeds, list): + # this is managed internally + error_message = ( + "{variable} should not be passed for flash attention when using List[List[int]] " + "input types attention mask logic is handled internally" + ) + assert cu_seqlens is None, error_message.format(variable="cu_seqlens") + assert max_seqlen is None, error_message.format(variable="max_seqlen") + assert attention_mask is None, error_message.format(variable="attention_mask") + + input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = ( + convert_padding_free_lists_to_tensors( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + token_type_ids=token_type_ids, + labels=labels, + device=torch.cuda.current_device(), + ) + ) + else: + assert ( + cu_seqlens is not None + ), "cu_seqlens needs to be specified when using tensor inputs with padding_free transformer" + assert position_ids is not None, "max_seqlen needs to be specified when specifying cu_seqlens" + assert max_seqlen is not None, "max_seqlen needs to be specified when specifying cu_seqlens" + assert attention_mask is None, "attention_mask should not be passed when specifying cu_seqlens" + + if use_cache or past_key_values is not None: + raise NotImplementedError("KV caching is not supported with padding_free transformer") + + assert not output_attentions + + return input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen + + +class BaseModelMixin(PreTrainedModelMixin): + mask_value = None + + def __init__(self, config: CommonConfig, **kwargs) -> None: + super().__init__(config, **kwargs) + self._init_model(config, **kwargs) + + def _init_model(self, config: CommonConfig, **kwargs) -> None: + 
"""this function purely exists because I have no clue how multiple inheritance works + + Args: + config (CommonConfig): a config object + """ + + self.attention_head_type = AttentionHeadType(config.attention_head_type) + self.embed_dim = config.n_embd + self.num_heads = config.n_head + self.m_emb = config.m_emb + self.initializer_range = config.initializer_range + + self.head_dim = divide_if_divisible( + self.embed_dim, + self.num_heads, + f"`embed_dim` ({self.embed_dim}) must be divisible by `num_heads` ({self.num_heads})", + ) + + self.wte = ParameterizedEmbedding(config.vocab_size, self.embed_dim, std=self.initializer_range) + + self.drop = nn.Identity() if config.embd_pdrop == 0 else nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList( + [ + self.layer_class( + config, + normalization_implementation=self.normalization_implementation, + attention_implementation=self.attention_implementation, + use_padding_free_transformer=self._use_padding_free_transformer, + layer_idx=i, + ) + for i in range(config.n_layer) + ] + ) + self.ln_f = get_normalization_function( + config.normalization_function, + self.embed_dim, + eps=config.layer_norm_epsilon, + normalization_implementation=self.normalization_implementation, + ) + + self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type) + self._setup_positional_encoding() + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> ParameterizedEmbedding: + return self.wte + + def set_input_embeddings(self, new_embeddings: ParameterizedEmbedding) -> None: + self.wte = new_embeddings + + def forward( + self, + input_ids: torch.Tensor | None = None, + past_key_values: DynamicCache | None = None, + attention_mask: torch.Tensor | None = None, + token_type_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + use_cache: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool = True, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, + ) -> tuple | BaseModelOutputWithPast: + ( + output_hidden_states, + use_cache, + hidden_states, + attention_mask, + position_ids, + rope_cos_sin, + past_key_values, + ) = self._prepare_a_bunch_of_stuff( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + # ========================================================================================== + # padding_free: + # attention_mask -> None + # flash: + # attention_mask -> (batch_size, key_length) + # else: + # attention_mask -> (batch_size, 1, query_length, key_length) + # ========================================================================================== + + past_key_values = DynamicCache() if use_cache and past_key_values is None else past_key_values + all_hidden_states = () if output_hidden_states else None + for block in self.h: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + hidden_states = block( + hidden_states, + past_key_values=past_key_values, + attention_mask=attention_mask, + rope_cos_sin=rope_cos_sin, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + hidden_states = self.ln_f(hidden_states) + + # Add last hidden state + if output_hidden_states: + all_hidden_states += 
(hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + ) + + def _get_position_ids( + self, attention_mask: torch.Tensor, past_length: int, query_length: int, key_length: int, device: torch.device + ) -> torch.Tensor: + if attention_mask is not None and len(attention_mask.shape) == 2: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 0) + if past_length > 0: + position_ids = position_ids[:, past_length:key_length:] + else: + position_ids = torch.arange(past_length, key_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, query_length) + + return position_ids + + def _get_alibi_bias( + self, + attention_mask: torch.Tensor, + batch_size: int, + query_length: int, + key_length: int, + device: torch.device, + dtype: torch.dtype, + ) -> torch.Tensor: + if self.position_embedding_type != PositionEmbeddingType.alibi: + return None + + alibi_bias = self.alibi(attention_mask, batch_size, key_length, device, dtype) + + # ========================================================================================== + # alibi_bias -> (batch_size, num_heads, key_length) + # ========================================================================================== + + alibi_bias = alibi_bias.unsqueeze(2) + if query_length != 1: + alibi_bias = alibi_bias.expand(-1, -1, query_length, -1) + + # ========================================================================================== + # alibi_bias -> (batch_size, num_heads, query_length, key_length) + # ========================================================================================== + + return alibi_bias + + def _get_rope_cos_sin( + self, key_length: int, position_ids: torch.Tensor, dtype: torch.dtype, device: torch.device + ) -> torch.Tensor: + if self.position_embedding_type == PositionEmbeddingType.rope: + cos, sin = self.rope(key_length, dtype=dtype, device=device) + cos = cos[position_ids].unsqueeze(1) + sin = sin[position_ids].unsqueeze(1) + return cos, sin + + def _prepare_causal_attention_mask( + self, + attention_mask: torch.Tensor | None, + batch_size: int, + query_length: int, + key_length: int, + device: torch.device, + ) -> torch.Tensor: + past_length = key_length - query_length + + # ========================================================================================== + # attention_mask -> (batch_size, key_length) + # ========================================================================================== + + if query_length > 1: + # (query_length, key_length) + causal_mask = torch.empty((query_length, key_length), dtype=torch.bool, device=device) + causal_mask[:, past_length:] = torch.tril( + torch.ones(query_length, query_length, dtype=torch.bool, device=device) + ) + + if past_length > 0: + causal_mask[:, :past_length] = True + + # (query_length, key_length) -> (1, query_length, key_length) + causal_mask = causal_mask.unsqueeze(0) + + if attention_mask is None: + # (1, query_length, key_length) -> (batch_size, query_length, key_length) + causal_mask = causal_mask.expand(batch_size, -1, -1) + else: + # (1, query_length, key_length) & (batch_size, 1, key_length) -> (batch_size, query_length, key_length) + causal_mask = causal_mask & attention_mask.unsqueeze(1).to(torch.bool) + else: + if attention_mask is None: + # (batch_size, query_length, key_length) + causal_mask = 
torch.ones(batch_size, query_length, key_length, dtype=torch.bool, device=device) + else: + # (batch_size, query_length, key_length) + causal_mask = attention_mask.unsqueeze(1).to(dtype=torch.bool, device=device) + + # ========================================================================================== + # attention_mask -> (batch_size, query_length, key_length) + # ========================================================================================== + + causal_mask = causal_mask.unsqueeze(1) + + # ========================================================================================== + # attention_mask -> (batch_size, 1, query_length, key_length) + # ========================================================================================== + + return causal_mask + + def _get_initial_hidden_state( + self, + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor | None, + position_ids: torch.Tensor | None, + token_type_ids: torch.Tensor | None, + ) -> torch.Tensor: + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + + if self.position_embedding_type == PositionEmbeddingType.learned_absolute: + inputs_embeds = inputs_embeds + self.wpe(position_ids) + + if token_type_ids is not None: + inputs_embeds = inputs_embeds + self.wte(token_type_ids) + + inputs_embeds = self.drop(inputs_embeds) + + if self.m_emb is not None: + inputs_embeds = inputs_embeds * self.m_emb + + return inputs_embeds + + def _prepare_a_bunch_of_stuff( + self, + input_ids: torch.Tensor | None = None, + past_key_values: DynamicCache | None = None, + attention_mask: torch.Tensor | None = None, + token_type_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + use_cache: bool | None = None, + output_hidden_states: bool | None = None, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, + ) -> tuple[ + bool, + bool, + bool, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + tuple[torch.Tensor], + ]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if use_cache is None: + use_cache = False if self._use_padding_free_transformer else self.config.use_cache + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + + # special handling for padding free transformer with list inputs + if self._use_padding_free_transformer: + # for flash attention, there is no padding and we do packing + # so, input_ids is of shape (s1 + s2 + ... 
+ sb) + batch_size = cu_seqlens.shape[0] - 1 + else: + batch_size = input_shape[0] + elif inputs_embeds is not None: + # TODO special handling for padding free transformer needed here if we support inputs_embeds argument + input_shape = inputs_embeds.size()[:-1] + batch_size = input_shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if self._use_padding_free_transformer: + assert position_ids is not None, ( + "GPTDolomiteModel needs position_ids from outside when using flash attention with List[List[int]] " + "inputs" + ) + else: + if self.position_embedding_type == PositionEmbeddingType.alibi: + if position_ids is not None: + warnings.warn("`position_ids` have no functionality with Alibi.", FutureWarning) + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + # ========================================================================================== + # padding_free: + # input_ids -> (total_q) + # attention_mask -> None + # position_ids -> (total_q) + # else: + # input_ids -> (batch_size, query_length) + # attention_mask -> None or (batch_size, key_length) + # position_ids -> None or (batch_size, key_length) + # ========================================================================================== + + past_length = None + query_length = None + key_length = None + if self._use_padding_free_transformer: + key_length = max_seqlen.item() + else: + past_length = 0 if past_key_values is None else past_key_values.get_seq_length() + query_length = input_shape[-1] + key_length = past_length + query_length + + if position_ids is None: + position_ids = self._get_position_ids(attention_mask, past_length, query_length, key_length, device) + + # ========================================================================================== + # padding_free: + # input_ids -> (total_q) + # attention_mask -> None + # position_ids -> (total_q) + # else: + # input_ids -> (batch_size, query_length) + # attention_mask -> None or (batch_size, key_length) + # position_ids -> (batch_size, query_length) + # ========================================================================================== + + hidden_states = self._get_initial_hidden_state(input_ids, inputs_embeds, position_ids, token_type_ids) + + # ========================================================================================== + # padding_free: + # hidden_states -> (total_q, num_heads * head_dim) + # else: + # hidden_states -> (batch_size, query_length, num_heads * head_dim) + # ========================================================================================== + + alibi_bias = self._get_alibi_bias( + attention_mask, batch_size, query_length, key_length, device, hidden_states.dtype + ) + + # ========================================================================================== + # alibi_bias -> (batch_size, num_heads, query_length, key_length) + # ========================================================================================== + + rope_cos_sin = self._get_rope_cos_sin( + key_length, position_ids, dtype=hidden_states.dtype, device=hidden_states.device + ) + + # ========================================================================================== + # padding_free: + # rope_cos_sin -> 2 * (max_seqlen, head_dim) + # else: + # rope_cos_sin -> 2 * (key_length, head_dim) + # 
========================================================================================== + + attention_mask = self._get_maybe_causal_mask( + attention_mask, alibi_bias, batch_size, query_length, key_length, hidden_states.dtype, device + ) + + return ( + output_hidden_states, + use_cache, + hidden_states, + attention_mask, + position_ids, + rope_cos_sin, + past_key_values, + ) + + def _setup_positional_encoding(self) -> None: + max_position_embeddings = self.config.max_position_embeddings + + if self.position_embedding_type == PositionEmbeddingType.learned_absolute: + self.wpe = ParameterizedEmbedding(max_position_embeddings, self.embed_dim, std=self.initializer_range) + elif self.position_embedding_type == PositionEmbeddingType.alibi: + assert not self._use_flash_attention_2, "alibi is not implemented with FlashAttention" + + self.alibi = Alibi(self.num_heads) + elif self.position_embedding_type == PositionEmbeddingType.rope: + if self.config.rope_scaling is None: + self.rope = RoPE( + self.head_dim, + max_position_embeddings=max_position_embeddings, + base=self.config.rope_theta, + ) + else: + self.rope = YaRNScaledRoPE( + self.head_dim, + max_position_embeddings=max_position_embeddings, + base=self.config.rope_theta, + scale=self.config.rope_scaling["factor"], + original_max_position_embeddings=self.config.rope_scaling["original_max_position_embeddings"], + ) + elif self.position_embedding_type == PositionEmbeddingType.nope: + pass + else: + raise NotImplementedError() + + def _get_mask_value(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor: + # torch.where expects a tensor. We use a cache to avoid recreating it every time. + if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device: + self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device) + return self.mask_value + + def _get_maybe_causal_mask( + self, + attention_mask: torch.Tensor | None, + alibi_bias: torch.Tensor | None, + batch_size: int, + query_length: int, + key_length: int, + dtype: torch.dtype, + device: torch.device, + ) -> torch.Tensor: + if self._use_sdpa: + # we use the causal/non-causal argument of SDPA for attention in this case + if attention_mask is not None: + attention_mask = self._prepare_causal_attention_mask( + attention_mask, batch_size, query_length, key_length, device + ) + + attention_mask = torch.where( + attention_mask, + ~attention_mask if alibi_bias is None else alibi_bias, + self._get_mask_value(attention_mask.device, dtype), + ) + + # this is needed to prevent NaN since SDPA + # see issue: https://github.com/pytorch/pytorch/issues/110213 + attention_mask = attention_mask * ~torch.all( + attention_mask == self._get_mask_value(attention_mask.device, dtype), dim=-1, keepdim=True + ) + elif self._use_eager_attention: + attention_mask = self._prepare_causal_attention_mask( + attention_mask, batch_size, query_length, key_length, device + ) + + attention_mask = torch.where( + attention_mask, + ~attention_mask if alibi_bias is None else alibi_bias, + self._get_mask_value(attention_mask.device, dtype), + ) + + return attention_mask diff --git a/src/instructlab/dolomite/hf_models/mixins/dense/main.py b/src/instructlab/dolomite/hf_models/mixins/dense/main.py new file mode 100644 index 0000000..b03b9ed --- /dev/null +++ b/src/instructlab/dolomite/hf_models/mixins/dense/main.py @@ -0,0 +1,198 @@ +import torch +import torch.nn.functional as F +from transformers import DynamicCache +from transformers.modeling_outputs import 
BaseModelOutputWithPast, CausalLMOutputWithPast + +from ...config import CommonConfig +from ...modeling_utils import ParameterizedEmbedding, ParameterizedLinear +from .base import PreTrainedModelMixin + + +class CausalLMModelMixin(PreTrainedModelMixin): + _tied_weights_keys = ["lm_head.weight"] + base_model_class = None + + def __init__(self, config: CommonConfig, **kwargs) -> None: + super().__init__(config, **kwargs) + self._init_model(config, **kwargs) + + def _init_model(self, config: CommonConfig, **kwargs) -> None: + self.transformer = self.base_model_class(config, **kwargs) + + if not self._tied_word_embeddings: + self.lm_head = ParameterizedLinear( + config.n_embd, config.vocab_size, bias=False, std=config.initializer_range + ) + + self.m_width = config.m_width + self.upcast_logits_for_loss = config.upcast_logits_for_loss + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> ParameterizedEmbedding: + return self.transformer.wte + + def set_input_embeddings(self, value: ParameterizedEmbedding) -> None: + self.transformer.wte = value + + def get_output_embeddings(self) -> ParameterizedLinear: + if not self._tied_word_embeddings: + return self.lm_head + + def set_output_embeddings(self, new_embeddings: ParameterizedLinear) -> None: + if not self._tied_word_embeddings: + self.lm_head = new_embeddings + + def prepare_inputs_for_generation( + self, + input_ids: torch.Tensor, + past_key_values: DynamicCache | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs, + ) -> dict: + token_type_ids = kwargs.get("token_type_ids", None) + # Omit tokens covered by past_key_values + if past_key_values: + past_length = past_key_values.get_seq_length() + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] + + attention_mask: torch.Tensor = kwargs.get("attention_mask", None) + position_ids: torch.Tensor = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 0) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + ) + return model_inputs + + def forward( + self, + input_ids: torch.Tensor | list[list[int]] | None = None, + past_key_values: DynamicCache | None = None, + attention_mask: torch.Tensor | None = None, + token_type_ids: torch.Tensor | list[list[int]] | None = None, + position_ids: torch.Tensor | list[list[int]] | None = None, + inputs_embeds: torch.Tensor | list[list[float]] | None = None, + labels: torch.Tensor | list[list[int]] | None = None, + use_cache: bool | None = None, + 
output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool = True, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, + ) -> tuple | CausalLMOutputWithPast: + input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = self.prepare_inputs_for_model( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + token_type_ids=token_type_ids, + labels=labels, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + past_key_values=past_key_values, + attention_mask=attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # ========================================================================================== + # padding_free: + # input_ids -> (total_q) + # attention_mask -> None + # position_ids -> (total_q) + # else: + # input_ids -> (batch_size, query_length) + # attention_mask -> None or (batch_size, key_length) + # position_ids -> None or (batch_size, key_length) + # ========================================================================================== + + transformer_outputs: BaseModelOutputWithPast = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + lm_logits = self.get_lm_logits(transformer_outputs.last_hidden_state) + + if self.m_width is not None: + lm_logits = lm_logits / self.m_width + + loss = self.get_autoregressive_language_modeling_loss(lm_logits, labels, cu_seqlens) + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + def get_lm_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + return ( + F.linear(hidden_states, self.transformer.wte.weight) + if self._tied_word_embeddings + else self.lm_head(hidden_states) + ) + + def get_autoregressive_language_modeling_loss( + self, lm_logits: torch.Tensor, labels: torch.Tensor | None, cu_seqlens: torch.Tensor + ) -> torch.Tensor: + if labels is None: + return None + + if self._use_padding_free_transformer: + shift_logits = lm_logits[:-1, :] + shift_labels = labels[1:].to(shift_logits.device) + + # this is needed so that the last token of current example doesn't predict first token of next example + drop_loss_positions = cu_seqlens[1:-1] - 1 + shift_labels[drop_loss_positions] = -100 + else: + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous().to(shift_logits.device) + + if self.upcast_logits_for_loss: + shift_logits = shift_logits.float() + + loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + return loss diff --git a/src/instructlab/dolomite/hf_models/mixins/dense_TP/__init__.py b/src/instructlab/dolomite/hf_models/mixins/dense_TP/__init__.py new file mode 100644 index 0000000..cbbb640 --- /dev/null +++ b/src/instructlab/dolomite/hf_models/mixins/dense_TP/__init__.py @@ -0,0 +1,2 @@ +from .base import BaseModelMixin_TP, PreTrainedModelMixin_TP +from .main import CausalLMModelMixin_TP diff --git a/src/instructlab/dolomite/hf_models/mixins/dense_TP/base.py b/src/instructlab/dolomite/hf_models/mixins/dense_TP/base.py new file 
mode 100644 index 0000000..801bd72 --- /dev/null +++ b/src/instructlab/dolomite/hf_models/mixins/dense_TP/base.py @@ -0,0 +1,104 @@ +import torch.nn as nn + +from ....utils import ProcessGroupManager +from ...config import CommonConfig +from ...enums import AttentionHeadType, PositionEmbeddingType +from ...modeling_utils import RoPE, YaRNScaledRoPE +from ...modeling_utils_TP import Alibi_TP, Dropout_TP, Embedding_TP, get_normalization_function_TP +from ..dense import BaseModelMixin, PreTrainedModelMixin + + +class PreTrainedModelMixin_TP(PreTrainedModelMixin): + def __init__(self, config: CommonConfig, *args, **kwargs): + self.tensor_parallel_word_embeddings = kwargs.get("tensor_parallel_word_embeddings", False) + self.sequence_parallel = kwargs.get("sequence_parallel", False) + + super().__init__(config, *args, **kwargs) + + +class BaseModelMixin_TP(PreTrainedModelMixin_TP, BaseModelMixin): + def _init_model(self, config: CommonConfig, **kwargs) -> None: + self.attention_head_type = AttentionHeadType(config.attention_head_type) + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.max_position_embeddings = config.max_position_embeddings + self.m_emb = config.m_emb + self.initializer_range = config.initializer_range + self.head_dim = self.embed_dim // self.num_heads + + self.tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() + self.wte = Embedding_TP( + config.vocab_size, + self.embed_dim, + std=self.initializer_range, + tensor_parallel_word_embeddings=self.tensor_parallel_word_embeddings, + use_padding_free_transformer=self._use_padding_free_transformer, + sequence_parallel=self.sequence_parallel, + ) + + self.drop = ( + nn.Identity() + if config.embd_pdrop == 0 + else Dropout_TP( + config.embd_pdrop, + use_padding_free_transformer=self._use_padding_free_transformer, + sequence_parallel=self.sequence_parallel, + ) + ) + self.h = nn.ModuleList( + [ + self.layer_class( + config, + normalization_implementation=self.normalization_implementation, + attention_implementation=self.attention_implementation, + use_padding_free_transformer=self._use_padding_free_transformer, + layer_idx=i, + sequence_parallel=self.sequence_parallel, + ) + for i in range(config.num_hidden_layers) + ] + ) + self.ln_f = get_normalization_function_TP( + config.normalization_function, + self.embed_dim, + eps=config.layer_norm_epsilon, + normalization_implementation=self.normalization_implementation, + use_padding_free_transformer=self._use_padding_free_transformer, + sequence_parallel=self.sequence_parallel, + ) + + self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type) + self._setup_positional_encoding() + + # Initialize weights and apply final processing + self.post_init() + + def _setup_positional_encoding(self) -> None: + max_position_embeddings = self.config.max_position_embeddings + + if self.position_embedding_type == PositionEmbeddingType.learned_absolute: + self.wpe = Embedding_TP( + max_position_embeddings, + self.embed_dim, + std=self.initializer_range, + tensor_parallel_word_embeddings=False, + use_padding_free_transformer=self._use_padding_free_transformer, + sequence_parallel=self.sequence_parallel, + ) + elif self.position_embedding_type == PositionEmbeddingType.alibi: + self.alibi = Alibi_TP(self.num_heads) + elif self.position_embedding_type == PositionEmbeddingType.rope: + if self.config.rope_scaling is None: + self.rope = RoPE( + self.head_dim, max_position_embeddings=max_position_embeddings, 
base=self.config.rope_theta + ) + else: + self.rope = YaRNScaledRoPE( + self.head_dim, + max_position_embeddings=max_position_embeddings, + base=self.config.rope_theta, + scale=self.config.rope_scaling["factor"], + original_max_position_embeddings=self.config.rope_scaling["original_max_position_embeddings"], + ) + else: + raise NotImplementedError() diff --git a/src/instructlab/dolomite/hf_models/mixins/dense_TP/main.py b/src/instructlab/dolomite/hf_models/mixins/dense_TP/main.py new file mode 100644 index 0000000..4505921 --- /dev/null +++ b/src/instructlab/dolomite/hf_models/mixins/dense_TP/main.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +from contextlib import nullcontext + +import torch +import torch.nn.functional as F +from torch.distributed._tensor.placement_types import Replicate, Shard +from torch.distributed.tensor.parallel import loss_parallel +from transformers import DynamicCache +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast + +from ....utils import ProcessGroupManager, SafeTensorsWeightsManager +from ...config import CommonConfig +from ...enums import PositionEmbeddingType +from ...modeling_utils_TP import LMHead_TP, dtensor_to_tensor, tensor_to_dtensor +from ..dense import CausalLMModelMixin +from .base import PreTrainedModelMixin_TP + + +class CausalLMModelMixin_TP(PreTrainedModelMixin_TP, CausalLMModelMixin): + tensor_parallel_state_dict_function = None + + def _init_model(self, config: CommonConfig, **kwargs) -> None: + self.vocab_size = config.vocab_size + self.transformer = self.base_model_class(config, **kwargs) + + if not self._tied_word_embeddings: + self.lm_head = LMHead_TP( + self.vocab_size, + config.n_embd, + std=config.initializer_range, + tensor_parallel_word_embeddings=self.tensor_parallel_word_embeddings, + sequence_parallel=self.sequence_parallel, + ) + + self.m_width = config.m_width + self.upcast_logits_for_loss = config.upcast_logits_for_loss + + self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: torch.Tensor | list[list[int]] | None = None, + past_key_values: DynamicCache | None = None, + attention_mask: torch.Tensor | None = None, + token_type_ids: torch.Tensor | list[list[int]] | None = None, + position_ids: torch.Tensor | list[list[int]] | None = None, + inputs_embeds: torch.Tensor | list[list[float]] | None = None, + labels: torch.Tensor | list[list[int]] | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool = True, + output_parallel_lm_logits: bool = False, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, + ) -> tuple | CausalLMOutputWithPast: + input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = self.prepare_inputs_for_model( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + token_type_ids=token_type_ids, + labels=labels, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + past_key_values=past_key_values, + attention_mask=attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + transformer_outputs: BaseModelOutputWithPast = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + 
output_hidden_states=output_hidden_states, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + lm_logits = self.get_lm_logits(transformer_outputs.last_hidden_state) + + if self.m_width is not None: + lm_logits = lm_logits / self.m_width + + loss = self.get_autoregressive_language_modeling_loss(lm_logits, labels, cu_seqlens) + + if output_parallel_lm_logits: + assert self.tensor_parallel_word_embeddings + else: + if self.tensor_parallel_word_embeddings: + # all gather + lm_logits = tensor_to_dtensor(lm_logits, device_mesh=self.tp_mesh, current_placement=Shard(-1)) + lm_logits = dtensor_to_tensor(lm_logits, device_mesh=self.tp_mesh, desired_placement=Replicate()) + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + def get_lm_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + return ( + LMHead_TP.compute_with_weight( + hidden_states, + weight=self.transformer.wte.weight, + tensor_parallel_word_embeddings=self.tensor_parallel_word_embeddings, + use_padding_free_transformer=self._use_padding_free_transformer, + sequence_parallel=self.sequence_parallel, + tp_mesh=self.tp_mesh, + ) + if self._tied_word_embeddings + else self.lm_head(hidden_states) + ) + + def get_autoregressive_language_modeling_loss( + self, lm_logits: torch.Tensor, labels: torch.Tensor | None, cu_seqlens: torch.Tensor + ) -> torch.Tensor: + if labels is None: + return None + + if self._use_padding_free_transformer: + shift_logits = lm_logits[:-1, :] + shift_labels = labels[1:].to(shift_logits.device) + + # this is needed so that the last token of current example doesn't predict first token of next example + drop_loss_positions = cu_seqlens[1:-1] - 1 + shift_labels[drop_loss_positions] = -100 + else: + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous().to(shift_logits.device) + + shift_logits = tensor_to_dtensor( + shift_logits, + device_mesh=self.tp_mesh, + current_placement=Shard(-1) if self.tensor_parallel_word_embeddings else Replicate(), + ) + shift_labels = tensor_to_dtensor(shift_labels, device_mesh=self.tp_mesh, current_placement=Replicate()) + + if self.upcast_logits_for_loss: + shift_logits = shift_logits.float() + + loss_context = loss_parallel if self.tensor_parallel_word_embeddings else nullcontext + with loss_context(): + loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + return loss + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + torch_dtype: torch.dtype = torch.float32, + tensor_parallel_word_embeddings: bool = False, + **kwargs, + ) -> CausalLMModelMixin_TP: + config: CommonConfig = cls.config_class.from_pretrained(pretrained_model_name_or_path) + + # use dummy tensors to avoid initializing model here + with torch.device("meta"): + # try sharding vocab matrices if really struggling for memory + model = cls._from_config(config, tensor_parallel_word_embeddings=tensor_parallel_word_embeddings, **kwargs) + model = model.to(dtype=torch_dtype) + + # copy to device without copying storage + model = model.to_empty(device=torch.cuda.current_device()) + model.load_from_safetensors_weights_manager(SafeTensorsWeightsManager(pretrained_model_name_or_path)) + + return model + + def load_from_safetensors_weights_manager(self, safetensors_weights_manager: 
SafeTensorsWeightsManager) -> None: + with torch.device(torch.cuda.current_device()): + position_embedding_type = PositionEmbeddingType(self.config.position_embedding_type) + + if position_embedding_type == PositionEmbeddingType.alibi: + self.transformer.alibi.reset_parameters() + elif position_embedding_type == PositionEmbeddingType.rope: + self.transformer.rope.reset_parameters() + + state_dict = self.__class__.tensor_parallel_state_dict_function( + config=self.config, + safetensors_weights_manager=safetensors_weights_manager, + tensor_parallel_word_embeddings=self.tensor_parallel_word_embeddings, + ) + + self.load_state_dict(state_dict) diff --git a/src/instructlab/dolomite/hf_models/mixins/moe/__init__.py b/src/instructlab/dolomite/hf_models/mixins/moe/__init__.py new file mode 100644 index 0000000..12b6465 --- /dev/null +++ b/src/instructlab/dolomite/hf_models/mixins/moe/__init__.py @@ -0,0 +1,2 @@ +from .base import BaseMoEModelMixin, MoeModelOutputWithPastAndAuxLoss, PreTrainedMoEModelMixin +from .main import CausalLMMoEModelMixin diff --git a/src/instructlab/dolomite/hf_models/mixins/moe/base.py b/src/instructlab/dolomite/hf_models/mixins/moe/base.py new file mode 100644 index 0000000..54ed982 --- /dev/null +++ b/src/instructlab/dolomite/hf_models/mixins/moe/base.py @@ -0,0 +1,205 @@ +from dataclasses import dataclass + +import torch +import torch.nn as nn +from transformers import DynamicCache +from transformers.modeling_outputs import MoeModelOutputWithPast + +from ...config import CommonConfig +from ...enums import AttentionHeadType, PositionEmbeddingType +from ...modeling_utils import ParameterizedEmbedding, get_normalization_function +from ..dense import BaseModelMixin, PreTrainedModelMixin + + +@dataclass +class MoeModelOutputWithPastAndAuxLoss(MoeModelOutputWithPast): + aux_loss: torch.Tensor | None = None + + +class PreTrainedMoEModelMixin(PreTrainedModelMixin): + def __init__(self, config: CommonConfig, *args, **kwargs) -> None: + self.moe_implementation = kwargs.get("moe_implementation", "eager") + assert self.moe_implementation in ["eager", "scattermoe"] + + super().__init__(config, *args, **kwargs) + + +class BaseMoEModelMixin(BaseModelMixin): + def _init_model(self, config: CommonConfig, **kwargs) -> None: + self.attention_head_type = AttentionHeadType(config.attention_head_type) + self.embed_dim = config.n_embd + self.num_heads = config.n_head + self.m_emb = config.m_emb + self.initializer_range = config.initializer_range + self.mask_value = None + + assert ( + self.embed_dim % self.num_heads == 0 + ), f"`embed_dim` ({self.embed_dim}) must be divisible by `num_heads` ({self.num_heads})" + + self.head_dim = self.embed_dim // self.num_heads + + self.wte = ParameterizedEmbedding(config.vocab_size, self.embed_dim, std=self.initializer_range) + + self.drop = nn.Identity() if config.embd_pdrop == 0 else nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList( + [ + self.layer_class( + config, + normalization_implementation=self.normalization_implementation, + attention_implementation=self.attention_implementation, + use_padding_free_transformer=self._use_padding_free_transformer, + moe_implementation=self.moe_implementation, + layer_idx=i, + ) + for i in range(config.n_layer) + ] + ) + self.ln_f = get_normalization_function( + config.normalization_function, + self.embed_dim, + eps=config.layer_norm_epsilon, + normalization_implementation=self.normalization_implementation, + ) + + self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type) + 
self._setup_positional_encoding() + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: torch.Tensor | None = None, + past_key_values: DynamicCache | None = None, + attention_mask: torch.Tensor | None = None, + token_type_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + use_cache: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool = True, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, + output_router_logits: bool | None = None, + output_aux_loss: bool = True, + ) -> tuple | MoeModelOutputWithPastAndAuxLoss: + ( + output_hidden_states, + use_cache, + hidden_states, + attention_mask, + position_ids, + rope_cos_sin, + past_key_values, + output_router_logits, + ) = self._prepare_a_bunch_of_stuff( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + output_router_logits=output_router_logits, + ) + + # ========================================================================================== + # padding_free: + # attention_mask -> None + # flash: + # attention_mask -> (batch_size, key_length) + # else: + # attention_mask -> (batch_size, 1, query_length, key_length) + # ========================================================================================== + + past_key_values = DynamicCache() if use_cache and past_key_values is None else past_key_values + all_hidden_states = () if output_hidden_states else None + all_router_logits = () if output_router_logits else None + total_aux_loss = 0 + + for block in self.h: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = block( + hidden_states, + past_key_values=past_key_values, + attention_mask=attention_mask, + rope_cos_sin=rope_cos_sin, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + output_router_logits=output_router_logits, + output_aux_loss=output_aux_loss, + ) + + hidden_states = outputs[0] + outputs = outputs[1:] + + if output_router_logits: + all_router_logits += (outputs[0],) + outputs = outputs[1:] + + if output_aux_loss: + aux_loss = outputs[0] + total_aux_loss = total_aux_loss + aux_loss + + hidden_states = self.ln_f(hidden_states) + + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + return MoeModelOutputWithPastAndAuxLoss( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + router_logits=all_router_logits, + aux_loss=total_aux_loss, + ) + + def _prepare_a_bunch_of_stuff( + self, + input_ids: torch.Tensor | None = None, + past_key_values: list[torch.Tensor] | None = None, + attention_mask: torch.Tensor | None = None, + token_type_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + use_cache: bool | None = None, + output_hidden_states: bool | None = None, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, + output_router_logits: bool = False, + ) -> tuple[ + bool, + bool, + bool, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + tuple[torch.Tensor], + ]: + output_router_logits = 
( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + return super()._prepare_a_bunch_of_stuff( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + (output_router_logits,) diff --git a/src/instructlab/dolomite/hf_models/mixins/moe/main.py b/src/instructlab/dolomite/hf_models/mixins/moe/main.py new file mode 100644 index 0000000..89e9632 --- /dev/null +++ b/src/instructlab/dolomite/hf_models/mixins/moe/main.py @@ -0,0 +1,95 @@ +import torch +from transformers import DynamicCache +from transformers.modeling_outputs import MoeCausalLMOutputWithPast + +from ...config import CommonConfig +from ..dense import CausalLMModelMixin +from .base import MoeModelOutputWithPastAndAuxLoss + + +class CausalLMMoEModelMixin(CausalLMModelMixin): + def __init__(self, config: CommonConfig, **kwargs) -> None: + super().__init__(config, **kwargs) + + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_experts + self.num_experts_per_tok = config.num_experts_per_tok + + def forward( + self, + input_ids: torch.Tensor | list[list[int]] | None = None, + past_key_values: DynamicCache | None = None, + attention_mask: torch.Tensor | None = None, + token_type_ids: torch.Tensor | list[list[int]] | None = None, + position_ids: torch.Tensor | list[list[int]] | None = None, + inputs_embeds: torch.Tensor | list[list[float]] | None = None, + labels: torch.Tensor | list[list[int]] | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool = True, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, + output_router_logits: bool | None = None, + ) -> tuple | MoeCausalLMOutputWithPast: + input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = self.prepare_inputs_for_model( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + token_type_ids=token_type_ids, + labels=labels, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + past_key_values=past_key_values, + attention_mask=attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # ========================================================================================== + # padding_free: + # input_ids -> (total_q) + # attention_mask -> None + # position_ids -> (total_q) + # else: + # input_ids -> (batch_size, query_length) + # attention_mask -> None or (batch_size, key_length) + # position_ids -> None or (batch_size, key_length) + # ========================================================================================== + + transformer_outputs: MoeModelOutputWithPastAndAuxLoss = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + output_router_logits=output_router_logits, + ) + + lm_logits = self.get_lm_logits(transformer_outputs.last_hidden_state) + + if self.m_width is not None: + lm_logits = lm_logits / self.m_width + + lm_loss = 
self.get_autoregressive_language_modeling_loss(lm_logits, labels, cu_seqlens) + aux_loss = transformer_outputs.aux_loss + + if lm_loss is None: + loss = None + else: + loss = lm_loss + self.router_aux_loss_coef * aux_loss + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + router_logits=transformer_outputs.router_logits, + ) diff --git a/src/instructlab/dolomite/hf_models/mixins/moe_TP/__init__.py b/src/instructlab/dolomite/hf_models/mixins/moe_TP/__init__.py new file mode 100644 index 0000000..e4e90ab --- /dev/null +++ b/src/instructlab/dolomite/hf_models/mixins/moe_TP/__init__.py @@ -0,0 +1,2 @@ +from .base import BaseMoEModelMixin_TP, PreTrainedMoEModelMixin_TP +from .main import CausalLMMoEModelMixin_TP diff --git a/src/instructlab/dolomite/hf_models/mixins/moe_TP/base.py b/src/instructlab/dolomite/hf_models/mixins/moe_TP/base.py new file mode 100644 index 0000000..55b09de --- /dev/null +++ b/src/instructlab/dolomite/hf_models/mixins/moe_TP/base.py @@ -0,0 +1,75 @@ +import torch.nn as nn + +from ....utils import ProcessGroupManager +from ...config import CommonConfig +from ...enums import AttentionHeadType, PositionEmbeddingType +from ...modeling_utils_TP import Dropout_TP, Embedding_TP, get_normalization_function_TP +from ..dense_TP import BaseModelMixin_TP, PreTrainedModelMixin_TP +from ..moe import BaseMoEModelMixin, PreTrainedMoEModelMixin + + +class PreTrainedMoEModelMixin_TP(PreTrainedMoEModelMixin, PreTrainedModelMixin_TP): + def __init__(self, config: CommonConfig, *args, **kwargs): + self.tensor_parallel_word_embeddings = kwargs.get("tensor_parallel_word_embeddings", False) + self.sequence_parallel = kwargs.get("sequence_parallel", False) + + super().__init__(config, *args, **kwargs) + + +class BaseMoEModelMixin_TP(BaseMoEModelMixin, BaseModelMixin_TP): + def _init_model(self, config: CommonConfig, **kwargs) -> None: + self.attention_head_type = AttentionHeadType(config.attention_head_type) + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.max_position_embeddings = config.max_position_embeddings + self.m_emb = config.m_emb + self.initializer_range = config.initializer_range + self.head_dim = self.embed_dim // self.num_heads + + self.tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() + self.wte = Embedding_TP( + config.vocab_size, + self.embed_dim, + std=self.initializer_range, + tensor_parallel_word_embeddings=self.tensor_parallel_word_embeddings, + use_padding_free_transformer=self._use_padding_free_transformer, + sequence_parallel=self.sequence_parallel, + ) + + self.drop = ( + nn.Identity() + if config.embd_pdrop == 0 + else Dropout_TP( + config.embd_pdrop, + use_padding_free_transformer=self._use_padding_free_transformer, + sequence_parallel=self.sequence_parallel, + ) + ) + self.h = nn.ModuleList( + [ + self.layer_class( + config, + normalization_implementation=self.normalization_implementation, + attention_implementation=self.attention_implementation, + use_padding_free_transformer=self._use_padding_free_transformer, + moe_implementation=self.moe_implementation, + layer_idx=i, + sequence_parallel=self.sequence_parallel, + ) + for i in range(config.num_hidden_layers) + ] + ) + self.ln_f = get_normalization_function_TP( + config.normalization_function, + self.embed_dim, + eps=config.layer_norm_epsilon, + 
normalization_implementation=self.normalization_implementation, + use_padding_free_transformer=self._use_padding_free_transformer, + sequence_parallel=self.sequence_parallel, + ) + + self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type) + self._setup_positional_encoding() + + # Initialize weights and apply final processing + self.post_init() diff --git a/src/instructlab/dolomite/hf_models/mixins/moe_TP/main.py b/src/instructlab/dolomite/hf_models/mixins/moe_TP/main.py new file mode 100644 index 0000000..8f5de69 --- /dev/null +++ b/src/instructlab/dolomite/hf_models/mixins/moe_TP/main.py @@ -0,0 +1,89 @@ +import torch +from torch.distributed._tensor.placement_types import Replicate, Shard +from transformers import DynamicCache +from transformers.modeling_outputs import MoeCausalLMOutputWithPast + +from ...modeling_utils_TP import dtensor_to_tensor, tensor_to_dtensor +from ..dense_TP import CausalLMModelMixin_TP +from ..moe import CausalLMMoEModelMixin, MoeModelOutputWithPastAndAuxLoss + + +class CausalLMMoEModelMixin_TP(CausalLMMoEModelMixin, CausalLMModelMixin_TP): + def forward( + self, + input_ids: torch.Tensor | list[list[int]] | None = None, + past_key_values: DynamicCache | None = None, + attention_mask: torch.Tensor | None = None, + token_type_ids: torch.Tensor | list[list[int]] | None = None, + position_ids: torch.Tensor | list[list[int]] | None = None, + inputs_embeds: torch.Tensor | list[list[float]] | None = None, + labels: torch.Tensor | list[list[int]] | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool = True, + output_parallel_lm_logits: bool = False, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, + output_router_logits: bool | None = None, + ) -> tuple | MoeCausalLMOutputWithPast: + input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = self.prepare_inputs_for_model( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + token_type_ids=token_type_ids, + labels=labels, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + past_key_values=past_key_values, + attention_mask=attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + transformer_outputs: MoeModelOutputWithPastAndAuxLoss = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + output_router_logits=output_router_logits, + ) + + lm_logits = self.get_lm_logits(transformer_outputs.last_hidden_state) + + if self.m_width is not None: + lm_logits = lm_logits / self.m_width + + lm_loss = self.get_autoregressive_language_modeling_loss(lm_logits, labels, cu_seqlens) + aux_loss = tensor_to_dtensor( + transformer_outputs.aux_loss, device_mesh=self.tp_mesh, current_placement=Replicate() + ) + + if lm_loss is None: + loss = None + else: + loss = lm_loss + self.router_aux_loss_coef * aux_loss + + if output_parallel_lm_logits: + assert self.tensor_parallel_word_embeddings + else: + if self.tensor_parallel_word_embeddings: + # all gather + lm_logits = tensor_to_dtensor(lm_logits, device_mesh=self.tp_mesh, current_placement=Shard(-1)) + lm_logits = dtensor_to_tensor(lm_logits, device_mesh=self.tp_mesh, desired_placement=Replicate()) + + 
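Note on the loss computation shared by CausalLMMoEModelMixin and the tensor-parallel CausalLMMoEModelMixin_TP shown here: both scale the accumulated router (load-balancing) auxiliary loss by config.router_aux_loss_coef and add it to the autoregressive LM loss, and only when labels are given (i.e. when lm_loss is not None); the TP variant additionally wraps aux_loss as a replicated DTensor before the sum. A minimal sketch with placeholder values, not part of the diff:

```python
import torch

# illustrative placeholders, not real training values
lm_loss = torch.tensor(2.31)   # from get_autoregressive_language_modeling_loss
aux_loss = torch.tensor(0.05)  # load-balancing loss accumulated across the MoE blocks
router_aux_loss_coef = 0.01    # config.router_aux_loss_coef

# same combination the mixins perform when lm_loss is not None
loss = lm_loss + router_aux_loss_coef * aux_loss
print(loss)  # -> approximately 2.3105
```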
return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + router_logits=transformer_outputs.router_logits, + ) diff --git a/src/instructlab/dolomite/hf_models/model_conversion/__init__.py b/src/instructlab/dolomite/hf_models/model_conversion/__init__.py index d39217e..0ddd148 100644 --- a/src/instructlab/dolomite/hf_models/model_conversion/__init__.py +++ b/src/instructlab/dolomite/hf_models/model_conversion/__init__.py @@ -1,15 +1,13 @@ -# ---------------------------------------------------------------- -# Extracted from https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- -# Third Party from transformers import AutoConfig -# Local from .bigcode import export_to_huggingface_bigcode, import_from_huggingface_bigcode +from .granite import export_to_huggingface_granite, import_from_huggingface_granite from .llama import export_to_huggingface_llama, import_from_huggingface_llama + _MODEL_IMPORT_FUNCTIONS = { "gpt_bigcode": import_from_huggingface_bigcode, + "granite": import_from_huggingface_granite, "llama": import_from_huggingface_llama, } @@ -19,9 +17,7 @@ def import_from_huggingface(pretrained_model_name_or_path: str, save_path: str) model_type = config.model_type if model_type not in _MODEL_IMPORT_FUNCTIONS: - raise NotImplementedError( - f"the current model_type ({model_type}) is not yet supported" - ) + raise NotImplementedError(f"the current model_type ({model_type}) is not yet supported") import_function = _MODEL_IMPORT_FUNCTIONS[model_type] import_function(pretrained_model_name_or_path, save_path) @@ -29,17 +25,14 @@ def import_from_huggingface(pretrained_model_name_or_path: str, save_path: str) _MODEL_EXPORT_FUNCTIONS = { "gpt_bigcode": export_to_huggingface_bigcode, + "granite": export_to_huggingface_granite, "llama": export_to_huggingface_llama, } -def export_to_huggingface( - pretrained_model_name_or_path: str, save_path: str, model_type: str -) -> None: +def export_to_huggingface(pretrained_model_name_or_path: str, save_path: str, model_type: str) -> None: if model_type not in _MODEL_EXPORT_FUNCTIONS: - raise NotImplementedError( - f"the current model_type ({model_type}) is not yet supported" - ) + raise NotImplementedError(f"the current model_type ({model_type}) is not yet supported") export_function = _MODEL_EXPORT_FUNCTIONS[model_type] export_function(pretrained_model_name_or_path, save_path) diff --git a/src/instructlab/dolomite/hf_models/model_conversion/bigcode.py b/src/instructlab/dolomite/hf_models/model_conversion/bigcode.py index 9bfd4da..9ee9339 100644 --- a/src/instructlab/dolomite/hf_models/model_conversion/bigcode.py +++ b/src/instructlab/dolomite/hf_models/model_conversion/bigcode.py @@ -1,20 +1,12 @@ -# ---------------------------------------------------------------- -# Extracted from https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- -# Standard import shutil -# Third Party -from transformers import AutoConfig, AutoTokenizer, GenerationConfig, GPTBigCodeConfig +from transformers import AutoConfig, AutoTokenizer, GenerationConfig, GPTBigCodeConfig, GPTBigCodeForCausalLM -# Local -from ..config import GPTDolomiteConfig from ..enums import AttentionHeadType, PositionEmbeddingType +from ..models import GPTDolomiteConfig -def import_from_huggingface_bigcode( - 
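For context on the dispatch tables extended in the model_conversion __init__ hunk above: import_from_huggingface routes on the source checkpoint's model_type, and export_to_huggingface on the requested target model_type, with "granite" now supported alongside "gpt_bigcode" and "llama". A hedged usage sketch; the paths below are placeholders:

```python
from instructlab.dolomite.hf_models import export_to_huggingface, import_from_huggingface

# HuggingFace checkpoint (gpt_bigcode / granite / llama) -> dolomite layout
import_from_huggingface("path/to/hf_checkpoint", "path/to/dolomite_checkpoint")

# dolomite layout -> HuggingFace checkpoint of the requested architecture
export_to_huggingface("path/to/dolomite_checkpoint", "path/to/hf_export", model_type="granite")
```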
pretrained_model_name_or_path: str, save_path: str -) -> None: +def import_from_huggingface_bigcode(pretrained_model_name_or_path: str, save_path: str) -> None: shutil.copytree(pretrained_model_name_or_path, save_path) original_config: GPTBigCodeConfig = AutoConfig.from_pretrained(save_path) @@ -27,13 +19,11 @@ def import_from_huggingface_bigcode( try: tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) tokenizer.save_pretrained(save_path, legacy_format=False) - except: # pylint: disable=bare-except + except: pass -def _import_config_from_huggingface( - original_config: GPTBigCodeConfig, -) -> GPTDolomiteConfig: +def _import_config_from_huggingface(original_config: GPTBigCodeConfig) -> GPTDolomiteConfig: assert original_config.activation_function in ["gelu_pytorch_tanh", "gelu"] config = GPTDolomiteConfig( @@ -62,9 +52,7 @@ def _import_config_from_huggingface( return config -def export_to_huggingface_bigcode( - pretrained_model_name_or_path: str, save_path: str -) -> None: +def export_to_huggingface_bigcode(pretrained_model_name_or_path: str, save_path: str) -> None: shutil.copytree(pretrained_model_name_or_path, save_path) config: GPTDolomiteConfig = AutoConfig.from_pretrained(save_path) @@ -77,21 +65,19 @@ def export_to_huggingface_bigcode( try: tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) tokenizer.save_pretrained(save_path, legacy_format=False) - except: # pylint: disable=bare-except + except: pass def _export_config_to_huggingface(config: GPTDolomiteConfig) -> GPTBigCodeConfig: assert config.activation_function == "gelu_pytorch_tanh" assert config.normalization_function == "layernorm" - assert AttentionHeadType(config.attention_head_type) in [ - AttentionHeadType.mha, - AttentionHeadType.mqa, - ] - assert ( - PositionEmbeddingType(config.position_embedding_type) - == PositionEmbeddingType.learned_absolute - ) + assert AttentionHeadType(config.attention_head_type) in [AttentionHeadType.mha, AttentionHeadType.mqa] + assert PositionEmbeddingType(config.position_embedding_type) == PositionEmbeddingType.learned_absolute + assert config.m_emb is None + assert config.m_residual is None + assert config.m_width is None + assert config.attention_multiplier is None original_config = GPTBigCodeConfig( vocab_size=config.vocab_size, @@ -109,12 +95,12 @@ def _export_config_to_huggingface(config: GPTDolomiteConfig) -> GPTBigCodeConfig scale_attn_weights=config.scale_attn_weights, use_cache=config.use_cache, attention_softmax_in_fp32=config.attention_softmax_in_fp32, - scale_attention_softmax_in_fp32=config.scale_attention_softmax_in_fp32, multi_query=config.multi_query, tie_word_embeddings=config.tie_word_embeddings, bos_token_id=config.bos_token_id, eos_token_id=config.eos_token_id, pad_token_id=config.pad_token_id, + architectures=[GPTBigCodeForCausalLM.__name__], ) return original_config diff --git a/src/instructlab/dolomite/hf_models/model_conversion/granite.py b/src/instructlab/dolomite/hf_models/model_conversion/granite.py new file mode 100644 index 0000000..c9af0d6 --- /dev/null +++ b/src/instructlab/dolomite/hf_models/model_conversion/granite.py @@ -0,0 +1,143 @@ +from transformers import AutoConfig, AutoTokenizer, GenerationConfig + +from ...utils import SafeTensorsWeightsManager, download_repo +from ..enums import AttentionHeadType +from ..models import GPTDolomiteConfig +from .llama import _export_state_dict_to_huggingface, _import_state_dict_from_huggingface + + +try: + from transformers import GraniteConfig, GraniteForCausalLM +except: + 
GraniteConfig = None + + +def import_from_huggingface_granite(pretrained_model_name_or_path: str, save_path: str) -> None: + original_config, tokenizer, downloaded_model_path = download_repo(pretrained_model_name_or_path) + config = _import_config_from_huggingface(original_config) + + safetensors_weights_manager = SafeTensorsWeightsManager(downloaded_model_path) + state_dict = _import_state_dict_from_huggingface( + safetensors_weights_manager, + config.n_layer, + config.n_head, + config.num_key_value_heads, + config.n_embd // config.n_head, + AttentionHeadType(config.attention_head_type), + ) + + SafeTensorsWeightsManager.save_state_dict(state_dict, save_path) + config.save_pretrained(save_path) + + generation_config = GenerationConfig.from_model_config(config) + generation_config.save_pretrained(save_path) + + if tokenizer is not None: + tokenizer.save_pretrained(save_path, legacy_format=False) + + +def _import_config_from_huggingface(original_config: GraniteConfig) -> GPTDolomiteConfig: + assert original_config.hidden_act == "silu" + + if original_config.num_attention_heads == original_config.num_key_value_heads: + attention_head_type = "mha" + elif original_config.num_key_value_heads == 1: + attention_head_type = "mqa" + elif original_config.num_attention_heads > original_config.num_key_value_heads: + attention_head_type = "gqa" + + assert original_config.mlp_bias == original_config.attention_bias + + config = GPTDolomiteConfig( + vocab_size=original_config.vocab_size, + n_positions=original_config.max_position_embeddings, + n_embd=original_config.hidden_size, + n_layer=original_config.num_hidden_layers, + n_head=original_config.num_attention_heads, + num_key_value_heads=original_config.num_key_value_heads, + attention_head_type=attention_head_type, + position_embedding_type="rope", + n_inner=original_config.intermediate_size, + activation_function="swiglu", + normalization_function="rmsnorm", + layer_norm_epsilon=original_config.rms_norm_eps, + use_cache=original_config.use_cache, + add_bias=original_config.attention_bias, + tie_word_embeddings=original_config.tie_word_embeddings, + initializer_range=original_config.initializer_range, + rope_theta=original_config.rope_theta, + rope_scaling=original_config.rope_scaling, + attn_pdrop=original_config.attention_dropout, + bos_token_id=original_config.bos_token_id, + eos_token_id=original_config.eos_token_id, + pad_token_id=original_config.pad_token_id, + m_emb=None if original_config.embedding_multiplier == 1 else original_config.embedding_multiplier, + m_residual=None if original_config.residual_multiplier == 1 else original_config.residual_multiplier, + m_width=None if original_config.logits_scaling == 1 else original_config.logits_scaling, + attention_multiplier=original_config.attention_multiplier, + ) + + return config + + +def export_to_huggingface_granite(pretrained_model_name_or_path: str, save_path: str) -> None: + config: GPTDolomiteConfig = AutoConfig.from_pretrained(pretrained_model_name_or_path) + original_config = _export_config_to_huggingface(config) + + safetensors_weights_manager = SafeTensorsWeightsManager(pretrained_model_name_or_path) + state_dict = _export_state_dict_to_huggingface( + safetensors_weights_manager, + config.n_layer, + config.n_head, + config.num_key_value_heads, + config.n_embd // config.n_head, + AttentionHeadType(config.attention_head_type), + ) + + SafeTensorsWeightsManager.save_state_dict(state_dict, save_path) + original_config.save_pretrained(save_path) + + original_generation_config = 
GenerationConfig.from_model_config(original_config) + original_generation_config.save_pretrained(save_path) + + try: + tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) + tokenizer.save_pretrained(save_path, legacy_format=False) + except: + pass + + +def _export_config_to_huggingface(config: GPTDolomiteConfig) -> GraniteConfig: + assert config.activation_function == "swiglu" + assert config.normalization_function == "rmsnorm" + assert config.position_embedding_type == "rope" + + original_config = GraniteConfig( + vocab_size=config.vocab_size, + max_position_embeddings=config.n_positions, + hidden_size=config.n_embd, + num_hidden_layers=config.n_layer, + num_attention_heads=config.n_head, + num_key_value_heads=config.num_key_value_heads, + intermediate_size=4 * config.n_embd if config.n_inner is None else config.n_inner, + hidden_act="silu", + rms_norm_eps=config.layer_norm_epsilon, + use_cache=config.use_cache, + attention_bias=config.add_bias, + tie_word_embeddings=config.tie_word_embeddings, + initializer_range=config.initializer_range, + rope_theta=config.rope_theta, + rope_scaling=config.rope_scaling, + attention_dropout=config.attn_pdrop, + mlp_bias=config.add_bias, + bos_token_id=config.bos_token_id, + eos_token_id=config.eos_token_id, + pad_token_id=config.pad_token_id, + embedding_multiplier=1 if config.m_emb is None else config.m_emb, + residual_multiplier=1 if config.m_residual is None else config.m_residual, + logits_scaling=1 if config.m_width is None else config.m_width, + attention_multiplier=config.attention_multiplier, + architectures=[GraniteForCausalLM.__name__], + ) + + return original_config diff --git a/src/instructlab/dolomite/hf_models/model_conversion/granitemoe.py b/src/instructlab/dolomite/hf_models/model_conversion/granitemoe.py new file mode 100644 index 0000000..478abac --- /dev/null +++ b/src/instructlab/dolomite/hf_models/model_conversion/granitemoe.py @@ -0,0 +1,277 @@ +import torch +from transformers import AutoConfig, AutoTokenizer, GenerationConfig + +from ...utils import SafeTensorsWeightsManager, download_repo +from ..enums import AttentionHeadType +from ..modeling_utils import ( + interleave_query_key_value_tensor_for_attention, + split_query_key_value_tensor_for_attention, +) +from ..models import MoEDolomiteConfig + + +try: + from transformers import GraniteMoeConfig, GraniteMoeForCausalLM +except: + GraniteMoeConfig = None + + +def import_from_huggingface_granitemoe(pretrained_model_name_or_path: str, save_path: str) -> None: + original_config, tokenizer, downloaded_model_path = download_repo(pretrained_model_name_or_path) + config = _import_config_from_huggingface(original_config) + + safetensors_weights_manager = SafeTensorsWeightsManager(downloaded_model_path) + state_dict = _import_state_dict_from_huggingface( + safetensors_weights_manager, + config.n_layer, + config.num_experts, + config.n_head, + config.num_key_value_heads, + config.n_embd // config.n_head, + AttentionHeadType(config.attention_head_type), + ) + + SafeTensorsWeightsManager.save_state_dict(state_dict, save_path) + config.save_pretrained(save_path) + + generation_config = GenerationConfig.from_model_config(config) + generation_config.save_pretrained(save_path) + + if tokenizer is not None: + tokenizer.save_pretrained(save_path, legacy_format=False) + + +def _import_config_from_huggingface(original_config: GraniteMoeConfig) -> MoEDolomiteConfig: + assert original_config.hidden_act == "silu" + + if original_config.num_attention_heads == 
original_config.num_key_value_heads: + attention_head_type = "mha" + elif original_config.num_key_value_heads == 1: + attention_head_type = "mqa" + elif original_config.num_attention_heads > original_config.num_key_value_heads: + attention_head_type = "gqa" + + assert not original_config.attention_bias + + config = MoEDolomiteConfig( + vocab_size=original_config.vocab_size, + n_positions=original_config.max_position_embeddings, + n_embd=original_config.hidden_size, + n_layer=original_config.num_hidden_layers, + n_head=original_config.num_attention_heads, + num_key_value_heads=original_config.num_key_value_heads, + attention_head_type=attention_head_type, + position_embedding_type="rope", + n_inner=original_config.intermediate_size, + activation_function="swiglu", + normalization_function="rmsnorm", + layer_norm_epsilon=original_config.rms_norm_eps, + use_cache=original_config.use_cache, + add_bias=original_config.attention_bias, + tie_word_embeddings=original_config.tie_word_embeddings, + initializer_range=original_config.initializer_range, + rope_theta=original_config.rope_theta, + rope_scaling=original_config.rope_scaling, + attn_pdrop=original_config.attention_dropout, + num_experts=original_config.num_local_experts, + num_experts_per_tok=original_config.num_experts_per_tok, + output_router_logits=original_config.output_router_logits, + router_aux_loss_coef=original_config.router_aux_loss_coef, + bos_token_id=original_config.bos_token_id, + eos_token_id=original_config.eos_token_id, + pad_token_id=original_config.pad_token_id, + m_emb=None if original_config.embedding_multiplier == 1 else original_config.embedding_multiplier, + m_residual=None if original_config.residual_multiplier == 1 else original_config.residual_multiplier, + m_width=None if original_config.logits_scaling == 1 else original_config.logits_scaling, + attention_multiplier=original_config.attention_multiplier, + ) + + return config + + +def _import_state_dict_from_huggingface( + safetensors_weights_manager: SafeTensorsWeightsManager, + num_layers: int, + num_experts: int, + num_heads: int, + num_key_value_heads: int, + head_dim: int, + attention_head_type: AttentionHeadType, +) -> None: + state_dict = { + "transformer.wte.weight": safetensors_weights_manager.get_tensor("model.embed_tokens.weight"), + "transformer.ln_f.weight": safetensors_weights_manager.get_tensor("model.norm.weight"), + } + + if safetensors_weights_manager.has_tensor("lm_head.weight"): + state_dict["lm_head.weight"] = safetensors_weights_manager.get_tensor("lm_head.weight") + + for layer_idx in range(num_layers): + state_dict[f"transformer.h.{layer_idx}.ln_1.weight"] = safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.input_layernorm.weight" + ) + state_dict[f"transformer.h.{layer_idx}.ln_2.weight"] = safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.post_attention_layernorm.weight" + ) + + state_dict[f"transformer.h.{layer_idx}.moe.gate.weight"] = safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.block_sparse_moe.router.layer.weight" + ).T.contiguous() + + state_dict[f"transformer.h.{layer_idx}.moe.c_fc.weight"] = ( + _split_and_reorder_for_glu( + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.block_sparse_moe.input_linear.weight" + ) + ) + .transpose(0, 1) + .contiguous() + ) + state_dict[f"transformer.h.{layer_idx}.moe.c_proj.weight"] = ( + safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.block_sparse_moe.output_linear.weight") + .transpose(0, 1) + 
.contiguous() + ) + + state_dict[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = interleave_query_key_value_tensor_for_attention( + safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.q_proj.weight"), + safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.k_proj.weight"), + safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.v_proj.weight"), + num_heads, + num_key_value_heads, + head_dim, + attention_head_type, + ) + state_dict[f"transformer.h.{layer_idx}.attn.c_proj.weight"] = safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.self_attn.o_proj.weight" + ) + + return state_dict + + +def export_to_huggingface_granitemoe(pretrained_model_name_or_path: str, save_path: str) -> None: + config: MoEDolomiteConfig = AutoConfig.from_pretrained(pretrained_model_name_or_path) + original_config = _export_config_to_huggingface(config) + + safetensors_weights_manager = SafeTensorsWeightsManager(pretrained_model_name_or_path) + state_dict = _export_state_dict_to_huggingface( + safetensors_weights_manager, + config.n_layer, + config.num_experts, + config.n_head, + config.num_key_value_heads, + config.n_embd // config.n_head, + AttentionHeadType(config.attention_head_type), + ) + + SafeTensorsWeightsManager.save_state_dict(state_dict, save_path) + original_config.save_pretrained(save_path) + + original_generation_config = GenerationConfig.from_model_config(original_config) + original_generation_config.save_pretrained(save_path) + + try: + tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) + tokenizer.save_pretrained(save_path, legacy_format=False) + except: + pass + + +def _export_config_to_huggingface(config: MoEDolomiteConfig) -> GraniteMoeConfig: + assert config.activation_function == "swiglu" + assert config.normalization_function == "rmsnorm" + assert config.position_embedding_type == "rope" + assert not config.add_bias + + original_config = GraniteMoeConfig( + vocab_size=config.vocab_size, + max_position_embeddings=config.n_positions, + hidden_size=config.n_embd, + num_hidden_layers=config.n_layer, + num_attention_heads=config.n_head, + num_key_value_heads=config.num_key_value_heads, + intermediate_size=4 * config.n_embd if config.n_inner is None else config.n_inner, + hidden_act="silu", + rms_norm_eps=config.layer_norm_epsilon, + use_cache=config.use_cache, + attention_bias=config.add_bias, + tie_word_embeddings=config.tie_word_embeddings, + initializer_range=config.initializer_range, + rope_theta=config.rope_theta, + rope_scaling=config.rope_scaling, + attention_dropout=config.attn_pdrop, + num_local_experts=config.num_experts, + num_experts_per_tok=config.num_experts_per_tok, + output_router_logits=config.output_router_logits, + router_aux_loss_coef=config.router_aux_loss_coef, + bos_token_id=config.bos_token_id, + eos_token_id=config.eos_token_id, + pad_token_id=config.pad_token_id, + embedding_multiplier=1 if config.m_emb is None else config.m_emb, + residual_multiplier=1 if config.m_residual is None else config.m_residual, + logits_scaling=1 if config.m_width is None else config.m_width, + attention_multiplier=config.attention_multiplier, + architectures=[GraniteMoeForCausalLM.__name__], + ) + + return original_config + + +def _export_state_dict_to_huggingface( + safetensors_weights_manager: SafeTensorsWeightsManager, + num_layers: int, + num_experts: int, + num_heads: int, + num_key_value_heads: int, + head_dim: int, + attention_head_type: AttentionHeadType, +) -> None: + state_dict = 
{ + "model.embed_tokens.weight": safetensors_weights_manager.get_tensor("transformer.wte.weight"), + "model.norm.weight": safetensors_weights_manager.get_tensor("transformer.ln_f.weight"), + } + + if safetensors_weights_manager.has_tensor("lm_head.weight"): + state_dict["lm_head.weight"] = safetensors_weights_manager.get_tensor("lm_head.weight") + + for layer_idx in range(num_layers): + state_dict[f"model.layers.{layer_idx}.input_layernorm.weight"] = safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.ln_1.weight" + ) + state_dict[f"model.layers.{layer_idx}.post_attention_layernorm.weight"] = ( + safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.ln_2.weight") + ) + + state_dict[f"model.layers.{layer_idx}.block_sparse_moe.router.layer.weight"] = ( + safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.moe.gate.weight") + ).T.contiguous() + + state_dict[f"model.layers.{layer_idx}.block_sparse_moe.input_linear.weight"] = _split_and_reorder_for_glu( + safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.moe.c_fc.weight").transpose(0, 1) + ).contiguous() + state_dict[f"model.layers.{layer_idx}.block_sparse_moe.output_linear.weight"] = ( + safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.moe.c_proj.weight").transpose(0, 1) + ).contiguous() + + query_weight, key_weight, value_weight = split_query_key_value_tensor_for_attention( + safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.attn.c_attn.weight"), + num_heads, + num_key_value_heads, + head_dim, + attention_head_type, + ) + state_dict[f"model.layers.{layer_idx}.self_attn.q_proj.weight"] = query_weight + state_dict[f"model.layers.{layer_idx}.self_attn.k_proj.weight"] = key_weight + state_dict[f"model.layers.{layer_idx}.self_attn.v_proj.weight"] = value_weight + + state_dict[f"model.layers.{layer_idx}.self_attn.o_proj.weight"] = safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.attn.c_proj.weight" + ) + + return state_dict + + +def _split_and_reorder_for_glu(weight: torch.Tensor) -> torch.Tensor: + x, y = weight.chunk(2, dim=1) + weight = torch.cat([y, x], dim=1) + return weight diff --git a/src/instructlab/dolomite/hf_models/model_conversion/llama.py b/src/instructlab/dolomite/hf_models/model_conversion/llama.py index cf028ab..dee5dd4 100644 --- a/src/instructlab/dolomite/hf_models/model_conversion/llama.py +++ b/src/instructlab/dolomite/hf_models/model_conversion/llama.py @@ -1,34 +1,22 @@ -# ---------------------------------------------------------------- -# Extracted from https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- -# Third Party -from transformers import AutoConfig, AutoTokenizer, GenerationConfig, LlamaConfig +from transformers import AutoConfig, AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM -# Local from ...utils import SafeTensorsWeightsManager, download_repo -from ..config import GPTDolomiteConfig from ..enums import AttentionHeadType from ..modeling_utils import ( interleave_query_key_value_tensor_for_attention, split_query_key_value_tensor_for_attention, ) -from ..models.gpt_dolomite import ( - interleave_up_gate_tensor_for_mlp, - split_up_gate_tensor_for_mlp, -) +from ..models import GPTDolomiteConfig +from ..models.gpt_dolomite import interleave_up_gate_tensor_for_mlp, split_up_gate_tensor_for_mlp -def import_from_huggingface_llama( - pretrained_model_name_or_path: str, save_path: str -) -> None: - original_config, tokenizer, 
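The _split_and_reorder_for_glu helper defined just above swaps the two halves of the fused GLU projection along dim=1, which is how the gate/up ordering of the GraniteMoe expert weights is reconciled with the dolomite layout during both import and export. A tiny standalone demonstration (the input values are arbitrary, for illustration only):

```python
import torch

def _split_and_reorder_for_glu(weight: torch.Tensor) -> torch.Tensor:
    # swap the two halves of the fused projection along dim=1
    x, y = weight.chunk(2, dim=1)
    weight = torch.cat([y, x], dim=1)
    return weight

w = torch.arange(12).reshape(2, 6)
print(_split_and_reorder_for_glu(w))
# tensor([[ 3,  4,  5,  0,  1,  2],
#         [ 9, 10, 11,  6,  7,  8]])
```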
downloaded_model_path = download_repo( - pretrained_model_name_or_path - ) +def import_from_huggingface_llama(pretrained_model_name_or_path: str, save_path: str) -> None: + original_config, tokenizer, downloaded_model_path = download_repo(pretrained_model_name_or_path) config = _import_config_from_huggingface(original_config) - safetensors_weight_manager = SafeTensorsWeightsManager(downloaded_model_path) + safetensors_weights_manager = SafeTensorsWeightsManager(downloaded_model_path) state_dict = _import_state_dict_from_huggingface( - safetensors_weight_manager, + safetensors_weights_manager, config.n_layer, config.n_head, config.num_key_value_heads, @@ -87,7 +75,7 @@ def _import_config_from_huggingface(original_config: LlamaConfig) -> GPTDolomite def _import_state_dict_from_huggingface( - safetensors_weight_manager: SafeTensorsWeightsManager, + safetensors_weights_manager: SafeTensorsWeightsManager, num_layers: int, num_heads: int, num_key_value_heads: int, @@ -95,97 +83,54 @@ def _import_state_dict_from_huggingface( attention_head_type: AttentionHeadType, ) -> None: state_dict = { - "transformer.wte.weight": safetensors_weight_manager.get_tensor( - "model.embed_tokens.weight" - ), - "transformer.ln_f.weight": safetensors_weight_manager.get_tensor( - "model.norm.weight" - ), + "transformer.wte.weight": safetensors_weights_manager.get_tensor("model.embed_tokens.weight"), + "transformer.ln_f.weight": safetensors_weights_manager.get_tensor("model.norm.weight"), } - if safetensors_weight_manager.has_tensor("lm_head.weight"): - state_dict["lm_head.weight"] = safetensors_weight_manager.get_tensor( - "lm_head.weight" - ) + if safetensors_weights_manager.has_tensor("lm_head.weight"): + state_dict["lm_head.weight"] = safetensors_weights_manager.get_tensor("lm_head.weight") for layer_idx in range(num_layers): - state_dict[f"transformer.h.{layer_idx}.ln_1.weight"] = ( - safetensors_weight_manager.get_tensor( - f"model.layers.{layer_idx}.input_layernorm.weight" - ) + state_dict[f"transformer.h.{layer_idx}.ln_1.weight"] = safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.input_layernorm.weight" ) - state_dict[f"transformer.h.{layer_idx}.ln_2.weight"] = ( - safetensors_weight_manager.get_tensor( - f"model.layers.{layer_idx}.post_attention_layernorm.weight" - ) + state_dict[f"transformer.h.{layer_idx}.ln_2.weight"] = safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.post_attention_layernorm.weight" ) - state_dict[f"transformer.h.{layer_idx}.mlp.c_fc.weight"] = ( - interleave_up_gate_tensor_for_mlp( - safetensors_weight_manager.get_tensor( - f"model.layers.{layer_idx}.mlp.up_proj.weight" - ), - safetensors_weight_manager.get_tensor( - f"model.layers.{layer_idx}.mlp.gate_proj.weight" - ), - ) + state_dict[f"transformer.h.{layer_idx}.mlp.c_fc.weight"] = interleave_up_gate_tensor_for_mlp( + safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mlp.up_proj.weight"), + safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mlp.gate_proj.weight"), ) - if f"model.layers.{layer_idx}.mlp.up_proj.bias" in safetensors_weight_manager: - state_dict[f"transformer.h.{layer_idx}.mlp.c_fc.bias"] = ( - interleave_up_gate_tensor_for_mlp( - safetensors_weight_manager.get_tensor( - f"model.layers.{layer_idx}.mlp.up_proj.bias" - ), - safetensors_weight_manager.get_tensor( - f"model.layers.{layer_idx}.mlp.gate_proj.bias" - ), - ) + if f"model.layers.{layer_idx}.mlp.up_proj.bias" in safetensors_weights_manager: + state_dict[f"transformer.h.{layer_idx}.mlp.c_fc.bias"] 
= interleave_up_gate_tensor_for_mlp( + safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mlp.up_proj.bias"), + safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mlp.gate_proj.bias"), ) - state_dict[f"transformer.h.{layer_idx}.mlp.c_proj.weight"] = ( - safetensors_weight_manager.get_tensor( - f"model.layers.{layer_idx}.mlp.down_proj.weight" - ) + state_dict[f"transformer.h.{layer_idx}.mlp.c_proj.weight"] = safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.mlp.down_proj.weight" ) - if f"model.layers.{layer_idx}.mlp.down_proj.bias" in safetensors_weight_manager: - state_dict[f"transformer.h.{layer_idx}.mlp.c_proj.bias"] = ( - safetensors_weight_manager.get_tensor( - f"model.layers.{layer_idx}.mlp.down_proj.bias" - ) + if f"model.layers.{layer_idx}.mlp.down_proj.bias" in safetensors_weights_manager: + state_dict[f"transformer.h.{layer_idx}.mlp.c_proj.bias"] = safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.mlp.down_proj.bias" ) - state_dict[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = ( - interleave_query_key_value_tensor_for_attention( - safetensors_weight_manager.get_slice( - f"model.layers.{layer_idx}.self_attn.q_proj.weight" - ), - safetensors_weight_manager.get_slice( - f"model.layers.{layer_idx}.self_attn.k_proj.weight" - ), - safetensors_weight_manager.get_slice( - f"model.layers.{layer_idx}.self_attn.v_proj.weight" - ), - num_heads, - num_key_value_heads, - head_dim, - attention_head_type, - ) + state_dict[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = interleave_query_key_value_tensor_for_attention( + safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.q_proj.weight"), + safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.k_proj.weight"), + safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.v_proj.weight"), + num_heads, + num_key_value_heads, + head_dim, + attention_head_type, ) - if ( - f"model.layers.{layer_idx}.self_attn.q_proj.bias" - in safetensors_weight_manager - ): + if f"model.layers.{layer_idx}.self_attn.q_proj.bias" in safetensors_weights_manager: state_dict[f"transformer.h.{layer_idx}.attn.c_attn.bias"] = ( interleave_query_key_value_tensor_for_attention( - safetensors_weight_manager.get_slice( - f"model.layers.{layer_idx}.self_attn.q_proj.bias" - ), - safetensors_weight_manager.get_slice( - f"model.layers.{layer_idx}.self_attn.k_proj.bias" - ), - safetensors_weight_manager.get_slice( - f"model.layers.{layer_idx}.self_attn.v_proj.bias" - ), + safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.q_proj.bias"), + safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.k_proj.bias"), + safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.v_proj.bias"), num_heads, num_key_value_heads, head_dim, @@ -193,37 +138,24 @@ def _import_state_dict_from_huggingface( ) ) - state_dict[f"transformer.h.{layer_idx}.attn.c_proj.weight"] = ( - safetensors_weight_manager.get_tensor( - f"model.layers.{layer_idx}.self_attn.o_proj.weight" - ) + state_dict[f"transformer.h.{layer_idx}.attn.c_proj.weight"] = safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.self_attn.o_proj.weight" ) - if ( - f"model.layers.{layer_idx}.self_attn.o_proj.bias" - in safetensors_weight_manager - ): - state_dict[f"transformer.h.{layer_idx}.attn.c_proj.bias"] = ( - safetensors_weight_manager.get_tensor( - f"model.layers.{layer_idx}.self_attn.o_proj.bias" - ) + if 
f"model.layers.{layer_idx}.self_attn.o_proj.bias" in safetensors_weights_manager: + state_dict[f"transformer.h.{layer_idx}.attn.c_proj.bias"] = safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.self_attn.o_proj.bias" ) return state_dict -def export_to_huggingface_llama( - pretrained_model_name_or_path: str, save_path: str -) -> None: - config: GPTDolomiteConfig = AutoConfig.from_pretrained( - pretrained_model_name_or_path - ) +def export_to_huggingface_llama(pretrained_model_name_or_path: str, save_path: str) -> None: + config: GPTDolomiteConfig = AutoConfig.from_pretrained(pretrained_model_name_or_path) original_config = _export_config_to_huggingface(config) - safetensors_weight_manager = SafeTensorsWeightsManager( - pretrained_model_name_or_path - ) + safetensors_weights_manager = SafeTensorsWeightsManager(pretrained_model_name_or_path) state_dict = _export_state_dict_to_huggingface( - safetensors_weight_manager, + safetensors_weights_manager, config.n_layer, config.n_head, config.num_key_value_heads, @@ -240,7 +172,7 @@ def export_to_huggingface_llama( try: tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) tokenizer.save_pretrained(save_path, legacy_format=False) - except: # pylint: disable=bare-except + except: pass @@ -248,18 +180,19 @@ def _export_config_to_huggingface(config: GPTDolomiteConfig) -> LlamaConfig: assert config.activation_function == "swiglu" assert config.normalization_function == "rmsnorm" assert config.position_embedding_type == "rope" + assert config.m_emb is None + assert config.m_residual is None + assert config.m_width is None + assert config.attention_multiplier is None original_config = LlamaConfig( - architectures=config.architectures, vocab_size=config.vocab_size, max_position_embeddings=config.n_positions, hidden_size=config.n_embd, num_hidden_layers=config.n_layer, num_attention_heads=config.n_head, num_key_value_heads=config.num_key_value_heads, - intermediate_size=4 * config.n_embd - if config.n_inner is None - else config.n_inner, + intermediate_size=4 * config.n_embd if config.n_inner is None else config.n_inner, hidden_act="silu", rms_norm_eps=config.layer_norm_epsilon, use_cache=config.use_cache, @@ -273,13 +206,14 @@ def _export_config_to_huggingface(config: GPTDolomiteConfig) -> LlamaConfig: bos_token_id=config.bos_token_id, eos_token_id=config.eos_token_id, pad_token_id=config.pad_token_id, + architectures=[LlamaForCausalLM.__name__], ) return original_config def _export_state_dict_to_huggingface( - safetensors_weight_manager: SafeTensorsWeightsManager, + safetensors_weights_manager: SafeTensorsWeightsManager, num_layers: int, num_heads: int, num_key_value_heads: int, @@ -287,101 +221,71 @@ def _export_state_dict_to_huggingface( attention_head_type: AttentionHeadType, ) -> None: state_dict = { - "model.embed_tokens.weight": safetensors_weight_manager.get_tensor( - "transformer.wte.weight" - ), - "model.norm.weight": safetensors_weight_manager.get_tensor( - "transformer.ln_f.weight" - ), + "model.embed_tokens.weight": safetensors_weights_manager.get_tensor("transformer.wte.weight"), + "model.norm.weight": safetensors_weights_manager.get_tensor("transformer.ln_f.weight"), } - if safetensors_weight_manager.has_tensor("lm_head.weight"): - state_dict["lm_head.weight"] = safetensors_weight_manager.get_tensor( - "lm_head.weight" - ) + if safetensors_weights_manager.has_tensor("lm_head.weight"): + state_dict["lm_head.weight"] = safetensors_weights_manager.get_tensor("lm_head.weight") for layer_idx in 
range(num_layers): - state_dict[f"model.layers.{layer_idx}.input_layernorm.weight"] = ( - safetensors_weight_manager.get_tensor( - f"transformer.h.{layer_idx}.ln_1.weight" - ) + state_dict[f"model.layers.{layer_idx}.input_layernorm.weight"] = safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.ln_1.weight" ) state_dict[f"model.layers.{layer_idx}.post_attention_layernorm.weight"] = ( - safetensors_weight_manager.get_tensor( - f"transformer.h.{layer_idx}.ln_2.weight" - ) + safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.ln_2.weight") ) up_weight, gate_weight = split_up_gate_tensor_for_mlp( - safetensors_weight_manager.get_tensor( - f"transformer.h.{layer_idx}.mlp.c_fc.weight" - ) + safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.mlp.c_fc.weight") ) state_dict[f"model.layers.{layer_idx}.mlp.up_proj.weight"] = up_weight state_dict[f"model.layers.{layer_idx}.mlp.gate_proj.weight"] = gate_weight - if f"transformer.h.{layer_idx}.mlp.c_fc.bias" in safetensors_weight_manager: + if f"transformer.h.{layer_idx}.mlp.c_fc.bias" in safetensors_weights_manager: up_bias, gate_bias = split_up_gate_tensor_for_mlp( - safetensors_weight_manager.get_tensor( - f"transformer.h.{layer_idx}.mlp.c_fc.bias" - ) + safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.mlp.c_fc.bias") ) state_dict[f"model.layers.{layer_idx}.mlp.up_proj.bias"] = up_bias state_dict[f"model.layers.{layer_idx}.mlp.gate_proj.bias"] = gate_bias - state_dict[f"model.layers.{layer_idx}.mlp.down_proj.weight"] = ( - safetensors_weight_manager.get_tensor( - f"transformer.h.{layer_idx}.mlp.c_proj.weight" - ) + state_dict[f"model.layers.{layer_idx}.mlp.down_proj.weight"] = safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.mlp.c_proj.weight" ) - if f"transformer.h.{layer_idx}.mlp.c_proj.bias" in safetensors_weight_manager: - state_dict[f"model.layers.{layer_idx}.mlp.down_proj.bias"] = ( - safetensors_weight_manager.get_tensor( - f"transformer.h.{layer_idx}.mlp.c_proj.bias" - ) + if f"transformer.h.{layer_idx}.mlp.c_proj.bias" in safetensors_weights_manager: + state_dict[f"model.layers.{layer_idx}.mlp.down_proj.bias"] = safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.mlp.c_proj.bias" ) - query_weight, key_weight, value_weight = ( - split_query_key_value_tensor_for_attention( - safetensors_weight_manager.get_tensor( - f"transformer.h.{layer_idx}.attn.c_attn.weight" - ), - num_heads, - num_key_value_heads, - head_dim, - attention_head_type, - ) + query_weight, key_weight, value_weight = split_query_key_value_tensor_for_attention( + safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.attn.c_attn.weight"), + num_heads, + num_key_value_heads, + head_dim, + attention_head_type, ) state_dict[f"model.layers.{layer_idx}.self_attn.q_proj.weight"] = query_weight state_dict[f"model.layers.{layer_idx}.self_attn.k_proj.weight"] = key_weight state_dict[f"model.layers.{layer_idx}.self_attn.v_proj.weight"] = value_weight - if f"transformer.h.{layer_idx}.attn.c_attn.bias" in safetensors_weight_manager: - query_bias, key_bias, value_bias = ( - split_query_key_value_tensor_for_attention( - safetensors_weight_manager.get_tensor( - f"transformer.h.{layer_idx}.attn.c_attn.bias" - ), - num_heads, - num_key_value_heads, - head_dim, - attention_head_type, - ) + if f"transformer.h.{layer_idx}.attn.c_attn.bias" in safetensors_weights_manager: + query_bias, key_bias, value_bias = split_query_key_value_tensor_for_attention( + 
safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.attn.c_attn.bias"), + num_heads, + num_key_value_heads, + head_dim, + attention_head_type, ) state_dict[f"model.layers.{layer_idx}.self_attn.q_proj.bias"] = query_bias state_dict[f"model.layers.{layer_idx}.self_attn.k_proj.bias"] = key_bias state_dict[f"model.layers.{layer_idx}.self_attn.v_proj.bias"] = value_bias - state_dict[f"model.layers.{layer_idx}.self_attn.o_proj.weight"] = ( - safetensors_weight_manager.get_tensor( - f"transformer.h.{layer_idx}.attn.c_proj.weight" - ) + state_dict[f"model.layers.{layer_idx}.self_attn.o_proj.weight"] = safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.attn.c_proj.weight" ) - if f"transformer.h.{layer_idx}.attn.c_proj.bias" in safetensors_weight_manager: - state_dict[f"model.layers.{layer_idx}.self_attn.o_proj.bias"] = ( - safetensors_weight_manager.get_tensor( - f"transformer.h.{layer_idx}.attn.c_proj.bias" - ) + if f"transformer.h.{layer_idx}.attn.c_proj.bias" in safetensors_weights_manager: + state_dict[f"model.layers.{layer_idx}.self_attn.o_proj.bias"] = safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.attn.c_proj.bias" ) return state_dict diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/__init__.py b/src/instructlab/dolomite/hf_models/modeling_utils/__init__.py index 1ad7d54..92aea83 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/__init__.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/__init__.py @@ -1,7 +1,3 @@ -# ---------------------------------------------------------------- -# Extracted from https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- -# Local from .activations import get_activation_function, is_glu from .attention import ( SDPA, @@ -9,10 +5,11 @@ FlashAttention2, PaddingFreeAttention, get_attention_module, - get_unpad_data, interleave_query_key_value_tensor_for_attention, repeat_key_value, split_query_key_value_tensor_for_attention, ) -from .normalization import RMSNorm, get_normalization_function -from .position_embedding import Alibi, RoPE, apply_rotary_pos_emb +from .embedding import ParameterizedEmbedding +from .linear import ParameterizedLinear, ParameterizedTransposedLinear +from .normalization import get_normalization_function +from .position_embedding import Alibi, RoPE, YaRNScaledRoPE, apply_rotary_pos_emb diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/activations/__init__.py b/src/instructlab/dolomite/hf_models/modeling_utils/activations/__init__.py index 052666a..478c5dd 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/activations/__init__.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/activations/__init__.py @@ -1,13 +1,8 @@ -# ---------------------------------------------------------------- -# Extracted from https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- -# Third Party -import torch +import torch.nn as nn -# Local from .base import get_base_activation from .glu import get_glu_activation, is_glu -def get_activation_function(name: str) -> torch.nn.Module: +def get_activation_function(name: str) -> nn.Module: return get_glu_activation(name) if is_glu(name) else get_base_activation(name) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/activations/base.py b/src/instructlab/dolomite/hf_models/modeling_utils/activations/base.py index 9dd6b68..3a8d155 100644 --- 
a/src/instructlab/dolomite/hf_models/modeling_utils/activations/base.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/activations/base.py @@ -1,25 +1,18 @@ -# ---------------------------------------------------------------- -# Extracted from https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- -# pylint: disable=consider-using-from-import -# Third Party -from transformers.activations import ACT2CLS, ClassInstantier import torch.nn as nn +from transformers.activations import ACT2CLS, ClassInstantier -# Local -from .math_gelu import MathGELU _BASE_ACTIVATIONS = { "celu": nn.modules.CELU, "elu": nn.modules.ELU, "gelu": nn.modules.GELU, "gelu_pytorch_tanh": (nn.modules.GELU, {"approximate": "tanh"}), - "gelu_math_tanh": MathGELU, "selu": nn.modules.SELU, "hard_shrink": nn.modules.Hardshrink, "hard_sigmoid": nn.modules.Hardsigmoid, "hard_swish": nn.modules.Hardswish, "hard_tanh": nn.modules.Hardtanh, + "identity": nn.modules.Identity, "laplace": ACT2CLS["laplace"], "leaky_reLU": nn.modules.LeakyReLU, "log_sigmoid": nn.modules.LogSigmoid, diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/activations/glu.py b/src/instructlab/dolomite/hf_models/modeling_utils/activations/glu.py index 0fc3d5e..1419488 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/activations/glu.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/activations/glu.py @@ -1,12 +1,9 @@ -# ---------------------------------------------------------------- -# Extracted from https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- -# Third Party import torch +import torch.nn as nn -# Local from .base import get_base_activation + _GLU_BASE_MAPPING = { "ceglu": "celu", "eglu": "elu", @@ -21,8 +18,8 @@ } -class GLUActivation(torch.nn.Module): - def __init__(self, base_activation: torch.nn.Module) -> None: +class GLUActivation(nn.Module): + def __init__(self, base_activation: nn.Module) -> None: super().__init__() self.base_activation = base_activation @@ -31,10 +28,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x[0] * self.base_activation(x[1]) -def get_glu_activation(name: str) -> torch.nn.Module: +def get_glu_activation(name: str) -> nn.Module: # for glu and sigmoid_glu, we directly return the pytorch's GLU if name in ["glu", "sigmoid_glu"]: - activation_function = torch.nn.modules.GLU() + activation_function = nn.modules.GLU() else: if name in _GLU_BASE_MAPPING: name = _GLU_BASE_MAPPING[name] diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/activations/math_gelu.py b/src/instructlab/dolomite/hf_models/modeling_utils/activations/math_gelu.py deleted file mode 100644 index 5c72e4e..0000000 --- a/src/instructlab/dolomite/hf_models/modeling_utils/activations/math_gelu.py +++ /dev/null @@ -1,37 +0,0 @@ -# ---------------------------------------------------------------- -# Extracted from https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- -# Third Party -import torch - - -@torch.compile -def _gelu_forward(x: torch.Tensor) -> torch.Tensor: - return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) - - -@torch.compile -def _gelu_backward(gradient: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - ff = 0.5 * x * ( - (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x) - ) + 0.5 * (1 + tanh_out) - 
return ff * gradient - - -class _MathGELU(torch.autograd.Function): - @staticmethod - def forward(ctx, input: torch.Tensor) -> torch.Tensor: - ctx.save_for_backward(input) - return _gelu_forward(input) - - @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: - (input,) = ctx.saved_tensors - tmp = _gelu_backward(grad_output, input) - return tmp - - -class MathGELU(torch.nn.Module): - def forward(self, input: torch.Tensor) -> torch.Tensor: - return _MathGELU.apply(input) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/attention/__init__.py b/src/instructlab/dolomite/hf_models/modeling_utils/attention/__init__.py index fca22e7..c743985 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/attention/__init__.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/attention/__init__.py @@ -1,22 +1,14 @@ -# ---------------------------------------------------------------- -# Extracted from https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- -# Standard -from typing import Tuple import inspect -# Third Party import torch -# Local -from ...config import GPTDolomiteConfig +from ...config import CommonConfig from ...enums import AttentionHeadType from .base import Attention from .flash import FlashAttention2 from .padding_free import PaddingFreeAttention from .sdpa import SDPA from .utils import ( - get_unpad_data, interleave_query_key_value_tensor_for_gqa, interleave_query_key_value_tensor_for_mha, interleave_query_key_value_tensor_for_mqa, @@ -26,6 +18,7 @@ split_query_key_value_tensor_for_mqa, ) + _ATTENTION_MODULES = { "eager": Attention, "sdpa": SDPA, @@ -48,7 +41,7 @@ def get_attention_module( - config: GPTDolomiteConfig, + config: CommonConfig, causal: bool, attention_implementation: str, use_padding_free_transformer: bool, @@ -76,9 +69,7 @@ def interleave_query_key_value_tensor_for_attention( ) -> torch.Tensor: if attention_head_type.value in _INTERLEAVE_FUNCTIONS: interleave_function = _INTERLEAVE_FUNCTIONS[attention_head_type.value] - interleave_function_parameters = inspect.signature( - interleave_function - ).parameters.keys() + interleave_function_parameters = inspect.signature(interleave_function).parameters.keys() parameters_to_pass = {} this_function_parameters = locals() @@ -98,7 +89,7 @@ def split_query_key_value_tensor_for_attention( num_key_value_heads: int, head_dim: int, attention_head_type: AttentionHeadType, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if attention_head_type.value in _SPLIT_FUNCTIONS: split_function = _SPLIT_FUNCTIONS[attention_head_type.value] split_function_parameters = inspect.signature(split_function).parameters.keys() diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/attention/base.py b/src/instructlab/dolomite/hf_models/modeling_utils/attention/base.py index 4c17941..51903f1 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/attention/base.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/attention/base.py @@ -1,26 +1,20 @@ -# ---------------------------------------------------------------- -# Extracted from https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- -# Standard -from typing import Tuple - -# Third Party -from torch.nn import Linear # replaces ParameterizedLinear -from transformers import DynamicCache +import math + import torch +import torch.nn as nn import 
torch.nn.functional as F +from transformers import DynamicCache -# Local -from ...config import GPTDolomiteConfig -from ...enums import AttentionHeadType, PositionEmbeddingType +from ...config import CommonConfig +from ...enums import AttentionHeadType, InitMethod, PositionEmbeddingType +from ...utils import divide_if_divisible +from ..linear import ParameterizedLinear from ..position_embedding import apply_rotary_pos_emb from .utils import repeat_key_value -class Attention(torch.nn.Module): - def __init__( - self, config: GPTDolomiteConfig, causal: bool, layer_idx: int = None - ) -> None: +class Attention(nn.Module): + def __init__(self, config: CommonConfig, causal: bool, layer_idx: int | None = None) -> None: super().__init__() self.causal = causal @@ -29,24 +23,25 @@ def __init__( self.num_key_value_heads = config.num_key_value_heads self.add_bias = config.add_bias - assert ( - self.hidden_size % self.num_heads == 0 - ), f"`hidden_size` ({self.hidden_size}) must be divisible by `num_heads` ({self.num_heads})" + initializer_range = config.initializer_range + m_width = config.m_width + n_layer = config.n_layer + init_method = InitMethod(config.init_method) + + self.head_dim = divide_if_divisible( + self.hidden_size, + self.num_heads, + f"`hidden_size` ({self.hidden_size}) must be divisible by `num_heads` ({self.num_heads})", + ) - self.head_dim = self.hidden_size // self.num_heads self.attention_head_type = AttentionHeadType(config.attention_head_type) - self.position_embedding_type = PositionEmbeddingType( - config.position_embedding_type - ) + self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type) self.scale_attn_weights = config.scale_attn_weights self.attention_multiplier = config.attention_multiplier self.layer_idx = layer_idx self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 - self.scale_attention_softmax_in_fp32 = ( - config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32 - ) if self.attention_head_type == AttentionHeadType.mha: if self.num_key_value_heads is None: @@ -60,49 +55,43 @@ def __init__( self.num_key_value_heads is not None ), "`num_key_value_heads` needs to be specified with GroupedQueryAttention" - assert self.num_heads % self.num_key_value_heads == 0, ( - f"`num_heads` ({self.num_heads}) should be a multiple of `num_key_value_heads` " - f"({self.num_key_value_heads})" + divide_if_divisible( + self.num_heads, + self.num_key_value_heads, + f"`num_heads` ({self.num_heads}) should be a multiple of `num_key_value_heads` ({self.num_key_value_heads})", ) elif self.attention_head_type == AttentionHeadType.mqa: if self.num_key_value_heads is None: self.num_key_value_heads = 1 - assert ( - self.num_key_value_heads == 1 - ), f"{self.__class__.__name__} should have 1 head for keys and values" + assert self.num_key_value_heads == 1, f"{self.__class__.__name__} should have 1 head for keys and values" else: - raise ValueError( - f"unexpected attention_head_type ({self.attention_head_type})" - ) + raise ValueError(f"unexpected attention_head_type ({self.attention_head_type})") # note that the actual layout is different for the output and depends on whether we are using MHA, MQA or GQA # (self.hidden_size + 2 * self.num_key_value_heads * self.head_dim) is just the actual number output features - self.c_attn = Linear( + std = initializer_range + if init_method == InitMethod.mup: + std /= math.sqrt(m_width) + self.c_attn = ParameterizedLinear( self.hidden_size, self.hidden_size + 2 * self.num_key_value_heads * 
self.head_dim, bias=self.add_bias, + std=std, ) - self.c_proj = Linear(self.hidden_size, self.hidden_size, bias=self.add_bias) + std = initializer_range / math.sqrt(2 * n_layer) + if init_method == InitMethod.mup: + std /= math.sqrt(m_width) + self.c_proj = ParameterizedLinear(self.hidden_size, self.hidden_size, bias=self.add_bias, std=std) self.attn_pdrop = config.attn_pdrop self.resid_pdrop = config.resid_pdrop - self.attn_dropout = ( - torch.nn.Identity() - if self.attn_pdrop == 0 - else torch.nn.Dropout(self.attn_pdrop) - ) - self.resid_dropout = ( - torch.nn.Identity() - if self.resid_pdrop == 0 - else torch.nn.Dropout(self.resid_pdrop) - ) + self.attn_dropout = nn.Identity() if self.attn_pdrop == 0 else nn.Dropout(self.attn_pdrop) + self.resid_dropout = nn.Identity() if self.resid_pdrop == 0 else nn.Dropout(self.resid_pdrop) - def _prepare_qkv_for_forward( - self, hidden_states: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def _prepare_qkv_for_forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # ========================================================================================== # hidden_states -> (batch_size, query_length, num_heads * head_dim) # ========================================================================================== @@ -122,9 +111,7 @@ def _prepare_qkv_for_forward( elif self.attention_head_type == AttentionHeadType.mqa: query, key, value = self._prepare_qkv_for_forward_mqa(hidden_states) else: - raise ValueError( - f"unexpected attention_head_type ({self.attention_head_type})" - ) + raise ValueError(f"unexpected attention_head_type ({self.attention_head_type})") # ========================================================================================== # query -> (batch_size, num_heads, query_length, head_dim) @@ -136,7 +123,7 @@ def _prepare_qkv_for_forward( def _prepare_qkv_for_forward_mha( self, hidden_states: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: batch_size, query_length = hidden_states.shape[:-1] hidden_states = hidden_states.view(batch_size, query_length, self.num_heads, -1) @@ -148,20 +135,13 @@ def _prepare_qkv_for_forward_mha( def _prepare_qkv_for_forward_gqa( self, hidden_states: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: batch_size, query_length = hidden_states.shape[:-1] - hidden_states = hidden_states.view( - batch_size, query_length, self.num_key_value_heads, -1 - ) + hidden_states = hidden_states.view(batch_size, query_length, self.num_key_value_heads, -1) query, key, value = hidden_states.split( - ( - (self.num_heads // self.num_key_value_heads) * self.head_dim, - self.head_dim, - self.head_dim, - ), - dim=-1, + ((self.num_heads // self.num_key_value_heads) * self.head_dim, self.head_dim, self.head_dim), dim=-1 ) # this needs to be a reshape instead of view sadly @@ -175,12 +155,10 @@ def _prepare_qkv_for_forward_gqa( def _prepare_qkv_for_forward_mqa( self, hidden_states: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: batch_size, query_length = hidden_states.shape[:-1] - query, key, value = hidden_states.split( - (self.hidden_size, self.head_dim, self.head_dim), dim=-1 - ) + query, key, value = hidden_states.split((self.hidden_size, self.head_dim, self.head_dim), dim=-1) query = query.view(batch_size, query_length, 
self.num_heads, -1) @@ -193,11 +171,11 @@ def _prepare_qkv_for_forward_mqa( def forward( self, hidden_states: torch.Tensor, - past_key_values: DynamicCache = None, - attention_mask: torch.Tensor = None, - rope_cos_sin: torch.Tensor = None, - cu_seqlens: torch.Tensor = None, - max_seqlen: torch.Tensor = None, + past_key_values: DynamicCache | None = None, + attention_mask: torch.Tensor | None = None, + rope_cos_sin: torch.Tensor | None = None, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, ) -> torch.Tensor: # ========================================================================================== # hidden_states -> (batch_size, query_length, num_heads * head_dim) @@ -255,20 +233,16 @@ def forward( if attention_mask is None: attn_weights = torch.empty( - (batch_size * self.num_heads, query_length, key_length), - device=query.device, - dtype=query.dtype, + (batch_size * self.num_heads, query_length, key_length), device=query.device, dtype=query.dtype ) beta = 0 else: - attn_weights = attention_mask.expand(-1, self.num_heads, -1, -1).reshape( - -1, query_length, key_length - ) + attn_weights = attention_mask.expand(-1, self.num_heads, -1, -1).reshape(-1, query_length, key_length) beta = 1 - attn_weights = torch.baddbmm( - attn_weights, query, key, beta=beta, alpha=self._get_softmax_scale(False) - ).view(batch_size, self.num_heads, query_length, key_length) + attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=self._get_softmax_scale(False)).view( + batch_size, self.num_heads, query_length, key_length + ) # ========================================================================================== # attn_weights -> (batch_size, num_heads, query_length, key_length) @@ -289,9 +263,7 @@ def forward( # ========================================================================================== attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape( - batch_size, -1, self.num_heads * self.head_dim - ) + attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim) # ========================================================================================== # attn_output -> (batch_size, query_length, num_heads * head_dim) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/attention/flash.py b/src/instructlab/dolomite/hf_models/modeling_utils/attention/flash.py index 05c7189..26bac53 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/attention/flash.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/attention/flash.py @@ -1,32 +1,21 @@ -# ---------------------------------------------------------------- -# Extracted from https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- -# Third Party -from transformers import DynamicCache import torch +from transformers import DynamicCache +from transformers.modeling_flash_attention_utils import _flash_attention_forward -# Local -from ....utils import is_flash_attention_available from ...enums import AttentionHeadType, PositionEmbeddingType from ..position_embedding import apply_rotary_pos_emb from .base import Attention -from .utils import get_unpad_data - -if is_flash_attention_available(): - # Third Party - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input - from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func class FlashAttention2(Attention): def forward( self, hidden_states: torch.Tensor, - past_key_values: 
DynamicCache = None, - attention_mask: torch.Tensor = None, - rope_cos_sin: torch.Tensor = None, - cu_seqlens: torch.Tensor = None, - max_seqlen: torch.Tensor = None, + past_key_values: DynamicCache | None = None, + attention_mask: torch.Tensor | None = None, + rope_cos_sin: torch.Tensor | None = None, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, ) -> torch.Tensor: # ========================================================================================== # hidden_states -> (batch_size, query_length, num_heads * head_dim) @@ -73,81 +62,16 @@ def forward( batch_size, query_length = query.shape[:2] - if attention_mask is None: - attn_output = flash_attn_func( - query, - key, - value, - dropout_p=dropout_p, - softmax_scale=softmax_scale, - causal=self.causal, - ) - else: - key_length = key.shape[1] - - indices_k, cu_seqlens_k, max_seqlen_k = get_unpad_data(attention_mask) - - key = index_first_axis( - key.reshape( - batch_size * key_length, self.num_key_value_heads, self.head_dim - ), - indices_k, - ) - value = index_first_axis( - value.reshape( - batch_size * key_length, self.num_key_value_heads, self.head_dim - ), - indices_k, - ) - - if query_length == key_length: - query = index_first_axis( - query.reshape( - batch_size * key_length, self.num_heads, self.head_dim - ), - indices_k, - ) - cu_seqlens_q = cu_seqlens_k - max_seqlen_q = max_seqlen_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query = query.squeeze(1) - else: - # The -q_len: slice assumes left padding. - attention_mask = attention_mask[:, -query_length:] - query, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input( - query, attention_mask - ) - - # ========================================================================================== - # query -> (total_q, num_heads, head_dim) - # key -> (total_q, num_heads, head_dim) - # value -> (total_q, num_heads, head_dim) - # ========================================================================================== - - attn_output = flash_attn_varlen_func( - query, - key, - value, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - dropout_p=dropout_p, - softmax_scale=softmax_scale, - causal=self.causal, - ) - - # ========================================================================================== - # attn_output -> (total_q, num_heads, head_dim) - # ========================================================================================== - - attn_output = pad_input(attn_output, indices_q, batch_size, query_length) + attn_output = _flash_attention_forward( + query_states=query, + key_states=key, + value_states=value, + attention_mask=attention_mask, + query_length=query_length, + is_causal=self.causal, + dropout=dropout_p, + softmax_scale=softmax_scale, + ) attn_output = attn_output.view(batch_size, query_length, -1) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/attention/padding_free.py b/src/instructlab/dolomite/hf_models/modeling_utils/attention/padding_free.py index 5d633ee..9b07a51 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/attention/padding_free.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/attention/padding_free.py @@ -1,21 +1,13 @@ -# ---------------------------------------------------------------- -# Extracted from 
https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- -# Standard -from typing import Tuple - -# Third Party -from transformers import DynamicCache import torch +from transformers import DynamicCache -# Local from ....utils import is_flash_attention_available from ...enums import PositionEmbeddingType from ..position_embedding import apply_rotary_pos_emb from .base import Attention + if is_flash_attention_available(): - # Third Party from flash_attn.flash_attn_interface import flash_attn_varlen_func @@ -23,11 +15,11 @@ class PaddingFreeAttention(Attention): def forward( self, hidden_states: torch.Tensor, - past_key_values: DynamicCache = None, - attention_mask: torch.Tensor = None, - rope_cos_sin: torch.Tensor = None, - cu_seqlens: torch.Tensor = None, - max_seqlen: torch.Tensor = None, + past_key_values: DynamicCache | None = None, + attention_mask: torch.Tensor | None = None, + rope_cos_sin: torch.Tensor | None = None, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, ) -> torch.Tensor: assert past_key_values is None @@ -86,7 +78,7 @@ def forward( def _prepare_qkv_for_forward_mha( self, hidden_states: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: total_q = hidden_states.shape[0] hidden_states = hidden_states.view(total_q, self.num_key_value_heads, -1) @@ -96,18 +88,13 @@ def _prepare_qkv_for_forward_mha( def _prepare_qkv_for_forward_gqa( self, hidden_states: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: total_q = hidden_states.shape[0] hidden_states = hidden_states.view(total_q, self.num_key_value_heads, -1) query, key, value = hidden_states.split( - ( - (self.num_heads // self.num_key_value_heads) * self.head_dim, - self.head_dim, - self.head_dim, - ), - dim=-1, + ((self.num_heads // self.num_key_value_heads) * self.head_dim, self.head_dim, self.head_dim), dim=-1 ) # this needs to be a reshape instead of view sadly @@ -117,12 +104,10 @@ def _prepare_qkv_for_forward_gqa( def _prepare_qkv_for_forward_mqa( self, hidden_states: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: total_q = hidden_states.shape[0] - query, key, value = hidden_states.split( - (self.hidden_size, self.head_dim, self.head_dim), dim=-1 - ) + query, key, value = hidden_states.split((self.hidden_size, self.head_dim, self.head_dim), dim=-1) query = query.view(total_q, self.num_heads, -1) key = key.unsqueeze(1) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/attention/sdpa.py b/src/instructlab/dolomite/hf_models/modeling_utils/attention/sdpa.py index dba33f1..ad3290e 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/attention/sdpa.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/attention/sdpa.py @@ -1,12 +1,7 @@ -# ---------------------------------------------------------------- -# Extracted from https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- -# Third Party -from transformers import DynamicCache import torch import torch.nn.functional as F +from transformers import DynamicCache -# Local from ...enums import PositionEmbeddingType from ..position_embedding import apply_rotary_pos_emb from .base import Attention @@ -17,11 +12,11 @@ class SDPA(Attention): def forward( self, hidden_states: 
torch.Tensor, - past_key_values: DynamicCache = None, - attention_mask: torch.Tensor = None, - rope_cos_sin: torch.Tensor = None, - cu_seqlens: torch.Tensor = None, - max_seqlen: torch.Tensor = None, + past_key_values: DynamicCache | None = None, + attention_mask: torch.Tensor | None = None, + rope_cos_sin: torch.Tensor | None = None, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, ) -> torch.Tensor: # ========================================================================================== # hidden_states -> (batch_size, query_length, num_heads * head_dim) @@ -76,9 +71,7 @@ def forward( batch_size = attn_output.shape[0] attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape( - batch_size, -1, self.num_heads * self.head_dim - ) + attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim) # ========================================================================================== # attn_output -> (batch_size, query_length, num_heads * head_dim) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/attention/utils.py b/src/instructlab/dolomite/hf_models/modeling_utils/attention/utils.py index a894acf..ca60ca9 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/attention/utils.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/attention/utils.py @@ -1,27 +1,4 @@ -# ---------------------------------------------------------------- -# Extracted from https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- -# Standard -from typing import Tuple - -# Third Party import torch -import torch.nn.functional as F - - -# Copied from transformers.models.llama.modeling_llama._get_unpad_data -def get_unpad_data( - attention_mask: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) def interleave_query_key_value_tensor_for_mha( @@ -45,7 +22,7 @@ def interleave_query_key_value_tensor_for_mha( def split_query_key_value_tensor_for_mha( query_key_value_weight: torch.Tensor, num_heads: int -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: original_shape = query_key_value_weight.shape query_key_value_weight = query_key_value_weight.view(num_heads, -1) @@ -84,21 +61,14 @@ def interleave_query_key_value_tensor_for_gqa( def split_query_key_value_tensor_for_gqa( - query_key_value_weight: torch.Tensor, - num_heads: int, - num_key_value_heads: int, - head_dim: int, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + query_key_value_weight: torch.Tensor, num_heads: int, num_key_value_heads: int, head_dim: int +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: query_heads_per_group = num_heads // num_key_value_heads original_shape = query_key_value_weight.shape - query_key_value_weight = query_key_value_weight.view( - num_key_value_heads, (query_heads_per_group + 2), -1 - ) + query_key_value_weight = query_key_value_weight.view(num_key_value_heads, (query_heads_per_group + 2), -1) - query_weight, key_weight, value_weight = query_key_value_weight.split( - (query_heads_per_group, 1, 1), 1 - ) + query_weight, key_weight, value_weight 
= query_key_value_weight.split((query_heads_per_group, 1, 1), 1) query_weight = query_weight.reshape(-1, *original_shape[1:]) key_weight = key_weight.reshape(-1, *original_shape[1:]) @@ -118,13 +88,11 @@ def interleave_query_key_value_tensor_for_mqa( def split_query_key_value_tensor_for_mqa( query_key_value_weight: torch.Tensor, num_heads: int, head_dim: int -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return query_key_value_weight.split((num_heads * head_dim, head_dim, head_dim)) -def repeat_key_value( - x: torch.Tensor, num_heads: int, num_key_value_heads: int -) -> torch.Tensor: +def repeat_key_value(x: torch.Tensor, num_heads: int, num_key_value_heads: int) -> torch.Tensor: num_groups = num_heads // num_key_value_heads if num_groups == 1: diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/embedding.py b/src/instructlab/dolomite/hf_models/modeling_utils/embedding.py new file mode 100644 index 0000000..3cff32e --- /dev/null +++ b/src/instructlab/dolomite/hf_models/modeling_utils/embedding.py @@ -0,0 +1,44 @@ +import torch +import torch.nn as nn + + +class ParameterizedEmbedding(nn.Embedding): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int | None = None, + max_norm: float | None = None, + norm_type: float = 2, + scale_grad_by_freq: bool = False, + sparse: bool = False, + _weight: torch.Tensor | None = None, + _freeze: bool = False, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + std: float | None = None, + ) -> None: + self.std = std + super().__init__( + num_embeddings, + embedding_dim, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + _weight, + _freeze, + device, + dtype, + ) + + @torch.no_grad() + def reset_parameters(self) -> None: + if self.std is None: + super().reset_parameters() + else: + # nn.init.trunc_normal_(self.weight, mean=0, std=self.std) + self.weight.data.normal_(mean=0, std=self.std) + if self.padding_idx is not None: + self.weight.data[self.padding_idx].zero_() diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/linear.py b/src/instructlab/dolomite/hf_models/modeling_utils/linear.py new file mode 100644 index 0000000..560e100 --- /dev/null +++ b/src/instructlab/dolomite/hf_models/modeling_utils/linear.py @@ -0,0 +1,50 @@ +import torch +import torch.nn as nn + + +class ParameterizedLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + std: float | None = None, + ) -> None: + self.std = std + super().__init__(in_features, out_features, bias, device, dtype) + + @torch.no_grad() + def reset_parameters(self) -> None: + if self.std is None: + super().reset_parameters() + else: + nn.init.normal_(self.weight, mean=0, std=self.std) + if hasattr(self, "bias") and self.bias is not None: + self.bias.zero_() + + +class ParameterizedTransposedLinear(ParameterizedLinear): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + std: float | None = None, + ) -> None: + self.std = std + + if bias: + raise NotImplementedError("bias is not supported with TransposedLinear yet") + + # pass in_features as out_features and vice-versa + super().__init__(out_features, in_features, bias, device, dtype) + + # invert them now to print the module correctly + self.in_features, 
self.out_features = self.out_features, self.in_features + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return input @ self.weight diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/__init__.py b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/__init__.py index eb68644..b4bf746 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/__init__.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/__init__.py @@ -1,11 +1,8 @@ -# ---------------------------------------------------------------- -# Extracted from https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- -# Third Party -import torch +import torch.nn as nn + +from .layernorm import get_layernorm +from .rmsnorm import get_rmsnorm -# Local -from .norms import RMSNorm, get_layernorm, get_rmsnorm _NORMALIZATION_FUNCTIONS = { "layernorm": get_layernorm, @@ -18,14 +15,10 @@ def get_normalization_function( normalized_shape: int, eps: float = 1e-5, normalization_implementation: str = "torch", -) -> torch.nn.LayerNorm: +) -> nn.LayerNorm: if name in _NORMALIZATION_FUNCTIONS: return _NORMALIZATION_FUNCTIONS[name]( - normalized_shape, - eps=eps, - normalization_implementation=normalization_implementation, + normalized_shape, eps=eps, normalization_implementation=normalization_implementation ) - raise ValueError( - f"unexpected `normalization_implementation` {normalization_implementation}" - ) + raise ValueError(f"unexpected `normalization_implementation` {normalization_implementation}") diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/__init__.py b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/__init__.py new file mode 100644 index 0000000..915c7ca --- /dev/null +++ b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/__init__.py @@ -0,0 +1,22 @@ +import torch.nn as nn + +from .apex import ApexLayerNorm +from .apex_persistent import ApexPersistentLayerNorm + + +_LAYERNORM_MODULES = { + "torch": nn.LayerNorm, + "apex": ApexLayerNorm, + "apex_persistent": ApexPersistentLayerNorm, +} + + +def get_layernorm( + normalized_shape: int, + eps: float, + normalization_implementation: str = "torch", +) -> nn.LayerNorm: + if normalization_implementation in _LAYERNORM_MODULES: + return _LAYERNORM_MODULES[normalization_implementation](normalized_shape=normalized_shape, eps=eps) + + raise ValueError(f"unexpected `normalization_implementation` {normalization_implementation}") diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/apex.py b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/apex.py new file mode 100644 index 0000000..763ad7f --- /dev/null +++ b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/apex.py @@ -0,0 +1,33 @@ +import torch +import torch.nn as nn + + +def is_apex_layernorm_available() -> bool: + try: + from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction + + return True + except ImportError: + return False + + +if is_apex_layernorm_available(): + from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction + + +def apex_layernorm( + input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float, memory_efficient: bool +) -> torch.Tensor: + normalized_shape = (input.shape[-1],) + return FusedLayerNormAffineFunction.apply(input, weight, bias, normalized_shape, eps, 
memory_efficient) + + +class ApexLayerNorm(nn.LayerNorm): + def __init__(self, normalized_shape: int, eps: float = 0.00001) -> None: + if not is_apex_layernorm_available(): + raise ImportError("build apex from source") + + super().__init__(normalized_shape, eps=eps) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return apex_layernorm(input, self.weight, self.bias, self.eps, True) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/apex_persistent.py b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/apex_persistent.py new file mode 100644 index 0000000..e3ac497 --- /dev/null +++ b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/apex_persistent.py @@ -0,0 +1,64 @@ +import torch +import torch.nn as nn + + +def is_apex_persistent_layernorm_available() -> bool: + try: + from apex.contrib.layer_norm.layer_norm import FastLayerNormFN + + return True + except ImportError: + return False + + +if is_apex_persistent_layernorm_available(): + from apex.contrib.layer_norm.layer_norm import FastLayerNormFN + + +_PERSISTENT_LAYERNORM_ALLOWED_HIDDEN_STATES = [ + 1024, + 1536, + 2048, + 2304, + 3072, + 3840, + 4096, + 5120, + 6144, + 8192, + 10240, + 12288, + 12800, + 15360, + 16384, + 18432, + 20480, + 24576, + 25600, + 30720, + 32768, + 40960, + 49152, + 65536, +] + + +def apex_persistent_layernorm( + input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float, memory_efficient +) -> torch.Tensor: + return FastLayerNormFN.apply(input, weight, bias, eps, memory_efficient) + + +class ApexPersistentLayerNorm(nn.LayerNorm): + def __init__(self, normalized_shape: int, eps: float = 0.00001) -> None: + if not is_apex_persistent_layernorm_available(): + raise ImportError("build apex from source with --fast_layer_norm") + + super().__init__(normalized_shape, eps=eps) + + assert ( + self.normalized_shape[0] in _PERSISTENT_LAYERNORM_ALLOWED_HIDDEN_STATES + ), "persistent layernorm kernel is not avilable for the specified hidden dimension" + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return apex_persistent_layernorm(input, self.weight, self.bias, self.eps, True) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/norms.py b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/norms.py deleted file mode 100644 index f752a6a..0000000 --- a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/norms.py +++ /dev/null @@ -1,81 +0,0 @@ -# ---------------------------------------------------------------- -# Extracted from https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- - -# Standard -import numbers - -# Third Party -import torch - -# ---------------- LayerNorm --------------- - -_LAYERNORM_MODULES = { - "torch": torch.nn.LayerNorm, -} - - -def get_layernorm( - normalized_shape: int, - eps: float, - normalization_implementation: str = "torch", -) -> torch.nn.LayerNorm: - if normalization_implementation in _LAYERNORM_MODULES: - return _LAYERNORM_MODULES[normalization_implementation]( - normalized_shape=normalized_shape, eps=eps - ) - - raise ValueError( - f"unexpected `normalization_implementation` {normalization_implementation}" - ) - - -# --------------- RMS Norm --------------- -# ---------------------------------------------------------------- -# Extracted from https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- - - 
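# For reference: the RMSNorm being removed below and its replacement in
# normalization/rmsnorm/base.py implement the same operation: upcast to fp32,
# divide by the root-mean-square, cast back to the input dtype, then apply the
# learned weight. A minimal self-contained sketch (illustration only; the names
# here are not part of this patch):
import torch

def rms_norm_fp32(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return weight * x.to(dtype)

x = torch.randn(2, 4, 8)
y = rms_norm_fp32(x, torch.ones(8))
# with a unit weight, every position ends up with root-mean-square ~= 1
print(y.pow(2).mean(-1).sqrt())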
-class RMSNorm(torch.nn.Module): - def __init__(self, normalized_shape: int, eps: float = 1e-6) -> None: - super().__init__() - - self.weight = torch.nn.Parameter(torch.ones(normalized_shape)) - self.eps = eps - - if isinstance(normalized_shape, numbers.Integral): - normalized_shape = (normalized_shape,) - self.normalized_shape = normalized_shape - - def forward(self, input: torch.Tensor) -> torch.Tensor: - input_dtype = input.dtype - - input = input.to(torch.float32) - variance = input.pow(2).mean(-1, keepdim=True) - input = input * torch.rsqrt(variance + self.eps) - - return self.weight * input.to(input_dtype) - - def extra_repr(self) -> str: - return f"{self.normalized_shape}, eps={self.eps}" - - def reset_parameters(self) -> None: - torch.nn.init.ones_(self.weight) - - -_RMSNORM_MODULES = {"torch": RMSNorm} - - -def get_rmsnorm( - normalized_shape: int, - eps: float, - normalization_implementation: str = "torch", -) -> torch.nn.LayerNorm: - if normalization_implementation in _RMSNORM_MODULES: - return _RMSNORM_MODULES[normalization_implementation]( - normalized_shape=normalized_shape, eps=eps - ) - - raise ValueError( - f"unexpected `normalization_implementation` {normalization_implementation}" - ) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/__init__.py b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/__init__.py new file mode 100644 index 0000000..42a64c3 --- /dev/null +++ b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/__init__.py @@ -0,0 +1,19 @@ +import torch.nn as nn + +from .apex import ApexRMSNorm +from .base import RMSNorm +#from .torchtitan import TorchTitanRMSNorm + +# Removing TorchTitanRMSNorm to avoid unecessary imports and checks +_RMSNORM_MODULES = {"torch": RMSNorm, "apex": ApexRMSNorm}#, "torchtitan": TorchTitanRMSNorm} + + +def get_rmsnorm( + normalized_shape: int, + eps: float, + normalization_implementation: str = "torch", +) -> nn.LayerNorm: + if normalization_implementation in _RMSNORM_MODULES: + return _RMSNORM_MODULES[normalization_implementation](normalized_shape=normalized_shape, eps=eps) + + raise ValueError(f"unexpected `normalization_implementation` {normalization_implementation}") diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/apex.py b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/apex.py new file mode 100644 index 0000000..c91f4e7 --- /dev/null +++ b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/apex.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn + + +def is_apex_rmsnorm_available() -> bool: + try: + from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction + + return True + except ImportError: + return False + + +if is_apex_rmsnorm_available(): + from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction + + +def apex_rmsnorm(input: torch.Tensor, weight: torch.Tensor, eps: float, memory_efficient: bool) -> torch.Tensor: + normalized_shape = (input.shape[-1],) + return FusedRMSNormAffineMixedDtypesFunction.apply(input, weight, normalized_shape, eps, memory_efficient) + + +class ApexRMSNorm(nn.RMSNorm): + def __init__(self, normalized_shape: int, eps: float = 1e-6) -> None: + if not is_apex_rmsnorm_available(): + raise ImportError("build apex from source") + + super().__init__(normalized_shape, eps=eps) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return apex_rmsnorm(input, self.weight, self.eps, 
True) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/base.py b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/base.py new file mode 100644 index 0000000..82dd4a2 --- /dev/null +++ b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/base.py @@ -0,0 +1,13 @@ +import torch +import torch.nn as nn + + +class RMSNorm(nn.RMSNorm): + def forward(self, input: torch.Tensor) -> torch.Tensor: + input_dtype = input.dtype + + input = input.to(torch.float32) + variance = input.pow(2).mean(-1, keepdim=True) + input = input * torch.rsqrt(variance + self.eps) + + return self.weight * input.to(input_dtype) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/torchtitan.py b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/torchtitan.py new file mode 100644 index 0000000..c5fd754 --- /dev/null +++ b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/torchtitan.py @@ -0,0 +1,205 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Credit +# Tri Dao's Triton LayerNorm: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py +# Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html + +"""Code taken from torchtitan: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/norms.py""" + + +import math + +import torch +import torch.nn as nn + +from .....utils import is_triton_available + + +if is_triton_available(): + import triton + import triton.language as tl + + @triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N"], + ) + @triton.jit + def _rms_norm_fwd_kernel( + X, + stride_x, + Y, + stride_y, + W, + Rstd, + eps, + M, # num rows + N, # num cols + block_N: tl.constexpr, + ): + row = tl.program_id(0) + cols = tl.arange(0, block_N) + + # Load input data and weights + mask = cols < N + x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32) + w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32) + + # Compute mean and variance + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + + # Store the reciprocal standard deviation + tl.store(Rstd + row, rstd) + + # Normalize and apply linear transformation + x_hat = x * rstd + y = x_hat * w + + # Write output + tl.store(Y + row * stride_y + cols, y, mask=mask) + + @triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N"], + ) + @triton.jit + def _rms_norm_bwd_kernel_sm( + X, + stride_x, + W, + DY, + stride_dy, + DX, + stride_dx, + Rstd, + DW, + eps, + M, # num rows + N, # num cols + rows_per_program, + block_N: tl.constexpr, + ): + row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + cols = tl.arange(0, block_N) + mask = cols < N + + # Load weights + w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32) + + # Accumulate gradients for weights + dw = 
tl.zeros((block_N,), dtype=tl.float32) + + row_end = min(row_start + rows_per_program, M) + for row in range(row_start, row_end): + # Load input, output gradient, and reciprocal standard deviation + x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(DY + row * stride_dy + cols, mask=mask, other=0.0).to(tl.float32) + rstd = tl.load(Rstd + row) + + # Compute normalized input and gradients + x_hat = x * rstd + wdy = w * dy + dw += dy * x_hat + c1 = tl.sum(x_hat * wdy, axis=0) / N + dx = (wdy - x_hat * c1) * rstd + + # Store input gradient + tl.store(DX + row * stride_dx + cols, dx, mask=mask) + + # Store weight gradients + tl.store(DW + row_block_id * N + cols, dw, mask=mask) + + +class _TorchTitanRMSNorm(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor: + x_shape_start = x.shape + + # Flatten input + x = x.view(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if weight.stride(-1) != 1: + weight = weight.contiguous() + + M, N = x.shape + y = torch.empty_like(x) + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + + max_size = 65536 // x.element_size() + block_N = min(max_size, triton.next_power_of_2(N)) + + if N > block_N: + raise ValueError(f"N {N} must be <= {block_N=}") + + grid = lambda meta: (M,) + _rms_norm_fwd_kernel[grid](x, x.stride(0), y, y.stride(0), weight, rstd, eps, M, N, block_N) + + ctx.eps = eps + ctx.save_for_backward(x, weight, rstd) + ctx.x_shape_start = x_shape_start + + y = y.reshape(x_shape_start) + return y + + @staticmethod + def backward(ctx, dy: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, None]: + x, weight, rstd = ctx.saved_tensors + eps = ctx.eps + x_shape_start = ctx.x_shape_start + + # Flatten input and output gradients + dy = dy.view(-1, dy.shape[-1]) + if dy.stride(-1) != 1: + dy = dy.contiguous() + + M, N = dy.shape + dx = torch.empty_like(x) + dw = torch.empty_like(weight) + + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) + + max_size = 65536 // x.element_size() + block_N = min(max_size, triton.next_power_of_2(N)) + rows_per_sm = math.ceil(M / sm_count) + + if N > block_N: + raise ValueError(f"N {N} must be <= {block_N=}") + + grid = lambda meta: (sm_count,) + _rms_norm_bwd_kernel_sm[grid]( + x, x.stride(0), weight, dy, dy.stride(0), dx, dx.stride(0), rstd, _dw, eps, M, N, rows_per_sm, block_N + ) + dw = _dw.sum(0).to(weight.dtype) + dx = dx.view(x_shape_start) + return dx, dw, None + + +def torchtitan_rmsnorm(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor: + return _TorchTitanRMSNorm.apply(input, weight, eps) + + +class TorchTitanRMSNorm(nn.RMSNorm): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torchtitan_rmsnorm(x, self.weight, self.eps) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/__init__.py b/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/__init__.py index a16ee80..e82f7cf 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/__init__.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/__init__.py @@ -1,6 +1,2 @@ -# ---------------------------------------------------------------- -# Extracted from https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- -# Local from .alibi import Alibi 
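# For reference: the slope computation in alibi.py below follows the ALiBi
# paper (https://arxiv.org/abs/2108.12409); when num_heads is a power of two it
# reduces to the geometric sequence 2 ** (-8 * i / num_heads) for i = 1..num_heads.
# A minimal sketch of that special case (illustration only, not part of this patch):
import math
import torch

def alibi_slopes_power_of_2(num_heads: int) -> torch.Tensor:
    base = 2 ** (-(2 ** -(math.log2(num_heads) - 3)))  # == 2 ** (-8 / num_heads)
    powers = torch.arange(1, 1 + num_heads, dtype=torch.int32)
    return torch.pow(torch.tensor(base), powers)

print(alibi_slopes_power_of_2(8))  # 0.5, 0.25, ..., 0.0039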
-from .rope import RoPE, apply_rotary_pos_emb +from .rope import RoPE, YaRNScaledRoPE, apply_rotary_pos_emb diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/alibi.py b/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/alibi.py index d41e207..3f49177 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/alibi.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/alibi.py @@ -1,16 +1,10 @@ -# ---------------------------------------------------------------- -# Extracted from https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- -"""copied from BLOOM's code with some minor changes""" - -# Standard import math -# Third Party import torch +import torch.nn as nn -class Alibi(torch.nn.Module): +class Alibi(nn.Module): def __init__(self, num_heads: int) -> None: super().__init__() self.num_heads = num_heads @@ -19,73 +13,32 @@ def __init__(self, num_heads: int) -> None: def forward( self, - attention_mask: torch.Tensor, + attention_mask: torch.Tensor | None, batch_size: int, key_length: int, device: torch.device, dtype: torch.dtype, ) -> torch.Tensor: - """ - Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it - relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value - `softmax(l+a) = softmax(l)`. Based on - https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 - TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly. - - Args: - attention_mask (torch.Tensor): attention_mask tensor of shape (`batch_size`, `key_length`) - num_heads (int): `num_heads` for the model - batch_size (int): `batch_size` - key_length (int): `key_length` - device (torch.device): device for the tensors - dtype (torch.dtype): dtype to use for the tensors - - Returns: - torch.Tensor: alibi tensor of shape (`batch_size`, `num_heads`, `key_length`) - """ - - # Note: alibi will added to the attention bias that will be applied to the query, key product of attention - # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length) - # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length) - # => the query_length dimension will then be broadcasted correctly - # This is more or less identical to T5's relative position bias: - # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527 if attention_mask is None: arange_tensor = ( - torch.arange(key_length, device=device) - .unsqueeze(0) - .unsqueeze(0) - .expand(batch_size, -1, -1) + torch.arange(key_length, device=device).unsqueeze(0).unsqueeze(0).expand(batch_size, -1, -1) ) else: - arange_tensor = ( - (attention_mask.cumsum(dim=-1) - 1) - .masked_fill_(attention_mask == 0, 0) - .unsqueeze(1) - ) + arange_tensor = (attention_mask.cumsum(dim=-1) - 1).masked_fill_(attention_mask == 0, 0).unsqueeze(1) alibi = self.slopes.unsqueeze(1) * arange_tensor return alibi.to(dtype) def reset_parameters(self) -> None: closest_power_of_2 = 2 ** math.floor(math.log2(self.num_heads)) - base = torch.tensor( - 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32 - ) + base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 
3))), dtype=torch.float32) powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) slopes = torch.pow(base, powers) if closest_power_of_2 != self.num_heads: - extra_base = torch.tensor( - 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), - dtype=torch.float32, - ) - num_remaining_heads = min( - closest_power_of_2, self.num_heads - closest_power_of_2 - ) - extra_powers = torch.arange( - 1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32 - ) + extra_base = torch.tensor(2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32) + num_remaining_heads = min(closest_power_of_2, self.num_heads - closest_power_of_2) + extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32) slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) self.register_buffer("slopes", slopes, persistent=False) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/rope.py b/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/rope.py index 1dd5bd6..71c5916 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/rope.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/rope.py @@ -1,16 +1,12 @@ -# ---------------------------------------------------------------- -# Extracted from https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- """Logic is copied from transformers.models.llama.modeling_utils with slight modifications""" -# Standard -from typing import Tuple +import math -# Third Party import torch +import torch.nn as nn -class RoPE(torch.nn.Module): +class RoPE(nn.Module): def __init__( self, head_dim: int, @@ -26,9 +22,7 @@ def __init__( self.reset_parameters() - def forward( - self, seq_len: int, dtype: torch.dtype, device: torch.device - ) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, seq_len: int, dtype: torch.dtype, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: if seq_len > self.max_seq_len_cached: self._set_cos_sin_cache(seq_len=seq_len, device=device, dtype=dtype) @@ -38,40 +32,78 @@ def forward( return cos, sin def reset_parameters(self) -> None: - inv_freq = 1.0 / ( - self.base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim) - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. 
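# For reference: a minimal sketch of the rotary-embedding math used in rope.py:
# inverse frequencies 1 / base**(2i/d), a cos/sin cache over positions, and the
# rotate-half formulation applied by apply_rotary_pos_emb. Illustration only,
# not part of this patch:
import torch

def rope_cos_sin(head_dim: int, seq_len: int, base: float = 10000.0) -> tuple[torch.Tensor, torch.Tensor]:
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    t = torch.arange(seq_len, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)  # (seq_len, head_dim)
    return emb.cos(), emb.sin()

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

cos, sin = rope_cos_sin(head_dim=8, seq_len=4)
q = torch.randn(4, 8)  # (position, head_dim)
q_rot = q * cos + rotate_half(q) * sin
# the pairwise rotation leaves the norm of each vector unchanged
torch.testing.assert_close(q_rot.norm(dim=-1), q.norm(dim=-1))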
- self._set_cos_sin_cache( - seq_len=self.max_position_embeddings, - device=self.inv_freq.device, - dtype=torch.get_default_dtype(), - ) + self._set_cos_sin_cache(seq_len=self.max_position_embeddings, device=None, dtype=torch.float32) @torch.no_grad() - def _set_cos_sin_cache( - self, seq_len: int, device: torch.device, dtype: torch.dtype - ) -> None: + def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None: self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32) - freqs = torch.outer(t, self.inv_freq) + inv_freq = self._get_inv_freq(device) + t = torch.arange(self.max_seq_len_cached, dtype=torch.float32, device=device) + + freqs = torch.outer(t, inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer( - "cos_cached", (emb.cos() * self.mscale).to(dtype), persistent=False + self.register_buffer("cos_cached", (emb.cos() * self.mscale).to(dtype), persistent=False) + self.register_buffer("sin_cached", (emb.sin() * self.mscale).to(dtype), persistent=False) + + def _get_inv_freq(self, device: torch.device) -> torch.Tensor: + return 1.0 / ( + self.base ** (torch.arange(0, self.head_dim, 2, dtype=torch.float32, device=device) / self.head_dim) ) - self.register_buffer( - "sin_cached", (emb.sin() * self.mscale).to(dtype), persistent=False + + +class YaRNScaledRoPE(RoPE): + def __init__( + self, + head_dim: int, + max_position_embeddings: int = 2048, + base: int = 10000, + scale: float = 1, + original_max_position_embeddings: int = 2048, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + ) -> None: + nn.Module.__init__(self) + + self.head_dim = head_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.scale = scale + self.original_max_position_embeddings = original_max_position_embeddings + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + + # Get n-d magnitude scaling corrected for interpolation + self.mscale = _yarn_get_mscale(self.scale) * self.attn_factor + + self.reset_parameters() + + def _get_inv_freq(self, device: torch.device) -> torch.Tensor: + pos_freqs = self.base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (self.scale * pos_freqs) + + low, high = _yarn_find_correction_range( + self.beta_fast, self.beta_slow, self.head_dim, self.base, self.original_max_position_embeddings ) + inv_freq_mask = ( + 1 - _yarn_linear_ramp_mask(low, high, self.head_dim // 2).float() + ) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation + inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + + return inv_freq def apply_rotary_pos_emb( - x: torch.Tensor, cos_sin: Tuple[torch.Tensor, torch.Tensor] -) -> Tuple[torch.Tensor, torch.Tensor]: + x: torch.Tensor, cos_sin: tuple[torch.Tensor, torch.Tensor] +) -> tuple[torch.Tensor, torch.Tensor]: cos, sin = cos_sin x = (x * cos) + (_rotate_half(x) * sin) return x @@ -80,3 +112,34 @@ def apply_rotary_pos_emb( def _rotate_half(x: torch.Tensor) -> torch.Tensor: x1, x2 = torch.chunk(x, 2, dim=-1) return torch.cat((-x2, x1), dim=-1) + + +# Inverse dim formula to find dim based on number of rotations +def 
_yarn_find_correction_dim( + num_rotations: int, dim: int, base: int = 10000, max_position_embeddings: int = 2048 +) -> float: + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + +# Find dim range bounds based on rotations +def _yarn_find_correction_range( + low_rot: int, high_rot: int, dim: int, base: int = 10000, max_position_embeddings: int = 2048 +) -> int: + low = math.floor(_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(_yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def _yarn_linear_ramp_mask(min: float, max: float, dim: int) -> torch.Tensor: + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +def _yarn_get_mscale(scale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * math.log(scale) + 1.0 diff --git a/src/instructlab/dolomite/hf_models/models/__init__.py b/src/instructlab/dolomite/hf_models/models/__init__.py index 684111f..871910e 100644 --- a/src/instructlab/dolomite/hf_models/models/__init__.py +++ b/src/instructlab/dolomite/hf_models/models/__init__.py @@ -2,4 +2,4 @@ # Extracted from https://github.com/ibm-granite/dolomite-engine # ---------------------------------------------------------------- # Local -from .gpt_dolomite import GPTDolomiteForCausalLM, GPTDolomiteModel +from .gpt_dolomite import GPTDolomiteForCausalLM, GPTDolomiteModel, GPTDolomiteConfig diff --git a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/__init__.py b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/__init__.py index d121b1e..347102e 100644 --- a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/__init__.py +++ b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/__init__.py @@ -1,7 +1,4 @@ -# ---------------------------------------------------------------- -# Extracted from https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- -# Local from .base import GPTDolomiteModel, GPTDolomitePreTrainedModel +from .config import GPTDolomiteConfig from .main import GPTDolomiteForCausalLM from .mlp import interleave_up_gate_tensor_for_mlp, split_up_gate_tensor_for_mlp diff --git a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/base.py b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/base.py index e9753bb..c9bee9d 100644 --- a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/base.py +++ b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/base.py @@ -1,739 +1,12 @@ -# ---------------------------------------------------------------- -# Extracted from https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- -# Standard -from typing import List, Tuple, Union -import warnings - -# Third Party -from transformers import DynamicCache, PreTrainedModel -from transformers.modeling_outputs import BaseModelOutputWithPast -import torch - -# Local -from ...config import GPTDolomiteConfig -from ...enums import AttentionHeadType, PositionEmbeddingType -from ...modeling_utils import Alibi, RMSNorm, RoPE, get_normalization_function -from ...utils import check_list_type, flatten_and_convert_to_tensors +from ...mixins import BaseModelMixin, PreTrainedModelMixin +from .config import GPTDolomiteConfig from 
.layer import GPTDolomiteBlock -DEFAULT_NORMALIZATION_IMPLEMENTATION = "torch" - - -class GPTDolomitePreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ +class GPTDolomitePreTrainedModel(PreTrainedModelMixin): config_class = GPTDolomiteConfig - base_model_prefix = "transformer" - causal = True + layer_class = GPTDolomiteBlock _no_split_modules = ["GPTDolomiteBlock"] - _skip_keys_device_placement = "past_key_values" - _supports_sdpa = True - _supports_flash_attn_2 = True - - def __init__(self, config: GPTDolomiteConfig, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - - self.normalization_implementation = kwargs.get( - "normalization_implementation", DEFAULT_NORMALIZATION_IMPLEMENTATION - ) - - self.attention_implementation = self.config._attn_implementation - self._use_eager_attention = self.attention_implementation == "eager" - self._use_sdpa = self.attention_implementation == "sdpa" - self._use_flash_attention_2 = ( - self.attention_implementation == "flash_attention_2" - ) - self._use_padding_free_transformer = kwargs.get( - "use_padding_free_transformer", False - ) - - self._tied_word_embeddings = config.tie_word_embeddings - - if self._use_padding_free_transformer: - assert ( - self._use_flash_attention_2 - ), "padding free transformer only works with flash attention" - - assert any( - [ - self._use_eager_attention, - self._use_sdpa, - self._use_flash_attention_2, - self._use_padding_free_transformer, - ] - ) and not all( - [ - self._use_eager_attention, - self._use_sdpa, - self._use_flash_attention_2, - self._use_padding_free_transformer, - ] - ) - - self.upcast_logits_for_loss = config.upcast_logits_for_loss - - def _init_weights(self, module: torch.nn.Module) -> None: - if isinstance( - module, - ( - torch.nn.Embedding, - torch.nn.Linear, - torch.nn.LayerNorm, - RMSNorm, - Alibi, - RoPE, - ), - ): - module.reset_parameters() - - def get_autoregressive_language_modeling_loss( - self, lm_logits: torch.Tensor, labels: torch.Tensor, cu_seqlens: torch.Tensor - ) -> torch.Tensor: - if labels is None: - return None - - if self._use_padding_free_transformer: - shift_logits = lm_logits[:-1, :] - shift_labels = labels[1:].to(shift_logits.device) - - # this is needed so that the last token of current example doesn't predict first token of next example - drop_loss_positions = cu_seqlens[1:-1] - 1 - shift_labels[drop_loss_positions] = -100 - else: - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous().to(shift_logits.device) - - # Flatten the tokens - loss_fct = torch.nn.CrossEntropyLoss() - if self.upcast_logits_for_loss: - shift_logits = shift_logits.float() - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) - ) - - return loss - - def prepare_inputs_for_model( - self, - input_ids: Union[torch.Tensor, List[List[int]]], - inputs_embeds: Union[torch.Tensor, List[List[float]]], - position_ids: Union[torch.Tensor, List[List[int]]], - token_type_ids: Union[torch.Tensor, List[List[int]]], - labels: Union[torch.Tensor, List[List[int]]], - cu_seqlens: torch.Tensor, - max_seqlen: torch.Tensor, - past_key_values: Tuple[Tuple[torch.Tensor]], - attention_mask: torch.Tensor, - use_cache: bool, - output_attentions: bool, - ) -> Tuple[torch.Tensor, torch.Tensor]: - if self._use_padding_free_transformer: - if isinstance(input_ids, list) or 
isinstance(inputs_embeds, list): - device = torch.cuda.current_device() - - # check input types are correct - error_message = "{variable} should be of type List[List[{dtype}]]" - check_list_type( - input_ids, error_message.format(variable="input_ids", dtype="int") - ) - check_list_type( - inputs_embeds, - error_message.format(variable="inputs_embeds", dtype="float"), - ) - check_list_type( - position_ids, - error_message.format(variable="position_ids", dtype="int"), - ) - check_list_type( - token_type_ids, - error_message.format(variable="token_type_ids", dtype="int"), - ) - check_list_type( - labels, error_message.format(variable="labels", dtype="int") - ) - - # this is managed internally - error_message = ( - "{variable} should not be passed for flash attention when using List[List[int]] " - "input types attention mask logic is handled internally" - ) - assert cu_seqlens is None, error_message.format(variable="cu_seqlens") - assert max_seqlen is None, error_message.format(variable="max_seqlen") - assert attention_mask is None, error_message.format( - variable="attention_mask" - ) - - # prepare inputs for the model - seqlens = torch.tensor([0] + [len(x) for x in input_ids]) - cu_seqlens = seqlens.cumsum(dim=-1).to(device, torch.int32) - max_seqlen = seqlens.max().to(device) - - if position_ids is None: - position_ids = [list(range(len(x))) for x in input_ids] - position_ids = flatten_and_convert_to_tensors(position_ids, device) - - input_ids = flatten_and_convert_to_tensors(input_ids, device) - - if inputs_embeds is not None: - inputs_embeds = flatten_and_convert_to_tensors( - inputs_embeds, device - ) - - if token_type_ids is not None: - token_type_ids = flatten_and_convert_to_tensors( - token_type_ids, device - ) - - if labels is not None: - labels = flatten_and_convert_to_tensors(labels, device) - else: - assert ( - cu_seqlens is not None - ), "cu_seqlens needs to be specified when using tensor inputs with padding_free transformer" - assert ( - position_ids is not None - ), "max_seqlen needs to be specified when specifying cu_seqlens" - assert ( - max_seqlen is not None - ), "max_seqlen needs to be specified when specifying cu_seqlens" - assert ( - attention_mask is None - ), "attention_mask should not be passed when specifying cu_seqlens" - - if use_cache or past_key_values is not None: - raise NotImplementedError( - "KV caching is not supported with padding_free transformer" - ) - - error_message = "{variable} is only supported with math attention" - - assert not output_attentions - - return input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen - - -class GPTDolomiteModel(GPTDolomitePreTrainedModel): - mask_value = None - - def __init__(self, config: GPTDolomiteConfig, **kwargs) -> None: - super().__init__(config, **kwargs) - - self.attention_head_type = AttentionHeadType(config.attention_head_type) - self.embed_dim = config.hidden_size - self.num_heads = config.num_attention_heads - self.num_key_value_heads = config.num_key_value_heads - self.m_emb = config.m_emb - - assert ( - self.embed_dim % self.num_heads == 0 - ), f"`embed_dim` ({self.embed_dim}) must be divisible by `num_heads` ({self.num_heads})" - - self.head_dim = self.embed_dim // self.num_heads - - self.wte = torch.nn.Embedding(config.vocab_size, self.embed_dim) - - self.drop = ( - torch.nn.Identity() - if config.embd_pdrop == 0 - else torch.nn.Dropout(config.embd_pdrop) - ) - self.h = torch.nn.ModuleList( - [ - GPTDolomiteBlock( - config, - self.normalization_implementation, - 
self.attention_implementation, - self._use_padding_free_transformer, - layer_idx=i, - ) - for i in range(config.num_hidden_layers) - ] - ) - self.ln_f = get_normalization_function( - config.normalization_function, - self.embed_dim, - eps=config.layer_norm_epsilon, - normalization_implementation=self.normalization_implementation, - ) - - self.position_embedding_type = PositionEmbeddingType( - config.position_embedding_type - ) - self._setup_positional_encoding() - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self) -> torch.nn.Embedding: - return self.wte - - def set_input_embeddings(self, new_embeddings: torch.nn.Embedding) -> None: - self.wte = new_embeddings - - def forward( - self, - input_ids: torch.Tensor = None, - past_key_values: DynamicCache = None, - attention_mask: torch.Tensor = None, - token_type_ids: torch.Tensor = None, - position_ids: torch.Tensor = None, - inputs_embeds: torch.Tensor = None, - use_cache: bool = None, - output_hidden_states: bool = None, - return_dict: bool = None, - cu_seqlens: torch.Tensor = None, - max_seqlen: torch.Tensor = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - ( - output_hidden_states, - use_cache, - return_dict, - input_shape, - hidden_states, - attention_mask, - position_ids, - rope_cos_sin, - past_key_values, - ) = self._prepare_a_bunch_of_stuff( - input_ids=input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - - # ========================================================================================== - # padding_free: - # attention_mask -> None - # flash: - # attention_mask -> (batch_size, key_length) - # else: - # attention_mask -> (batch_size, 1, query_length, key_length) - # ========================================================================================== - - output_shape = input_shape + (hidden_states.size(-1),) - - past_key_values = ( - DynamicCache() if use_cache and past_key_values is None else past_key_values - ) - all_hidden_states = () if output_hidden_states else None - for block in self.h: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - hidden_states = block( - hidden_states, - past_key_values=past_key_values, - attention_mask=attention_mask, - rope_cos_sin=rope_cos_sin, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - - hidden_states = self.ln_f(hidden_states) - - hidden_states = hidden_states.view(output_shape) - # Add last hidden state - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [hidden_states, past_key_values, all_hidden_states] - if v is not None - ) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_values, - hidden_states=all_hidden_states, - ) - - def _get_position_ids( - self, - attention_mask: torch.Tensor, - past_length: int, - query_length: int, - key_length: int, - device: torch.device, - ) -> torch.Tensor: - if attention_mask is not None and len(attention_mask.shape) == 2: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 0) - if past_length > 0: - position_ids = position_ids[:, past_length:key_length:] - else: - 
position_ids = torch.arange( - past_length, key_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, query_length) - - return position_ids - - def _get_alibi_bias( - self, - attention_mask: torch.Tensor, - batch_size: int, - query_length: int, - key_length: int, - device: torch.device, - dtype: torch.dtype, - ) -> torch.Tensor: - if self.position_embedding_type != PositionEmbeddingType.alibi: - return None - - alibi_bias = self.alibi(attention_mask, batch_size, key_length, device, dtype) - - # ========================================================================================== - # alibi_bias -> (batch_size, num_heads, key_length) - # ========================================================================================== - - alibi_bias = alibi_bias.unsqueeze(2) - if query_length != 1: - alibi_bias = alibi_bias.expand(-1, -1, query_length, -1) - - # ========================================================================================== - # alibi_bias -> (batch_size, num_heads, query_length, key_length) - # ========================================================================================== - - return alibi_bias - - def _get_rope_cos_sin( - self, - key_length: int, - position_ids: torch.Tensor, - dtype: torch.dtype, - device: torch.device, - ) -> torch.Tensor: - if self.position_embedding_type == PositionEmbeddingType.rope: - cos, sin = self.rope(key_length, dtype=dtype, device=device) - cos = cos[position_ids].unsqueeze(1) - sin = sin[position_ids].unsqueeze(1) - return cos, sin - - def _prepare_causal_attention_mask( - self, - attention_mask: torch.Tensor, - batch_size: int, - query_length: int, - key_length: int, - device: torch.device, - ) -> torch.Tensor: - past_length = key_length - query_length - - # ========================================================================================== - # attention_mask -> (batch_size, key_length) - # ========================================================================================== - - if query_length > 1: - # (query_length, key_length) - causal_mask = torch.empty( - (query_length, key_length), dtype=torch.bool, device=device - ) - causal_mask[:, past_length:] = torch.tril( - torch.ones(query_length, query_length, dtype=torch.bool, device=device) - ) - - if past_length > 0: - causal_mask[:, :past_length] = True - - # (query_length, key_length) -> (1, query_length, key_length) - causal_mask = causal_mask.unsqueeze(0) - - if attention_mask is None: - # (1, query_length, key_length) -> (batch_size, query_length, key_length) - causal_mask = causal_mask.expand(batch_size, -1, -1) - else: - # (1, query_length, key_length) & (batch_size, 1, key_length) -> (batch_size, query_length, key_length) - causal_mask = causal_mask & attention_mask.unsqueeze(1).to(torch.bool) - else: - if attention_mask is None: - # (batch_size, query_length, key_length) - causal_mask = torch.ones( - batch_size, - query_length, - key_length, - dtype=torch.bool, - device=device, - ) - else: - # (batch_size, query_length, key_length) - causal_mask = attention_mask.unsqueeze(1).to( - dtype=torch.bool, device=device - ) - - # ========================================================================================== - # attention_mask -> (batch_size, query_length, key_length) - # ========================================================================================== - - causal_mask = causal_mask.unsqueeze(1) - - # ========================================================================================== - # 
attention_mask -> (batch_size, 1, query_length, key_length) - # ========================================================================================== - - return causal_mask - - def _get_initial_hidden_state( - self, - input_ids: torch.Tensor, - inputs_embeds: torch.Tensor, - position_ids: torch.Tensor, - token_type_ids: torch.Tensor, - ) -> torch.Tensor: - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - - if self.position_embedding_type == PositionEmbeddingType.learned_absolute: - inputs_embeds = inputs_embeds + self.wpe(position_ids) - - if token_type_ids is not None: - inputs_embeds = inputs_embeds + self.wte(token_type_ids) - - inputs_embeds = self.drop(inputs_embeds) - - if self.m_emb is not None: - inputs_embeds = inputs_embeds * self.m_emb - - return inputs_embeds - - def _prepare_a_bunch_of_stuff( - self, - input_ids: torch.Tensor = None, - past_key_values: DynamicCache = None, - attention_mask: torch.Tensor = None, - token_type_ids: torch.Tensor = None, - position_ids: torch.Tensor = None, - inputs_embeds: torch.Tensor = None, - use_cache: bool = None, - output_hidden_states: bool = None, - return_dict: bool = None, - cu_seqlens: torch.Tensor = None, - max_seqlen: torch.Tensor = None, - ) -> Tuple[ - bool, - bool, - bool, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - Union[Tuple[torch.Tensor], Tuple[Tuple[torch.Tensor, torch.Tensor]]], - ]: - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - - if use_cache is None: - use_cache = ( - False if self._use_padding_free_transformer else self.config.use_cache - ) - - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time" - ) - if input_ids is not None: - input_shape = input_ids.size() - - # special handling for padding free transformer with list inputs - if self._use_padding_free_transformer: - # for flash attention, there is no padding and we do packing - # so, input_ids is of shape (s1 + s2 + ... 
+ sb) - batch_size = cu_seqlens.shape[0] - 1 - else: - batch_size = input_shape[0] - elif inputs_embeds is not None: - # TODO special handling for padding free transformer needed here if we support inputs_embeds argument - input_shape = inputs_embeds.size()[:-1] - batch_size = input_shape[0] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - - device = input_ids.device if input_ids is not None else inputs_embeds.device - - if self._use_padding_free_transformer: - assert position_ids is not None, ( - "GPTDolomiteModel needs position_ids from outside when using flash attention with List[List[int]] " - "inputs" - ) - else: - if self.position_embedding_type == PositionEmbeddingType.alibi: - if position_ids is not None: - warnings.warn( - "`position_ids` have no functionality with Alibi.", - FutureWarning, - ) - - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, input_shape[-1]) - - # ========================================================================================== - # padding_free: - # input_ids -> (total_q) - # attention_mask -> None - # position_ids -> (total_q) - # else: - # input_ids -> (batch_size, query_length) - # attention_mask -> None or (batch_size, key_length) - # position_ids -> None or (batch_size, key_length) - # ========================================================================================== - - past_length = None - query_length = None - key_length = None - if self._use_padding_free_transformer: - key_length = max_seqlen - else: - past_length = ( - 0 if past_key_values is None else past_key_values.get_seq_length() - ) - query_length = input_shape[-1] - key_length = past_length + query_length - - if position_ids is None: - position_ids = self._get_position_ids( - attention_mask, past_length, query_length, key_length, device - ) - - # ========================================================================================== - # padding_free: - # input_ids -> (total_q) - # attention_mask -> None - # position_ids -> (total_q) - # else: - # input_ids -> (batch_size, query_length) - # attention_mask -> None or (batch_size, key_length) - # position_ids -> (batch_size, query_length) - # ========================================================================================== - - hidden_states = self._get_initial_hidden_state( - input_ids, inputs_embeds, position_ids, token_type_ids - ) - - # ========================================================================================== - # padding_free: - # hidden_states -> (total_q, num_heads * head_dim) - # else: - # hidden_states -> (batch_size, query_length, num_heads * head_dim) - # ========================================================================================== - - alibi_bias = self._get_alibi_bias( - attention_mask, - batch_size, - query_length, - key_length, - device, - hidden_states.dtype, - ) - - # ========================================================================================== - # alibi_bias -> (batch_size, num_heads, query_length, key_length) - # ========================================================================================== - - rope_cos_sin = self._get_rope_cos_sin( - key_length, - position_ids, - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - - # ========================================================================================== - # padding_free: - # rope_cos_sin -> 2 * (max_seqlen, head_dim) - # else: - # 
rope_cos_sin -> 2 * (key_length, head_dim) - # ========================================================================================== - - # prepare causal mask only if not using flash attention - if self._use_sdpa: - # we use the causal/non-causal argument of SDPA for attention in this case - if attention_mask is not None: - attention_mask = self._prepare_causal_attention_mask( - attention_mask, batch_size, query_length, key_length, device - ) - - attention_mask = torch.where( - attention_mask, - ~attention_mask if alibi_bias is None else alibi_bias, - self._get_mask_value(attention_mask.device, hidden_states.dtype), - ) - elif self._use_eager_attention: - attention_mask = self._prepare_causal_attention_mask( - attention_mask, batch_size, query_length, key_length, device - ) - - attention_mask = torch.where( - attention_mask, - ~attention_mask if alibi_bias is None else alibi_bias, - self._get_mask_value(attention_mask.device, hidden_states.dtype), - ) - - return ( - output_hidden_states, - use_cache, - return_dict, - input_shape, - hidden_states, - attention_mask, - position_ids, - rope_cos_sin, - past_key_values, - ) - - def _setup_positional_encoding(self) -> None: - max_position_embeddings = self.config.max_position_embeddings - - if self.position_embedding_type == PositionEmbeddingType.learned_absolute: - self.wpe = torch.nn.Embedding(max_position_embeddings, self.embed_dim) - elif self.position_embedding_type == PositionEmbeddingType.alibi: - assert ( - not self._use_flash_attention_2 - ), "alibi is not implemented with FlashAttention" - self.alibi = Alibi(self.num_heads) - elif self.position_embedding_type == PositionEmbeddingType.rope: - self.rope = RoPE( - self.head_dim, - max_position_embeddings=max_position_embeddings, - base=self.config.rope_theta, - ) - else: - raise NotImplementedError() - def _get_mask_value(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor: - # torch.where expects a tensor. We use a cache to avoid recreating it every time. - if ( - self.mask_value is None - or self.mask_value.dtype != dtype - or self.mask_value.device != device - ): - self.mask_value = torch.full( - [], torch.finfo(torch.float16).min, dtype=dtype, device=device - ) - return self.mask_value +class GPTDolomiteModel(GPTDolomitePreTrainedModel, BaseModelMixin): ... 
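With base.py collapsed to the mixin-backed GPTDolomitePreTrainedModel and GPTDolomiteModel above, construction is expected to keep going through the config plus the causal-LM wrapper this patch defines in main.py, optionally after registering the custom classes with transformers' Auto* machinery. A minimal sketch, assuming the mixin __init__ still accepts a bare config (the attention-implementation and padding-free kwargs are not visible in this hunk):

from instructlab.dolomite.hf_models import (
    GPTDolomiteConfig,
    GPTDolomiteForCausalLM,
    register_model_classes,
)

# Register "gpt_dolomite" with AutoConfig / AutoModel / AutoModelForCausalLM.
register_model_classes()

# Small config; attention_head_type="mqa" keeps num_key_value_heads at 1.
config = GPTDolomiteConfig(
    vocab_size=50257,
    n_embd=256,
    n_layer=4,
    n_head=4,
    attention_head_type="mqa",
    position_embedding_type="rope",
)

# GPTDolomiteForCausalLM wires GPTDolomiteModel in as its base_model_class.
model = GPTDolomiteForCausalLM(config)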
diff --git a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/config.py b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/config.py new file mode 100644 index 0000000..8b83592 --- /dev/null +++ b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/config.py @@ -0,0 +1,5 @@ +from ...config import CommonConfig + + +class GPTDolomiteConfig(CommonConfig): + model_type = "gpt_dolomite" diff --git a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/layer.py b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/layer.py index 34fc3b3..5fc15a5 100644 --- a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/layer.py +++ b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/layer.py @@ -1,21 +1,14 @@ -# ---------------------------------------------------------------- -# Extracted from https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- -# Standard -from typing import Tuple, Union - -# Third Party -from transformers import DynamicCache import torch +import torch.nn as nn +from transformers import DynamicCache -# Local -from ...config import GPTDolomiteConfig from ...enums import AttentionHeadType from ...modeling_utils import get_attention_module, get_normalization_function +from .config import GPTDolomiteConfig from .mlp import MLP -class GPTDolomiteBlock(torch.nn.Module): +class GPTDolomiteBlock(nn.Module): """ Layer implementation for the transformer block """ @@ -26,7 +19,7 @@ def __init__( normalization_implementation: str, attention_implementation: str, use_padding_free_transformer: bool, - layer_idx: int = None, + layer_idx: int | None = None, ) -> None: super().__init__() @@ -43,11 +36,7 @@ def __init__( normalization_implementation=normalization_implementation, ) self.attn = get_attention_module( - config, - True, - attention_implementation, - use_padding_free_transformer, - layer_idx, + config, True, attention_implementation, use_padding_free_transformer, layer_idx ) self.ln_2 = get_normalization_function( config.normalization_function, @@ -60,20 +49,16 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - past_key_values: DynamicCache = None, - attention_mask: torch.Tensor = None, - rope_cos_sin: torch.Tensor = None, - cu_seqlens: torch.Tensor = None, - max_seqlen: torch.Tensor = None, - ) -> Union[ - Tuple[torch.Tensor], - Tuple[torch.Tensor, torch.Tensor], - Tuple[torch.Tensor, torch.Tensor, torch.Tensor], - ]: + past_key_values: DynamicCache | None = None, + attention_mask: torch.Tensor | None = None, + rope_cos_sin: torch.Tensor | None = None, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, + ) -> tuple[torch.Tensor]: residual = hidden_states hidden_states = self.ln_1(hidden_states) - attn_output = self.attn( + hidden_states = self.attn( hidden_states, past_key_values=past_key_values, attention_mask=attention_mask, @@ -83,20 +68,20 @@ def forward( ) if self.m_residual is not None: - attn_output = attn_output * self.m_residual + hidden_states = hidden_states * self.m_residual # residual connection - hidden_states = attn_output + residual + hidden_states = hidden_states + residual residual = hidden_states hidden_states = self.ln_2(hidden_states) - feed_forward_hidden_states = self.mlp(hidden_states) + hidden_states = self.mlp(hidden_states) if self.m_residual is not None: - feed_forward_hidden_states = feed_forward_hidden_states * self.m_residual + hidden_states = hidden_states * self.m_residual # residual connection - hidden_states = residual + 
feed_forward_hidden_states + hidden_states = hidden_states + residual return hidden_states diff --git a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/main.py b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/main.py index 8d84c28..cba1599 100644 --- a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/main.py +++ b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/main.py @@ -1,186 +1,6 @@ -# ---------------------------------------------------------------- -# Extracted from https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- -# Standard -from typing import List, Optional, Tuple, Union - -# Third Party -from transformers import DynamicCache -from transformers.modeling_outputs import CausalLMOutputWithPast -import torch -import torch.nn.functional as F - -# Local -from ...config import GPTDolomiteConfig +from ...mixins import CausalLMModelMixin from .base import GPTDolomiteModel, GPTDolomitePreTrainedModel -class GPTDolomiteForCausalLM(GPTDolomitePreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config: GPTDolomiteConfig, **kwargs) -> None: - super().__init__(config, **kwargs) - self.transformer = GPTDolomiteModel(config, **kwargs) - - if not self._tied_word_embeddings: - self.lm_head = torch.nn.Linear(config.n_embd, config.vocab_size, bias=False) - - self.m_width = config.m_width - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self) -> torch.nn.Embedding: - return self.transformer.wte - - def set_input_embeddings(self, value: torch.nn.Embedding) -> None: - self.transformer.wte = value - - def get_output_embeddings(self) -> torch.nn.Linear: - if not self._tied_word_embeddings: - return self.lm_head - - def set_output_embeddings(self, new_embeddings: torch.nn.Linear) -> None: - if not self._tied_word_embeddings: - self.lm_head = new_embeddings - - # FIXME typing - def prepare_inputs_for_generation( - self, - input_ids: torch.Tensor, - past_key_values: Optional[DynamicCache] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs, - ) -> dict: - token_type_ids = kwargs.get("token_type_ids", None) - # Omit tokens covered by past_key_values - if past_key_values: - past_length = past_key_values.get_seq_length() - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - if token_type_ids is not None: - token_type_ids = token_type_ids[:, -input_ids.shape[1] :] - - attention_mask = kwargs.get("attention_mask", None) - position_ids = kwargs.get("position_ids", None) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 0) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - else: - position_ids = None - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "position_ids": position_ids, - 
"attention_mask": attention_mask, - "token_type_ids": token_type_ids, - } - ) - return model_inputs - - def forward( - self, - input_ids: Union[torch.Tensor, List[List[int]]] = None, - past_key_values: DynamicCache = None, - attention_mask: torch.Tensor = None, - token_type_ids: Union[torch.Tensor, List[List[int]]] = None, - position_ids: Union[torch.Tensor, List[List[int]]] = None, - inputs_embeds: Union[torch.Tensor, List[List[float]]] = None, - labels: Union[torch.Tensor, List[List[int]]] = None, - use_cache: bool = None, - output_attentions: bool = None, - output_hidden_states: bool = None, - return_dict: bool = None, - cu_seqlens: torch.Tensor = None, - max_seqlen: torch.Tensor = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = ( - self.prepare_inputs_for_model( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - position_ids=position_ids, - token_type_ids=token_type_ids, - labels=labels, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - past_key_values=past_key_values, - attention_mask=attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - ) - - # ========================================================================================== - # padding_free: - # input_ids -> (total_q) - # attention_mask -> None - # position_ids -> (total_q) - # else: - # input_ids -> (batch_size, query_length) - # attention_mask -> None or (batch_size, key_length) - # position_ids -> None or (batch_size, key_length) - # ========================================================================================== - - # pylint: disable=duplicate-code - transformer_outputs = self.transformer( - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - hidden_states = transformer_outputs[0] - - lm_logits = ( - F.linear(hidden_states, self.transformer.wte.weight) - if self._tied_word_embeddings - else self.lm_head(hidden_states) - ) - - if self.m_width is not None: - lm_logits = lm_logits / self.m_width - - loss = self.get_autoregressive_language_modeling_loss( - lm_logits, labels, cu_seqlens - ) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) +class GPTDolomiteForCausalLM(GPTDolomitePreTrainedModel, CausalLMModelMixin): + base_model_class = GPTDolomiteModel diff --git a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/mlp.py b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/mlp.py index 7e5b214..b94e41a 100644 --- a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/mlp.py +++ b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/mlp.py @@ -1,18 +1,14 @@ -# ---------------------------------------------------------------- -# Extracted from https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- -# Standard -from typing import Tuple +import math -# Third Party 
import torch +import torch.nn as nn -# Local -from ...config import GPTDolomiteConfig -from ...modeling_utils import get_activation_function, is_glu +from ...enums import InitMethod +from ...modeling_utils import ParameterizedLinear, get_activation_function, is_glu +from .config import GPTDolomiteConfig -class MLP(torch.nn.Module): +class MLP(nn.Module): def __init__(self, config: GPTDolomiteConfig) -> None: super().__init__() @@ -22,21 +18,29 @@ def __init__(self, config: GPTDolomiteConfig) -> None: add_bias = config.add_bias residual_dropout = config.resid_pdrop - self.c_fc = torch.nn.Linear( + init_method = InitMethod(config.init_method) + initializer_range = config.initializer_range + m_width = config.m_width + n_layer = config.n_layer + + std = initializer_range + if init_method == InitMethod.mup: + std /= math.sqrt(m_width) + self.c_fc = ParameterizedLinear( hidden_size, 2 * intermediate_size if is_glu(activation_function) else intermediate_size, bias=add_bias, + std=std, ) self.act = get_activation_function(activation_function) - self.c_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=add_bias) + std = initializer_range / math.sqrt(2 * n_layer) + if init_method == InitMethod.mup: + std /= math.sqrt(m_width) + self.c_proj = ParameterizedLinear(intermediate_size, hidden_size, bias=add_bias, std=std) - self.dropout = ( - torch.nn.Identity() - if residual_dropout == 0 - else torch.nn.Dropout(residual_dropout) - ) + self.dropout = nn.Identity() if residual_dropout == 0 else nn.Dropout(residual_dropout) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.c_fc(hidden_states) @@ -46,13 +50,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -def interleave_up_gate_tensor_for_mlp( - up_weight: torch.Tensor, gate_weight: torch.Tensor -) -> torch.Tensor: +def interleave_up_gate_tensor_for_mlp(up_weight: torch.Tensor, gate_weight: torch.Tensor) -> torch.Tensor: return torch.cat([up_weight, gate_weight]) -def split_up_gate_tensor_for_mlp( - c_fc_weight: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: +def split_up_gate_tensor_for_mlp(c_fc_weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: return c_fc_weight.chunk(2) diff --git a/src/instructlab/dolomite/hf_models/register_hf.py b/src/instructlab/dolomite/hf_models/register_hf.py index d264fd8..e92e456 100644 --- a/src/instructlab/dolomite/hf_models/register_hf.py +++ b/src/instructlab/dolomite/hf_models/register_hf.py @@ -1,12 +1,11 @@ -# ---------------------------------------------------------------- -# Extracted from https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- -# Third Party -from transformers import AutoConfig, AutoModel, AutoModelForCausalLM +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM + +from .models import ( + GPTDolomiteConfig, + GPTDolomiteForCausalLM, + GPTDolomiteModel, +) -# Local -from .config import GPTDolomiteConfig -from .models import GPTDolomiteForCausalLM, GPTDolomiteModel # (AutoConfig, AutoModel, AutoModelForCausalLM) _CUSTOM_MODEL_REGISTRY = [ @@ -17,11 +16,7 @@ def register_model_classes() -> None: - for ( - config_class, - auto_model_class, - auto_model_for_causal_lm_class, - ) in _CUSTOM_MODEL_REGISTRY: + for config_class, auto_model_class, auto_model_for_causal_lm_class in _CUSTOM_MODEL_REGISTRY: model_type = config_class.model_type AutoConfig.register(model_type, config_class) @@ -30,3 +25,7 @@ def 
register_model_classes() -> None: _CUSTOM_MODEL_TYPES.append(model_type) _CUSTOM_MODEL_CLASSES.append(auto_model_for_causal_lm_class) + + +def is_custom_model(model_class: type[AutoModelForCausalLM] | type[AutoModelForSeq2SeqLM], model_type: str) -> bool: + return model_class.__name__ in _CUSTOM_MODEL_CLASSES or model_type in _CUSTOM_MODEL_TYPES diff --git a/src/instructlab/dolomite/hf_models/utils.py b/src/instructlab/dolomite/hf_models/utils.py index e270ac4..d6ae749 100644 --- a/src/instructlab/dolomite/hf_models/utils.py +++ b/src/instructlab/dolomite/hf_models/utils.py @@ -1,16 +1,63 @@ -# ---------------------------------------------------------------- -# Extracted from https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- -# Standard -from typing import List, Union - -# Third Party import torch -def check_list_type( - list_of_list: List[List[Union[int, float]]], error_message: str -) -> None: +def divide_if_divisible(dividend: int, divisor: int, msg: str) -> int: + """divide if divisible else raise an error + + Args: + dividend (int): dividend + divisor (int): divisor + msg (str): error message + + Returns: + int: result + """ + + assert dividend % divisor == 0, msg + return dividend // divisor + + +def convert_padding_free_lists_to_tensors( + input_ids: list[list[int]] | None = None, + inputs_embeds: list[list[float]] | None = None, + position_ids: list[list[int]] | None = None, + token_type_ids: list[list[int]] | None = None, + labels: list[list[int]] | None = None, + device: torch.device = None, +) -> tuple[torch.Tensor]: + + # check input types are correct + error_message = "{variable} should be of type List[List[{dtype}]]" + _check_list_type(input_ids, error_message.format(variable="input_ids", dtype="int")) + _check_list_type(inputs_embeds, error_message.format(variable="inputs_embeds", dtype="float")) + _check_list_type(position_ids, error_message.format(variable="position_ids", dtype="int")) + _check_list_type(token_type_ids, error_message.format(variable="token_type_ids", dtype="int")) + _check_list_type(labels, error_message.format(variable="labels", dtype="int")) + + # prepare inputs for the model + seqlens = torch.tensor([0] + [len(x) for x in input_ids], device=device) + cu_seqlens = seqlens.cumsum(dim=-1).to(torch.int32) + max_seqlen = seqlens.max().to(device) + + if position_ids is None: + position_ids = [list(range(len(x))) for x in input_ids] + position_ids = _flatten_and_convert_to_tensors(position_ids, device) + + input_ids = _flatten_and_convert_to_tensors(input_ids, device) + + if inputs_embeds is not None: + inputs_embeds = _flatten_and_convert_to_tensors(inputs_embeds, device) + + if token_type_ids is not None: + token_type_ids = _flatten_and_convert_to_tensors(token_type_ids, device) + + if labels is not None: + labels = _flatten_and_convert_to_tensors(labels, device) + + return input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen + + +def _check_list_type(list_of_list: list[list[int | float]] | None, error_message: str) -> None: if list_of_list is None: return @@ -18,7 +65,7 @@ def check_list_type( assert isinstance(list_of_list[0], list), error_message -def flatten_and_convert_to_tensors(x: List[int], device: torch.device) -> torch.Tensor: +def _flatten_and_convert_to_tensors(x: list[int], device: torch.device) -> torch.Tensor: y = [] for sequence in x: y.extend(sequence) diff --git a/src/instructlab/dolomite/utils/hf_hub.py b/src/instructlab/dolomite/utils/hf_hub.py index 
30054e0..82d3431 100644 --- a/src/instructlab/dolomite/utils/hf_hub.py +++ b/src/instructlab/dolomite/utils/hf_hub.py @@ -1,17 +1,11 @@ -# ---------------------------------------------------------------- -# Extracted from https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- -# Standard -from typing import Tuple import os -# Third Party from transformers import AutoConfig, AutoTokenizer from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, cached_file from transformers.utils.hub import get_checkpoint_shard_files -def download_repo(repo_name_or_path: str) -> Tuple[AutoConfig, AutoTokenizer, str]: +def download_repo(repo_name_or_path: str) -> tuple[AutoConfig | None, AutoTokenizer | None, str]: config = _download_config(repo_name_or_path) tokenizer = _download_tokenizer(repo_name_or_path) model_path = None @@ -23,33 +17,31 @@ def download_repo(repo_name_or_path: str) -> Tuple[AutoConfig, AutoTokenizer, st try: model_path = cached_file(repo_name_or_path, SAFE_WEIGHTS_NAME) model_path = os.path.dirname(model_path) - except: # pylint: disable=bare-except + except: # try downloading model weights if they are sharded try: - sharded_filename = cached_file( - repo_name_or_path, SAFE_WEIGHTS_INDEX_NAME - ) + sharded_filename = cached_file(repo_name_or_path, SAFE_WEIGHTS_INDEX_NAME) get_checkpoint_shard_files(repo_name_or_path, sharded_filename) model_path = os.path.dirname(sharded_filename) - except: # pylint: disable=bare-except + except: pass return config, tokenizer, model_path -def _download_config(repo_name_or_path: str) -> AutoConfig: +def _download_config(repo_name_or_path: str) -> AutoConfig | None: try: config = AutoConfig.from_pretrained(repo_name_or_path) - except: # pylint: disable=bare-except + except: config = None return config -def _download_tokenizer(repo_name_or_path: str) -> AutoTokenizer: +def _download_tokenizer(repo_name_or_path: str) -> AutoTokenizer | None: try: tokenizer = AutoTokenizer.from_pretrained(repo_name_or_path) - except: # pylint: disable=bare-except + except: tokenizer = None return tokenizer diff --git a/src/instructlab/dolomite/utils/safetensors.py b/src/instructlab/dolomite/utils/safetensors.py index cc19b75..a9ffd0b 100644 --- a/src/instructlab/dolomite/utils/safetensors.py +++ b/src/instructlab/dolomite/utils/safetensors.py @@ -1,19 +1,11 @@ -# ---------------------------------------------------------------- -# Extracted from https://github.com/ibm-granite/dolomite-engine -# ---------------------------------------------------------------- -# Standard import json import os -# Third Party +import torch +from huggingface_hub import split_torch_state_dict_into_shards from safetensors import safe_open from safetensors.torch import save_file -from transformers.modeling_utils import ( - SAFE_WEIGHTS_INDEX_NAME, - SAFE_WEIGHTS_NAME, - shard_checkpoint, -) -import torch +from transformers.modeling_utils import SAFE_WEIGHTS_INDEX_NAME class SafeTensorsWeightsManager: @@ -41,7 +33,7 @@ def get_slice(self, tensor_name: str): return f.get_slice(tensor_name) def get_tensor( - self, tensor_name: str, dtype: torch.dtype = None, device: torch.device = None + self, tensor_name: str, dtype: torch.dtype | None = None, device: torch.device | None = None ) -> torch.Tensor: filename = self.tensor_filenames[tensor_name] f = self.file_handles[filename] @@ -59,8 +51,9 @@ def has_tensor(self, tensor_name: str) -> bool: def __len__(self) -> int: return len(self.tensor_filenames) - def __iter__(self) -> 
str: - yield from self.tensor_filenames + def __iter__(self): + for tensor_name in self.tensor_filenames: + yield tensor_name def __eq__(self, __value: object) -> bool: if not isinstance(__value, SafeTensorsWeightsManager): @@ -83,21 +76,22 @@ def state_dict(self) -> dict: @staticmethod def save_state_dict(state_dict: dict, save_path: str) -> None: - os.makedirs(save_path) + os.makedirs(save_path, exist_ok=True) - shards, index = shard_checkpoint( - state_dict, max_shard_size="5GB", weights_name=SAFE_WEIGHTS_NAME - ) - - for shard_file, shard in shards.items(): + state_dict_split = split_torch_state_dict_into_shards(state_dict) + for filename, tensors in state_dict_split.filename_to_tensors.items(): + shard = {tensor: state_dict[tensor] for tensor in tensors} save_file( - shard, os.path.join(save_path, shard_file), metadata={"format": "pt"} + shard, + os.path.join(save_path, filename), + metadata={"format": "pt"}, ) - if index is not None: + if state_dict_split.is_sharded: + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + with open(os.path.join(save_path, SAFE_WEIGHTS_INDEX_NAME), "w") as f: - json.dump( - index, - f, - indent=4, - ) + f.write(json.dumps(index, indent=2))
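Since save_state_dict() now leans on huggingface_hub.split_torch_state_dict_into_shards() instead of the removed transformers shard_checkpoint(), a quick round-trip is the easiest sanity check. A minimal sketch, assuming the SafeTensorsWeightsManager constructor takes the directory the shards were written to (the constructor itself is outside this hunk):

import torch

from instructlab.dolomite.utils.safetensors import SafeTensorsWeightsManager

# Toy state dict; with the default shard size everything lands in one file,
# so is_sharded stays False and no index file is written.
state_dict = {
    "wte.weight": torch.randn(1024, 256),
    "ln_f.weight": torch.ones(256),
}

SafeTensorsWeightsManager.save_state_dict(state_dict, "/tmp/dolomite-ckpt")

weights = SafeTensorsWeightsManager("/tmp/dolomite-ckpt")  # assumed signature
assert weights.has_tensor("wte.weight")
assert len(weights) == 2
ln_f = weights.get_tensor("ln_f.weight", dtype=torch.float32)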