From 97cb9a827d0bee9585775f883578817cfb1c1486 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 9 Jun 2024 01:55:18 +0000 Subject: [PATCH 01/88] start of layer dropout implementation --- torchtune/modules/__init__.py | 2 ++ torchtune/modules/layer_dropout.py | 38 ++++++++++++++++++++++++++++++ torchtune/modules/transformer.py | 7 +++++- 3 files changed, 46 insertions(+), 1 deletion(-) create mode 100644 torchtune/modules/layer_dropout.py diff --git a/torchtune/modules/__init__.py b/torchtune/modules/__init__.py index d767956526..7ae955d9dd 100644 --- a/torchtune/modules/__init__.py +++ b/torchtune/modules/__init__.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. from .attention import CausalSelfAttention # noqa +from .layer_dropout import LayerDropout # noqa from .common_utils import reparametrize_as_dtype_state_dict_post_hook from .feed_forward import FeedForward # noqa from .kv_cache import KVCache # noqa @@ -24,4 +25,5 @@ "TransformerDecoderLayer", "TransformerClassifier", "reparametrize_as_dtype_state_dict_post_hook", + "LayerDropout", ] diff --git a/torchtune/modules/layer_dropout.py b/torchtune/modules/layer_dropout.py new file mode 100644 index 0000000000..fa71a8e993 --- /dev/null +++ b/torchtune/modules/layer_dropout.py @@ -0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable +import torch + +class LayerDropout(torch.nn.Module): + def __init__(self, prob=0.0, dim=0, disable_on_eval=True): + super().__init__() + self.prob: float = prob + self.dim = dim + self.disable_on_eval: bool = disable_on_eval + self.generator = torch.Generator(device="cpu") + self.inferred: float = None + + def forward(self, function: Callable, input: torch.Tensor, *args, **kwargs): + n = input.shape[self.dim] + + if self.prob == 0 or (self.disable_on_eval and self.training is False): + self.inferred = 1.0 + return function(input, *args, **kwargs) + + skip = torch.bernoulli(torch.Tensor((n) * [self.prob]), generator=self.generator).to(input.device).to(input.dtype) + self.inferred = 1 - torch.mean(skip) + ind_selected = (skip == 0).nonzero().squeeze().to(input.device) + + if ind_selected.numel() > 0: + x_selected = torch.index_select(input, self.dim, ind_selected) + out_selected = function(x_selected, *args, **kwargs) + + out = input.clone() + assert self.dim == 0, "Currently only supporting dropping elements along the 0th dimension" + if ind_selected.numel() > 0: + out[ind_selected] = out_selected + return out diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index bf0af8b797..0c3121007d 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -10,6 +10,7 @@ from torch import nn, Tensor from torchtune.modules import CausalSelfAttention, KVCache +from torchtune.modules import LayerDropout class TransformerDecoderLayer(nn.Module): @@ -121,6 +122,8 @@ class TransformerDecoder(nn.Module): before final MLP. output (nn.Linear): Callable that applies a linear transformation to the output of the decoder. + layer_dropout_prob (float): Probability of skipping samples in the transformer + layer. Note: Arg values are checked for correctness (eg: ``attn_dropout`` belongs to [0,1]) @@ -138,6 +141,7 @@ def __init__( head_dim: int, norm: nn.Module, output: nn.Linear, + layer_dropout_prob: float= 0.0, ) -> None: super().__init__() @@ -149,6 +153,7 @@ def __init__( self.num_heads = num_heads self.head_dim = head_dim self.causal_mask = None + self.layer_dropout = LayerDropout(layer_dropout_prob) def setup_caches(self, batch_size: int, dtype: torch.dtype) -> None: """Setup key value caches for attention calculation. @@ -242,7 +247,7 @@ def forward( for layer in self.layers: # shape: [b, s, d] - h = layer(h, mask=mask, input_pos=input_pos) + h = self.layer_dropout(layer, h, mask=mask, input_pos=input_pos) # shape: [b, s, d] h = self.norm(h) From 4a25c5bd9bbaca05f4ce5ace7ba987de8031ddda Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 9 Jun 2024 04:18:15 +0000 Subject: [PATCH 02/88] have different dropouts at different layers --- torchtune/modules/__init__.py | 3 +- torchtune/modules/layer_dropout.py | 49 ++++++++++++++++++++++++++++-- torchtune/modules/transformer.py | 9 +++--- 3 files changed, 54 insertions(+), 7 deletions(-) diff --git a/torchtune/modules/__init__.py b/torchtune/modules/__init__.py index 7ae955d9dd..93aeeb10df 100644 --- a/torchtune/modules/__init__.py +++ b/torchtune/modules/__init__.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from .attention import CausalSelfAttention # noqa -from .layer_dropout import LayerDropout # noqa +from .layer_dropout import LayerDropout, create_layer_dropout_modules # noqa from .common_utils import reparametrize_as_dtype_state_dict_post_hook from .feed_forward import FeedForward # noqa from .kv_cache import KVCache # noqa @@ -26,4 +26,5 @@ "TransformerClassifier", "reparametrize_as_dtype_state_dict_post_hook", "LayerDropout", + "create_layer_dropout_modules", ] diff --git a/torchtune/modules/layer_dropout.py b/torchtune/modules/layer_dropout.py index fa71a8e993..91595c60d0 100644 --- a/torchtune/modules/layer_dropout.py +++ b/torchtune/modules/layer_dropout.py @@ -4,11 +4,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Callable +from enum import Enum +from typing import Callable, Optional +import math import torch class LayerDropout(torch.nn.Module): - def __init__(self, prob=0.0, dim=0, disable_on_eval=True): + def __init__(self, prob=0.0, dim=0, disable_on_eval=True, seed=None): super().__init__() self.prob: float = prob self.dim = dim @@ -16,6 +18,9 @@ def __init__(self, prob=0.0, dim=0, disable_on_eval=True): self.generator = torch.Generator(device="cpu") self.inferred: float = None + if seed is not None: + self.generator.manual_seed(seed) + def forward(self, function: Callable, input: torch.Tensor, *args, **kwargs): n = input.shape[self.dim] @@ -36,3 +41,43 @@ def forward(self, function: Callable, input: torch.Tensor, *args, **kwargs): if ind_selected.numel() > 0: out[ind_selected] = out_selected return out + +class ScaleType(str, Enum): + UNIFORM = "uniform" + EXP = "exp" + LINEAR = "linear" + LOG = "log" + SIN = "sin" + SIGMOID = "sigmoid" + STEP = "step" + +def get_scale(scale_type: ScaleType, scale_period: int, val: int): + if scale_period == 0: + return 1 + + # all the equations below aim to make scale = 0 when val=0, and scale = 1 when val=scale_period + return { + ScaleType.UNIFORM: 1, + ScaleType.EXP: math.exp(val * math.log(2) / scale_period) - 1, + ScaleType.LINEAR: val / scale_period, + ScaleType.LOG: math.log(val + 1) / math.log(scale_period + 1), + ScaleType.SIN: math.sin(0.5 * math.pi * val / scale_period), + ScaleType.SIGMOID: 1 / (1 + math.exp(-10 * (val / scale_period - 0.5))), + ScaleType.STEP: 0 if val < scale_period else 1 + }[scale_type] + +def create_layer_dropout_modules(num_layers: int, prob_max: float= 0.0, prob_layer_scale: ScaleType = ScaleType.EXP, prob_layer_scale_period: Optional[int] = None, disable_on_eval: bool = True): + layer_dropouts = torch.nn.ModuleList() + + for layer_id in range(num_layers): + prob = prob_max * get_scale( + scale_type = prob_layer_scale, + scale_period = num_layers - 1 if prob_layer_scale_period is None else prob_layer_scale_period, + val = layer_id, + ) + assert prob >= 0.0 and prob <= prob_max, f"prob={prob} should be between 0 and {prob_max}" + # We would like each layer to have a different seed, so that we don't have the same samples skipped across layers. Hence, we use the layer_id as a seed for each layer's dropout. + layer_dropout = LayerDropout(prob, disable_on_eval=disable_on_eval, seed=layer_id) + layer_dropouts.append(layer_dropout) + + return layer_dropouts diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index 0c3121007d..3b5a2e9e45 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -10,7 +10,7 @@ from torch import nn, Tensor from torchtune.modules import CausalSelfAttention, KVCache -from torchtune.modules import LayerDropout +from torchtune.modules import LayerDropout, create_layer_dropout_modules class TransformerDecoderLayer(nn.Module): @@ -153,7 +153,8 @@ def __init__( self.num_heads = num_heads self.head_dim = head_dim self.causal_mask = None - self.layer_dropout = LayerDropout(layer_dropout_prob) + + self.layer_dropouts = create_layer_dropout_modules(num_layers, layer_dropout_prob, "exp") def setup_caches(self, batch_size: int, dtype: torch.dtype) -> None: """Setup key value caches for attention calculation. @@ -245,9 +246,9 @@ def forward( # in most cases input_pos_len should be 1 mask = self.causal_mask[None, input_pos] - for layer in self.layers: + for i, layer in enumerate(self.layers): # shape: [b, s, d] - h = self.layer_dropout(layer, h, mask=mask, input_pos=input_pos) + h = self.layer_dropouts[i](layer, h, mask=mask, input_pos=input_pos) # shape: [b, s, d] h = self.norm(h) From ac8ad0b0a99dc0e8e280510491e694b9f86ad42a Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 9 Jun 2024 04:49:43 +0000 Subject: [PATCH 03/88] add option to specify which layers to apply dropout --- torchtune/modules/common_utils.py | 27 +++++++++++++++++++++++++++ torchtune/modules/layer_dropout.py | 10 ++++++---- torchtune/modules/transformer.py | 4 +++- 3 files changed, 36 insertions(+), 5 deletions(-) diff --git a/torchtune/modules/common_utils.py b/torchtune/modules/common_utils.py index 9c588fabba..15a0a7d684 100644 --- a/torchtune/modules/common_utils.py +++ b/torchtune/modules/common_utils.py @@ -48,3 +48,30 @@ def reparametrize_as_dtype_state_dict_post_hook( state_dict[k] = v.to(dtype) if offload_to_cpu: state_dict[k] = state_dict[k].cpu() + +def slice_str_to_array(slice_str, length): + # Parse the slice string + parts = slice_str.split(':') + start, end, step = None, None, None + + if len(parts) == 1 and parts[0] != '': + start = int(parts[0]) + elif len(parts) == 2: + start = int(parts[0]) if parts[0] != '' else None + end = int(parts[1]) if parts[1] != '' else None + elif len(parts) == 3: + start = int(parts[0]) if parts[0] != '' else None + end = int(parts[1]) if parts[1] != '' else None + step = int(parts[2]) if parts[2] != '' else None + + # Create a boolean array based on the slice + result = [False] * length + slice_indices = range(start if start is not None else 0, + end if end is not None else length, + step if step is not None else 1) + + for i in slice_indices: + if 0 <= i < length: + result[i] = True + + return result diff --git a/torchtune/modules/layer_dropout.py b/torchtune/modules/layer_dropout.py index 91595c60d0..75b9163377 100644 --- a/torchtune/modules/layer_dropout.py +++ b/torchtune/modules/layer_dropout.py @@ -9,6 +9,8 @@ import math import torch +from .common_utils import slice_str_to_array + class LayerDropout(torch.nn.Module): def __init__(self, prob=0.0, dim=0, disable_on_eval=True, seed=None): super().__init__() @@ -63,18 +65,18 @@ def get_scale(scale_type: ScaleType, scale_period: int, val: int): ScaleType.LOG: math.log(val + 1) / math.log(scale_period + 1), ScaleType.SIN: math.sin(0.5 * math.pi * val / scale_period), ScaleType.SIGMOID: 1 / (1 + math.exp(-10 * (val / scale_period - 0.5))), - ScaleType.STEP: 0 if val < scale_period else 1 }[scale_type] -def create_layer_dropout_modules(num_layers: int, prob_max: float= 0.0, prob_layer_scale: ScaleType = ScaleType.EXP, prob_layer_scale_period: Optional[int] = None, disable_on_eval: bool = True): +def create_layer_dropout_modules(num_layers: int, prob_max: float= 0.0, prob_layer_scale: ScaleType = ScaleType.EXP, layers_str: Optional[str] = None, disable_on_eval: bool = True): layer_dropouts = torch.nn.ModuleList() + has_dropout = slice_str_to_array(layers_str, num_layers) if layers_str else [True] * num_layers for layer_id in range(num_layers): prob = prob_max * get_scale( scale_type = prob_layer_scale, - scale_period = num_layers - 1 if prob_layer_scale_period is None else prob_layer_scale_period, + scale_period = num_layers - 1, val = layer_id, - ) + ) if has_dropout[layer_id] else 0.0 assert prob >= 0.0 and prob <= prob_max, f"prob={prob} should be between 0 and {prob_max}" # We would like each layer to have a different seed, so that we don't have the same samples skipped across layers. Hence, we use the layer_id as a seed for each layer's dropout. layer_dropout = LayerDropout(prob, disable_on_eval=disable_on_eval, seed=layer_id) diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index 3b5a2e9e45..8974599235 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -142,6 +142,8 @@ def __init__( norm: nn.Module, output: nn.Linear, layer_dropout_prob: float= 0.0, + layer_dropout_prob_layer_scale: str="exp", + layer_dropout_str: str = ":", ) -> None: super().__init__() @@ -154,7 +156,7 @@ def __init__( self.head_dim = head_dim self.causal_mask = None - self.layer_dropouts = create_layer_dropout_modules(num_layers, layer_dropout_prob, "exp") + self.layer_dropouts = create_layer_dropout_modules(num_layers, layer_dropout_prob, layer_dropout_prob_layer_scale, layer_dropout_str) def setup_caches(self, batch_size: int, dtype: torch.dtype) -> None: """Setup key value caches for attention calculation. From ae61c858950186a9956295404f1ae9ea91afb962 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Mon, 10 Jun 2024 04:08:28 +0000 Subject: [PATCH 04/88] start early exit loss --- recipes/full_finetune_distributed.py | 16 +++++++++++++++- torchtune/modules/transformer.py | 14 +++++++++++--- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 3ec99c021a..82fb106baa 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -516,7 +516,7 @@ def train(self) -> None: input_pos.to(self._device) if input_pos is not None else None ) - logits = self._model(tokens, mask=mask, input_pos=input_pos) + logits = self._model(tokens, mask=mask, input_pos=input_pos, output_hidden_states=True) # Shift so that tokens < n predict n logits = logits[..., :-1, :].contiguous() labels = labels[..., 1:].contiguous() @@ -524,6 +524,20 @@ def train(self) -> None: # Compute loss loss = self._loss_fn(logits, labels) + # Compute early exit loss + if self._model.output_hidden_states: + # TODO: calculate early_logits in one shot: + # logits_early = self._model.output(self._model.norm(torch.stack(tuple(self._model.output_hidden_states.values())))) + for layer_id, hidden_state in self._model.output_hidden_states.items(): + h_early = self._model.norm(hidden_state) + logits_early = self._model.output(h_early) + # Shift so that tokens < n predict n + logits_early = logits_early[..., :-1, :].contiguous() + logits_early = logits_early.transpose(1, 2) + # Compute early loss + loss_early = self._loss_fn(logits_early, labels) + loss += 0.1 / len(self._model.layers) * loss_early + loss = loss / self._gradient_accumulation_steps running_loss += loss loss.backward() diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index 8974599235..8e021759c3 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -4,7 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import copy -from typing import Optional +from collections import OrderedDict +from typing import List, Optional, Union import torch from torch import nn, Tensor @@ -141,8 +142,8 @@ def __init__( head_dim: int, norm: nn.Module, output: nn.Linear, - layer_dropout_prob: float= 0.0, - layer_dropout_prob_layer_scale: str="exp", + layer_dropout_prob: float = 0.5, + layer_dropout_prob_layer_scale: str = "exp", layer_dropout_str: str = ":", ) -> None: super().__init__() @@ -157,6 +158,7 @@ def __init__( self.causal_mask = None self.layer_dropouts = create_layer_dropout_modules(num_layers, layer_dropout_prob, layer_dropout_prob_layer_scale, layer_dropout_str) + self.output_hidden_states = OrderedDict() # TODO: use tensordict? def setup_caches(self, batch_size: int, dtype: torch.dtype) -> None: """Setup key value caches for attention calculation. @@ -196,6 +198,7 @@ def forward( *, mask: Optional[Tensor] = None, input_pos: Optional[Tensor] = None, + output_hidden_states: Union[bool, List[bool]] = False, ) -> Tensor: """ Args: @@ -235,6 +238,9 @@ def forward( # shape: [b, s, d] h = self.tok_embeddings(tokens) + if isinstance(output_hidden_states, bool): + output_hidden_states = [output_hidden_states] * len(self.layers) + if self.causal_mask is not None: if input_pos is None: raise ValueError( @@ -251,6 +257,8 @@ def forward( for i, layer in enumerate(self.layers): # shape: [b, s, d] h = self.layer_dropouts[i](layer, h, mask=mask, input_pos=input_pos) + if output_hidden_states[i]: + self.output_hidden_states[i] = h # shape: [b, s, d] h = self.norm(h) From 735d2a8023531a4b306dc59e150e3f20d56d0d9c Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Wed, 19 Jun 2024 18:27:44 +0000 Subject: [PATCH 05/88] parallelize processing of early exit losses --- recipes/full_finetune_distributed.py | 42 ++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 82fb106baa..458076df1f 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import copy import sys import time @@ -503,6 +504,7 @@ def train(self) -> None: # Both are shape [b, s] tokens, labels = batch["tokens"], batch["labels"] + b, s = tokens.shape # Get the attention mask and position ids from the dataset if they # exist. Currently, only sample packing in PackedDataset returns these mask = batch.get("mask", None) # shape [b, s, s] @@ -526,17 +528,35 @@ def train(self) -> None: # Compute early exit loss if self._model.output_hidden_states: - # TODO: calculate early_logits in one shot: - # logits_early = self._model.output(self._model.norm(torch.stack(tuple(self._model.output_hidden_states.values())))) - for layer_id, hidden_state in self._model.output_hidden_states.items(): - h_early = self._model.norm(hidden_state) - logits_early = self._model.output(h_early) - # Shift so that tokens < n predict n - logits_early = logits_early[..., :-1, :].contiguous() - logits_early = logits_early.transpose(1, 2) - # Compute early loss - loss_early = self._loss_fn(logits_early, labels) - loss += 0.1 / len(self._model.layers) * loss_early + self._batch_loss_fn = copy.deepcopy(self._loss_fn) + self._batch_loss_fn.reduction = "none" + + e = len(self._model.output_hidden_states) + # List of e tensors with shape [b, s, d] + hidden_states = tuple(self._model.output_hidden_states.values()) + hidden_layer_ids = tuple(self._model.output_hidden_states.keys()) + # Shape: [e, b, s, d] + hidden_states_stacked = torch.stack(hidden_states) + # Shape: [e, b, s, out_dim] + logits_early = self._model.output(self._model.norm(hidden_states_stacked)) + logits_early = logits_early[..., :-1, :].contiguous() + # Shape: [e*b, s, out_dim] + logits_early = logits_early.flatten(0, 1) + logits_early = logits_early.transpose(1, 2) + # Shape: [e, b*s] + labels_repeated = labels.repeat(e, 1) + # Compute early losses: Shape: [e*b, s] + losses_early = self._batch_loss_fn(logits_early, labels_repeated) + # Shape: [e, b*s] + losses_early = losses_early.view(e, -1) + # Shape: [e] + s_unpadded = (labels != self._loss_fn.ignore_index).sum() + losses_early = losses_early.float().sum(-1) / s_unpadded + # Shape: [e] + losses_scales = 0.1 * torch.Tensor(hidden_layer_ids).to(losses_early) / len(self._model.layers) + + val1 = torch.sum(losses_scales * losses_early) + loss += val1 loss = loss / self._gradient_accumulation_steps running_loss += loss From be912a6f871af9efd2271b054a7ca43bde98efb9 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Wed, 19 Jun 2024 18:30:56 +0000 Subject: [PATCH 06/88] use absolute imports --- torchtune/modules/layer_dropout.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtune/modules/layer_dropout.py b/torchtune/modules/layer_dropout.py index 75b9163377..4f619cba3a 100644 --- a/torchtune/modules/layer_dropout.py +++ b/torchtune/modules/layer_dropout.py @@ -9,7 +9,7 @@ import math import torch -from .common_utils import slice_str_to_array +from torchtune.modules.common_utils import slice_str_to_array class LayerDropout(torch.nn.Module): def __init__(self, prob=0.0, dim=0, disable_on_eval=True, seed=None): From 0686dd28d602fb50bf57cf117f88314d49f58af2 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Wed, 19 Jun 2024 18:52:20 +0000 Subject: [PATCH 07/88] remove unnecessary sync --- torchtune/modules/layer_dropout.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtune/modules/layer_dropout.py b/torchtune/modules/layer_dropout.py index 4f619cba3a..a04e0d5b24 100644 --- a/torchtune/modules/layer_dropout.py +++ b/torchtune/modules/layer_dropout.py @@ -32,7 +32,7 @@ def forward(self, function: Callable, input: torch.Tensor, *args, **kwargs): skip = torch.bernoulli(torch.Tensor((n) * [self.prob]), generator=self.generator).to(input.device).to(input.dtype) self.inferred = 1 - torch.mean(skip) - ind_selected = (skip == 0).nonzero().squeeze().to(input.device) + ind_selected = (skip == 0).nonzero().squeeze() if ind_selected.numel() > 0: x_selected = torch.index_select(input, self.dim, ind_selected) From 4e4783f12cb7a52fdca540097ed4138ca25ea6cd Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Wed, 19 Jun 2024 21:37:50 +0000 Subject: [PATCH 08/88] move early exit loss to separate file and add layers as arg --- recipes/full_finetune_distributed.py | 41 +++++++-------------------- torchtune/utils/early_exit.py | 42 ++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 31 deletions(-) create mode 100644 torchtune/utils/early_exit.py diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 458076df1f..cacd457424 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import copy import sys import time @@ -31,6 +30,8 @@ from torchtune.datasets import ConcatDataset from torchtune.recipe_interfaces import FTRecipeInterface from torchtune.utils.activations import apply_selective_activation_checkpointing +from torchtune.utils.early_exit import early_exit_loss +from torchtune.modules.common_utils import slice_str_to_array from tqdm import tqdm @@ -136,6 +137,8 @@ def __init__(self, cfg: DictConfig) -> None: self.max_steps_per_epoch = cfg.max_steps_per_epoch self.global_step = 0 + self.early_exit_layers = cfg.get("early_exit_layers", None) + def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: """ Extract the checkpoint state from file and validate. If resume_from_checkpoint @@ -486,6 +489,9 @@ def train(self) -> None: running_loss = 0 num_tokens = 0 + # Early exit loss settings + output_hidden_states = slice_str_to_array(self.early_exit_layers, len(self._model.layers)) + # self.epochs_run should be non-zero when we're resuming from a checkpoint for curr_epoch in range(self.epochs_run, self.total_epochs): @@ -518,7 +524,7 @@ def train(self) -> None: input_pos.to(self._device) if input_pos is not None else None ) - logits = self._model(tokens, mask=mask, input_pos=input_pos, output_hidden_states=True) + logits = self._model(tokens, mask=mask, input_pos=input_pos, output_hidden_states=output_hidden_states) # Shift so that tokens < n predict n logits = logits[..., :-1, :].contiguous() labels = labels[..., 1:].contiguous() @@ -527,36 +533,9 @@ def train(self) -> None: loss = self._loss_fn(logits, labels) # Compute early exit loss + # TODO: change condition to "if early_exit_loss" if self._model.output_hidden_states: - self._batch_loss_fn = copy.deepcopy(self._loss_fn) - self._batch_loss_fn.reduction = "none" - - e = len(self._model.output_hidden_states) - # List of e tensors with shape [b, s, d] - hidden_states = tuple(self._model.output_hidden_states.values()) - hidden_layer_ids = tuple(self._model.output_hidden_states.keys()) - # Shape: [e, b, s, d] - hidden_states_stacked = torch.stack(hidden_states) - # Shape: [e, b, s, out_dim] - logits_early = self._model.output(self._model.norm(hidden_states_stacked)) - logits_early = logits_early[..., :-1, :].contiguous() - # Shape: [e*b, s, out_dim] - logits_early = logits_early.flatten(0, 1) - logits_early = logits_early.transpose(1, 2) - # Shape: [e, b*s] - labels_repeated = labels.repeat(e, 1) - # Compute early losses: Shape: [e*b, s] - losses_early = self._batch_loss_fn(logits_early, labels_repeated) - # Shape: [e, b*s] - losses_early = losses_early.view(e, -1) - # Shape: [e] - s_unpadded = (labels != self._loss_fn.ignore_index).sum() - losses_early = losses_early.float().sum(-1) / s_unpadded - # Shape: [e] - losses_scales = 0.1 * torch.Tensor(hidden_layer_ids).to(losses_early) / len(self._model.layers) - - val1 = torch.sum(losses_scales * losses_early) - loss += val1 + loss += early_exit_loss(self._model, self._model.output_hidden_states, labels, self._loss_fn) loss = loss / self._gradient_accumulation_steps running_loss += loss diff --git a/torchtune/utils/early_exit.py b/torchtune/utils/early_exit.py new file mode 100644 index 0000000000..446e44ccba --- /dev/null +++ b/torchtune/utils/early_exit.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import torch + +def early_exit_loss(model, hidden_states_dict, labels, loss_fn): + # Pop last layer as we already calculated its loss + if len(model.layers) - 1 in hidden_states_dict: + hidden_states_dict.pop(len(model.layers) - 1) + + batch_loss_fn = copy.deepcopy(loss_fn) + batch_loss_fn.reduction = "none" + + e = len(hidden_states_dict) + # List of e tensors with shape [b, s, d] + hidden_states = tuple(hidden_states_dict.values()) + hidden_layer_ids = tuple(hidden_states_dict.keys()) + # Shape: [e, b, s, d] + hidden_states_stacked = torch.stack(hidden_states) + # Shape: [e, b, s, out_dim] + logits_early = model.output(model.norm(hidden_states_stacked)) + logits_early = logits_early[..., :-1, :].contiguous() + # Shape: [e*b, s, out_dim] + logits_early = logits_early.flatten(0, 1) + logits_early = logits_early.transpose(1, 2) + # Shape: [e, b*s] + labels_repeated = labels.repeat(e, 1) + # Compute early losses: Shape: [e*b, s] + losses_early = batch_loss_fn(logits_early, labels_repeated) + # Shape: [e, b*s] + losses_early = losses_early.view(e, -1) + # Shape: [e] + s_unpadded = (labels != loss_fn.ignore_index).sum() + losses_early = losses_early.float().sum(-1) / s_unpadded + # Shape: [e] + losses_scales = 0.1 * torch.Tensor(hidden_layer_ids).to(losses_early) / len(model.layers) + + return torch.sum(losses_scales * losses_early) From 268813e83334b0975e4fb0b1460b3047ff125339 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Wed, 19 Jun 2024 22:56:11 +0000 Subject: [PATCH 09/88] perform loss scaling every iteration --- recipes/full_finetune_distributed.py | 8 +++-- torchtune/modules/transformer.py | 2 +- torchtune/utils/early_exit.py | 44 ++++++++++++++++++++++++++-- 3 files changed, 48 insertions(+), 6 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index cacd457424..524d861044 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -490,7 +490,10 @@ def train(self) -> None: num_tokens = 0 # Early exit loss settings - output_hidden_states = slice_str_to_array(self.early_exit_layers, len(self._model.layers)) + if self.early_exit_layers: + output_hidden_states = slice_str_to_array(self.early_exit_layers, len(self._model.layers)) + else: + output_hidden_states = False # self.epochs_run should be non-zero when we're resuming from a checkpoint for curr_epoch in range(self.epochs_run, self.total_epochs): @@ -533,8 +536,7 @@ def train(self) -> None: loss = self._loss_fn(logits, labels) # Compute early exit loss - # TODO: change condition to "if early_exit_loss" - if self._model.output_hidden_states: + if self.early_exit_layers: loss += early_exit_loss(self._model, self._model.output_hidden_states, labels, self._loss_fn) loss = loss / self._gradient_accumulation_steps diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index 8e021759c3..15a7e31459 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -142,7 +142,7 @@ def __init__( head_dim: int, norm: nn.Module, output: nn.Linear, - layer_dropout_prob: float = 0.5, + layer_dropout_prob: float = 0.0, layer_dropout_prob_layer_scale: str = "exp", layer_dropout_str: str = ":", ) -> None: diff --git a/torchtune/utils/early_exit.py b/torchtune/utils/early_exit.py index 446e44ccba..9aeae06923 100644 --- a/torchtune/utils/early_exit.py +++ b/torchtune/utils/early_exit.py @@ -5,9 +5,21 @@ # LICENSE file in the root directory of this source tree. import copy +import numpy as np import torch +from enum import Enum +from typing import List +import math -def early_exit_loss(model, hidden_states_dict, labels, loss_fn): +class LossScaleType(str, Enum): + ONE = "one" + L = "l" + SUM_L = "sum_l" + INV_L = "inv_l" + SQRT_L = "sqrt_l" + INV_SQRT_L = "inv_sqrt_l" + +def early_exit_loss(model, hidden_states_dict, labels, loss_fn, e_scale: float=0.1, loss_scale_type=LossScaleType.SUM_L): # Pop last layer as we already calculated its loss if len(model.layers) - 1 in hidden_states_dict: hidden_states_dict.pop(len(model.layers) - 1) @@ -37,6 +49,34 @@ def early_exit_loss(model, hidden_states_dict, labels, loss_fn): s_unpadded = (labels != loss_fn.ignore_index).sum() losses_early = losses_early.float().sum(-1) / s_unpadded # Shape: [e] - losses_scales = 0.1 * torch.Tensor(hidden_layer_ids).to(losses_early) / len(model.layers) + # losses_scales = 0.1 * torch.Tensor(hidden_layer_ids).to(losses_early) / len(model.layers) + losses_scales = layer_ids_to_loss_scales(torch.Tensor(hidden_layer_ids).to(losses_early), len(model.layers), loss_scale_type, e_scale) return torch.sum(losses_scales * losses_early) + +def layer_ids_to_loss_scales(layer_ids, n_layers, loss_scale_type: LossScaleType, e_scale: float): + match loss_scale_type: + case LossScaleType.ONE: + loss_scales = torch.ones(len(layer_ids)) + case LossScaleType.L: + loss_scales = torch.Tensor(layer_ids+1) + case LossScaleType.SUM_L: + # TODO: should we change to sum 0:i ? Perhaps create a new scale_type + loss_scales = torch.cumsum(layer_ids+1, dim=0) + case LossScaleType.SQRT_L: + loss_scales = torch.sqrt(layer_ids+1) + case LossScaleType.INV_L: + loss_scales = 1.0 / (layer_ids+1) + case LossScaleType.INV_SQRT_L: + loss_scales = 1.0 / torch.sqrt(layer_ids+1) + case _: + raise ValueError(f"Unsupported loss_scale type {loss_scale_type}") + + loss_scales = loss_scales * torch.where(loss_scales < n_layers - 1, e_scale, 1.0) + + # normalize loss scales to ensure that their sum is 1.0 + loss_scales = loss_scales / torch.sum(loss_scales) + assert torch.isclose(torch.sum(loss_scales), torch.Tensor([1.0]).to(loss_scales)) + + print(loss_scales) + return loss_scales From ccb4a5006ece5436505d8f7b1764734aa347d422 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Wed, 19 Jun 2024 23:06:00 +0000 Subject: [PATCH 10/88] return hidden states as an output rather than storing --- recipes/full_finetune_distributed.py | 7 +++++-- torchtune/modules/transformer.py | 12 +++++++++--- torchtune/utils/early_exit.py | 1 - 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 524d861044..cca36cf43e 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -527,7 +527,10 @@ def train(self) -> None: input_pos.to(self._device) if input_pos is not None else None ) - logits = self._model(tokens, mask=mask, input_pos=input_pos, output_hidden_states=output_hidden_states) + if self.early_exit_layers: + logits, hidden_states = self._model(tokens, mask=mask, input_pos=input_pos, output_hidden_states=output_hidden_states) + else: + logits = self._model(tokens, mask=mask, input_pos=input_pos, output_hidden_states=output_hidden_states) # Shift so that tokens < n predict n logits = logits[..., :-1, :].contiguous() labels = labels[..., 1:].contiguous() @@ -537,7 +540,7 @@ def train(self) -> None: # Compute early exit loss if self.early_exit_layers: - loss += early_exit_loss(self._model, self._model.output_hidden_states, labels, self._loss_fn) + loss += early_exit_loss(self._model, hidden_states, labels, self._loss_fn) loss = loss / self._gradient_accumulation_steps running_loss += loss diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index 15a7e31459..f78f928a21 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -158,7 +158,6 @@ def __init__( self.causal_mask = None self.layer_dropouts = create_layer_dropout_modules(num_layers, layer_dropout_prob, layer_dropout_prob_layer_scale, layer_dropout_str) - self.output_hidden_states = OrderedDict() # TODO: use tensordict? def setup_caches(self, batch_size: int, dtype: torch.dtype) -> None: """Setup key value caches for attention calculation. @@ -254,15 +253,22 @@ def forward( # in most cases input_pos_len should be 1 mask = self.causal_mask[None, input_pos] + if any(output_hidden_states): + hidden_states = OrderedDict() # TODO: use tensordict? + for i, layer in enumerate(self.layers): # shape: [b, s, d] h = self.layer_dropouts[i](layer, h, mask=mask, input_pos=input_pos) if output_hidden_states[i]: - self.output_hidden_states[i] = h + hidden_states[i] = h # shape: [b, s, d] h = self.norm(h) # shape: [b, s, out_dim] - out_dim is usually the vocab size output = self.output(h).float() - return output + + if any(output_hidden_states): + return output, hidden_states + else: + return output diff --git a/torchtune/utils/early_exit.py b/torchtune/utils/early_exit.py index 9aeae06923..7f6ab8513e 100644 --- a/torchtune/utils/early_exit.py +++ b/torchtune/utils/early_exit.py @@ -78,5 +78,4 @@ def layer_ids_to_loss_scales(layer_ids, n_layers, loss_scale_type: LossScaleType loss_scales = loss_scales / torch.sum(loss_scales) assert torch.isclose(torch.sum(loss_scales), torch.Tensor([1.0]).to(loss_scales)) - print(loss_scales) return loss_scales From ff7d1570807c6c7308add08355cc8c6a49db43c2 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Thu, 20 Jun 2024 01:09:49 +0000 Subject: [PATCH 11/88] ensure last layer is always included --- recipes/full_finetune_distributed.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index cca36cf43e..0fd3cbfa3a 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -492,6 +492,10 @@ def train(self) -> None: # Early exit loss settings if self.early_exit_layers: output_hidden_states = slice_str_to_array(self.early_exit_layers, len(self._model.layers)) + if True: # TODO: add cli option + if len(self._model.layers) - 1 not in output_hidden_states: + # ensure we include last layer + output_hidden_states[len(self._model.layers) - 1] = True else: output_hidden_states = False From 5a23811ccf7a5459054460b406e7979972b12d2c Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Thu, 20 Jun 2024 01:10:32 +0000 Subject: [PATCH 12/88] return either last logits or hidden states --- recipes/full_finetune_distributed.py | 9 +++++---- torchtune/utils/early_exit.py | 4 ---- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 0fd3cbfa3a..ba3c0f1095 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -535,16 +535,17 @@ def train(self) -> None: logits, hidden_states = self._model(tokens, mask=mask, input_pos=input_pos, output_hidden_states=output_hidden_states) else: logits = self._model(tokens, mask=mask, input_pos=input_pos, output_hidden_states=output_hidden_states) + # Shift so that tokens < n predict n logits = logits[..., :-1, :].contiguous() labels = labels[..., 1:].contiguous() logits = logits.transpose(1, 2) - # Compute loss - loss = self._loss_fn(logits, labels) - # Compute early exit loss + # Compute loss if self.early_exit_layers: - loss += early_exit_loss(self._model, hidden_states, labels, self._loss_fn) + loss = early_exit_loss(self._model, hidden_states, labels, self._loss_fn) + else: + loss = self._loss_fn(logits, labels) loss = loss / self._gradient_accumulation_steps running_loss += loss diff --git a/torchtune/utils/early_exit.py b/torchtune/utils/early_exit.py index 7f6ab8513e..94c4ff9b87 100644 --- a/torchtune/utils/early_exit.py +++ b/torchtune/utils/early_exit.py @@ -20,10 +20,6 @@ class LossScaleType(str, Enum): INV_SQRT_L = "inv_sqrt_l" def early_exit_loss(model, hidden_states_dict, labels, loss_fn, e_scale: float=0.1, loss_scale_type=LossScaleType.SUM_L): - # Pop last layer as we already calculated its loss - if len(model.layers) - 1 in hidden_states_dict: - hidden_states_dict.pop(len(model.layers) - 1) - batch_loss_fn = copy.deepcopy(loss_fn) batch_loss_fn.reduction = "none" From e11aebafae7ccd1b6893e7cd3275f15c68a3af99 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Thu, 20 Jun 2024 01:10:43 +0000 Subject: [PATCH 13/88] fix scaling layers --- torchtune/utils/early_exit.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchtune/utils/early_exit.py b/torchtune/utils/early_exit.py index 94c4ff9b87..70d7e90b4d 100644 --- a/torchtune/utils/early_exit.py +++ b/torchtune/utils/early_exit.py @@ -68,8 +68,7 @@ def layer_ids_to_loss_scales(layer_ids, n_layers, loss_scale_type: LossScaleType case _: raise ValueError(f"Unsupported loss_scale type {loss_scale_type}") - loss_scales = loss_scales * torch.where(loss_scales < n_layers - 1, e_scale, 1.0) - + loss_scales = loss_scales * torch.where(layer_ids < n_layers - 1, e_scale, 1.0) # normalize loss scales to ensure that their sum is 1.0 loss_scales = loss_scales / torch.sum(loss_scales) assert torch.isclose(torch.sum(loss_scales), torch.Tensor([1.0]).to(loss_scales)) From f9e164f1dcc7c58ca08c8a3c120d2ebb9732941e Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Thu, 20 Jun 2024 02:45:17 +0000 Subject: [PATCH 14/88] rotational early exit curriculum --- recipes/full_finetune_distributed.py | 15 ++++++++-- torchtune/utils/early_exit.py | 44 +++++++++++++++++++++++++++- 2 files changed, 56 insertions(+), 3 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index ba3c0f1095..39fca659e0 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -30,7 +30,7 @@ from torchtune.datasets import ConcatDataset from torchtune.recipe_interfaces import FTRecipeInterface from torchtune.utils.activations import apply_selective_activation_checkpointing -from torchtune.utils.early_exit import early_exit_loss +from torchtune.utils.early_exit import early_exit_loss, build_early_exit_curriculum from torchtune.modules.common_utils import slice_str_to_array from tqdm import tqdm @@ -137,7 +137,11 @@ def __init__(self, cfg: DictConfig) -> None: self.max_steps_per_epoch = cfg.max_steps_per_epoch self.global_step = 0 - self.early_exit_layers = cfg.get("early_exit_layers", None) + self.early_exit = cfg.get("early_exit", None) + # TODO: create a "setup" function similar to setup_model? + if self.early_exit: + self.early_exit_layers = cfg.get("early_exit.layers", ":") + self.early_exit_curriculum = cfg.get("early_exit.curriculum", "none") def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: """ @@ -490,6 +494,7 @@ def train(self) -> None: num_tokens = 0 # Early exit loss settings + # TODO: move to _init_() or setup() if self.early_exit_layers: output_hidden_states = slice_str_to_array(self.early_exit_layers, len(self._model.layers)) if True: # TODO: add cli option @@ -498,6 +503,8 @@ def train(self) -> None: output_hidden_states[len(self._model.layers) - 1] = True else: output_hidden_states = False + if self.early_exit_curriculum: + self.early_exit_curriculum = build_early_exit_curriculum(self.early_exit_curriculum, output_hidden_states) # self.epochs_run should be non-zero when we're resuming from a checkpoint for curr_epoch in range(self.epochs_run, self.total_epochs): @@ -588,6 +595,10 @@ def train(self) -> None: num_tokens = 0 t0 = time.perf_counter() + if self.early_exit_curriculum: + self.early_exit_curriculum.step() + output_hidden_states = self.early_exit_curriculum.get() + self.epochs_run += 1 self.save_checkpoint(epoch=curr_epoch) diff --git a/torchtune/utils/early_exit.py b/torchtune/utils/early_exit.py index 70d7e90b4d..049ed1abca 100644 --- a/torchtune/utils/early_exit.py +++ b/torchtune/utils/early_exit.py @@ -9,7 +9,10 @@ import torch from enum import Enum from typing import List -import math + +from torchtune import utils + +log = utils.get_logger("DEBUG") class LossScaleType(str, Enum): ONE = "one" @@ -74,3 +77,42 @@ def layer_ids_to_loss_scales(layer_ids, n_layers, loss_scale_type: LossScaleType assert torch.isclose(torch.sum(loss_scales), torch.Tensor([1.0]).to(loss_scales)) return loss_scales + +class EarlyExitCurriculumType(str, Enum): + NONE = "none" + ROTATIONAL = "rot" + GRADUAL = "gradual" + +def build_early_exit_curriculum(early_exit_curriculum: EarlyExitCurriculumType, *args, **kwargs): + match early_exit_curriculum: + case EarlyExitCurriculumType.NONE: + return None + + case EarlyExitCurriculumType.ROTATIONAL: + return RotationalEarlyExitCurriculum(*args, **kwargs) + + case _: + raise ValueError(f"Unsupported early loss curriculum {early_exit_curriculum}.") + + +# TODO: create a base curriculum class that can be used for other aspects, e.g., dropout, datasets, etc. +class EarlyExitCurriculum(): + def __init__(self, output_hidden_states, verbose=True): + self._init_output_hidden_states = output_hidden_states + self.output_hidden_states = output_hidden_states + self.verbose = verbose + + def step(self): + pass + + def get(self): + return self.output_hidden_states + +class RotationalEarlyExitCurriculum(EarlyExitCurriculum): + def __init__(self, output_hidden_states, verbose=True): + super().__init__(output_hidden_states) + + def step(self): + self.output_hidden_states = torch.roll(self.output_hidden_states, -1) + if self.verbose: + log.info(f"Updating self.output_hidden_states to {self.output_hidden_states}.") From 069b6615217ae7fdba30c5689742fa6ce2f06277 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Mon, 22 Jul 2024 01:38:29 +0000 Subject: [PATCH 15/88] set early exit params from cli --- recipes/full_finetune_distributed.py | 14 ++++++++++---- torchtune/utils/early_exit.py | 2 +- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 39fca659e0..c9b4b4b095 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -137,11 +137,17 @@ def __init__(self, cfg: DictConfig) -> None: self.max_steps_per_epoch = cfg.max_steps_per_epoch self.global_step = 0 - self.early_exit = cfg.get("early_exit", None) + cfg_early_exit = cfg.get("early_exit", None) # TODO: create a "setup" function similar to setup_model? - if self.early_exit: - self.early_exit_layers = cfg.get("early_exit.layers", ":") - self.early_exit_curriculum = cfg.get("early_exit.curriculum", "none") + # TODO: rename "early_exit" to "early_exit_loss" + if cfg_early_exit: + self.early_exit_layers = cfg_early_exit.get("layers", ":") + self.early_exit_curriculum = cfg_early_exit.get("curriculum", "none") + self.early_exit_scale = cfg_early_exit.get("scale", 1.0) + else: + self.early_exit_layers = None + self.early_exit_curriculum = None + self.early_exit_scale = None def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: """ diff --git a/torchtune/utils/early_exit.py b/torchtune/utils/early_exit.py index 049ed1abca..a71bc331bc 100644 --- a/torchtune/utils/early_exit.py +++ b/torchtune/utils/early_exit.py @@ -22,7 +22,7 @@ class LossScaleType(str, Enum): SQRT_L = "sqrt_l" INV_SQRT_L = "inv_sqrt_l" -def early_exit_loss(model, hidden_states_dict, labels, loss_fn, e_scale: float=0.1, loss_scale_type=LossScaleType.SUM_L): +def early_exit_loss(model, hidden_states_dict, labels, loss_fn, e_scale: float=1.0, loss_scale_type=LossScaleType.SUM_L): batch_loss_fn = copy.deepcopy(loss_fn) batch_loss_fn.reduction = "none" From 954d09712011fb4bc78890d0ddef26dfd92edae9 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Mon, 22 Jul 2024 01:40:05 +0000 Subject: [PATCH 16/88] ensure last layer loss is always calculated --- recipes/full_finetune_distributed.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index c9b4b4b095..7953370c2e 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -504,11 +504,10 @@ def train(self) -> None: if self.early_exit_layers: output_hidden_states = slice_str_to_array(self.early_exit_layers, len(self._model.layers)) if True: # TODO: add cli option - if len(self._model.layers) - 1 not in output_hidden_states: - # ensure we include last layer - output_hidden_states[len(self._model.layers) - 1] = True + output_hidden_states[len(self._model.layers) - 1] = True else: output_hidden_states = False + if self.early_exit_curriculum: self.early_exit_curriculum = build_early_exit_curriculum(self.early_exit_curriculum, output_hidden_states) @@ -604,6 +603,8 @@ def train(self) -> None: if self.early_exit_curriculum: self.early_exit_curriculum.step() output_hidden_states = self.early_exit_curriculum.get() + if True: # TODO: add cli option + output_hidden_states[len(self._model.layers) - 1] = True self.epochs_run += 1 self.save_checkpoint(epoch=curr_epoch) From 5789745f8faa5802e79e206ee3975607ee3fe330 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Mon, 22 Jul 2024 01:40:40 +0000 Subject: [PATCH 17/88] implement gradual early exit --- recipes/full_finetune_distributed.py | 2 +- torchtune/utils/early_exit.py | 31 ++++++++++++++++++++++++---- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 7953370c2e..6047831881 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -509,7 +509,7 @@ def train(self) -> None: output_hidden_states = False if self.early_exit_curriculum: - self.early_exit_curriculum = build_early_exit_curriculum(self.early_exit_curriculum, output_hidden_states) + self.early_exit_curriculum = build_early_exit_curriculum(self.early_exit_curriculum, output_hidden_states, self.total_epochs*self._steps_per_epoch) # self.epochs_run should be non-zero when we're resuming from a checkpoint for curr_epoch in range(self.epochs_run, self.total_epochs): diff --git a/torchtune/utils/early_exit.py b/torchtune/utils/early_exit.py index a71bc331bc..9ef3b0a915 100644 --- a/torchtune/utils/early_exit.py +++ b/torchtune/utils/early_exit.py @@ -91,16 +91,20 @@ def build_early_exit_curriculum(early_exit_curriculum: EarlyExitCurriculumType, case EarlyExitCurriculumType.ROTATIONAL: return RotationalEarlyExitCurriculum(*args, **kwargs) + case EarlyExitCurriculumType.GRADUAL: + return GradualEarlyExitCurriculum(*args, **kwargs) + case _: raise ValueError(f"Unsupported early loss curriculum {early_exit_curriculum}.") # TODO: create a base curriculum class that can be used for other aspects, e.g., dropout, datasets, etc. class EarlyExitCurriculum(): - def __init__(self, output_hidden_states, verbose=True): + def __init__(self, output_hidden_states, max_steps, verbose=False): self._init_output_hidden_states = output_hidden_states self.output_hidden_states = output_hidden_states self.verbose = verbose + self.max_steps = max_steps def step(self): pass @@ -109,10 +113,29 @@ def get(self): return self.output_hidden_states class RotationalEarlyExitCurriculum(EarlyExitCurriculum): - def __init__(self, output_hidden_states, verbose=True): - super().__init__(output_hidden_states) + def __init__(self, output_hidden_states, max_steps, verbose=False): + super().__init__(output_hidden_states, max_steps, verbose) def step(self): - self.output_hidden_states = torch.roll(self.output_hidden_states, -1) + self.output_hidden_states = np.roll(self.output_hidden_states, -1) if self.verbose: log.info(f"Updating self.output_hidden_states to {self.output_hidden_states}.") + +class GradualEarlyExitCurriculum(EarlyExitCurriculum): + def __init__(self, output_hidden_states, max_steps, verbose=False): + super().__init__(output_hidden_states, max_steps, verbose) + self._step = 0 + + def step(self): + percent_trained = self._step / self.max_steps + n_layers = len(self.output_hidden_states) + for layer_index in range(len(self.output_hidden_states)): + # TODO: replace 2 with an argument + should_train = (percent_trained * 2) >= ((n_layers - 1 - layer_index) / (n_layers - 1)) + self.output_hidden_states[layer_index] = should_train + + # TODO: move this to step() in parent class? + # TODO: how to ensure we always call parent step() in derived class? + self._step += 1 + if self.verbose: + log.info(f"Updating self.output_hidden_states to {self.output_hidden_states}.") \ No newline at end of file From c3534e67bb34f69d3e1f1d8258babf26435b43de Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Mon, 22 Jul 2024 01:40:55 +0000 Subject: [PATCH 18/88] get streaming to work --- torchtune/datasets/_text_completion.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/torchtune/datasets/_text_completion.py b/torchtune/datasets/_text_completion.py index 49f2ad84e7..1c9b8786a7 100644 --- a/torchtune/datasets/_text_completion.py +++ b/torchtune/datasets/_text_completion.py @@ -36,18 +36,28 @@ def __init__( source: str, column: str = "text", max_seq_len: Optional[int] = None, + num_samples: Optional[int] = None, **load_dataset_kwargs: Dict[str, Any], ) -> None: self._tokenizer = tokenizer self._data = load_dataset(source, **load_dataset_kwargs) self.max_seq_len = max_seq_len self._column = column + self._num_samples = num_samples + self._streaming = load_dataset_kwargs["streaming"] if "streaming" in load_dataset_kwargs else False + self._data_itr = iter(self._data) if self._streaming else None def __len__(self): - return len(self._data) + if self._num_samples is None or not self._streaming: + return len(self._data) + else: + return self._num_samples def __getitem__(self, index: int) -> Dict[str, List[int]]: - sample = self._data[index] + if self._streaming: + sample = next(self._data_itr) + else: + sample = self._data[index] return self._prepare_sample(sample) def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, List[int]]: From 7849130b1ded5cbf46ad6e79a6e1926c265fc4cd Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Wed, 13 Nov 2024 05:30:54 +0000 Subject: [PATCH 19/88] add separate recipe for early exit --- .../dev/early_exit_finetune_distributed.py | 929 ++++++++++++++++++ recipes/full_finetune_distributed.py | 12 - torchtune/_recipe_registry.py | 8 + .../early_exit_loss.py} | 0 torchtune/modules/transformer.py | 9 - 5 files changed, 937 insertions(+), 21 deletions(-) create mode 100644 recipes/dev/early_exit_finetune_distributed.py rename torchtune/{utils/early_exit.py => modules/early_exit_loss.py} (100%) diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py new file mode 100644 index 0000000000..f0c7eda610 --- /dev/null +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -0,0 +1,929 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import sys +import time + +from functools import partial +from typing import Any, Dict, List, Optional, Tuple, Union +from warnings import warn + +import torch +from omegaconf import DictConfig, ListConfig + +from torch import nn +from torch.distributed import destroy_process_group, init_process_group + +from torch.optim import Optimizer +from torch.utils.data import DataLoader, DistributedSampler +from torchtune import config, modules, training, utils +from torchtune.config._utils import _get_component_from_path +from torchtune.data import padded_collate_packed +from torchtune.datasets import ConcatDataset +from torchtune.recipe_interfaces import FTRecipeInterface +from torchtune.training import DummyProfiler, PROFILER_KEY +from torchtune.training.activations import apply_selective_activation_checkpointing +from torchtune.training.lr_schedulers import get_lr + +from tqdm import tqdm + +log = utils.get_logger("DEBUG") + +class EarlyExitFinetuneRecipeDistributed(FTRecipeInterface): + """ + Full finetuning recipe for dense transformer-based LLMs such as Llama2. This recipe supports + distributed training and can be run on a single node (1 to 8 GPUs). + + Features: + - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states + is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is + done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config + ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy). + DDP is currently not supported. Training on CPU is not supported. + + - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing`` + flag. Activation checkpointing helps reduce the memory footprint since we no longer keep + activations in memory and instead recompute them during the backward pass. This is especially + helpful for larger batch sizes when you're memory constrained. But these savings in memory + come at the cost of training performance. In most cases training can slow-down quite a bit as + a result of this activation recomputation. + + - Activation Offloading. This can be controlled using the ``enable_activation_offloading`` + flag. Activation offloading is a technique similar to activations checkpointing that helps + reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activations + checkpointing drops the activation in the forward to recompute it later in the backward, + activations offloading will drop the activation in the forward to the CPU and bring it + back during the backward pass. As always, there is a tradeoff--these savings in memory can + come at the cost of training performance and CPU resources. To recover some runtime cost, + we've added an option to enable offloading on a different stream to permit overlapping with + the computation. This option is currently only available on PyTorch 2.5 or later and will + be enabled by default if an acceptable torch version is found. Activation offloading can be + used in conjunction with activation checkpointing. + + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` + flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In + most cases this should halve the memory footprint of full precision (fp32) training, without + loss in model quality (will depend on the model, training data and other settings). For + GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16 + precision are currently not supported. + + - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is + controlled using the ``gradient_accumulation_steps`` flag. + + Total Batch Size = batch_size * number of GPUs * gradient accumulation steps. + + For example: with batch_size=1, nproc_per_node=2 and gradient_accumulation_steps=32 we get a + total batch size of 64. + + Gradient accumulation is especially useful when you are memory constrained. In this case, + accumulating gradients might give you better training speed than enabling activation + checkpointing. + + - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of + training. Optimizer state and recipe state (seed, total_epochs, number of epochs run etc) are + only saved at the end of a given epoch and used in case of resuming training. + + Resuming training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is + currently not supported. + + For more details on the checkpointer, please take a look at + our checkpointer deepdive (https://pytorch.org/torchtune/main/deep_dives/checkpointer.html). + + - Logging. Terminal, Disk, WandB and TensorBoard are all supported. + + - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default, + ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set + ``clip_grad_norm='inf'``. + + For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config + has example commands for how to kick-off training. + + Args: + cfg (DictConfig): OmegaConf object parsed from yaml file + + Raises: + ValueError: If ``dtype`` is set to fp16. + RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. + RuntimeError: If ``left_pad_sequence`` is set as the data collator. + RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA. + RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False. + """ + + def __init__(self, cfg: DictConfig) -> None: + self._device = utils.get_device(device=cfg.device) + self._dtype = training.get_dtype(cfg.dtype, device=self._device) + + if self._dtype == torch.float16: + raise ValueError( + "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." + ) + + if ( + cfg.get("fsdp_cpu_offload", False) + and cfg.optimizer.get("fused", False) + and not utils.torch_version_ge("2.4.0") + ): + raise RuntimeError( + "Using fused optimizer on CPU is only supported in PyTorch nightly." + ) + + # logging attributes + self._output_dir = cfg.output_dir + self._log_every_n_steps = cfg.get("log_every_n_steps", 1) + self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) + + if self._log_peak_memory_stats and self._device.type != "cuda": + log.info( + "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False." + ) + self._log_peak_memory_stats = False + + # _is_rank_zero is used primarily for logging. In the future, the logger + # should directly take care of this + _, rank = training.get_world_size_and_rank() + self._is_rank_zero = rank == 0 + + # Training cfg + self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._gradient_accumulation_steps = cfg.gradient_accumulation_steps + self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False) + self._clip_grad_norm = cfg.get("clip_grad_norm", None) + + # Optimizer in backward is not compatible with gradient accumulation or gradient clipping + if self._optimizer_in_bwd: + if self._clip_grad_norm is not None: + raise RuntimeError( + "Gradient clipping is not supported with optimizer in bwd." + "Please set clip_grad_norm=None, or optimizer_in_bwd=False." + ) + if self._gradient_accumulation_steps > 1: + raise RuntimeError( + "Gradient accumulation is not supported with optimizer in bwd." + "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False." + ) + + # activation checkpointing/offloading + self._enable_activation_checkpointing = cfg.get( + "enable_activation_checkpointing", False + ) + self._enable_activation_offloading = cfg.get( + "enable_activation_offloading", False + ) + if self._enable_activation_offloading: + if self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be True when training on CUDA" + ) + if not self._enable_activation_checkpointing: + raise RuntimeError( + "enable_activation_offloading should only be True when enable_activation_checkpointing is True" + ) + elif ( + self._enable_activation_checkpointing + and cfg.checkpointer.model_type != "LLAMA3_VISION" + ): + utils.log_rank_zero( + log, + "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " + "Enabling activation offloading should reduce memory further.", + ) + + # These are public properties which are updated by the checkpoint loader + # when ``resume_from_checkpoint`` is `True` or validated in tests + self.seed = training.set_seed(seed=cfg.seed) + self.epochs_run = 0 + self.total_epochs = cfg.epochs + self.max_steps_per_epoch = cfg.max_steps_per_epoch + self.global_step = 0 + + cfg_early_exit = cfg.get("early_exit", None) + # TODO: create a "setup" function similar to setup_model? + # TODO: rename "early_exit" to "early_exit_loss" + if cfg_early_exit: + self.early_exit_layers = cfg_early_exit.get("layers", ":") + self.early_exit_curriculum = cfg_early_exit.get("curriculum", "none") + self.early_exit_scale = cfg_early_exit.get("scale", 1.0) + else: + self.early_exit_layers = None + self.early_exit_curriculum = None + self.early_exit_scale = None + + def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: + """ + Extract the checkpoint state from file and validate. If resume_from_checkpoint + is True, this also includes the recipe state. + """ + self._checkpointer = config.instantiate( + cfg_checkpointer, + resume_from_checkpoint=self._resume_from_checkpoint, + ) + checkpoint_dict = self._checkpointer.load_checkpoint() + + if self._resume_from_checkpoint: + self._update_recipe_state(checkpoint_dict) + return checkpoint_dict + + def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: + """ + Updates the recipe state from checkpoint. + """ + try: + self.epochs_run = ckpt_dict[training.EPOCHS_KEY] + + # on mismatch, warn the user and prevent the override + if self.seed != ckpt_dict[training.SEED_KEY]: + warn( + message=( + "Config value for seed does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}" + ) + ) + self.seed = ckpt_dict[training.SEED_KEY] + if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]: + warn( + message=( + "Config value for max_steps_per_epoch does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}" + ) + ) + self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY] + + # on mismatch, warn the user but allow the override + if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]: + warn( + message=( + "Config value for total_epochs does not match the checkpoint value, " + f"using the config value: {self.total_epochs}" + ) + ) + + except KeyError as e: + raise KeyError( + "Checkpoint does not contain the required keys needed for updating recipe state. " + "Are you sure you passed in the right recipe checkpoint?" + ) from e + + def setup(self, cfg: DictConfig) -> None: + """ + Setup the recipe. This includes training state (if resume_from_checkpoint is True), + model, tokenizer, loss, optimizer, sampler, and dataloader. + """ + if self._is_rank_zero: + self._metric_logger = config.instantiate(cfg.metric_logger) + + # log config with parameter override + self._metric_logger.log_config(cfg) + + checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) + + self._compile = cfg.get("compile", False) + self._model = self._setup_model( + cfg_model=cfg.model, + enable_activation_checkpointing=self._enable_activation_checkpointing, + enable_activation_offloading=self._enable_activation_offloading, + custom_sharded_layers=cfg.get("custom_sharded_layers", None), + fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), + reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), + model_state_dict=checkpoint_dict[training.MODEL_KEY], + ac_mode=cfg.get("ac_mode", None), + ac_option=cfg.get("ac_option", None), + ) + self._tokenizer = config.instantiate(cfg.tokenizer) + + self._optimizer = self._setup_optimizer( + cfg_optimizer=cfg.optimizer, + optimizer_in_bwd=self._optimizer_in_bwd, + opt_state_dict=( + checkpoint_dict[training.OPT_KEY] + if self._resume_from_checkpoint + else None + ), + ) + + # initialize loss + self._loss_fn = config.instantiate(cfg.loss) + + if self._compile: + training.compile_loss(self._loss_fn, verbose=self._is_rank_zero) + + if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": + # set num_output_chunks for model + self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) + + if self._is_rank_zero: + log.info("Loss is initialized.") + + # sampler and dataloader depend on the tokenizer and loss_fn and should be + # setup after both of these are initialized + collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft") + self._sampler, self._dataloader = self._setup_data( + cfg_dataset=cfg.dataset, + shuffle=cfg.shuffle, + batch_size=cfg.batch_size, + collate_fn=collate_name, + ) + + # Finally update the recipe state which can only be correctly set after all of the + # other components have been initialized and updated. + # + # Number of training steps in each epoch depends on the number of batches produced + # by the dataloader, the max_steps_per_epoch param set by the user and the + # gradient_accumulation_steps param. This value is used for logging and tracking + # training state. The computation should happen after the dataloader has been setup + self._steps_per_epoch = ( + len(self._dataloader) // self._gradient_accumulation_steps + ) + if ( + self.max_steps_per_epoch is not None + and self.max_steps_per_epoch < self._steps_per_epoch + ): + self._steps_per_epoch = self.max_steps_per_epoch + self.global_step = self.epochs_run * self._steps_per_epoch + + # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) + # if cfg is missing profiler key or if `cfg.profiler.enabled = False` + self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) + + # Used to ignore labels for loss computation + self.ignore_labels_cache = torch.full( + (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device + ) + + def _setup_profiler( + self, cfg_profiler: Optional[DictConfig] = None + ) -> Union[torch.profiler.profile, DummyProfiler]: + """ + Parses the `profiler` section of top-level `cfg` and sets up profiler + + Args: + cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to + `recipe.main`). Default None. + + Returns: + profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods + for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such + that the instrumented training loop does not need to be changed profiling is disabled. + + The profiler config can be provided in configs under the `profiler` key with the following layout: + + .. code-block:: yaml + profiler: + enabled: bool + + #Output directory of trace artifacts + output_dir: str + + #`torch.profiler.ProfilerActivity` types to trace + cpu: bool + cuda: bool + + #Trace options + profile_memory: bool + with_stack: bool + record_shapes: bool + with_flops: bool + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: int + warmup_steps: int + active_steps: int + num_cycles: int + """ + # Missing profiler section in config, assume disabled + if cfg_profiler is None: + cfg_profiler = DictConfig({"enabled": False}) + + # Check that component is included and set correctly + if cfg_profiler.get("_component_", None) is None: + cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler" + else: + assert ( + cfg_profiler.get("_component_") + == "torchtune.training.setup_torch_profiler" + ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`" + + profiler, profiler_cfg = config.instantiate(cfg_profiler) + + if self._is_rank_zero: + log.info(f" Profiler config after instantiation: {profiler_cfg}") + + self.profiler_profile_memory = profiler_cfg.get("profile_memory", False) + if profiler_cfg["enabled"]: + self.profiler_wait_steps = profiler_cfg["wait_steps"] + self.profiler_warmup_steps = profiler_cfg["warmup_steps"] + self.profiler_active_steps = profiler_cfg["active_steps"] + + return profiler + + def _setup_model( + self, + cfg_model: DictConfig, + enable_activation_checkpointing: bool, + enable_activation_offloading: bool, + fsdp_cpu_offload: bool, + reshard_after_forward: bool, + model_state_dict: Dict[str, Any], + custom_sharded_layers: Optional[List[str]] = None, + ac_mode: Optional[str] = None, + ac_option: Optional[int] = None, + ) -> nn.Module: + """ + Model initialization has some important considerations: + a. To minimize GPU peak memory, we initialize the model on meta device with + the right dtype + b. All ranks calls ``load_state_dict`` without peaking CPU RAMs since + full state dicts are loaded with ``torch.load(mmap=True)`` + """ + + if self._is_rank_zero: + log.info( + "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..." + ) + init_start = time.perf_counter() + + with training.set_default_dtype(self._dtype), torch.device("meta"): + model = config.instantiate(cfg_model) + + if self._compile: + training.compile_model(model, verbose=self._is_rank_zero) + + # We currently have two versions of activation checkpointing in this recipe + # for testing and BC purposes. ``enable_activation_checkpointing`` controls + # the older version of AC and this behavior is unchanged + # ac_mode and ac_option together control selective AC. This is only enabled + # when these are set AND ``enable_activation_checkpointing`` is set to False + # We'll clean this up as soon as testing of AC is complete + if (not enable_activation_checkpointing) and (ac_mode is not None): + apply_selective_activation_checkpointing( + model, + ac_mode, + ac_option, + ) + + # original activation checkpointing (full) - flip the condition above + if enable_activation_checkpointing and ac_mode is None: + training.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} + ) + + # For FSDP sharding + fsdp_shard_conditions = [ + partial( + training.get_shard_conditions, + names_to_match=custom_sharded_layers, + ) + ] + training.shard_model( + model=model, + shard_conditions=fsdp_shard_conditions, + cpu_offload=fsdp_cpu_offload, + reshard_after_forward=reshard_after_forward, + ) + + with training.set_default_dtype(self._dtype), self._device: + for m in model.modules(): + # RoPE is not covered in state dict + if hasattr(m, "rope_init"): + m.rope_init() + + # This method will convert the full model state dict into a sharded state + # dict and load into the model + training.load_from_full_model_state_dict( + model, + model_state_dict, + self._device, + self._is_rank_zero, + strict=True, + cpu_offload=fsdp_cpu_offload, + ) + + # activation offloading + self.activations_handling_ctx = training.get_act_offloading_ctx_manager( + model, enable_activation_offloading + ) + + # Ensure no params and buffers are on meta device + training.validate_no_params_on_meta_device(model) + + if self._is_rank_zero: + log.info( + f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs" + ) + memory_stats = training.get_memory_stats(device=self._device) + training.log_memory_stats(memory_stats) + + # synchronize before training begins + torch.distributed.barrier() + + return model + + def _setup_optimizer( + self, + cfg_optimizer: DictConfig, + optimizer_in_bwd: bool = False, + opt_state_dict: Optional[Dict[str, Any]] = None, + ) -> Optional[Optimizer]: + if optimizer_in_bwd: + # Maintain a dict of optims for every parameter. + optim_dict = { + param: config.instantiate(cfg_optimizer, [param]) + for param in self._model.parameters() + } + + # Register optimizer step hooks on the model to run optimizer in backward. + training.register_optim_in_bwd_hooks( + model=self._model, optim_dict=optim_dict + ) + # Create a wrapper for checkpoint save/load of optimizer states when running in backward. + self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper( + model=self._model, optim_dict=optim_dict + ) + # Load optimizer states for each param. If optimizer states are being restored in an optimizer in + # backward run, these need to have been saved with the same setting. Cannot restore from runs that + # did not use optimizer in backward. + if opt_state_dict is not None: + for param in opt_state_dict.keys(): + try: + training.load_from_full_optimizer_state_dict( + self._optim_ckpt_wrapper.state_dict()[param], + opt_state_dict[param], + self._device, + ) + except BaseException as e: + raise RuntimeError( + "Failed loading in-backward optimizer checkpoints." + "Please make sure run being restored from was using in-backward optimizer." + ) from e + if self._is_rank_zero: + log.info("In-backward optimizers are set up.") + return None + else: + optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) + if opt_state_dict: + training.load_from_full_optimizer_state_dict( + optimizer, + opt_state_dict, + self._device, + ) + + if self._is_rank_zero: + log.info("Optimizer is initialized.") + return optimizer + + def _setup_data( + self, + cfg_dataset: DictConfig, + shuffle: bool, + batch_size: int, + collate_fn: str, + ) -> Tuple[DistributedSampler, DataLoader]: + """ + All data related setup happens here. Currently this recipe only supports the + DistributedSamplers with Map-style Datasets which fit into memory. Other samplers, + iterable datasets and streaming datasets are not supported. + """ + world_size, rank = training.get_world_size_and_rank() + + if isinstance(cfg_dataset, ListConfig): + datasets = [ + config.instantiate(single_cfg_dataset, self._tokenizer) + for single_cfg_dataset in cfg_dataset + ] + ds = ConcatDataset(datasets=datasets) + packed = False + else: + ds = config.instantiate(cfg_dataset, self._tokenizer) + packed = cfg_dataset.get("packed", False) + + # Instantiate collate_fn + if "left_pad_sequence" in collate_fn: + raise RuntimeError("left_pad_sequence collator is only for inference.") + collate_fn = _get_component_from_path(collate_fn) + + sampler = DistributedSampler( + ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0 + ) + dataloader = DataLoader( + dataset=ds, + batch_size=batch_size, + sampler=sampler, + # dropping last avoids shape issues with compile + flex attention + drop_last=True, + collate_fn=( + partial( + collate_fn, + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, + ) + if not packed + else padded_collate_packed + ), + ) + + if self._is_rank_zero: + log.info("Dataset and Sampler are initialized.") + + return sampler, dataloader + + def save_checkpoint( + self, + epoch: int, + ) -> None: + """ + Checkpoint the state of the recipe. The constructed checkpoint state dict + contains the following information: + - Model weights with key training.MODEL_KEY + - Relevant recipe state if training is not complete + + Checkpointer will save the model weights and recipe state in + different checkpoint files. To correctly resume training from an intermediate checkpoint, + the model weights and recipe state must be provided. + """ + # final dict passed onto the checkpointer + checkpoint_dict = {} + + intermediate_checkpoint = epoch + 1 < self.total_epochs + + if self._is_rank_zero: + log.info( + "Saving checkpoint. This may take some time. Retrieving full model state dict..." + ) + start = time.perf_counter() + + # To prevent GPU memory from spiking during checkpoint save, + # we consolidate the full model and optim state dicts on CPU for rank 0 + cpu_state_dict = training.gather_cpu_state_dict( + self._model.state_dict(), + self._is_rank_zero, + device=self._device, + ) + + if self._is_rank_zero: + log.info( + f"Getting full model state dict took {time.perf_counter() - start:.2f} secs" + ) + + if intermediate_checkpoint: + start = time.perf_counter() + if self._is_rank_zero: + log.info("Getting optimizer state dict...") + if not self._optimizer_in_bwd: + opt_state_dict = training.get_full_optimizer_state_dict( + self._optimizer, + self._is_rank_zero, + device=self._device, + ) + else: + opt_state_dict = {} + for param, opt in self._optim_ckpt_wrapper.optim_map.items(): + opt_state_dict[param] = training.get_full_optimizer_state_dict( + opt, self._is_rank_zero, device=self._device + ) + if self._is_rank_zero: + log.info( + f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs" + ) + else: + opt_state_dict = None + + # Now that we have the model and opt state dict, create the actual checkpoint dict + # to be sent to the checkpointer and ultimately written to file + + if self._is_rank_zero: + start = time.perf_counter() + checkpoint_dict.update({training.MODEL_KEY: cpu_state_dict}) + + # if training is in-progress, checkpoint the optimizer state and recipe state + # as well. + if intermediate_checkpoint: + checkpoint_dict.update( + { + training.OPT_KEY: opt_state_dict, + training.SEED_KEY: self.seed, + training.EPOCHS_KEY: self.epochs_run, + training.TOTAL_EPOCHS_KEY: self.total_epochs, + training.MAX_STEPS_KEY: self.max_steps_per_epoch, + } + ) + + self._checkpointer.save_checkpoint( + checkpoint_dict, + epoch=epoch, + intermediate_checkpoint=intermediate_checkpoint, + ) + log.info(f"Saving checkpoint took {time.perf_counter() - start:.2f} secs") + + torch.distributed.barrier() + + def train(self) -> None: + """ + The core training loop. + """ + # clean up before training begins + training.cleanup_before_training() + + world_size, rank = training.get_world_size_and_rank() + + # zero out the gradients before starting training + if not self._optimizer_in_bwd: + self._optimizer.zero_grad() + else: + for opt in self._optim_ckpt_wrapper.optim_map.values(): + opt.zero_grad() + + # Initialize tokens count and running loss (for grad accumulation) + t0 = time.perf_counter() + running_loss = 0 + num_tokens = 0 + + self._profiler.start() + # self.epochs_run should be non-zero when we're resuming from a checkpoint + for curr_epoch in range(self.epochs_run, self.total_epochs): + # Update the sampler to ensure data is correctly shuffled across epochs + # in case shuffle is True + self._sampler.set_epoch(curr_epoch) + + pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0)) + for idx, batch in enumerate(self._dataloader): + if ( + self.max_steps_per_epoch is not None + and (idx // self._gradient_accumulation_steps) + == self.max_steps_per_epoch + ): + break + + # Start tracking CUDA memory for active steps for just the first epoch + if ( + self._is_rank_zero + and curr_epoch == 0 + and self.profiler_profile_memory + and idx == self.profiler_wait_steps + self.profiler_warmup_steps + ): + torch.cuda.memory._record_memory_history() + + utils.batch_to_device(batch, self._device) + + # Calculate the number of unmasked tokens in the current batch + # and increment the total number of tokens seen in the step + current_num_tokens = ( + batch["labels"] != self._loss_fn.ignore_index + ).sum() + num_tokens += current_num_tokens + + # Shape [b, s], needed for the loss not the model + labels = batch.pop("labels") + + with self.activations_handling_ctx: + logits = self._model(**batch) + + # Shift labels to compute loss + # equivalent to doing labels[..., 1:] and logits[..., :-1, :] + # But this way we dont need to slice the logits. We just add an ignore index to labels. + labels = torch.hstack( + (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]) + ) + if not isinstance(logits, list): + labels = labels.reshape(-1) + logits = logits.reshape(-1, logits.size(-1)) + + # Compute loss + # Loss is normalized by default so we multiply by the number of tokens + # This way we can normalize by the total number of tokens if we're accumulating gradients + current_loss = self._loss_fn(logits, labels) * current_num_tokens + + # free logits otherwise it peaks backward memory + del logits + + running_loss += current_loss + + # For optimizer in backward, we need to normalize before calling backward + # This case and gradient accumulation are mutually exclusive + if self._optimizer_in_bwd: + torch.distributed.all_reduce(num_tokens) + torch.distributed.all_reduce(running_loss) + current_loss = current_loss / num_tokens + + current_loss.backward() + + # Step with optimizer + if (idx + 1) % self._gradient_accumulation_steps == 0: + if not self._optimizer_in_bwd: + # Get total number of tokens across all ranks to normalize gradients + torch.distributed.all_reduce(num_tokens) + # This will ensure that the logged loss matches what we're optimizing + torch.distributed.all_reduce(running_loss) + # Manually scale the gradients from unnormalized loss by total # of tokens + training.scale_grads(self._model, 1 / num_tokens) + if self._clip_grad_norm is not None: + grad_norm = torch.nn.utils.clip_grad_norm_( + self._model.parameters(), + max_norm=float(self._clip_grad_norm), + ) + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + + # Update the number of steps when the weights are updated + self.global_step += 1 + + loss_to_log = running_loss.item() / num_tokens + pbar.update(1) + pbar.set_description( + f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" + ) + + # Log per-step metrics + if ( + self.global_step % self._log_every_n_steps == 0 + and self._is_rank_zero + ): + time_per_step = time.perf_counter() - t0 + log_dict = { + "loss": loss_to_log, + "lr": get_lr( + ( + self._optimizer + if not self._optimizer_in_bwd + else self._optim_ckpt_wrapper + ), + ), + "tokens_per_second_per_gpu": num_tokens + / (time_per_step * world_size), + } + if self._log_peak_memory_stats: + log_dict.update( + training.get_memory_stats(device=self._device) + ) + if self._clip_grad_norm is not None: + log_dict.update({"grad_norm": grad_norm}) + self._metric_logger.log_dict( + log_dict, + step=self.global_step, + ) + + # Reset running stats for the next step + running_loss = 0 + num_tokens = 0 + t0 = time.perf_counter() + + # Stop tracking CUDA memory now that active steps are complete + if ( + self._is_rank_zero + and curr_epoch == 0 + and self.profiler_profile_memory + and idx + == self.profiler_wait_steps + + self.profiler_warmup_steps + + self.profiler_active_steps + ): + torch.cuda.memory._record_memory_history(enabled=None) + + # Step profiler + # Note that this is called within gradient accumulation block, hence + # will include multiple forward / backward passes if gradient accumulation > 1 + self._profiler.step() + + self.epochs_run += 1 + self.save_checkpoint(epoch=curr_epoch) + + self._profiler.stop() + + def cleanup(self) -> None: + if self._is_rank_zero: + self._metric_logger.close() + destroy_process_group() + + +@config.parse +def recipe_main(cfg: DictConfig) -> None: + """ + Entry point for the recipe. + + Configurable parameters are read in the following order: + - Parameters specified in config (see available configs through ``tune ls``) + - Overwritten by arguments from the command-line + """ + if not training.is_distributed(): + raise RuntimeError( + "Distributed finetune recipe should be run via a distributed launcher." + "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]" + ) + init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl") + if cfg.get("fsdp_cpu_offload", False): + # Utilize all available CPU cores for intra-op parallelism. This provides ~2x + # speed up when benchmarking fused AdamW on CPU + training.set_torch_num_threads() + + config.log_config(recipe_name="EarlyExitFinetuneRecipeDistributed", cfg=cfg) + + recipe = EarlyExitFinetuneRecipeDistributed(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + +if __name__ == "__main__": + sys.exit(recipe_main()) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 395677c08c..98d34b5f94 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -200,18 +200,6 @@ def __init__(self, cfg: DictConfig) -> None: self.max_steps_per_epoch = cfg.max_steps_per_epoch self.global_step = 0 - cfg_early_exit = cfg.get("early_exit", None) - # TODO: create a "setup" function similar to setup_model? - # TODO: rename "early_exit" to "early_exit_loss" - if cfg_early_exit: - self.early_exit_layers = cfg_early_exit.get("layers", ":") - self.early_exit_curriculum = cfg_early_exit.get("curriculum", "none") - self.early_exit_scale = cfg_early_exit.get("scale", 1.0) - else: - self.early_exit_layers = None - self.early_exit_curriculum = None - self.early_exit_scale = None - def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: """ Extract the checkpoint state from file and validate. If resume_from_checkpoint diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index c40e89184b..4742e7cdd3 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -411,6 +411,14 @@ class Recipe: ], supports_distributed=False, ), + Recipe( + name="dev/early_exit_finetune_distributed", + file_path="dev/early_exit_finetune_distributed.py", + configs=[ + Config(name="llama2/7B_full", file_path="llama2/7B_full.yaml"), + ], + supports_distributed=True, + ), Recipe( name="eleuther_eval", file_path="eleuther_eval.py", diff --git a/torchtune/utils/early_exit.py b/torchtune/modules/early_exit_loss.py similarity index 100% rename from torchtune/utils/early_exit.py rename to torchtune/modules/early_exit_loss.py diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index eff3c2eb46..97b50a21e5 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -14,8 +14,6 @@ from torchtune.modules.attention_utils import _MaskType from torchtune.utils._logging import deprecated -from torchtune.modules import LayerDropout, create_layer_dropout_modules - class TransformerSelfAttentionLayer(nn.Module): """ Transformer layer derived from the Llama2 model. Normalization is applied before the attention **and** FF layer. @@ -343,8 +341,6 @@ class TransformerDecoder(nn.Module): before final MLP. output (Union[nn.Linear, Callable]): Callable that applies a linear transformation to the output of the decoder. - layer_dropout_prob (float): Probability of skipping samples in the transformer - layer. num_layers (Optional[int]): Number of Transformer Decoder layers, only define when layers is not a list. output_hidden_states (Optional[List[int]]): List of layers (indices) to include in the output @@ -371,9 +367,6 @@ def __init__( output: Union[nn.Linear, Callable], num_layers: Optional[int] = None, output_hidden_states: Optional[List[int]] = None, - layer_dropout_prob: float = 0.0, - layer_dropout_prob_layer_scale: str = "exp", - layer_dropout_str: str = ":", ) -> None: super().__init__() if isinstance(layers, nn.ModuleList): @@ -402,8 +395,6 @@ def __init__( self.encoder_max_cache_seq_len = None self.decoder_max_cache_seq_len = None - self.layer_dropouts = create_layer_dropout_modules(num_layers, layer_dropout_prob, layer_dropout_prob_layer_scale, layer_dropout_str) - def set_num_output_chunks(self, num_output_chunks: int) -> None: """Used to save memory in combination with :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss`. This should be called before the first forward pass, in the recipe.""" From df89c4f699fe004ced0412c180ee8714270f361e Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Wed, 13 Nov 2024 05:40:04 +0000 Subject: [PATCH 20/88] port early exit loss code from PR --- .../dev/early_exit_finetune_distributed.py | 32 +++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index f0c7eda610..ba6bbb16f2 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -28,6 +28,9 @@ from torchtune.training.activations import apply_selective_activation_checkpointing from torchtune.training.lr_schedulers import get_lr +from torchtune.modules.early_exit_loss import early_exit_loss, build_early_exit_curriculum +from torchtune.modules.common_utils import slice_str_to_array + from tqdm import tqdm log = utils.get_logger("DEBUG") @@ -740,6 +743,18 @@ def train(self) -> None: running_loss = 0 num_tokens = 0 + # Early exit loss settings + # TODO: move to _init_() or setup() + if self.early_exit_layers: + output_hidden_states = slice_str_to_array(self.early_exit_layers, len(self._model.layers)) + if True: # TODO: add cli option + output_hidden_states[len(self._model.layers) - 1] = True + else: + output_hidden_states = False + + if self.early_exit_curriculum: + self.early_exit_curriculum = build_early_exit_curriculum(self.early_exit_curriculum, output_hidden_states, self.total_epochs*self._steps_per_epoch) + self._profiler.start() # self.epochs_run should be non-zero when we're resuming from a checkpoint for curr_epoch in range(self.epochs_run, self.total_epochs): @@ -778,7 +793,10 @@ def train(self) -> None: labels = batch.pop("labels") with self.activations_handling_ctx: - logits = self._model(**batch) + if self.early_exit_layers: + logits, hidden_states = self._model(**batch) + else: + logits = self._model(**batch) # Shift labels to compute loss # equivalent to doing labels[..., 1:] and logits[..., :-1, :] @@ -793,7 +811,10 @@ def train(self) -> None: # Compute loss # Loss is normalized by default so we multiply by the number of tokens # This way we can normalize by the total number of tokens if we're accumulating gradients - current_loss = self._loss_fn(logits, labels) * current_num_tokens + if self.early_exit_layers: + current_loss = early_exit_loss(self._model, hidden_states, labels, self._loss_fn) * current_num_tokens + else: + current_loss = self._loss_fn(logits, labels) * current_num_tokens # free logits otherwise it peaks backward memory del logits @@ -869,6 +890,13 @@ def train(self) -> None: num_tokens = 0 t0 = time.perf_counter() + # Update Early Exit Layers/Scales + if self.early_exit_curriculum: + self.early_exit_curriculum.step() + output_hidden_states = self.early_exit_curriculum.get() + if True: # TODO: add cli option + output_hidden_states[len(self._model.layers) - 1] = True + # Stop tracking CUDA memory now that active steps are complete if ( self._is_rank_zero From 6cedb19fbb6b630618f3090dc9488e0c94f13e2d Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 17 Nov 2024 17:39:57 +0000 Subject: [PATCH 21/88] convert boolean array to indices --- recipes/dev/early_exit_finetune_distributed.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index ba6bbb16f2..09b0c34a1a 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -746,14 +746,13 @@ def train(self) -> None: # Early exit loss settings # TODO: move to _init_() or setup() if self.early_exit_layers: - output_hidden_states = slice_str_to_array(self.early_exit_layers, len(self._model.layers)) + do_output_hidden_states = slice_str_to_array(self.early_exit_layers, len(self._model.layers)) if True: # TODO: add cli option - output_hidden_states[len(self._model.layers) - 1] = True - else: - output_hidden_states = False + do_output_hidden_states[len(self._model.layers) - 1] = True + self._model.output_hidden_states = [i for i in range(len(do_output_hidden_states)) if do_output_hidden_states[i]] if self.early_exit_curriculum: - self.early_exit_curriculum = build_early_exit_curriculum(self.early_exit_curriculum, output_hidden_states, self.total_epochs*self._steps_per_epoch) + self.early_exit_curriculum = build_early_exit_curriculum(self.early_exit_curriculum, self._model.output_hidden_states, self.total_epochs*self._steps_per_epoch) self._profiler.start() # self.epochs_run should be non-zero when we're resuming from a checkpoint @@ -893,9 +892,10 @@ def train(self) -> None: # Update Early Exit Layers/Scales if self.early_exit_curriculum: self.early_exit_curriculum.step() - output_hidden_states = self.early_exit_curriculum.get() + do_output_hidden_states = self.early_exit_curriculum.get() if True: # TODO: add cli option - output_hidden_states[len(self._model.layers) - 1] = True + do_output_hidden_states[len(self._model.layers) - 1] = True + self._model.output_hidden_states = do_output_hidden_states # Stop tracking CUDA memory now that active steps are complete if ( From a83da5a24ace390b1317fe9fd48019a09dc2271e Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 17 Nov 2024 17:40:49 +0000 Subject: [PATCH 22/88] decide on hidden outputs by member variable not forward pass --- recipes/dev/early_exit_finetune_distributed.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index 09b0c34a1a..420db4e9ed 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -792,10 +792,12 @@ def train(self) -> None: labels = batch.pop("labels") with self.activations_handling_ctx: + outputs = self._model(**batch) if self.early_exit_layers: - logits, hidden_states = self._model(**batch) + logits = outputs.pop(-1) + hidden_states = {i:h for i,h in zip(self._model.output_hidden_states, outputs)} else: - logits = self._model(**batch) + logits = outputs # Shift labels to compute loss # equivalent to doing labels[..., 1:] and logits[..., :-1, :] From 2a8791d2ff015871d17a87286e106ea537f30a91 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 17 Nov 2024 17:42:05 +0000 Subject: [PATCH 23/88] add early exit recipe config --- recipes/dev/7B_full_early_exit.yaml | 111 ++++++++++++++++++++++++++++ torchtune/_recipe_registry.py | 2 +- 2 files changed, 112 insertions(+), 1 deletion(-) create mode 100644 recipes/dev/7B_full_early_exit.yaml diff --git a/recipes/dev/7B_full_early_exit.yaml b/recipes/dev/7B_full_early_exit.yaml new file mode 100644 index 0000000000..7c41e3653b --- /dev/null +++ b/recipes/dev/7B_full_early_exit.yaml @@ -0,0 +1,111 @@ +# Config for multi-device full finetuning in full_finetune_distributed.py +# using a Llama2 7B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --hf-token +# +# To launch on 4 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama2/7B_full +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama2/7B_full checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. +# Single device full finetuning requires more memory optimizations. It's +# best to use 7B_full_single_device.yaml for those cases + + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/Llama-2-7b-hf/tokenizer.model + max_seq_len: null + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.llama2.llama2_7b + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-2-7b-hf + checkpoint_files: [ + pytorch_model-00001-of-00002.bin, + pytorch_model-00002-of-00002.bin + ] + recipe_checkpoint: null + output_dir: /tmp/Llama-2-7b-hf + model_type: LLAMA2 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 1 +optimizer: + _component_: torch.optim.AdamW + fused: True + lr: 2e-5 +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-llama2-finetune +log_every_n_steps: 1 +log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 + +# early exit loss +early_exit: + layers: "0:10" + curriculum: "gradual" + scale: "one" \ No newline at end of file diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index 4742e7cdd3..d72b2ea69d 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -415,7 +415,7 @@ class Recipe: name="dev/early_exit_finetune_distributed", file_path="dev/early_exit_finetune_distributed.py", configs=[ - Config(name="llama2/7B_full", file_path="llama2/7B_full.yaml"), + Config(name="llama2/7B_full_early_exit", file_path="dev/7B_full_early_exit.yaml"), ], supports_distributed=True, ), From a326937aa56eb5e74912afabefa334abfc7e5413 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 17 Nov 2024 17:42:29 +0000 Subject: [PATCH 24/88] refactor unembedding --- torchtune/modules/early_exit_loss.py | 9 ++++++--- torchtune/modules/transformer.py | 12 +++++++++--- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/torchtune/modules/early_exit_loss.py b/torchtune/modules/early_exit_loss.py index 9ef3b0a915..0f089f9797 100644 --- a/torchtune/modules/early_exit_loss.py +++ b/torchtune/modules/early_exit_loss.py @@ -33,11 +33,14 @@ def early_exit_loss(model, hidden_states_dict, labels, loss_fn, e_scale: float=1 # Shape: [e, b, s, d] hidden_states_stacked = torch.stack(hidden_states) # Shape: [e, b, s, out_dim] - logits_early = model.output(model.norm(hidden_states_stacked)) - logits_early = logits_early[..., :-1, :].contiguous() + logits_early = model.unembed(hidden_states_stacked) + if not isinstance(logits_early, list): + labels = labels.reshape(-1) + logits_early = logits_early.reshape(-1, logits_early.size(-1)) + ###### logits_early = logits_early[..., :-1, :].contiguous() # Shape: [e*b, s, out_dim] logits_early = logits_early.flatten(0, 1) - logits_early = logits_early.transpose(1, 2) + ###### logits_early = logits_early.transpose(1, 2) # Shape: [e, b*s] labels_repeated = labels.repeat(e, 1) # Compute early losses: Shape: [e*b, s] diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index 97b50a21e5..6fe558d112 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -642,6 +642,15 @@ def forward( input_pos=input_pos, ) + # shape: [b, seq_len, out_dim] + output = self.unembed(h) + + # Output list if hidden states are requested, otherwise just the output + # TODO: always output a list to have a consistent output type + output = output if not hidden else [*hidden, output] + return output + + def unembed(self, h): # shape: [b, s, d] h = self.norm(h) @@ -651,9 +660,6 @@ def forward( # shape: [b, seq_len, out_dim] output = self.output(h).float() - # Output list if hidden states are requested, otherwise just the output - # TODO: always output a list to have a consistent output type - output = output if not hidden else [*hidden, output] return output From 8ba6ab46b4c7f69048a01cd1a1fb67aeda9a2e02 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Mon, 18 Nov 2024 16:09:38 +0000 Subject: [PATCH 25/88] got early exit loss to work --- recipes/dev/7B_full_early_exit.yaml | 2 +- .../dev/early_exit_finetune_distributed.py | 2 +- torchtune/modules/early_exit_loss.py | 46 +++++++++---------- 3 files changed, 24 insertions(+), 26 deletions(-) diff --git a/recipes/dev/7B_full_early_exit.yaml b/recipes/dev/7B_full_early_exit.yaml index 7c41e3653b..1741383a3d 100644 --- a/recipes/dev/7B_full_early_exit.yaml +++ b/recipes/dev/7B_full_early_exit.yaml @@ -55,7 +55,7 @@ optimizer: fused: True lr: 2e-5 loss: - _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + _component_: torch.nn.CrossEntropyLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 # Use to increase virtual batch size compile: False # pytorch compile, set to true for better perf/memory diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index 420db4e9ed..901a986970 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -752,7 +752,7 @@ def train(self) -> None: self._model.output_hidden_states = [i for i in range(len(do_output_hidden_states)) if do_output_hidden_states[i]] if self.early_exit_curriculum: - self.early_exit_curriculum = build_early_exit_curriculum(self.early_exit_curriculum, self._model.output_hidden_states, self.total_epochs*self._steps_per_epoch) + self.early_exit_curriculum = build_early_exit_curriculum(self.early_exit_curriculum, do_output_hidden_states, self.total_epochs*self._steps_per_epoch) self._profiler.start() # self.epochs_run should be non-zero when we're resuming from a checkpoint diff --git a/torchtune/modules/early_exit_loss.py b/torchtune/modules/early_exit_loss.py index 0f089f9797..9a7d6331cc 100644 --- a/torchtune/modules/early_exit_loss.py +++ b/torchtune/modules/early_exit_loss.py @@ -22,6 +22,8 @@ class LossScaleType(str, Enum): SQRT_L = "sqrt_l" INV_SQRT_L = "inv_sqrt_l" +# TODO: create docstring using other functions as template +# TODO: add assert on type of loss_fn def early_exit_loss(model, hidden_states_dict, labels, loss_fn, e_scale: float=1.0, loss_scale_type=LossScaleType.SUM_L): batch_loss_fn = copy.deepcopy(loss_fn) batch_loss_fn.reduction = "none" @@ -34,16 +36,12 @@ def early_exit_loss(model, hidden_states_dict, labels, loss_fn, e_scale: float=1 hidden_states_stacked = torch.stack(hidden_states) # Shape: [e, b, s, out_dim] logits_early = model.unembed(hidden_states_stacked) - if not isinstance(logits_early, list): - labels = labels.reshape(-1) - logits_early = logits_early.reshape(-1, logits_early.size(-1)) - ###### logits_early = logits_early[..., :-1, :].contiguous() - # Shape: [e*b, s, out_dim] - logits_early = logits_early.flatten(0, 1) - ###### logits_early = logits_early.transpose(1, 2) - # Shape: [e, b*s] - labels_repeated = labels.repeat(e, 1) - # Compute early losses: Shape: [e*b, s] + # Shape: [e*b*s, out_dim] + logits_early = logits_early.reshape(-1, logits_early.size(-1)) + logits_early = logits_early.contiguous() + # Shape: [e*b*s] + labels_repeated = labels.repeat(e, 1).reshape(-1) + # Compute early losses: Shape: [e*b*s] losses_early = batch_loss_fn(logits_early, labels_repeated) # Shape: [e, b*s] losses_early = losses_early.view(e, -1) @@ -103,9 +101,9 @@ def build_early_exit_curriculum(early_exit_curriculum: EarlyExitCurriculumType, # TODO: create a base curriculum class that can be used for other aspects, e.g., dropout, datasets, etc. class EarlyExitCurriculum(): - def __init__(self, output_hidden_states, max_steps, verbose=False): - self._init_output_hidden_states = output_hidden_states - self.output_hidden_states = output_hidden_states + def __init__(self, do_output_hidden_states, max_steps, verbose=False): + self._init_do_output_hidden_states = do_output_hidden_states + self.do_output_hidden_states = do_output_hidden_states self.verbose = verbose self.max_steps = max_steps @@ -113,32 +111,32 @@ def step(self): pass def get(self): - return self.output_hidden_states + return self.do_output_hidden_states class RotationalEarlyExitCurriculum(EarlyExitCurriculum): - def __init__(self, output_hidden_states, max_steps, verbose=False): - super().__init__(output_hidden_states, max_steps, verbose) + def __init__(self, do_output_hidden_states, max_steps, verbose=False): + super().__init__(do_output_hidden_states, max_steps, verbose) def step(self): - self.output_hidden_states = np.roll(self.output_hidden_states, -1) + self.do_output_hidden_states = np.roll(self.do_output_hidden_states, -1) if self.verbose: - log.info(f"Updating self.output_hidden_states to {self.output_hidden_states}.") + log.info(f"Updating self.output_hidden_states to {self.do_output_hidden_states}.") class GradualEarlyExitCurriculum(EarlyExitCurriculum): - def __init__(self, output_hidden_states, max_steps, verbose=False): - super().__init__(output_hidden_states, max_steps, verbose) + def __init__(self, do_output_hidden_states, max_steps, verbose=False): + super().__init__(do_output_hidden_states, max_steps, verbose) self._step = 0 def step(self): percent_trained = self._step / self.max_steps - n_layers = len(self.output_hidden_states) - for layer_index in range(len(self.output_hidden_states)): + n_layers = len(self.do_output_hidden_states) + for layer_index in range(len(self.do_output_hidden_states)): # TODO: replace 2 with an argument should_train = (percent_trained * 2) >= ((n_layers - 1 - layer_index) / (n_layers - 1)) - self.output_hidden_states[layer_index] = should_train + self.do_output_hidden_states[layer_index] = should_train # TODO: move this to step() in parent class? # TODO: how to ensure we always call parent step() in derived class? self._step += 1 if self.verbose: - log.info(f"Updating self.output_hidden_states to {self.output_hidden_states}.") \ No newline at end of file + log.info(f"Updating self.do_output_hidden_states to {self.do_output_hidden_states}.") From 681e7cada8e21a339909c850a5f094818ed9a6c5 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Tue, 19 Nov 2024 03:59:17 +0000 Subject: [PATCH 26/88] add TopV2 instruction set --- recipes/dev/7B_full_early_exit.yaml | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/recipes/dev/7B_full_early_exit.yaml b/recipes/dev/7B_full_early_exit.yaml index 1741383a3d..c3c876cb18 100644 --- a/recipes/dev/7B_full_early_exit.yaml +++ b/recipes/dev/7B_full_early_exit.yaml @@ -25,9 +25,17 @@ tokenizer: max_seq_len: null # Dataset +# dataset: +# _component_: torchtune.datasets.alpaca_dataset +# packed: False # True increases speed dataset: - _component_: torchtune.datasets.alpaca_dataset - packed: False # True increases speed + _component_: torchtune.datasets.instruct_dataset + source: WillHeld/top_v2 + split: train + column_map: + input: utterance + output: semantic_parse + seed: null shuffle: True From 119ac7d234ae2d87f8b03daf17587f7da1c3fa04 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Tue, 19 Nov 2024 05:13:08 +0000 Subject: [PATCH 27/88] ensure all early exit loss params from cfg file are passed to code --- recipes/dev/7B_full_early_exit.yaml | 4 +++- recipes/dev/early_exit_finetune_distributed.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/recipes/dev/7B_full_early_exit.yaml b/recipes/dev/7B_full_early_exit.yaml index c3c876cb18..fabb9b0d4f 100644 --- a/recipes/dev/7B_full_early_exit.yaml +++ b/recipes/dev/7B_full_early_exit.yaml @@ -113,7 +113,9 @@ profiler: num_cycles: 1 # early exit loss +# TODO: rename this and variables to early exit loss early_exit: layers: "0:10" curriculum: "gradual" - scale: "one" \ No newline at end of file + scale_type: "one" + scale: 1.0 diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index 901a986970..acb8058c4a 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -209,10 +209,12 @@ def __init__(self, cfg: DictConfig) -> None: self.early_exit_layers = cfg_early_exit.get("layers", ":") self.early_exit_curriculum = cfg_early_exit.get("curriculum", "none") self.early_exit_scale = cfg_early_exit.get("scale", 1.0) + self.early_exit_scale_type = cfg_early_exit.get("scale_type", "one") else: self.early_exit_layers = None self.early_exit_curriculum = None self.early_exit_scale = None + self.early_exit_scale_type = None def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: """ @@ -813,7 +815,7 @@ def train(self) -> None: # Loss is normalized by default so we multiply by the number of tokens # This way we can normalize by the total number of tokens if we're accumulating gradients if self.early_exit_layers: - current_loss = early_exit_loss(self._model, hidden_states, labels, self._loss_fn) * current_num_tokens + current_loss = early_exit_loss(self._model, hidden_states, labels, self._loss_fn, self.early_exit_scale, self.early_exit_scale_type) * current_num_tokens else: current_loss = self._loss_fn(logits, labels) * current_num_tokens From 3ec9d23b70129eeab95f739051ed7419959a8586 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Tue, 19 Nov 2024 05:13:55 +0000 Subject: [PATCH 28/88] fix gradual early exit --- torchtune/modules/early_exit_loss.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchtune/modules/early_exit_loss.py b/torchtune/modules/early_exit_loss.py index 9a7d6331cc..5adaf52bc2 100644 --- a/torchtune/modules/early_exit_loss.py +++ b/torchtune/modules/early_exit_loss.py @@ -123,16 +123,16 @@ def step(self): log.info(f"Updating self.output_hidden_states to {self.do_output_hidden_states}.") class GradualEarlyExitCurriculum(EarlyExitCurriculum): - def __init__(self, do_output_hidden_states, max_steps, verbose=False): + def __init__(self, do_output_hidden_states, max_steps, percent_scale=2, verbose=False): super().__init__(do_output_hidden_states, max_steps, verbose) self._step = 0 + self._percent_scale = percent_scale def step(self): percent_trained = self._step / self.max_steps n_layers = len(self.do_output_hidden_states) for layer_index in range(len(self.do_output_hidden_states)): - # TODO: replace 2 with an argument - should_train = (percent_trained * 2) >= ((n_layers - 1 - layer_index) / (n_layers - 1)) + should_train = (percent_trained * self._percent_scale) >= (n_layers - layer_index) / n_layers self.do_output_hidden_states[layer_index] = should_train # TODO: move this to step() in parent class? From 04a590f5a6255f8b87bf331436631d5547fb30e4 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Tue, 19 Nov 2024 05:14:26 +0000 Subject: [PATCH 29/88] add test cases for early exit loss --- .../torchtune/modules/test_early_exit_loss.py | 129 ++++++++++++++++++ 1 file changed, 129 insertions(+) create mode 100644 tests/torchtune/modules/test_early_exit_loss.py diff --git a/tests/torchtune/modules/test_early_exit_loss.py b/tests/torchtune/modules/test_early_exit_loss.py new file mode 100644 index 0000000000..f4e8c807c1 --- /dev/null +++ b/tests/torchtune/modules/test_early_exit_loss.py @@ -0,0 +1,129 @@ +import numpy as np +import pytest +import torch +import torch.nn as nn +from torchtune import utils +from torchtune.modules import TransformerDecoder +from torchtune.modules.early_exit_loss import ( + early_exit_loss, + layer_ids_to_loss_scales, + LossScaleType, + EarlyExitCurriculumType, + build_early_exit_curriculum, + RotationalEarlyExitCurriculum, + GradualEarlyExitCurriculum, +) + +# Mock components for TransformerDecoder +class MockLayer(nn.Module): + def forward(self, x, mask=None, encoder_input=None, encoder_mask=None, input_pos=None): + return x # Simply return the input for testing purposes + +@pytest.fixture +def mock_model(): + # Create mock components + tok_embeddings = nn.Embedding(1000, 512) # Example vocab size and embedding dim + layers = nn.ModuleList([MockLayer() for _ in range(12)]) # 12 mock layers + norm = nn.LayerNorm(512) # Example layer normalization + output = nn.Linear(512, 1000) # Example output layer + + # Create an instance of TransformerDecoder + model = TransformerDecoder( + tok_embeddings=tok_embeddings, + layers=layers, + max_seq_len=512, + num_heads=8, + head_dim=64, + norm=norm, + output=output, + num_layers=12, + output_hidden_states=[0, 1, 2] # Example layers to output hidden states + ) + return model + +@pytest.fixture +def hidden_states_dict(): + return {i: torch.randn(4, 5, 512) for i in range(3)} # Adjusted embedding dim + +@pytest.fixture +def labels(): + return torch.randint(0, 1000, (4, 5)) # Adjusted vocab size + +@pytest.fixture +def loss_fn(): + return nn.CrossEntropyLoss(ignore_index=-1) + +def test_early_exit_loss(mock_model, hidden_states_dict, labels, loss_fn): + loss = early_exit_loss(mock_model, hidden_states_dict, labels, loss_fn) + assert isinstance(loss, torch.Tensor) + assert loss.item() >= 0 + +def test_layer_ids_to_loss_scales(): + layer_ids = torch.tensor([0, 1, 2]) + n_layers = 12 + scales = layer_ids_to_loss_scales(layer_ids, n_layers, LossScaleType.SUM_L, 1.0) + assert torch.isclose(scales.sum(), torch.tensor(1.0)) + +def test_build_early_exit_curriculum(): + curriculum = build_early_exit_curriculum(EarlyExitCurriculumType.ROTATIONAL, [True, False, True], 100) + assert isinstance(curriculum, RotationalEarlyExitCurriculum) + + curriculum = build_early_exit_curriculum(EarlyExitCurriculumType.GRADUAL, [True, False, True], 100) + assert isinstance(curriculum, GradualEarlyExitCurriculum) + +def test_rotational_early_exit_curriculum(): + curriculum = RotationalEarlyExitCurriculum([True, False, True], 100) + curriculum.step() + expected = np.array([False, True, True]) + assert np.array_equal(curriculum.get(), expected), f"Expected {expected}, but got {curriculum.get()}" + +def test_gradual_early_exit_curriculum(): + curriculum = GradualEarlyExitCurriculum([False, False, False, False], max_steps=4, percent_scale=1) + curriculum.step() + assert curriculum.get() == [False, False, False, False] + curriculum.step() + assert curriculum.get() == [False, False, False, True] + curriculum.step() + assert curriculum.get() == [False, False, True, True] + curriculum.step() + assert curriculum.get() == [False, True, True, True] + curriculum.step() + assert curriculum.get() == [True, True, True, True] + curriculum.step() + assert curriculum.get() == [True, True, True, True] + +@pytest.fixture +def hidden_states_dict(): + return {i: torch.randn(4, 5, 512) for i in range(3)} # Adjusted embedding dim + +@pytest.fixture +def labels(): + return torch.randint(0, 1000, (4, 5)) # Adjusted vocab size + +@pytest.fixture +def loss_fn(): + return nn.CrossEntropyLoss(ignore_index=-1) + +def test_early_exit_loss_vs_manual(mock_model, hidden_states_dict, labels, loss_fn): + # Convert to float32 for numeric equivalence + # Calculate early exit loss using the function + calculated_loss = early_exit_loss(mock_model, hidden_states_dict, labels, loss_fn, e_scale=1, loss_scale_type="one") + # Manually calculate the loss for each hidden state + total_loss = 0.0 + num_hidden_states = len(hidden_states_dict) + for i, hidden_state in hidden_states_dict.items(): + # Compute logits for the current hidden state + logits = mock_model.unembed(hidden_state) + logits = logits.view(-1, logits.size(-1)) # Flatten for loss computation + labels = labels.view(-1) + # Compute the loss for the current hidden state + loss = loss_fn(logits, labels) + total_loss += loss + # Average the losses across all hidden states + manual_loss = total_loss / num_hidden_states + # Compare the two losses + assert torch.isclose(calculated_loss, manual_loss, atol=1e-6), \ + f"Calculated loss: {calculated_loss}, Manual loss: {manual_loss}" + +if __name__ == "__main__": + pytest.main() From 9b5c96afdb40fbeef7ad0d148b029583047d2c20 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Tue, 19 Nov 2024 05:18:16 +0000 Subject: [PATCH 30/88] add more assertions for rotational early exit --- tests/torchtune/modules/test_early_exit_loss.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/torchtune/modules/test_early_exit_loss.py b/tests/torchtune/modules/test_early_exit_loss.py index f4e8c807c1..ab09646e04 100644 --- a/tests/torchtune/modules/test_early_exit_loss.py +++ b/tests/torchtune/modules/test_early_exit_loss.py @@ -75,7 +75,13 @@ def test_rotational_early_exit_curriculum(): curriculum = RotationalEarlyExitCurriculum([True, False, True], 100) curriculum.step() expected = np.array([False, True, True]) - assert np.array_equal(curriculum.get(), expected), f"Expected {expected}, but got {curriculum.get()}" + assert np.array_equal(curriculum.get(), expected) + curriculum.step() + expected = np.array([True, True, False]) + assert np.array_equal(curriculum.get(), expected) + curriculum.step() + expected = np.array([True, False, True]) + assert np.array_equal(curriculum.get(), expected) def test_gradual_early_exit_curriculum(): curriculum = GradualEarlyExitCurriculum([False, False, False, False], max_steps=4, percent_scale=1) From 3319ab00f510d7ac4e7c3345a3c3d98f5bba9175 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Tue, 19 Nov 2024 05:26:45 +0000 Subject: [PATCH 31/88] test to follow training code --- tests/torchtune/modules/test_early_exit_loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/torchtune/modules/test_early_exit_loss.py b/tests/torchtune/modules/test_early_exit_loss.py index ab09646e04..ef828a5e02 100644 --- a/tests/torchtune/modules/test_early_exit_loss.py +++ b/tests/torchtune/modules/test_early_exit_loss.py @@ -120,8 +120,8 @@ def test_early_exit_loss_vs_manual(mock_model, hidden_states_dict, labels, loss_ for i, hidden_state in hidden_states_dict.items(): # Compute logits for the current hidden state logits = mock_model.unembed(hidden_state) - logits = logits.view(-1, logits.size(-1)) # Flatten for loss computation - labels = labels.view(-1) + labels = labels.reshape(-1) + logits = logits.reshape(-1, logits.size(-1)) # Compute the loss for the current hidden state loss = loss_fn(logits, labels) total_loss += loss From 619b3eb99e5a94724dea6dac11bd8832dda919c1 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Wed, 20 Nov 2024 22:41:03 +0000 Subject: [PATCH 32/88] fix curriculum update --- recipes/dev/early_exit_finetune_distributed.py | 2 +- torchtune/modules/early_exit_loss.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index acb8058c4a..eacdac8384 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -899,7 +899,7 @@ def train(self) -> None: do_output_hidden_states = self.early_exit_curriculum.get() if True: # TODO: add cli option do_output_hidden_states[len(self._model.layers) - 1] = True - self._model.output_hidden_states = do_output_hidden_states + self._model.output_hidden_states = [i for i in range(len(do_output_hidden_states)) if do_output_hidden_states[i]] # Stop tracking CUDA memory now that active steps are complete if ( diff --git a/torchtune/modules/early_exit_loss.py b/torchtune/modules/early_exit_loss.py index 5adaf52bc2..776d6b52bb 100644 --- a/torchtune/modules/early_exit_loss.py +++ b/torchtune/modules/early_exit_loss.py @@ -133,6 +133,7 @@ def step(self): n_layers = len(self.do_output_hidden_states) for layer_index in range(len(self.do_output_hidden_states)): should_train = (percent_trained * self._percent_scale) >= (n_layers - layer_index) / n_layers + # TODO: either handle if layers_str != ":", or add an assert statement layers_str == ":" self.do_output_hidden_states[layer_index] = should_train # TODO: move this to step() in parent class? From d376dddf47f5a3f1673b0a6dcf850a4ef1cddd12 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Thu, 21 Nov 2024 19:38:20 +0000 Subject: [PATCH 33/88] update recipe --- recipes/dev/7B_full_early_exit.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/recipes/dev/7B_full_early_exit.yaml b/recipes/dev/7B_full_early_exit.yaml index fabb9b0d4f..4850669a6d 100644 --- a/recipes/dev/7B_full_early_exit.yaml +++ b/recipes/dev/7B_full_early_exit.yaml @@ -115,7 +115,7 @@ profiler: # early exit loss # TODO: rename this and variables to early exit loss early_exit: - layers: "0:10" - curriculum: "gradual" - scale_type: "one" + layers: "0::8" + curriculum: "rot" + scale_type: "sum_l" scale: 1.0 From ff3977bf035e995bdc682a0f7f98af4b1878deb5 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Thu, 21 Nov 2024 19:40:42 +0000 Subject: [PATCH 34/88] reset changes to data loading --- torchtune/datasets/_text_completion.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/torchtune/datasets/_text_completion.py b/torchtune/datasets/_text_completion.py index f84e74ef98..5b5cc94299 100644 --- a/torchtune/datasets/_text_completion.py +++ b/torchtune/datasets/_text_completion.py @@ -56,16 +56,10 @@ def __init__( self._data = self._data.filter(filter_fn) def __len__(self): - if self._num_samples is None or not self._streaming: - return len(self._data) - else: - return self._num_samples + return len(self._data) def __getitem__(self, index: int) -> Dict[str, List[int]]: - if self._streaming: - sample = next(self._data_itr) - else: - sample = self._data[index] + sample = self._data[index] return self._prepare_sample(sample) def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, List[int]]: From 75b2e013a89f1c3a8392c012574e11a03f369c91 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sat, 23 Nov 2024 04:50:00 +0000 Subject: [PATCH 35/88] code cleanup --- torchtune/modules/__init__.py | 2 +- torchtune/modules/early_exit_loss.py | 2 +- torchtune/modules/transformer.py | 19 +++++++++++++------ 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/torchtune/modules/__init__.py b/torchtune/modules/__init__.py index 9cc8570c1d..509c7f7e8c 100644 --- a/torchtune/modules/__init__.py +++ b/torchtune/modules/__init__.py @@ -21,6 +21,7 @@ from .rms_norm import RMSNorm # noqa from .tanh_gate import TanhGate # noqa from .tied_linear import TiedLinear # noqa +from .layer_dropout import LayerDropout, create_layer_dropout_modules # noqa from .transformer import ( # noqa TiedEmbeddingTransformerDecoder, TransformerCrossAttentionLayer, @@ -28,7 +29,6 @@ TransformerSelfAttentionLayer, ) from .vision_transformer import VisionTransformer -from .layer_dropout import LayerDropout, create_layer_dropout_modules # noqa __all__ = [ diff --git a/torchtune/modules/early_exit_loss.py b/torchtune/modules/early_exit_loss.py index 776d6b52bb..23ec776458 100644 --- a/torchtune/modules/early_exit_loss.py +++ b/torchtune/modules/early_exit_loss.py @@ -68,7 +68,7 @@ def layer_ids_to_loss_scales(layer_ids, n_layers, loss_scale_type: LossScaleType case LossScaleType.INV_L: loss_scales = 1.0 / (layer_ids+1) case LossScaleType.INV_SQRT_L: - loss_scales = 1.0 / torch.sqrt(layer_ids+1) + loss_scales = torch.reciprocal(torch.sqrt(layer_ids+1)) case _: raise ValueError(f"Unsupported loss_scale type {loss_scale_type}") diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index 6fe558d112..b19f14b5ef 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import copy -from collections import OrderedDict from typing import Callable, Dict, List, Optional, Union import torch @@ -839,7 +838,6 @@ def forward( encoder_input: Optional[torch.Tensor] = None, encoder_mask: Optional[torch.Tensor] = None, input_pos: Optional[torch.Tensor] = None, - output_hidden_states: Union[bool, List[bool]] = False, ) -> Union[torch.Tensor, List[torch.Tensor]]: """ Args: @@ -934,6 +932,18 @@ def forward( input_pos=input_pos, ) + # shape: [b, seq_len, out_dim] + output = self.unembed(h) + + # Output list if hidden states are requested, otherwise just the output + # TODO: always output a list to have a consistent output type + output = output if not hidden else [*hidden, output] + return output + + def unembed(self, h): + # shape: [b, s, d] + h = self.norm(h) + # shape: [b, s, d] h = self.norm(h) @@ -943,7 +953,4 @@ def forward( # shape: [b, seq_len, out_dim] output = F.linear(h, self.tok_embeddings.weight).float() - # Output list if hidden states are requested, otherwise just the output - # TODO: always output a list to have a consistent output type - output = output if not hidden else [*hidden, output] - return output + return output \ No newline at end of file From 33a95f57be7e585b5ad9534d71da89428112a5d1 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sat, 23 Nov 2024 06:02:38 +0000 Subject: [PATCH 36/88] rename early_exit to early_exit_loss --- recipes/dev/7B_full_early_exit.yaml | 5 ++--- recipes/dev/early_exit_finetune_distributed.py | 3 +-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/recipes/dev/7B_full_early_exit.yaml b/recipes/dev/7B_full_early_exit.yaml index 4850669a6d..547eb9e3c2 100644 --- a/recipes/dev/7B_full_early_exit.yaml +++ b/recipes/dev/7B_full_early_exit.yaml @@ -112,9 +112,8 @@ profiler: active_steps: 2 num_cycles: 1 -# early exit loss -# TODO: rename this and variables to early exit loss -early_exit: +# Early Exit Loss +early_exit_loss: layers: "0::8" curriculum: "rot" scale_type: "sum_l" diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index eacdac8384..6f5bab9ba7 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -202,9 +202,8 @@ def __init__(self, cfg: DictConfig) -> None: self.max_steps_per_epoch = cfg.max_steps_per_epoch self.global_step = 0 - cfg_early_exit = cfg.get("early_exit", None) + cfg_early_exit = cfg.get("early_exit_loss", None) # TODO: create a "setup" function similar to setup_model? - # TODO: rename "early_exit" to "early_exit_loss" if cfg_early_exit: self.early_exit_layers = cfg_early_exit.get("layers", ":") self.early_exit_curriculum = cfg_early_exit.get("curriculum", "none") From 5d7e903a8f4fe8f15b066205e9f790d5bd0ae947 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sat, 23 Nov 2024 14:51:15 +0000 Subject: [PATCH 37/88] address some early exit TODOs --- recipes/dev/7B_full_early_exit.yaml | 4 +- .../dev/early_exit_finetune_distributed.py | 80 ++++++++++++------- 2 files changed, 52 insertions(+), 32 deletions(-) diff --git a/recipes/dev/7B_full_early_exit.yaml b/recipes/dev/7B_full_early_exit.yaml index 547eb9e3c2..c2de2ce364 100644 --- a/recipes/dev/7B_full_early_exit.yaml +++ b/recipes/dev/7B_full_early_exit.yaml @@ -114,7 +114,7 @@ profiler: # Early Exit Loss early_exit_loss: - layers: "0::8" - curriculum: "rot" + layers: ":" + curriculum: "gradual" scale_type: "sum_l" scale: 1.0 diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index 6f5bab9ba7..544a3b4295 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -28,7 +28,7 @@ from torchtune.training.activations import apply_selective_activation_checkpointing from torchtune.training.lr_schedulers import get_lr -from torchtune.modules.early_exit_loss import early_exit_loss, build_early_exit_curriculum +from torchtune.modules.early_exit_loss import early_exit_loss, build_early_exit_curriculum, EarlyExitCurriculum from torchtune.modules.common_utils import slice_str_to_array from tqdm import tqdm @@ -202,18 +202,15 @@ def __init__(self, cfg: DictConfig) -> None: self.max_steps_per_epoch = cfg.max_steps_per_epoch self.global_step = 0 - cfg_early_exit = cfg.get("early_exit_loss", None) - # TODO: create a "setup" function similar to setup_model? - if cfg_early_exit: - self.early_exit_layers = cfg_early_exit.get("layers", ":") - self.early_exit_curriculum = cfg_early_exit.get("curriculum", "none") - self.early_exit_scale = cfg_early_exit.get("scale", 1.0) - self.early_exit_scale_type = cfg_early_exit.get("scale_type", "one") + cfg_early_exit_loss = cfg.get("early_exit_loss", None) + if cfg_early_exit_loss: + self._do_early_exit_loss = True + self._early_exit_loss_scale = cfg_early_exit_loss.get("scale", 1.0) + self._early_exit_loss_scale_type = cfg_early_exit_loss.get("scale_type", "one") else: - self.early_exit_layers = None - self.early_exit_curriculum = None - self.early_exit_scale = None - self.early_exit_scale_type = None + self._do_early_exit_loss = False + self._early_exit_loss_scale = None + self._early_exit_loss_scale_type = None def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: """ @@ -356,6 +353,9 @@ def setup(self, cfg: DictConfig) -> None: (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device ) + # Setup early exit loss + self._do_output_hidden_states, self._early_exit_loss_curriculum = self._setup_early_exit_loss(cfg.get("early_exit_loss", None)) + def _setup_profiler( self, cfg_profiler: Optional[DictConfig] = None ) -> Union[torch.profiler.profile, DummyProfiler]: @@ -633,6 +633,33 @@ def _setup_data( return sampler, dataloader + def _setup_early_exit_loss( + self, + cfg_early_exit_loss: DictConfig, + ) -> Tuple[List[bool], EarlyExitCurriculum]: + """ + All early exit loss related setup happens here. + """ + do_output_hidden_states = None + early_exit_loss_curriculum = None + + if cfg_early_exit_loss: + do_output_hidden_states = slice_str_to_array(cfg_early_exit_loss.get("layers", ":"), len(self._model.layers)) + # TODO: add cli option + # TODO: move this statement to inside curriculum + if True: + do_output_hidden_states[len(self._model.layers) - 1] = True + + # TODO: rename build_early_exit_curriculum to setup_early_exit_loss_curriculum + if cfg_early_exit_loss.curriculum: + early_exit_loss_curriculum = build_early_exit_curriculum(cfg_early_exit_loss.curriculum, do_output_hidden_states, self.total_epochs*self._steps_per_epoch) + else: + early_exit_loss_curriculum = None + + # TODO: get initial do_output_hidden_states from curriculum + + return do_output_hidden_states, early_exit_loss_curriculum + def save_checkpoint( self, epoch: int, @@ -744,16 +771,9 @@ def train(self) -> None: running_loss = 0 num_tokens = 0 - # Early exit loss settings - # TODO: move to _init_() or setup() - if self.early_exit_layers: - do_output_hidden_states = slice_str_to_array(self.early_exit_layers, len(self._model.layers)) - if True: # TODO: add cli option - do_output_hidden_states[len(self._model.layers) - 1] = True - self._model.output_hidden_states = [i for i in range(len(do_output_hidden_states)) if do_output_hidden_states[i]] - - if self.early_exit_curriculum: - self.early_exit_curriculum = build_early_exit_curriculum(self.early_exit_curriculum, do_output_hidden_states, self.total_epochs*self._steps_per_epoch) + # Initialize output hidden states + if self._do_output_hidden_states: + self._model.output_hidden_states = [i for i in range(len(self._do_output_hidden_states)) if self._do_output_hidden_states[i]] self._profiler.start() # self.epochs_run should be non-zero when we're resuming from a checkpoint @@ -794,7 +814,7 @@ def train(self) -> None: with self.activations_handling_ctx: outputs = self._model(**batch) - if self.early_exit_layers: + if self._do_early_exit_loss: logits = outputs.pop(-1) hidden_states = {i:h for i,h in zip(self._model.output_hidden_states, outputs)} else: @@ -813,8 +833,8 @@ def train(self) -> None: # Compute loss # Loss is normalized by default so we multiply by the number of tokens # This way we can normalize by the total number of tokens if we're accumulating gradients - if self.early_exit_layers: - current_loss = early_exit_loss(self._model, hidden_states, labels, self._loss_fn, self.early_exit_scale, self.early_exit_scale_type) * current_num_tokens + if self._do_early_exit_loss: + current_loss = early_exit_loss(self._model, hidden_states, labels, self._loss_fn, self._early_exit_loss_scale, self._early_exit_loss_scale_type) * current_num_tokens else: current_loss = self._loss_fn(logits, labels) * current_num_tokens @@ -893,12 +913,12 @@ def train(self) -> None: t0 = time.perf_counter() # Update Early Exit Layers/Scales - if self.early_exit_curriculum: - self.early_exit_curriculum.step() - do_output_hidden_states = self.early_exit_curriculum.get() + if self._early_exit_loss_curriculum: + self._early_exit_loss_curriculum.step() + self._do_output_hidden_states = self._early_exit_loss_curriculum.get() if True: # TODO: add cli option - do_output_hidden_states[len(self._model.layers) - 1] = True - self._model.output_hidden_states = [i for i in range(len(do_output_hidden_states)) if do_output_hidden_states[i]] + self._do_output_hidden_states[len(self._model.layers) - 1] = True + self._model.output_hidden_states = [i for i in range(len(self._do_output_hidden_states)) if self._do_output_hidden_states[i]] # Stop tracking CUDA memory now that active steps are complete if ( From 87f2ee0990f525792e1a2f959ff5d9ab7a2b5694 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sat, 23 Nov 2024 16:22:15 +0000 Subject: [PATCH 38/88] get layer dropout to work --- recipes/dev/7B_full_early_exit.yaml | 7 +++ .../dev/early_exit_finetune_distributed.py | 8 ++++ torchtune/modules/__init__.py | 2 +- torchtune/modules/layer_dropout.py | 48 +++++++++++++++---- 4 files changed, 56 insertions(+), 9 deletions(-) diff --git a/recipes/dev/7B_full_early_exit.yaml b/recipes/dev/7B_full_early_exit.yaml index c2de2ce364..69e66d897c 100644 --- a/recipes/dev/7B_full_early_exit.yaml +++ b/recipes/dev/7B_full_early_exit.yaml @@ -118,3 +118,10 @@ early_exit_loss: curriculum: "gradual" scale_type: "sum_l" scale: 1.0 + +# Layer Dropout +layer_dropout: + prob: 0.5 + layers: ":" + layers_scale: "exp" + disable_on_eval: True diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index 544a3b4295..2f1b466f16 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -29,6 +29,7 @@ from torchtune.training.lr_schedulers import get_lr from torchtune.modules.early_exit_loss import early_exit_loss, build_early_exit_curriculum, EarlyExitCurriculum +from torchtune.modules.layer_dropout import apply_layer_dropout_modules from torchtune.modules.common_utils import slice_str_to_array from tqdm import tqdm @@ -202,6 +203,7 @@ def __init__(self, cfg: DictConfig) -> None: self.max_steps_per_epoch = cfg.max_steps_per_epoch self.global_step = 0 + # Early Exit Properties cfg_early_exit_loss = cfg.get("early_exit_loss", None) if cfg_early_exit_loss: self._do_early_exit_loss = True @@ -356,6 +358,12 @@ def setup(self, cfg: DictConfig) -> None: # Setup early exit loss self._do_output_hidden_states, self._early_exit_loss_curriculum = self._setup_early_exit_loss(cfg.get("early_exit_loss", None)) + # Layer Dropout Setup + # TODO: move to a setup function? + cfg_layer_dropout = cfg.get("layer_dropout", None) + if cfg_layer_dropout: + apply_layer_dropout_modules(self._model, prob_max=cfg_layer_dropout.get("prob", 0.0), prob_layer_scale=cfg_layer_dropout.get("layers_scale", "uniform"), layers_str=cfg_layer_dropout.get("layers", ":"), disable_on_eval=cfg_layer_dropout.get("disable_on_eval", True)) + def _setup_profiler( self, cfg_profiler: Optional[DictConfig] = None ) -> Union[torch.profiler.profile, DummyProfiler]: diff --git a/torchtune/modules/__init__.py b/torchtune/modules/__init__.py index 509c7f7e8c..86a1de3ba8 100644 --- a/torchtune/modules/__init__.py +++ b/torchtune/modules/__init__.py @@ -21,7 +21,7 @@ from .rms_norm import RMSNorm # noqa from .tanh_gate import TanhGate # noqa from .tied_linear import TiedLinear # noqa -from .layer_dropout import LayerDropout, create_layer_dropout_modules # noqa +from .layer_dropout import LayerDropout, apply_layer_dropout_modules # noqa from .transformer import ( # noqa TiedEmbeddingTransformerDecoder, TransformerCrossAttentionLayer, diff --git a/torchtune/modules/layer_dropout.py b/torchtune/modules/layer_dropout.py index a04e0d5b24..599d534f0f 100644 --- a/torchtune/modules/layer_dropout.py +++ b/torchtune/modules/layer_dropout.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from enum import Enum -from typing import Callable, Optional +from typing import Any, Callable, Optional import math import torch @@ -44,6 +44,40 @@ def forward(self, function: Callable, input: torch.Tensor, *args, **kwargs): out[ind_selected] = out_selected return out +class ModuleLayerDropoutWrapper(torch.nn.Module): + def __init__(self, module: torch.nn.Module, dropout: LayerDropout): + super().__init__() + self.module = module + self.dropout = dropout + + def forward(self, input: torch.Tensor, *args, **kwargs): + return self.dropout(self.module, input, *args, **kwargs) + + def __getattr__(self, name: str) -> Any: + """Forward missing attributes to wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self.module, name) # fallback to wrapped module + + def __setattr__(self, name: str, value: Any) -> Any: + """Forward missing attributes to wrapped module.""" + try: + return super().__setattr__(name, value) # defer to nn.Module's logic + except AttributeError: + return setattr(self.module, name, value) # fallback to wrapped module + + def __getitem__(self, key: int) -> Any: + """Forward indexing calls in case the module is a nn.Sequential.""" + return self.module.__getitem__(key) + + def state_dict(self, *args, **kwargs): + return self.module.state_dict(*args, **kwargs) + + def load_state_dict(self, state_dict, *args, **kwargs): + self.module.load_state_dict(state_dict, *args, **kwargs) + return + class ScaleType(str, Enum): UNIFORM = "uniform" EXP = "exp" @@ -67,11 +101,11 @@ def get_scale(scale_type: ScaleType, scale_period: int, val: int): ScaleType.SIGMOID: 1 / (1 + math.exp(-10 * (val / scale_period - 0.5))), }[scale_type] -def create_layer_dropout_modules(num_layers: int, prob_max: float= 0.0, prob_layer_scale: ScaleType = ScaleType.EXP, layers_str: Optional[str] = None, disable_on_eval: bool = True): - layer_dropouts = torch.nn.ModuleList() +# TODO: rename to prepare() just like quantizer()? +def apply_layer_dropout_modules(model, prob_max: float= 0.0, prob_layer_scale: ScaleType = ScaleType.EXP, layers_str: Optional[str] = None, disable_on_eval: bool = True): + num_layers = len(model.layers) has_dropout = slice_str_to_array(layers_str, num_layers) if layers_str else [True] * num_layers - - for layer_id in range(num_layers): + for layer_id in range(len(model.layers)): prob = prob_max * get_scale( scale_type = prob_layer_scale, scale_period = num_layers - 1, @@ -80,6 +114,4 @@ def create_layer_dropout_modules(num_layers: int, prob_max: float= 0.0, prob_lay assert prob >= 0.0 and prob <= prob_max, f"prob={prob} should be between 0 and {prob_max}" # We would like each layer to have a different seed, so that we don't have the same samples skipped across layers. Hence, we use the layer_id as a seed for each layer's dropout. layer_dropout = LayerDropout(prob, disable_on_eval=disable_on_eval, seed=layer_id) - layer_dropouts.append(layer_dropout) - - return layer_dropouts + model.layers[layer_id] = ModuleLayerDropoutWrapper(model.layers[layer_id], layer_dropout) From 1de0c2a489da8f30acbac7834fa3b6c7a8892201 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 24 Nov 2024 04:23:02 +0000 Subject: [PATCH 39/88] clean up early exit curriculum --- .../dev/early_exit_finetune_distributed.py | 13 ++++-------- .../torchtune/modules/test_early_exit_loss.py | 8 +++---- torchtune/modules/early_exit_loss.py | 21 ++++++++++++------- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index 2f1b466f16..f14d52bbc3 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -28,7 +28,7 @@ from torchtune.training.activations import apply_selective_activation_checkpointing from torchtune.training.lr_schedulers import get_lr -from torchtune.modules.early_exit_loss import early_exit_loss, build_early_exit_curriculum, EarlyExitCurriculum +from torchtune.modules.early_exit_loss import early_exit_loss, setup_early_exit_loss_curriculum, EarlyExitCurriculum from torchtune.modules.layer_dropout import apply_layer_dropout_modules from torchtune.modules.common_utils import slice_str_to_array @@ -359,7 +359,6 @@ def setup(self, cfg: DictConfig) -> None: self._do_output_hidden_states, self._early_exit_loss_curriculum = self._setup_early_exit_loss(cfg.get("early_exit_loss", None)) # Layer Dropout Setup - # TODO: move to a setup function? cfg_layer_dropout = cfg.get("layer_dropout", None) if cfg_layer_dropout: apply_layer_dropout_modules(self._model, prob_max=cfg_layer_dropout.get("prob", 0.0), prob_layer_scale=cfg_layer_dropout.get("layers_scale", "uniform"), layers_str=cfg_layer_dropout.get("layers", ":"), disable_on_eval=cfg_layer_dropout.get("disable_on_eval", True)) @@ -653,14 +652,12 @@ def _setup_early_exit_loss( if cfg_early_exit_loss: do_output_hidden_states = slice_str_to_array(cfg_early_exit_loss.get("layers", ":"), len(self._model.layers)) - # TODO: add cli option - # TODO: move this statement to inside curriculum - if True: + train_last_layer = cfg_early_exit_loss.get("include_last_layer", True) + if train_last_layer: do_output_hidden_states[len(self._model.layers) - 1] = True - # TODO: rename build_early_exit_curriculum to setup_early_exit_loss_curriculum if cfg_early_exit_loss.curriculum: - early_exit_loss_curriculum = build_early_exit_curriculum(cfg_early_exit_loss.curriculum, do_output_hidden_states, self.total_epochs*self._steps_per_epoch) + early_exit_loss_curriculum = setup_early_exit_loss_curriculum(cfg_early_exit_loss.curriculum, do_output_hidden_states, self.total_epochs*self._steps_per_epoch, train_last_layer) else: early_exit_loss_curriculum = None @@ -924,8 +921,6 @@ def train(self) -> None: if self._early_exit_loss_curriculum: self._early_exit_loss_curriculum.step() self._do_output_hidden_states = self._early_exit_loss_curriculum.get() - if True: # TODO: add cli option - self._do_output_hidden_states[len(self._model.layers) - 1] = True self._model.output_hidden_states = [i for i in range(len(self._do_output_hidden_states)) if self._do_output_hidden_states[i]] # Stop tracking CUDA memory now that active steps are complete diff --git a/tests/torchtune/modules/test_early_exit_loss.py b/tests/torchtune/modules/test_early_exit_loss.py index ef828a5e02..968b77b097 100644 --- a/tests/torchtune/modules/test_early_exit_loss.py +++ b/tests/torchtune/modules/test_early_exit_loss.py @@ -9,7 +9,7 @@ layer_ids_to_loss_scales, LossScaleType, EarlyExitCurriculumType, - build_early_exit_curriculum, + setup_early_exit_loss_curriculum, RotationalEarlyExitCurriculum, GradualEarlyExitCurriculum, ) @@ -64,11 +64,11 @@ def test_layer_ids_to_loss_scales(): scales = layer_ids_to_loss_scales(layer_ids, n_layers, LossScaleType.SUM_L, 1.0) assert torch.isclose(scales.sum(), torch.tensor(1.0)) -def test_build_early_exit_curriculum(): - curriculum = build_early_exit_curriculum(EarlyExitCurriculumType.ROTATIONAL, [True, False, True], 100) +def test_setup_early_exit_loss_curriculum(): + curriculum = setup_early_exit_loss_curriculum(EarlyExitCurriculumType.ROTATIONAL, [True, False, True], 100) assert isinstance(curriculum, RotationalEarlyExitCurriculum) - curriculum = build_early_exit_curriculum(EarlyExitCurriculumType.GRADUAL, [True, False, True], 100) + curriculum = setup_early_exit_loss_curriculum(EarlyExitCurriculumType.GRADUAL, [True, False, True], 100) assert isinstance(curriculum, GradualEarlyExitCurriculum) def test_rotational_early_exit_curriculum(): diff --git a/torchtune/modules/early_exit_loss.py b/torchtune/modules/early_exit_loss.py index 23ec776458..28eb1f154f 100644 --- a/torchtune/modules/early_exit_loss.py +++ b/torchtune/modules/early_exit_loss.py @@ -84,7 +84,7 @@ class EarlyExitCurriculumType(str, Enum): ROTATIONAL = "rot" GRADUAL = "gradual" -def build_early_exit_curriculum(early_exit_curriculum: EarlyExitCurriculumType, *args, **kwargs): +def setup_early_exit_loss_curriculum(early_exit_curriculum: EarlyExitCurriculumType, *args, **kwargs): match early_exit_curriculum: case EarlyExitCurriculumType.NONE: return None @@ -101,9 +101,10 @@ def build_early_exit_curriculum(early_exit_curriculum: EarlyExitCurriculumType, # TODO: create a base curriculum class that can be used for other aspects, e.g., dropout, datasets, etc. class EarlyExitCurriculum(): - def __init__(self, do_output_hidden_states, max_steps, verbose=False): + def __init__(self, do_output_hidden_states, max_steps, train_last_layer=True, verbose=False): self._init_do_output_hidden_states = do_output_hidden_states self.do_output_hidden_states = do_output_hidden_states + self.train_last_layer = train_last_layer self.verbose = verbose self.max_steps = max_steps @@ -114,17 +115,19 @@ def get(self): return self.do_output_hidden_states class RotationalEarlyExitCurriculum(EarlyExitCurriculum): - def __init__(self, do_output_hidden_states, max_steps, verbose=False): - super().__init__(do_output_hidden_states, max_steps, verbose) + def __init__(self, do_output_hidden_states, max_steps, train_last_layer=True, verbose=False): + super().__init__(do_output_hidden_states, max_steps, train_last_layer, verbose) def step(self): self.do_output_hidden_states = np.roll(self.do_output_hidden_states, -1) + if self.train_last_layer: + self.do_output_hidden_states[-1] = True if self.verbose: - log.info(f"Updating self.output_hidden_states to {self.do_output_hidden_states}.") + log.info(f"Updated self.output_hidden_states to {self.do_output_hidden_states}.") class GradualEarlyExitCurriculum(EarlyExitCurriculum): - def __init__(self, do_output_hidden_states, max_steps, percent_scale=2, verbose=False): - super().__init__(do_output_hidden_states, max_steps, verbose) + def __init__(self, do_output_hidden_states, max_steps, train_last_layer=True, percent_scale=2, verbose=False): + super().__init__(do_output_hidden_states, max_steps, train_last_layer, verbose) self._step = 0 self._percent_scale = percent_scale @@ -136,8 +139,10 @@ def step(self): # TODO: either handle if layers_str != ":", or add an assert statement layers_str == ":" self.do_output_hidden_states[layer_index] = should_train + if self.train_last_layer: + self.do_output_hidden_states[-1] = True # TODO: move this to step() in parent class? # TODO: how to ensure we always call parent step() in derived class? self._step += 1 if self.verbose: - log.info(f"Updating self.do_output_hidden_states to {self.do_output_hidden_states}.") + log.info(f"Updated self.do_output_hidden_states to {self.do_output_hidden_states}.") From 2b0cdd19974b90da893e10b9b5b94618306de149 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 24 Nov 2024 06:41:06 +0000 Subject: [PATCH 40/88] enable grad curriculum for subset of layers + clear hidden_states at start of grad curriculum --- .../dev/early_exit_finetune_distributed.py | 11 +++++----- torchtune/modules/early_exit_loss.py | 20 +++++++++++++++---- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index f14d52bbc3..976f4b9797 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -653,16 +653,17 @@ def _setup_early_exit_loss( if cfg_early_exit_loss: do_output_hidden_states = slice_str_to_array(cfg_early_exit_loss.get("layers", ":"), len(self._model.layers)) train_last_layer = cfg_early_exit_loss.get("include_last_layer", True) + verbose = cfg_early_exit_loss.get("verbose", False) + if train_last_layer: do_output_hidden_states[len(self._model.layers) - 1] = True if cfg_early_exit_loss.curriculum: - early_exit_loss_curriculum = setup_early_exit_loss_curriculum(cfg_early_exit_loss.curriculum, do_output_hidden_states, self.total_epochs*self._steps_per_epoch, train_last_layer) + early_exit_loss_curriculum = setup_early_exit_loss_curriculum(early_exit_curriculum=cfg_early_exit_loss.curriculum, do_output_hidden_states=do_output_hidden_states, max_steps=self.total_epochs*self._steps_per_epoch, train_last_layer=train_last_layer, verbose=verbose) + do_output_hidden_states = early_exit_loss_curriculum.get() else: early_exit_loss_curriculum = None - # TODO: get initial do_output_hidden_states from curriculum - return do_output_hidden_states, early_exit_loss_curriculum def save_checkpoint( @@ -819,7 +820,7 @@ def train(self) -> None: with self.activations_handling_ctx: outputs = self._model(**batch) - if self._do_early_exit_loss: + if self._model.output_hidden_states: logits = outputs.pop(-1) hidden_states = {i:h for i,h in zip(self._model.output_hidden_states, outputs)} else: @@ -838,7 +839,7 @@ def train(self) -> None: # Compute loss # Loss is normalized by default so we multiply by the number of tokens # This way we can normalize by the total number of tokens if we're accumulating gradients - if self._do_early_exit_loss: + if self._model.output_hidden_states: current_loss = early_exit_loss(self._model, hidden_states, labels, self._loss_fn, self._early_exit_loss_scale, self._early_exit_loss_scale_type) * current_num_tokens else: current_loss = self._loss_fn(logits, labels) * current_num_tokens diff --git a/torchtune/modules/early_exit_loss.py b/torchtune/modules/early_exit_loss.py index 28eb1f154f..de14207275 100644 --- a/torchtune/modules/early_exit_loss.py +++ b/torchtune/modules/early_exit_loss.py @@ -61,7 +61,6 @@ def layer_ids_to_loss_scales(layer_ids, n_layers, loss_scale_type: LossScaleType case LossScaleType.L: loss_scales = torch.Tensor(layer_ids+1) case LossScaleType.SUM_L: - # TODO: should we change to sum 0:i ? Perhaps create a new scale_type loss_scales = torch.cumsum(layer_ids+1, dim=0) case LossScaleType.SQRT_L: loss_scales = torch.sqrt(layer_ids+1) @@ -117,32 +116,45 @@ def get(self): class RotationalEarlyExitCurriculum(EarlyExitCurriculum): def __init__(self, do_output_hidden_states, max_steps, train_last_layer=True, verbose=False): super().__init__(do_output_hidden_states, max_steps, train_last_layer, verbose) + self._initial_do_output_hidden_states = np.copy(do_output_hidden_states) def step(self): + # Rotate layer enablement one step forward self.do_output_hidden_states = np.roll(self.do_output_hidden_states, -1) + + # Ensure last layer is trained if self.train_last_layer: self.do_output_hidden_states[-1] = True + if self.verbose: log.info(f"Updated self.output_hidden_states to {self.do_output_hidden_states}.") class GradualEarlyExitCurriculum(EarlyExitCurriculum): def __init__(self, do_output_hidden_states, max_steps, train_last_layer=True, percent_scale=2, verbose=False): super().__init__(do_output_hidden_states, max_steps, train_last_layer, verbose) + self._final_do_output_hidden_states = np.copy(do_output_hidden_states) self._step = 0 self._percent_scale = percent_scale + # Initialize all layers to False + for i in range(len(self.do_output_hidden_states)): + self.do_output_hidden_states[i] = False + def step(self): percent_trained = self._step / self.max_steps n_layers = len(self.do_output_hidden_states) + # Enable each layer based on proportion of completed training steps for layer_index in range(len(self.do_output_hidden_states)): should_train = (percent_trained * self._percent_scale) >= (n_layers - layer_index) / n_layers - # TODO: either handle if layers_str != ":", or add an assert statement layers_str == ":" self.do_output_hidden_states[layer_index] = should_train + # Only enable layers that are set by the user + self.do_output_hidden_states = np.logical_and(self.do_output_hidden_states, self._final_do_output_hidden_states) + + # Ensure last layer is trained if self.train_last_layer: self.do_output_hidden_states[-1] = True - # TODO: move this to step() in parent class? - # TODO: how to ensure we always call parent step() in derived class? + self._step += 1 if self.verbose: log.info(f"Updated self.do_output_hidden_states to {self.do_output_hidden_states}.") From 7973459656b4c168ed4a04141365f64cf07bf47d Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 24 Nov 2024 13:28:37 +0000 Subject: [PATCH 41/88] add docstring for slice_str_to_array --- torchtune/modules/common_utils.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/torchtune/modules/common_utils.py b/torchtune/modules/common_utils.py index 0f3f6dcf90..3d1ece9d5d 100644 --- a/torchtune/modules/common_utils.py +++ b/torchtune/modules/common_utils.py @@ -58,7 +58,30 @@ def reparametrize_as_dtype_state_dict_post_hook( if offload_to_cpu: state_dict[k] = state_dict[k].cpu() + def slice_str_to_array(slice_str, length): + """ + Convert a string representing a Python slice into a boolean array. + The resulting array will have the same length as the specified `length` parameter. + Each element in the array corresponds to an index in the original sequence, + with `True` indicating that the index is included in the slice and `False` otherwise. + Args: + slice_str (str): A string representing a Python slice, e.g. "1:3", ":5", "2::3". + length (int): The length of the original sequence. + Returns: + list[bool]: A boolean array representing the slice. + Examples: + >>> slice_str_to_array("1:3", 5) + [False, True, True, False, False] + >>> slice_str_to_array(":", 5) + [True, True, True, True, True] + >>> slice_str_to_array("::2", 5) + [True, False, True, False, True] + >>> slice_str_to_array("1::2", 5) + [False, True, False, True, False] + >>> slice_str_to_array("2:5:2", 6) + [False, False, True, False, True, False] + """ # Parse the slice string parts = slice_str.split(':') start, end, step = None, None, None From baed8a9b79cad193c54d4efbca49864a54a60f77 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 24 Nov 2024 13:32:52 +0000 Subject: [PATCH 42/88] support commas and add assertion statements --- torchtune/modules/common_utils.py | 31 ++++++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/torchtune/modules/common_utils.py b/torchtune/modules/common_utils.py index 3d1ece9d5d..95399ce729 100644 --- a/torchtune/modules/common_utils.py +++ b/torchtune/modules/common_utils.py @@ -59,17 +59,21 @@ def reparametrize_as_dtype_state_dict_post_hook( state_dict[k] = state_dict[k].cpu() -def slice_str_to_array(slice_str, length): +def slice_str_to_array(slice_str: str, length: int) -> list[bool]: """ - Convert a string representing a Python slice into a boolean array. + Convert a string representing a Python slice or index into a boolean array. + The resulting array will have the same length as the specified `length` parameter. Each element in the array corresponds to an index in the original sequence, with `True` indicating that the index is included in the slice and `False` otherwise. + Args: - slice_str (str): A string representing a Python slice, e.g. "1:3", ":5", "2::3". + slice_str (str): A string representing a Python slice or index, e.g. "1:3", ":5", "2::3", "0,4,5". length (int): The length of the original sequence. + Returns: list[bool]: A boolean array representing the slice. + Examples: >>> slice_str_to_array("1:3", 5) [False, True, True, False, False] @@ -81,9 +85,22 @@ def slice_str_to_array(slice_str, length): [False, True, False, True, False] >>> slice_str_to_array("2:5:2", 6) [False, False, True, False, True, False] + >>> slice_str_to_array("0,4,5", 7) + [True, False, False, False, True, True, False] """ - # Parse the slice string + + assert ',' not in slice_str or ':' not in slice_str, "Cannot mix commas and colons" + + if ',' in slice_str: + indices = [int(i) for i in slice_str.split(',')] + assert all(0 <= i < length for i in indices), "Index out of range" + result = [False] * length + for i in indices: + result[i] = True + return result + parts = slice_str.split(':') + assert len(parts) <= 3, "Invalid slice format" start, end, step = None, None, None if len(parts) == 1 and parts[0] != '': @@ -96,7 +113,10 @@ def slice_str_to_array(slice_str, length): end = int(parts[1]) if parts[1] != '' else None step = int(parts[2]) if parts[2] != '' else None - # Create a boolean array based on the slice + assert start is None or 0 <= start < length, "Start index out of range" + assert end is None or 0 <= end < length, "End index out of range" + assert step is None or step != 0, "Step cannot be zero" + result = [False] * length slice_indices = range(start if start is not None else 0, end if end is not None else length, @@ -108,6 +128,7 @@ def slice_str_to_array(slice_str, length): return result + def _low_ram_reparametrize_as_dtype_state_dict_post_hook( model: nn.Module, state_dict: Dict[str, Any], From 27f6b56d76b858b23f1e93d90692f8ba5c8ee611 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 24 Nov 2024 13:50:23 +0000 Subject: [PATCH 43/88] add test cases for slice_to_str_array --- tests/torchtune/modules/test_common_utils.py | 26 +++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/tests/torchtune/modules/test_common_utils.py b/tests/torchtune/modules/test_common_utils.py index 41dc472f00..13ae33d65c 100644 --- a/tests/torchtune/modules/test_common_utils.py +++ b/tests/torchtune/modules/test_common_utils.py @@ -15,7 +15,7 @@ ) from torchtune.modules import delete_kv_caches, disable_kv_cache, local_kv_cache from torchtune.modules.model_fusion import DeepFusionModel - +from torchtune.modules.common_utils import slice_str_to_array @pytest.fixture def llama_vision_model(): @@ -191,3 +191,27 @@ def test_disable_kv_cache_raises_error_caches_not_setup(self, model, request): with pytest.raises(ValueError, match="Model caches must be setup"): with disable_kv_cache(model): pass + +class TestSliceStrToArray: + def test_single_index(self): + assert slice_str_to_array("0", 5) == [True, False, False, False, False] + + def test_slice_with_start_and_end(self): + assert slice_str_to_array("1:3", 5) == [False, True, True, False, False] + + def test_slice_with_start_and_step(self): + assert slice_str_to_array("1::2", 5) == [False, True, False, True, False] + + def test_slice_with_start_end_and_step(self): + assert slice_str_to_array("1:4:2", 5) == [False, True, False, True, False] + + def test_multiple_indices(self): + assert slice_str_to_array("0,2,4", 6) == [True, False, True, False, True, False] + + def test_out_of_range_index(self): + with pytest.raises(AssertionError): + slice_str_to_array("10", 5) + + def test_invalid_slice_format(self): + with pytest.raises(AssertionError): + slice_str_to_array("1:2:3:4", 5) From 63e7c5b21f19642b19fdd9ba437470e63d3f2fb5 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 24 Nov 2024 13:50:34 +0000 Subject: [PATCH 44/88] add copyright header --- tests/torchtune/modules/test_early_exit_loss.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/torchtune/modules/test_early_exit_loss.py b/tests/torchtune/modules/test_early_exit_loss.py index 968b77b097..91bfa1c508 100644 --- a/tests/torchtune/modules/test_early_exit_loss.py +++ b/tests/torchtune/modules/test_early_exit_loss.py @@ -1,3 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + import numpy as np import pytest import torch From 638056b63b5031103373ca8a01ffed412689d981 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 24 Nov 2024 13:50:45 +0000 Subject: [PATCH 45/88] support single index --- torchtune/modules/common_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchtune/modules/common_utils.py b/torchtune/modules/common_utils.py index 95399ce729..33aca8e98d 100644 --- a/torchtune/modules/common_utils.py +++ b/torchtune/modules/common_utils.py @@ -105,6 +105,8 @@ def slice_str_to_array(slice_str: str, length: int) -> list[bool]: if len(parts) == 1 and parts[0] != '': start = int(parts[0]) + end = start + 1 + step = 1 elif len(parts) == 2: start = int(parts[0]) if parts[0] != '' else None end = int(parts[1]) if parts[1] != '' else None From a20b07c7b96cc1841cb2e0ac4bb38eca078e1a11 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 24 Nov 2024 13:55:10 +0000 Subject: [PATCH 46/88] add new line at end of file --- torchtune/modules/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index b19f14b5ef..2d717b8dcd 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -953,4 +953,4 @@ def unembed(self, h): # shape: [b, seq_len, out_dim] output = F.linear(h, self.tok_embeddings.weight).float() - return output \ No newline at end of file + return output From 98897a825979d324b1f56b015a475ef7fc5d2533 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 24 Nov 2024 14:26:01 +0000 Subject: [PATCH 47/88] add layer dropout test cases --- tests/torchtune/modules/test_layer_dropout.py | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 tests/torchtune/modules/test_layer_dropout.py diff --git a/tests/torchtune/modules/test_layer_dropout.py b/tests/torchtune/modules/test_layer_dropout.py new file mode 100644 index 0000000000..3def1a3393 --- /dev/null +++ b/tests/torchtune/modules/test_layer_dropout.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Tuple +import pytest +import torch +from tests.test_utils import assert_expected +from torchtune.modules.layer_dropout import LayerDropout + + +class TestLayerDropout: + """Class for testing LayerDropout implementation.""" + + + @pytest.fixture(autouse=True) + def random(self): + torch.manual_seed(0) + + + @pytest.fixture + def input_shape(self) -> Tuple[int, int]: + bsz = 32 + seqlen = 1024 + dim = 4096 + return bsz, seqlen, dim + + + @pytest.fixture + def input(self, input_shape: Tuple[int]) -> torch.Tensor: + return torch.randn(input_shape) + + + @pytest.fixture + def layer_dropout(self, prob: float = 0.5, disable_on_eval: bool = True) -> LayerDropout: + return LayerDropout(prob=prob, disable_on_eval=disable_on_eval) + + + def test_forward_train_prob_1(self, layer_dropout: LayerDropout, input: torch.Tensor) -> None: + # With dropout probability = 1.0, we expect output to be the same as input + layer_dropout.prob = 1.0 + output = layer_dropout.forward(lambda x: x**2, input) + assert torch.allclose(output, input, atol=1e-7, rtol=1e-3) + + + def test_forward_train_prob_0(self, layer_dropout: LayerDropout, input: torch.Tensor) -> None: + # With dropout probability = 1.0, we expect the operation to be applied on all elements in the input + layer_dropout.prob = 0.0 + output = layer_dropout.forward(lambda x: x**2, input) + assert torch.allclose(output, input**2, atol=1e-7, rtol=1e-3) + + + def test_forward_eval(self, layer_dropout: LayerDropout, input: torch.Tensor) -> None: + layer_dropout.prob = 1.0 + layer_dropout.eval() + + layer_dropout.disable_on_eval = True + output = layer_dropout.forward(lambda x: x**2, input) + assert torch.allclose(output, input**2, atol=1e-7, rtol=1e-3) + + layer_dropout.disable_on_eval = False + with torch.no_grad(): + output = layer_dropout.forward(lambda x: x**2, input) + assert torch.allclose(output, input, atol=1e-7, rtol=1e-3) \ No newline at end of file From 2cc94cccce5aae80272889e0ed31c8680c35e438 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 24 Nov 2024 20:06:50 +0000 Subject: [PATCH 48/88] rename apply_layer_dropout to prepare_layer_dropout --- recipes/dev/early_exit_finetune_distributed.py | 4 ++-- torchtune/modules/__init__.py | 4 ++-- torchtune/modules/layer_dropout.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index 976f4b9797..ab510a55c7 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -29,7 +29,7 @@ from torchtune.training.lr_schedulers import get_lr from torchtune.modules.early_exit_loss import early_exit_loss, setup_early_exit_loss_curriculum, EarlyExitCurriculum -from torchtune.modules.layer_dropout import apply_layer_dropout_modules +from torchtune.modules.layer_dropout import prepare_layer_dropout from torchtune.modules.common_utils import slice_str_to_array from tqdm import tqdm @@ -361,7 +361,7 @@ def setup(self, cfg: DictConfig) -> None: # Layer Dropout Setup cfg_layer_dropout = cfg.get("layer_dropout", None) if cfg_layer_dropout: - apply_layer_dropout_modules(self._model, prob_max=cfg_layer_dropout.get("prob", 0.0), prob_layer_scale=cfg_layer_dropout.get("layers_scale", "uniform"), layers_str=cfg_layer_dropout.get("layers", ":"), disable_on_eval=cfg_layer_dropout.get("disable_on_eval", True)) + prepare_layer_dropout(self._model, prob_max=cfg_layer_dropout.get("prob", 0.0), prob_layer_scale=cfg_layer_dropout.get("layers_scale", "uniform"), layers_str=cfg_layer_dropout.get("layers", ":"), disable_on_eval=cfg_layer_dropout.get("disable_on_eval", True)) def _setup_profiler( self, cfg_profiler: Optional[DictConfig] = None diff --git a/torchtune/modules/__init__.py b/torchtune/modules/__init__.py index d3a3cc75d4..c5917ab053 100644 --- a/torchtune/modules/__init__.py +++ b/torchtune/modules/__init__.py @@ -23,7 +23,7 @@ from .rms_norm import RMSNorm # noqa from .tanh_gate import TanhGate # noqa from .tied_linear import TiedLinear # noqa -from .layer_dropout import LayerDropout, apply_layer_dropout_modules # noqa +from .layer_dropout import LayerDropout, prepare_layer_dropout # noqa from .transformer import ( # noqa TransformerCrossAttentionLayer, TransformerDecoder, @@ -54,5 +54,5 @@ "delete_kv_caches", "disable_kv_cache", "LayerDropout", - "create_layer_dropout_modules", + "prepare_layer_dropout", ] diff --git a/torchtune/modules/layer_dropout.py b/torchtune/modules/layer_dropout.py index 599d534f0f..e12e381d23 100644 --- a/torchtune/modules/layer_dropout.py +++ b/torchtune/modules/layer_dropout.py @@ -101,8 +101,8 @@ def get_scale(scale_type: ScaleType, scale_period: int, val: int): ScaleType.SIGMOID: 1 / (1 + math.exp(-10 * (val / scale_period - 0.5))), }[scale_type] -# TODO: rename to prepare() just like quantizer()? -def apply_layer_dropout_modules(model, prob_max: float= 0.0, prob_layer_scale: ScaleType = ScaleType.EXP, layers_str: Optional[str] = None, disable_on_eval: bool = True): + +def prepare_layer_dropout(model, prob_max: float= 0.0, prob_layer_scale: ScaleType = ScaleType.EXP, layers_str: Optional[str] = None, disable_on_eval: bool = True): num_layers = len(model.layers) has_dropout = slice_str_to_array(layers_str, num_layers) if layers_str else [True] * num_layers for layer_id in range(len(model.layers)): From f4f8e020f0b8305690f90e08b41c19533f9ca389 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 24 Nov 2024 20:07:28 +0000 Subject: [PATCH 49/88] add test cases for get_scale --- tests/torchtune/modules/test_layer_dropout.py | 59 +++++++++++++++++-- 1 file changed, 54 insertions(+), 5 deletions(-) diff --git a/tests/torchtune/modules/test_layer_dropout.py b/tests/torchtune/modules/test_layer_dropout.py index 3def1a3393..0fead9f160 100644 --- a/tests/torchtune/modules/test_layer_dropout.py +++ b/tests/torchtune/modules/test_layer_dropout.py @@ -6,10 +6,11 @@ from typing import Tuple +import math import pytest import torch from tests.test_utils import assert_expected -from torchtune.modules.layer_dropout import LayerDropout +from torchtune.modules.layer_dropout import LayerDropout, get_scale, ScaleType class TestLayerDropout: @@ -43,14 +44,14 @@ def test_forward_train_prob_1(self, layer_dropout: LayerDropout, input: torch.Te # With dropout probability = 1.0, we expect output to be the same as input layer_dropout.prob = 1.0 output = layer_dropout.forward(lambda x: x**2, input) - assert torch.allclose(output, input, atol=1e-7, rtol=1e-3) + assert torch.allclose(output, input) def test_forward_train_prob_0(self, layer_dropout: LayerDropout, input: torch.Tensor) -> None: # With dropout probability = 1.0, we expect the operation to be applied on all elements in the input layer_dropout.prob = 0.0 output = layer_dropout.forward(lambda x: x**2, input) - assert torch.allclose(output, input**2, atol=1e-7, rtol=1e-3) + assert torch.allclose(output, input**2) def test_forward_eval(self, layer_dropout: LayerDropout, input: torch.Tensor) -> None: @@ -59,9 +60,57 @@ def test_forward_eval(self, layer_dropout: LayerDropout, input: torch.Tensor) -> layer_dropout.disable_on_eval = True output = layer_dropout.forward(lambda x: x**2, input) - assert torch.allclose(output, input**2, atol=1e-7, rtol=1e-3) + assert torch.allclose(output, input**2) layer_dropout.disable_on_eval = False with torch.no_grad(): output = layer_dropout.forward(lambda x: x**2, input) - assert torch.allclose(output, input, atol=1e-7, rtol=1e-3) \ No newline at end of file + assert torch.allclose(output, input) + + + def test_get_scale_uniform(self) -> None: + scale_type = ScaleType.UNIFORM + scale_period = 10 + + assert_expected(get_scale(scale_type, scale_period, 0), 1.0) + assert_expected(get_scale(scale_type, scale_period, scale_period/2), 1.0) + assert_expected(get_scale(scale_type, scale_period, scale_period), 1.0) + assert_expected(get_scale(scale_type, scale_period, scale_period*2), 1.0) + + + def test_get_scale_linear(self) -> None: + scale_type = ScaleType.LINEAR + scale_period = 10 + + assert_expected(get_scale(scale_type, scale_period, 0), 0.0) + assert_expected(get_scale(scale_type, scale_period, scale_period/2), 1/2) + assert_expected(get_scale(scale_type, scale_period, scale_period), 1.0) + assert_expected(get_scale(scale_type, scale_period, scale_period*2), 1.0) + + + def test_get_scale_exp(self) -> None: + scale_type = ScaleType.EXP + scale_period = 10 + + assert_expected(get_scale(scale_type, scale_period, 0), 0.0) + assert_expected(get_scale(scale_type, scale_period, scale_period/2), math.pow(2, 1/2) - 1) + assert_expected(get_scale(scale_type, scale_period, scale_period), 1.0) + assert_expected(get_scale(scale_type, scale_period, scale_period*2), 1.0) + + def test_get_scale_log(self) -> None: + scale_type = ScaleType.LOG + scale_period = 10 + + assert_expected(get_scale(scale_type, scale_period, 0), 0.0) + assert_expected(get_scale(scale_type, scale_period, scale_period/2), math.log(5 + 1, scale_period + 1)) + assert_expected(get_scale(scale_type, scale_period, scale_period), 1.0) + assert_expected(get_scale(scale_type, scale_period, scale_period*2), 1.0) + + + def test_get_scale_sin(self) -> None: + scale_type = ScaleType.SIN + scale_period = 10 + val = 5 + expected_scale = math.sin(0.5 * math.pi * 5 / 10) + actual_scale = get_scale(scale_type, scale_period, val) + assert_expected(actual_scale, expected_scale, atol=1e-7, rtol=1e-3) From fed955e3e0606d2e1d25948a7fbbabe652f602fc Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 24 Nov 2024 20:07:51 +0000 Subject: [PATCH 50/88] cleanup get_scale + re-write mathematically equivalent + ensure max scale is 1 --- torchtune/modules/layer_dropout.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/torchtune/modules/layer_dropout.py b/torchtune/modules/layer_dropout.py index e12e381d23..7c15c28bbc 100644 --- a/torchtune/modules/layer_dropout.py +++ b/torchtune/modules/layer_dropout.py @@ -89,18 +89,20 @@ class ScaleType(str, Enum): def get_scale(scale_type: ScaleType, scale_period: int, val: int): if scale_period == 0: - return 1 + return 1.0 # all the equations below aim to make scale = 0 when val=0, and scale = 1 when val=scale_period - return { - ScaleType.UNIFORM: 1, - ScaleType.EXP: math.exp(val * math.log(2) / scale_period) - 1, + scale = { + ScaleType.UNIFORM: 1.0, + ScaleType.EXP: math.pow(2, val / scale_period) - 1, ScaleType.LINEAR: val / scale_period, - ScaleType.LOG: math.log(val + 1) / math.log(scale_period + 1), + ScaleType.LOG: math.log(val + 1, scale_period + 1), ScaleType.SIN: math.sin(0.5 * math.pi * val / scale_period), ScaleType.SIGMOID: 1 / (1 + math.exp(-10 * (val / scale_period - 0.5))), }[scale_type] + # after scale_period, scale should be 1 + return min(scale, 1.0) def prepare_layer_dropout(model, prob_max: float= 0.0, prob_layer_scale: ScaleType = ScaleType.EXP, layers_str: Optional[str] = None, disable_on_eval: bool = True): num_layers = len(model.layers) From ca7d8dafe8183d0b861f9450275416cd13830bdf Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 24 Nov 2024 20:28:14 +0000 Subject: [PATCH 51/88] test layer_dropout --- tests/torchtune/modules/test_layer_dropout.py | 66 ++++++++++++++++++- 1 file changed, 65 insertions(+), 1 deletion(-) diff --git a/tests/torchtune/modules/test_layer_dropout.py b/tests/torchtune/modules/test_layer_dropout.py index 0fead9f160..79ee98bed0 100644 --- a/tests/torchtune/modules/test_layer_dropout.py +++ b/tests/torchtune/modules/test_layer_dropout.py @@ -10,7 +10,7 @@ import pytest import torch from tests.test_utils import assert_expected -from torchtune.modules.layer_dropout import LayerDropout, get_scale, ScaleType +from torchtune.modules.layer_dropout import LayerDropout, get_scale, ScaleType, prepare_layer_dropout class TestLayerDropout: @@ -114,3 +114,67 @@ def test_get_scale_sin(self) -> None: expected_scale = math.sin(0.5 * math.pi * 5 / 10) actual_scale = get_scale(scale_type, scale_period, val) assert_expected(actual_scale, expected_scale, atol=1e-7, rtol=1e-3) + + @pytest.fixture(autouse=True) + def random(self): + torch.manual_seed(0) + + + def test_prepare_layer_dropout_uniform(self) -> None: + class MockModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.ModuleList([torch.nn.Linear(10, 10) for _ in range(5)]) + model = MockModel() + prob_max = 0.5 + prob_layer_scale = ScaleType.UNIFORM + layers_str = "0:4" + prepare_layer_dropout(model, prob_max, prob_layer_scale, layers_str) + for i, layer in enumerate(model.layers): + assert hasattr(layer, "dropout") + if i in range(0, 4): + assert layer.dropout.prob == prob_max + else: + assert layer.dropout.prob == 0 + + + def test_prepare_layer_dropout_exp(self) -> None: + class MockModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.ModuleList([torch.nn.Linear(10, 10) for _ in range(5)]) + model = MockModel() + prob_max = 0.5 + prob_layer_scale = ScaleType.EXP + layers_str = ":" + prepare_layer_dropout(model, prob_max, prob_layer_scale, layers_str) + for i, layer in enumerate(model.layers): + assert hasattr(layer, "dropout") + if i == 0: + assert layer.dropout.prob == 0 + elif i == len(model.layers) - 1: + assert layer.dropout.prob == prob_max + else: + assert layer.dropout.prob > 0 and layer.dropout.prob < prob_max + + + def test_prepare_layer_dropout_linear(self) -> None: + class MockModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.ModuleList([torch.nn.Linear(10, 10) for _ in range(5)]) + model = MockModel() + prob_max = 0.5 + prob_layer_scale = ScaleType.LINEAR + layers_str = ":" + prepare_layer_dropout(model, prob_max, prob_layer_scale, layers_str) + for i, layer in enumerate(model.layers): + assert hasattr(layer, "dropout") + if i == 0: + assert layer.dropout.prob == 0 + elif i == len(model.layers) - 1: + assert layer.dropout.prob == prob_max + elif i == len(model.layers)/2: + assert layer.dropout.prob == prob_max/2 + else: + assert layer.dropout.prob >= 0.0 and layer.dropout.prob <= prob_max From 0146764d0091c206884f0d22f0bc4d728912b648 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 24 Nov 2024 21:11:40 +0000 Subject: [PATCH 52/88] start adding early exit loss and layer dropout to docstring --- recipes/dev/early_exit_finetune_distributed.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index ab510a55c7..7056dbb1d5 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -38,10 +38,15 @@ class EarlyExitFinetuneRecipeDistributed(FTRecipeInterface): """ - Full finetuning recipe for dense transformer-based LLMs such as Llama2. This recipe supports - distributed training and can be run on a single node (1 to 8 GPUs). + Early exit and layer dropout full finetuning to make the model more robust to early exit and skipping + intermediate layers for dense transformer-based LLMs such as Llama2. This recipe supports distributed + training and can be run on a single node (1 to 8 GPUs). Features: + - Early Exit Loss. + + - Layer Dropout. + - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config From f599eca75ed2817296f8c3ff4eb4e7114669c7e2 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 24 Nov 2024 21:13:29 +0000 Subject: [PATCH 53/88] fix and update code and test cases to handle updating last layer separately --- .../torchtune/modules/test_early_exit_loss.py | 40 +++++++++++++------ torchtune/modules/early_exit_loss.py | 17 +++----- 2 files changed, 33 insertions(+), 24 deletions(-) diff --git a/tests/torchtune/modules/test_early_exit_loss.py b/tests/torchtune/modules/test_early_exit_loss.py index 91bfa1c508..9cf82fc2f2 100644 --- a/tests/torchtune/modules/test_early_exit_loss.py +++ b/tests/torchtune/modules/test_early_exit_loss.py @@ -79,32 +79,46 @@ def test_setup_early_exit_loss_curriculum(): curriculum = setup_early_exit_loss_curriculum(EarlyExitCurriculumType.GRADUAL, [True, False, True], 100) assert isinstance(curriculum, GradualEarlyExitCurriculum) -def test_rotational_early_exit_curriculum(): - curriculum = RotationalEarlyExitCurriculum([True, False, True], 100) + +@pytest.mark.parametrize("train_last_layer", [ + True, + False, +]) +def test_rotational_early_exit_curriculum(train_last_layer): + curriculum = RotationalEarlyExitCurriculum([True, False, False], max_steps=100, train_last_layer=train_last_layer) + expected = np.array([True, False, train_last_layer]) + assert np.array_equal(curriculum.get(), expected) curriculum.step() - expected = np.array([False, True, True]) + expected = np.array([False, True, train_last_layer]) assert np.array_equal(curriculum.get(), expected) curriculum.step() - expected = np.array([True, True, False]) + expected = np.array([False, False, True]) assert np.array_equal(curriculum.get(), expected) curriculum.step() - expected = np.array([True, False, True]) + expected = np.array([True, False, train_last_layer]) assert np.array_equal(curriculum.get(), expected) -def test_gradual_early_exit_curriculum(): - curriculum = GradualEarlyExitCurriculum([False, False, False, False], max_steps=4, percent_scale=1) + +@pytest.mark.parametrize("train_last_layer", [ + True, + False, +]) +def test_gradual_early_exit_curriculum(train_last_layer): + curriculum = GradualEarlyExitCurriculum([True, True, True, True], max_steps=4, train_last_layer=train_last_layer, percent_scale=1) + expected = np.array([False, False, False, train_last_layer]) + assert np.array_equal(curriculum.get(), expected) curriculum.step() - assert curriculum.get() == [False, False, False, False] + assert np.array_equal(curriculum.get(), [False, False, False, train_last_layer]) curriculum.step() - assert curriculum.get() == [False, False, False, True] + assert np.array_equal(curriculum.get(), [False, False, False, True]) curriculum.step() - assert curriculum.get() == [False, False, True, True] + assert np.array_equal(curriculum.get(), [False, False, True, True]) curriculum.step() - assert curriculum.get() == [False, True, True, True] + assert np.array_equal(curriculum.get(), [False, True, True, True]) curriculum.step() - assert curriculum.get() == [True, True, True, True] + assert np.array_equal(curriculum.get(), [True, True, True, True]) curriculum.step() - assert curriculum.get() == [True, True, True, True] + assert np.array_equal(curriculum.get(), [True, True, True, True]) @pytest.fixture def hidden_states_dict(): diff --git a/torchtune/modules/early_exit_loss.py b/torchtune/modules/early_exit_loss.py index de14207275..9cfc6f5bac 100644 --- a/torchtune/modules/early_exit_loss.py +++ b/torchtune/modules/early_exit_loss.py @@ -49,7 +49,6 @@ def early_exit_loss(model, hidden_states_dict, labels, loss_fn, e_scale: float=1 s_unpadded = (labels != loss_fn.ignore_index).sum() losses_early = losses_early.float().sum(-1) / s_unpadded # Shape: [e] - # losses_scales = 0.1 * torch.Tensor(hidden_layer_ids).to(losses_early) / len(model.layers) losses_scales = layer_ids_to_loss_scales(torch.Tensor(hidden_layer_ids).to(losses_early), len(model.layers), loss_scale_type, e_scale) return torch.sum(losses_scales * losses_early) @@ -111,7 +110,11 @@ def step(self): pass def get(self): - return self.do_output_hidden_states + do_output_hidden_states = np.copy(self.do_output_hidden_states) + # Ensure last layer is trained + if self.train_last_layer: + do_output_hidden_states[-1] = True + return do_output_hidden_states class RotationalEarlyExitCurriculum(EarlyExitCurriculum): def __init__(self, do_output_hidden_states, max_steps, train_last_layer=True, verbose=False): @@ -120,11 +123,7 @@ def __init__(self, do_output_hidden_states, max_steps, train_last_layer=True, ve def step(self): # Rotate layer enablement one step forward - self.do_output_hidden_states = np.roll(self.do_output_hidden_states, -1) - - # Ensure last layer is trained - if self.train_last_layer: - self.do_output_hidden_states[-1] = True + self.do_output_hidden_states = np.roll(self.do_output_hidden_states, 1) if self.verbose: log.info(f"Updated self.output_hidden_states to {self.do_output_hidden_states}.") @@ -151,10 +150,6 @@ def step(self): # Only enable layers that are set by the user self.do_output_hidden_states = np.logical_and(self.do_output_hidden_states, self._final_do_output_hidden_states) - # Ensure last layer is trained - if self.train_last_layer: - self.do_output_hidden_states[-1] = True - self._step += 1 if self.verbose: log.info(f"Updated self.do_output_hidden_states to {self.do_output_hidden_states}.") From 2437092c08a3f8fdf6e4209ddbdd0b6133ea915b Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 24 Nov 2024 21:18:01 +0000 Subject: [PATCH 54/88] change match to if-else for CI --- .../dev/early_exit_finetune_distributed.py | 77 +++++++++--- tests/torchtune/modules/test_common_utils.py | 4 +- .../torchtune/modules/test_early_exit_loss.py | 87 ++++++++----- tests/torchtune/modules/test_layer_dropout.py | 85 +++++++------ torchtune/_recipe_registry.py | 5 +- torchtune/modules/__init__.py | 2 +- torchtune/modules/common_utils.py | 28 +++-- torchtune/modules/early_exit_loss.py | 117 ++++++++++++------ torchtune/modules/layer_dropout.py | 62 +++++++--- torchtune/modules/transformer.py | 1 + 10 files changed, 314 insertions(+), 154 deletions(-) diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index 7056dbb1d5..4a38bda2e5 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -23,22 +23,27 @@ from torchtune.config._utils import _get_component_from_path from torchtune.data import padded_collate_packed from torchtune.datasets import ConcatDataset +from torchtune.modules.common_utils import slice_str_to_array + +from torchtune.modules.early_exit_loss import ( + early_exit_loss, + EarlyExitCurriculum, + setup_early_exit_loss_curriculum, +) +from torchtune.modules.layer_dropout import prepare_layer_dropout from torchtune.recipe_interfaces import FTRecipeInterface from torchtune.training import DummyProfiler, PROFILER_KEY from torchtune.training.activations import apply_selective_activation_checkpointing from torchtune.training.lr_schedulers import get_lr -from torchtune.modules.early_exit_loss import early_exit_loss, setup_early_exit_loss_curriculum, EarlyExitCurriculum -from torchtune.modules.layer_dropout import prepare_layer_dropout -from torchtune.modules.common_utils import slice_str_to_array - from tqdm import tqdm log = utils.get_logger("DEBUG") + class EarlyExitFinetuneRecipeDistributed(FTRecipeInterface): """ - Early exit and layer dropout full finetuning to make the model more robust to early exit and skipping + Early exit and layer dropout full finetuning to make the model more robust to early exit and skipping intermediate layers for dense transformer-based LLMs such as Llama2. This recipe supports distributed training and can be run on a single node (1 to 8 GPUs). @@ -213,7 +218,9 @@ def __init__(self, cfg: DictConfig) -> None: if cfg_early_exit_loss: self._do_early_exit_loss = True self._early_exit_loss_scale = cfg_early_exit_loss.get("scale", 1.0) - self._early_exit_loss_scale_type = cfg_early_exit_loss.get("scale_type", "one") + self._early_exit_loss_scale_type = cfg_early_exit_loss.get( + "scale_type", "one" + ) else: self._do_early_exit_loss = False self._early_exit_loss_scale = None @@ -361,12 +368,21 @@ def setup(self, cfg: DictConfig) -> None: ) # Setup early exit loss - self._do_output_hidden_states, self._early_exit_loss_curriculum = self._setup_early_exit_loss(cfg.get("early_exit_loss", None)) + ( + self._do_output_hidden_states, + self._early_exit_loss_curriculum, + ) = self._setup_early_exit_loss(cfg.get("early_exit_loss", None)) # Layer Dropout Setup cfg_layer_dropout = cfg.get("layer_dropout", None) if cfg_layer_dropout: - prepare_layer_dropout(self._model, prob_max=cfg_layer_dropout.get("prob", 0.0), prob_layer_scale=cfg_layer_dropout.get("layers_scale", "uniform"), layers_str=cfg_layer_dropout.get("layers", ":"), disable_on_eval=cfg_layer_dropout.get("disable_on_eval", True)) + prepare_layer_dropout( + self._model, + prob_max=cfg_layer_dropout.get("prob", 0.0), + prob_layer_scale=cfg_layer_dropout.get("layers_scale", "uniform"), + layers_str=cfg_layer_dropout.get("layers", ":"), + disable_on_eval=cfg_layer_dropout.get("disable_on_eval", True), + ) def _setup_profiler( self, cfg_profiler: Optional[DictConfig] = None @@ -656,7 +672,9 @@ def _setup_early_exit_loss( early_exit_loss_curriculum = None if cfg_early_exit_loss: - do_output_hidden_states = slice_str_to_array(cfg_early_exit_loss.get("layers", ":"), len(self._model.layers)) + do_output_hidden_states = slice_str_to_array( + cfg_early_exit_loss.get("layers", ":"), len(self._model.layers) + ) train_last_layer = cfg_early_exit_loss.get("include_last_layer", True) verbose = cfg_early_exit_loss.get("verbose", False) @@ -664,7 +682,13 @@ def _setup_early_exit_loss( do_output_hidden_states[len(self._model.layers) - 1] = True if cfg_early_exit_loss.curriculum: - early_exit_loss_curriculum = setup_early_exit_loss_curriculum(early_exit_curriculum=cfg_early_exit_loss.curriculum, do_output_hidden_states=do_output_hidden_states, max_steps=self.total_epochs*self._steps_per_epoch, train_last_layer=train_last_layer, verbose=verbose) + early_exit_loss_curriculum = setup_early_exit_loss_curriculum( + early_exit_curriculum=cfg_early_exit_loss.curriculum, + do_output_hidden_states=do_output_hidden_states, + max_steps=self.total_epochs * self._steps_per_epoch, + train_last_layer=train_last_layer, + verbose=verbose, + ) do_output_hidden_states = early_exit_loss_curriculum.get() else: early_exit_loss_curriculum = None @@ -784,7 +808,11 @@ def train(self) -> None: # Initialize output hidden states if self._do_output_hidden_states: - self._model.output_hidden_states = [i for i in range(len(self._do_output_hidden_states)) if self._do_output_hidden_states[i]] + self._model.output_hidden_states = [ + i + for i in range(len(self._do_output_hidden_states)) + if self._do_output_hidden_states[i] + ] self._profiler.start() # self.epochs_run should be non-zero when we're resuming from a checkpoint @@ -827,7 +855,10 @@ def train(self) -> None: outputs = self._model(**batch) if self._model.output_hidden_states: logits = outputs.pop(-1) - hidden_states = {i:h for i,h in zip(self._model.output_hidden_states, outputs)} + hidden_states = { + i: h + for i, h in zip(self._model.output_hidden_states, outputs) + } else: logits = outputs @@ -845,7 +876,17 @@ def train(self) -> None: # Loss is normalized by default so we multiply by the number of tokens # This way we can normalize by the total number of tokens if we're accumulating gradients if self._model.output_hidden_states: - current_loss = early_exit_loss(self._model, hidden_states, labels, self._loss_fn, self._early_exit_loss_scale, self._early_exit_loss_scale_type) * current_num_tokens + current_loss = ( + early_exit_loss( + self._model, + hidden_states, + labels, + self._loss_fn, + self._early_exit_loss_scale, + self._early_exit_loss_scale_type, + ) + * current_num_tokens + ) else: current_loss = self._loss_fn(logits, labels) * current_num_tokens @@ -926,8 +967,14 @@ def train(self) -> None: # Update Early Exit Layers/Scales if self._early_exit_loss_curriculum: self._early_exit_loss_curriculum.step() - self._do_output_hidden_states = self._early_exit_loss_curriculum.get() - self._model.output_hidden_states = [i for i in range(len(self._do_output_hidden_states)) if self._do_output_hidden_states[i]] + self._do_output_hidden_states = ( + self._early_exit_loss_curriculum.get() + ) + self._model.output_hidden_states = [ + i + for i in range(len(self._do_output_hidden_states)) + if self._do_output_hidden_states[i] + ] # Stop tracking CUDA memory now that active steps are complete if ( diff --git a/tests/torchtune/modules/test_common_utils.py b/tests/torchtune/modules/test_common_utils.py index 13ae33d65c..0e2a6400e0 100644 --- a/tests/torchtune/modules/test_common_utils.py +++ b/tests/torchtune/modules/test_common_utils.py @@ -14,8 +14,9 @@ llama3_2_vision_encoder, ) from torchtune.modules import delete_kv_caches, disable_kv_cache, local_kv_cache -from torchtune.modules.model_fusion import DeepFusionModel from torchtune.modules.common_utils import slice_str_to_array +from torchtune.modules.model_fusion import DeepFusionModel + @pytest.fixture def llama_vision_model(): @@ -192,6 +193,7 @@ def test_disable_kv_cache_raises_error_caches_not_setup(self, model, request): with disable_kv_cache(model): pass + class TestSliceStrToArray: def test_single_index(self): assert slice_str_to_array("0", 5) == [True, False, False, False, False] diff --git a/tests/torchtune/modules/test_early_exit_loss.py b/tests/torchtune/modules/test_early_exit_loss.py index 9cf82fc2f2..feff2fe264 100644 --- a/tests/torchtune/modules/test_early_exit_loss.py +++ b/tests/torchtune/modules/test_early_exit_loss.py @@ -10,23 +10,25 @@ import pytest import torch import torch.nn as nn -from torchtune import utils from torchtune.modules import TransformerDecoder from torchtune.modules.early_exit_loss import ( early_exit_loss, + EarlyExitCurriculumType, + GradualEarlyExitCurriculum, layer_ids_to_loss_scales, LossScaleType, - EarlyExitCurriculumType, - setup_early_exit_loss_curriculum, RotationalEarlyExitCurriculum, - GradualEarlyExitCurriculum, + setup_early_exit_loss_curriculum, ) # Mock components for TransformerDecoder class MockLayer(nn.Module): - def forward(self, x, mask=None, encoder_input=None, encoder_mask=None, input_pos=None): + def forward( + self, x, mask=None, encoder_input=None, encoder_mask=None, input_pos=None + ): return x # Simply return the input for testing purposes + @pytest.fixture def mock_model(): # Create mock components @@ -45,47 +47,62 @@ def mock_model(): norm=norm, output=output, num_layers=12, - output_hidden_states=[0, 1, 2] # Example layers to output hidden states + output_hidden_states=[0, 1, 2], # Example layers to output hidden states ) return model + @pytest.fixture def hidden_states_dict(): return {i: torch.randn(4, 5, 512) for i in range(3)} # Adjusted embedding dim + @pytest.fixture def labels(): return torch.randint(0, 1000, (4, 5)) # Adjusted vocab size + @pytest.fixture def loss_fn(): return nn.CrossEntropyLoss(ignore_index=-1) + def test_early_exit_loss(mock_model, hidden_states_dict, labels, loss_fn): loss = early_exit_loss(mock_model, hidden_states_dict, labels, loss_fn) assert isinstance(loss, torch.Tensor) assert loss.item() >= 0 + def test_layer_ids_to_loss_scales(): layer_ids = torch.tensor([0, 1, 2]) n_layers = 12 scales = layer_ids_to_loss_scales(layer_ids, n_layers, LossScaleType.SUM_L, 1.0) assert torch.isclose(scales.sum(), torch.tensor(1.0)) + def test_setup_early_exit_loss_curriculum(): - curriculum = setup_early_exit_loss_curriculum(EarlyExitCurriculumType.ROTATIONAL, [True, False, True], 100) + curriculum = setup_early_exit_loss_curriculum( + EarlyExitCurriculumType.ROTATIONAL, [True, False, True], 100 + ) assert isinstance(curriculum, RotationalEarlyExitCurriculum) - curriculum = setup_early_exit_loss_curriculum(EarlyExitCurriculumType.GRADUAL, [True, False, True], 100) + curriculum = setup_early_exit_loss_curriculum( + EarlyExitCurriculumType.GRADUAL, [True, False, True], 100 + ) assert isinstance(curriculum, GradualEarlyExitCurriculum) -@pytest.mark.parametrize("train_last_layer", [ - True, - False, -]) +@pytest.mark.parametrize( + "train_last_layer", + [ + True, + False, + ], +) def test_rotational_early_exit_curriculum(train_last_layer): - curriculum = RotationalEarlyExitCurriculum([True, False, False], max_steps=100, train_last_layer=train_last_layer) + curriculum = RotationalEarlyExitCurriculum( + [True, False, False], max_steps=100, train_last_layer=train_last_layer + ) expected = np.array([True, False, train_last_layer]) assert np.array_equal(curriculum.get(), expected) curriculum.step() @@ -99,12 +116,20 @@ def test_rotational_early_exit_curriculum(train_last_layer): assert np.array_equal(curriculum.get(), expected) -@pytest.mark.parametrize("train_last_layer", [ - True, - False, -]) +@pytest.mark.parametrize( + "train_last_layer", + [ + True, + False, + ], +) def test_gradual_early_exit_curriculum(train_last_layer): - curriculum = GradualEarlyExitCurriculum([True, True, True, True], max_steps=4, train_last_layer=train_last_layer, percent_scale=1) + curriculum = GradualEarlyExitCurriculum( + [True, True, True, True], + max_steps=4, + train_last_layer=train_last_layer, + percent_scale=1, + ) expected = np.array([False, False, False, train_last_layer]) assert np.array_equal(curriculum.get(), expected) curriculum.step() @@ -120,22 +145,18 @@ def test_gradual_early_exit_curriculum(train_last_layer): curriculum.step() assert np.array_equal(curriculum.get(), [True, True, True, True]) -@pytest.fixture -def hidden_states_dict(): - return {i: torch.randn(4, 5, 512) for i in range(3)} # Adjusted embedding dim - -@pytest.fixture -def labels(): - return torch.randint(0, 1000, (4, 5)) # Adjusted vocab size - -@pytest.fixture -def loss_fn(): - return nn.CrossEntropyLoss(ignore_index=-1) def test_early_exit_loss_vs_manual(mock_model, hidden_states_dict, labels, loss_fn): # Convert to float32 for numeric equivalence # Calculate early exit loss using the function - calculated_loss = early_exit_loss(mock_model, hidden_states_dict, labels, loss_fn, e_scale=1, loss_scale_type="one") + calculated_loss = early_exit_loss( + mock_model, + hidden_states_dict, + labels, + loss_fn, + e_scale=1, + loss_scale_type="one", + ) # Manually calculate the loss for each hidden state total_loss = 0.0 num_hidden_states = len(hidden_states_dict) @@ -150,8 +171,10 @@ def test_early_exit_loss_vs_manual(mock_model, hidden_states_dict, labels, loss_ # Average the losses across all hidden states manual_loss = total_loss / num_hidden_states # Compare the two losses - assert torch.isclose(calculated_loss, manual_loss, atol=1e-6), \ - f"Calculated loss: {calculated_loss}, Manual loss: {manual_loss}" + assert torch.isclose( + calculated_loss, manual_loss, atol=1e-6 + ), f"Calculated loss: {calculated_loss}, Manual loss: {manual_loss}" + if __name__ == "__main__": pytest.main() diff --git a/tests/torchtune/modules/test_layer_dropout.py b/tests/torchtune/modules/test_layer_dropout.py index 79ee98bed0..8e28a01c45 100644 --- a/tests/torchtune/modules/test_layer_dropout.py +++ b/tests/torchtune/modules/test_layer_dropout.py @@ -5,23 +5,27 @@ # LICENSE file in the root directory of this source tree. -from typing import Tuple import math +from typing import Tuple + import pytest import torch from tests.test_utils import assert_expected -from torchtune.modules.layer_dropout import LayerDropout, get_scale, ScaleType, prepare_layer_dropout +from torchtune.modules.layer_dropout import ( + get_scale, + LayerDropout, + prepare_layer_dropout, + ScaleType, +) class TestLayerDropout: """Class for testing LayerDropout implementation.""" - @pytest.fixture(autouse=True) def random(self): torch.manual_seed(0) - @pytest.fixture def input_shape(self) -> Tuple[int, int]: bsz = 32 @@ -29,32 +33,35 @@ def input_shape(self) -> Tuple[int, int]: dim = 4096 return bsz, seqlen, dim - @pytest.fixture def input(self, input_shape: Tuple[int]) -> torch.Tensor: return torch.randn(input_shape) - @pytest.fixture - def layer_dropout(self, prob: float = 0.5, disable_on_eval: bool = True) -> LayerDropout: + def layer_dropout( + self, prob: float = 0.5, disable_on_eval: bool = True + ) -> LayerDropout: return LayerDropout(prob=prob, disable_on_eval=disable_on_eval) - - def test_forward_train_prob_1(self, layer_dropout: LayerDropout, input: torch.Tensor) -> None: + def test_forward_train_prob_1( + self, layer_dropout: LayerDropout, input: torch.Tensor + ) -> None: # With dropout probability = 1.0, we expect output to be the same as input layer_dropout.prob = 1.0 output = layer_dropout.forward(lambda x: x**2, input) assert torch.allclose(output, input) - - def test_forward_train_prob_0(self, layer_dropout: LayerDropout, input: torch.Tensor) -> None: + def test_forward_train_prob_0( + self, layer_dropout: LayerDropout, input: torch.Tensor + ) -> None: # With dropout probability = 1.0, we expect the operation to be applied on all elements in the input layer_dropout.prob = 0.0 output = layer_dropout.forward(lambda x: x**2, input) assert torch.allclose(output, input**2) - - def test_forward_eval(self, layer_dropout: LayerDropout, input: torch.Tensor) -> None: + def test_forward_eval( + self, layer_dropout: LayerDropout, input: torch.Tensor + ) -> None: layer_dropout.prob = 1.0 layer_dropout.eval() @@ -67,45 +74,47 @@ def test_forward_eval(self, layer_dropout: LayerDropout, input: torch.Tensor) -> output = layer_dropout.forward(lambda x: x**2, input) assert torch.allclose(output, input) - def test_get_scale_uniform(self) -> None: scale_type = ScaleType.UNIFORM scale_period = 10 assert_expected(get_scale(scale_type, scale_period, 0), 1.0) - assert_expected(get_scale(scale_type, scale_period, scale_period/2), 1.0) + assert_expected(get_scale(scale_type, scale_period, scale_period / 2), 1.0) assert_expected(get_scale(scale_type, scale_period, scale_period), 1.0) - assert_expected(get_scale(scale_type, scale_period, scale_period*2), 1.0) - + assert_expected(get_scale(scale_type, scale_period, scale_period * 2), 1.0) def test_get_scale_linear(self) -> None: scale_type = ScaleType.LINEAR scale_period = 10 assert_expected(get_scale(scale_type, scale_period, 0), 0.0) - assert_expected(get_scale(scale_type, scale_period, scale_period/2), 1/2) + assert_expected(get_scale(scale_type, scale_period, scale_period / 2), 1 / 2) assert_expected(get_scale(scale_type, scale_period, scale_period), 1.0) - assert_expected(get_scale(scale_type, scale_period, scale_period*2), 1.0) - + assert_expected(get_scale(scale_type, scale_period, scale_period * 2), 1.0) def test_get_scale_exp(self) -> None: scale_type = ScaleType.EXP scale_period = 10 assert_expected(get_scale(scale_type, scale_period, 0), 0.0) - assert_expected(get_scale(scale_type, scale_period, scale_period/2), math.pow(2, 1/2) - 1) + assert_expected( + get_scale(scale_type, scale_period, scale_period / 2), + math.pow(2, 1 / 2) - 1, + ) assert_expected(get_scale(scale_type, scale_period, scale_period), 1.0) - assert_expected(get_scale(scale_type, scale_period, scale_period*2), 1.0) + assert_expected(get_scale(scale_type, scale_period, scale_period * 2), 1.0) def test_get_scale_log(self) -> None: scale_type = ScaleType.LOG scale_period = 10 assert_expected(get_scale(scale_type, scale_period, 0), 0.0) - assert_expected(get_scale(scale_type, scale_period, scale_period/2), math.log(5 + 1, scale_period + 1)) + assert_expected( + get_scale(scale_type, scale_period, scale_period / 2), + math.log(5 + 1, scale_period + 1), + ) assert_expected(get_scale(scale_type, scale_period, scale_period), 1.0) - assert_expected(get_scale(scale_type, scale_period, scale_period*2), 1.0) - + assert_expected(get_scale(scale_type, scale_period, scale_period * 2), 1.0) def test_get_scale_sin(self) -> None: scale_type = ScaleType.SIN @@ -115,16 +124,14 @@ def test_get_scale_sin(self) -> None: actual_scale = get_scale(scale_type, scale_period, val) assert_expected(actual_scale, expected_scale, atol=1e-7, rtol=1e-3) - @pytest.fixture(autouse=True) - def random(self): - torch.manual_seed(0) - - def test_prepare_layer_dropout_uniform(self) -> None: class MockModel(torch.nn.Module): def __init__(self): super().__init__() - self.layers = torch.nn.ModuleList([torch.nn.Linear(10, 10) for _ in range(5)]) + self.layers = torch.nn.ModuleList( + [torch.nn.Linear(10, 10) for _ in range(5)] + ) + model = MockModel() prob_max = 0.5 prob_layer_scale = ScaleType.UNIFORM @@ -137,12 +144,14 @@ def __init__(self): else: assert layer.dropout.prob == 0 - def test_prepare_layer_dropout_exp(self) -> None: class MockModel(torch.nn.Module): def __init__(self): super().__init__() - self.layers = torch.nn.ModuleList([torch.nn.Linear(10, 10) for _ in range(5)]) + self.layers = torch.nn.ModuleList( + [torch.nn.Linear(10, 10) for _ in range(5)] + ) + model = MockModel() prob_max = 0.5 prob_layer_scale = ScaleType.EXP @@ -157,12 +166,14 @@ def __init__(self): else: assert layer.dropout.prob > 0 and layer.dropout.prob < prob_max - def test_prepare_layer_dropout_linear(self) -> None: class MockModel(torch.nn.Module): def __init__(self): super().__init__() - self.layers = torch.nn.ModuleList([torch.nn.Linear(10, 10) for _ in range(5)]) + self.layers = torch.nn.ModuleList( + [torch.nn.Linear(10, 10) for _ in range(5)] + ) + model = MockModel() prob_max = 0.5 prob_layer_scale = ScaleType.LINEAR @@ -174,7 +185,7 @@ def __init__(self): assert layer.dropout.prob == 0 elif i == len(model.layers) - 1: assert layer.dropout.prob == prob_max - elif i == len(model.layers)/2: - assert layer.dropout.prob == prob_max/2 + elif i == len(model.layers) / 2: + assert layer.dropout.prob == prob_max / 2 else: assert layer.dropout.prob >= 0.0 and layer.dropout.prob <= prob_max diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index 676d7e0151..5776269b8d 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -415,7 +415,10 @@ class Recipe: name="dev/early_exit_finetune_distributed", file_path="dev/early_exit_finetune_distributed.py", configs=[ - Config(name="llama2/7B_full_early_exit", file_path="dev/7B_full_early_exit.yaml"), + Config( + name="llama2/7B_full_early_exit", + file_path="dev/7B_full_early_exit.yaml", + ), ], supports_distributed=True, ), diff --git a/torchtune/modules/__init__.py b/torchtune/modules/__init__.py index c5917ab053..3554698d42 100644 --- a/torchtune/modules/__init__.py +++ b/torchtune/modules/__init__.py @@ -14,6 +14,7 @@ ) from .feed_forward import FeedForward # noqa from .kv_cache import KVCache # noqa +from .layer_dropout import LayerDropout, prepare_layer_dropout # noqa from .layer_norm import Fp32LayerNorm # noqa from .low_precision import FrozenNF4Linear # noqa from .position_embeddings import ( # noqa @@ -23,7 +24,6 @@ from .rms_norm import RMSNorm # noqa from .tanh_gate import TanhGate # noqa from .tied_linear import TiedLinear # noqa -from .layer_dropout import LayerDropout, prepare_layer_dropout # noqa from .transformer import ( # noqa TransformerCrossAttentionLayer, TransformerDecoder, diff --git a/torchtune/modules/common_utils.py b/torchtune/modules/common_utils.py index 7753395e85..16df7801cb 100644 --- a/torchtune/modules/common_utils.py +++ b/torchtune/modules/common_utils.py @@ -89,40 +89,42 @@ def slice_str_to_array(slice_str: str, length: int) -> list[bool]: [True, False, False, False, True, True, False] """ - assert ',' not in slice_str or ':' not in slice_str, "Cannot mix commas and colons" + assert "," not in slice_str or ":" not in slice_str, "Cannot mix commas and colons" - if ',' in slice_str: - indices = [int(i) for i in slice_str.split(',')] + if "," in slice_str: + indices = [int(i) for i in slice_str.split(",")] assert all(0 <= i < length for i in indices), "Index out of range" result = [False] * length for i in indices: result[i] = True return result - parts = slice_str.split(':') + parts = slice_str.split(":") assert len(parts) <= 3, "Invalid slice format" start, end, step = None, None, None - if len(parts) == 1 and parts[0] != '': + if len(parts) == 1 and parts[0] != "": start = int(parts[0]) end = start + 1 step = 1 elif len(parts) == 2: - start = int(parts[0]) if parts[0] != '' else None - end = int(parts[1]) if parts[1] != '' else None + start = int(parts[0]) if parts[0] != "" else None + end = int(parts[1]) if parts[1] != "" else None elif len(parts) == 3: - start = int(parts[0]) if parts[0] != '' else None - end = int(parts[1]) if parts[1] != '' else None - step = int(parts[2]) if parts[2] != '' else None + start = int(parts[0]) if parts[0] != "" else None + end = int(parts[1]) if parts[1] != "" else None + step = int(parts[2]) if parts[2] != "" else None assert start is None or 0 <= start < length, "Start index out of range" assert end is None or 0 <= end < length, "End index out of range" assert step is None or step != 0, "Step cannot be zero" result = [False] * length - slice_indices = range(start if start is not None else 0, - end if end is not None else length, - step if step is not None else 1) + slice_indices = range( + start if start is not None else 0, + end if end is not None else length, + step if step is not None else 1, + ) for i in slice_indices: if 0 <= i < length: diff --git a/torchtune/modules/early_exit_loss.py b/torchtune/modules/early_exit_loss.py index 9cfc6f5bac..9a4d191fc3 100644 --- a/torchtune/modules/early_exit_loss.py +++ b/torchtune/modules/early_exit_loss.py @@ -5,15 +5,16 @@ # LICENSE file in the root directory of this source tree. import copy +from enum import Enum + import numpy as np import torch -from enum import Enum -from typing import List from torchtune import utils log = utils.get_logger("DEBUG") + class LossScaleType(str, Enum): ONE = "one" L = "l" @@ -22,9 +23,17 @@ class LossScaleType(str, Enum): SQRT_L = "sqrt_l" INV_SQRT_L = "inv_sqrt_l" + # TODO: create docstring using other functions as template # TODO: add assert on type of loss_fn -def early_exit_loss(model, hidden_states_dict, labels, loss_fn, e_scale: float=1.0, loss_scale_type=LossScaleType.SUM_L): +def early_exit_loss( + model, + hidden_states_dict, + labels, + loss_fn, + e_scale: float = 1.0, + loss_scale_type=LossScaleType.SUM_L, +): batch_loss_fn = copy.deepcopy(loss_fn) batch_loss_fn.reduction = "none" @@ -49,26 +58,33 @@ def early_exit_loss(model, hidden_states_dict, labels, loss_fn, e_scale: float=1 s_unpadded = (labels != loss_fn.ignore_index).sum() losses_early = losses_early.float().sum(-1) / s_unpadded # Shape: [e] - losses_scales = layer_ids_to_loss_scales(torch.Tensor(hidden_layer_ids).to(losses_early), len(model.layers), loss_scale_type, e_scale) + losses_scales = layer_ids_to_loss_scales( + torch.Tensor(hidden_layer_ids).to(losses_early), + len(model.layers), + loss_scale_type, + e_scale, + ) return torch.sum(losses_scales * losses_early) -def layer_ids_to_loss_scales(layer_ids, n_layers, loss_scale_type: LossScaleType, e_scale: float): - match loss_scale_type: - case LossScaleType.ONE: - loss_scales = torch.ones(len(layer_ids)) - case LossScaleType.L: - loss_scales = torch.Tensor(layer_ids+1) - case LossScaleType.SUM_L: - loss_scales = torch.cumsum(layer_ids+1, dim=0) - case LossScaleType.SQRT_L: - loss_scales = torch.sqrt(layer_ids+1) - case LossScaleType.INV_L: - loss_scales = 1.0 / (layer_ids+1) - case LossScaleType.INV_SQRT_L: - loss_scales = torch.reciprocal(torch.sqrt(layer_ids+1)) - case _: - raise ValueError(f"Unsupported loss_scale type {loss_scale_type}") + +def layer_ids_to_loss_scales( + layer_ids, n_layers, loss_scale_type: LossScaleType, e_scale: float +): + if loss_scale_type == LossScaleType.ONE: + loss_scales = torch.ones(len(layer_ids)) + elif loss_scale_type == LossScaleType.L: + loss_scales = torch.Tensor(layer_ids + 1) + elif loss_scale_type == LossScaleType.SUM_L: + loss_scales = torch.cumsum(layer_ids + 1, dim=0) + elif loss_scale_type == LossScaleType.SQRT_L: + loss_scales = torch.sqrt(layer_ids + 1) + elif loss_scale_type == LossScaleType.INV_L: + loss_scales = 1.0 / (layer_ids + 1) + elif loss_scale_type == LossScaleType.INV_SQRT_L: + loss_scales = torch.reciprocal(torch.sqrt(layer_ids + 1)) + else: + raise ValueError(f"Unsupported loss_scale type {loss_scale_type}") loss_scales = loss_scales * torch.where(layer_ids < n_layers - 1, e_scale, 1.0) # normalize loss scales to ensure that their sum is 1.0 @@ -77,29 +93,31 @@ def layer_ids_to_loss_scales(layer_ids, n_layers, loss_scale_type: LossScaleType return loss_scales + class EarlyExitCurriculumType(str, Enum): NONE = "none" ROTATIONAL = "rot" GRADUAL = "gradual" -def setup_early_exit_loss_curriculum(early_exit_curriculum: EarlyExitCurriculumType, *args, **kwargs): - match early_exit_curriculum: - case EarlyExitCurriculumType.NONE: - return None - - case EarlyExitCurriculumType.ROTATIONAL: - return RotationalEarlyExitCurriculum(*args, **kwargs) - case EarlyExitCurriculumType.GRADUAL: - return GradualEarlyExitCurriculum(*args, **kwargs) +def setup_early_exit_loss_curriculum( + early_exit_curriculum: EarlyExitCurriculumType, *args, **kwargs +): + if early_exit_curriculum == EarlyExitCurriculumType.NONE: + return None + elif early_exit_curriculum == EarlyExitCurriculumType.ROTATIONAL: + return RotationalEarlyExitCurriculum(*args, **kwargs) + elif early_exit_curriculum == EarlyExitCurriculumType.GRADUAL: + return GradualEarlyExitCurriculum(*args, **kwargs) + else: + raise ValueError(f"Unsupported early loss curriculum {early_exit_curriculum}.") - case _: - raise ValueError(f"Unsupported early loss curriculum {early_exit_curriculum}.") - # TODO: create a base curriculum class that can be used for other aspects, e.g., dropout, datasets, etc. -class EarlyExitCurriculum(): - def __init__(self, do_output_hidden_states, max_steps, train_last_layer=True, verbose=False): +class EarlyExitCurriculum: + def __init__( + self, do_output_hidden_states, max_steps, train_last_layer=True, verbose=False + ): self._init_do_output_hidden_states = do_output_hidden_states self.do_output_hidden_states = do_output_hidden_states self.train_last_layer = train_last_layer @@ -116,8 +134,11 @@ def get(self): do_output_hidden_states[-1] = True return do_output_hidden_states + class RotationalEarlyExitCurriculum(EarlyExitCurriculum): - def __init__(self, do_output_hidden_states, max_steps, train_last_layer=True, verbose=False): + def __init__( + self, do_output_hidden_states, max_steps, train_last_layer=True, verbose=False + ): super().__init__(do_output_hidden_states, max_steps, train_last_layer, verbose) self._initial_do_output_hidden_states = np.copy(do_output_hidden_states) @@ -126,10 +147,20 @@ def step(self): self.do_output_hidden_states = np.roll(self.do_output_hidden_states, 1) if self.verbose: - log.info(f"Updated self.output_hidden_states to {self.do_output_hidden_states}.") + log.info( + f"Updated self.output_hidden_states to {self.do_output_hidden_states}." + ) + class GradualEarlyExitCurriculum(EarlyExitCurriculum): - def __init__(self, do_output_hidden_states, max_steps, train_last_layer=True, percent_scale=2, verbose=False): + def __init__( + self, + do_output_hidden_states, + max_steps, + train_last_layer=True, + percent_scale=2, + verbose=False, + ): super().__init__(do_output_hidden_states, max_steps, train_last_layer, verbose) self._final_do_output_hidden_states = np.copy(do_output_hidden_states) self._step = 0 @@ -144,12 +175,18 @@ def step(self): n_layers = len(self.do_output_hidden_states) # Enable each layer based on proportion of completed training steps for layer_index in range(len(self.do_output_hidden_states)): - should_train = (percent_trained * self._percent_scale) >= (n_layers - layer_index) / n_layers + should_train = (percent_trained * self._percent_scale) >= ( + n_layers - layer_index + ) / n_layers self.do_output_hidden_states[layer_index] = should_train # Only enable layers that are set by the user - self.do_output_hidden_states = np.logical_and(self.do_output_hidden_states, self._final_do_output_hidden_states) + self.do_output_hidden_states = np.logical_and( + self.do_output_hidden_states, self._final_do_output_hidden_states + ) self._step += 1 if self.verbose: - log.info(f"Updated self.do_output_hidden_states to {self.do_output_hidden_states}.") + log.info( + f"Updated self.do_output_hidden_states to {self.do_output_hidden_states}." + ) diff --git a/torchtune/modules/layer_dropout.py b/torchtune/modules/layer_dropout.py index 7c15c28bbc..cb9fbb4fd9 100644 --- a/torchtune/modules/layer_dropout.py +++ b/torchtune/modules/layer_dropout.py @@ -4,13 +4,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import math from enum import Enum from typing import Any, Callable, Optional -import math + import torch from torchtune.modules.common_utils import slice_str_to_array + class LayerDropout(torch.nn.Module): def __init__(self, prob=0.0, dim=0, disable_on_eval=True, seed=None): super().__init__() @@ -30,7 +32,11 @@ def forward(self, function: Callable, input: torch.Tensor, *args, **kwargs): self.inferred = 1.0 return function(input, *args, **kwargs) - skip = torch.bernoulli(torch.Tensor((n) * [self.prob]), generator=self.generator).to(input.device).to(input.dtype) + skip = ( + torch.bernoulli(torch.Tensor((n) * [self.prob]), generator=self.generator) + .to(input.device) + .to(input.dtype) + ) self.inferred = 1 - torch.mean(skip) ind_selected = (skip == 0).nonzero().squeeze() @@ -39,11 +45,14 @@ def forward(self, function: Callable, input: torch.Tensor, *args, **kwargs): out_selected = function(x_selected, *args, **kwargs) out = input.clone() - assert self.dim == 0, "Currently only supporting dropping elements along the 0th dimension" + assert ( + self.dim == 0 + ), "Currently only supporting dropping elements along the 0th dimension" if ind_selected.numel() > 0: out[ind_selected] = out_selected return out + class ModuleLayerDropoutWrapper(torch.nn.Module): def __init__(self, module: torch.nn.Module, dropout: LayerDropout): super().__init__() @@ -78,6 +87,7 @@ def load_state_dict(self, state_dict, *args, **kwargs): self.module.load_state_dict(state_dict, *args, **kwargs) return + class ScaleType(str, Enum): UNIFORM = "uniform" EXP = "exp" @@ -87,6 +97,7 @@ class ScaleType(str, Enum): SIGMOID = "sigmoid" STEP = "step" + def get_scale(scale_type: ScaleType, scale_period: int, val: int): if scale_period == 0: return 1.0 @@ -104,16 +115,39 @@ def get_scale(scale_type: ScaleType, scale_period: int, val: int): # after scale_period, scale should be 1 return min(scale, 1.0) -def prepare_layer_dropout(model, prob_max: float= 0.0, prob_layer_scale: ScaleType = ScaleType.EXP, layers_str: Optional[str] = None, disable_on_eval: bool = True): + +def prepare_layer_dropout( + model, + prob_max: float = 0.0, + prob_layer_scale: ScaleType = ScaleType.EXP, + layers_str: Optional[str] = None, + disable_on_eval: bool = True, +): num_layers = len(model.layers) - has_dropout = slice_str_to_array(layers_str, num_layers) if layers_str else [True] * num_layers + has_dropout = ( + slice_str_to_array(layers_str, num_layers) + if layers_str + else [True] * num_layers + ) for layer_id in range(len(model.layers)): - prob = prob_max * get_scale( - scale_type = prob_layer_scale, - scale_period = num_layers - 1, - val = layer_id, - ) if has_dropout[layer_id] else 0.0 - assert prob >= 0.0 and prob <= prob_max, f"prob={prob} should be between 0 and {prob_max}" - # We would like each layer to have a different seed, so that we don't have the same samples skipped across layers. Hence, we use the layer_id as a seed for each layer's dropout. - layer_dropout = LayerDropout(prob, disable_on_eval=disable_on_eval, seed=layer_id) - model.layers[layer_id] = ModuleLayerDropoutWrapper(model.layers[layer_id], layer_dropout) + prob = ( + prob_max + * get_scale( + scale_type=prob_layer_scale, + scale_period=num_layers - 1, + val=layer_id, + ) + if has_dropout[layer_id] + else 0.0 + ) + assert ( + prob >= 0.0 and prob <= prob_max + ), f"prob={prob} should be between 0 and {prob_max}" + # We would like each layer to have a different seed, so that we don't have the same samples skipped across layers. + # Hence, we use the layer_id as a seed for each layer's dropout. + layer_dropout = LayerDropout( + prob, disable_on_eval=disable_on_eval, seed=layer_id + ) + model.layers[layer_id] = ModuleLayerDropoutWrapper( + model.layers[layer_id], layer_dropout + ) diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index 1ea1cb4a38..66ac92002f 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -11,6 +11,7 @@ from torchtune.modules import MultiHeadAttention from torchtune.modules.attention_utils import _MaskType + class TransformerSelfAttentionLayer(nn.Module): """ Transformer layer derived from the Llama2 model. Normalization is applied before the attention **and** FF layer. From ad090af9f2db002bb4d6f8746c80230bbe95e9ed Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Mon, 25 Nov 2024 00:21:09 +0000 Subject: [PATCH 55/88] add assertion on type of loss fn for early exit loss --- recipes/dev/early_exit_finetune_distributed.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index 4a38bda2e5..9f59c40dac 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -672,6 +672,11 @@ def _setup_early_exit_loss( early_exit_loss_curriculum = None if cfg_early_exit_loss: + assert ( + hasattr(self._loss_fn, "reduction") + and self._loss_fn.reduction == "mean" + ), "Currently early exit loss is only implemented for loss functions that apply a mean reduction." + do_output_hidden_states = slice_str_to_array( cfg_early_exit_loss.get("layers", ":"), len(self._model.layers) ) From cec8cd4cef408e83bed4d38204e3b7b7050dd3e0 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Mon, 25 Nov 2024 03:40:32 +0000 Subject: [PATCH 56/88] add docstring and slightly change attribute of layer_dropout and early_exit --- torchtune/modules/early_exit_loss.py | 201 ++++++++++++++++++++------- torchtune/modules/layer_dropout.py | 165 ++++++++++++++++++++-- 2 files changed, 308 insertions(+), 58 deletions(-) diff --git a/torchtune/modules/early_exit_loss.py b/torchtune/modules/early_exit_loss.py index 9a4d191fc3..6d10c54711 100644 --- a/torchtune/modules/early_exit_loss.py +++ b/torchtune/modules/early_exit_loss.py @@ -6,11 +6,13 @@ import copy from enum import Enum +from typing import Dict, List, Optional import numpy as np import torch from torchtune import utils +from torchtune.modules.transformer import TransformerDecoder log = utils.get_logger("DEBUG") @@ -24,16 +26,32 @@ class LossScaleType(str, Enum): INV_SQRT_L = "inv_sqrt_l" -# TODO: create docstring using other functions as template -# TODO: add assert on type of loss_fn def early_exit_loss( - model, - hidden_states_dict, - labels, - loss_fn, + model: TransformerDecoder, + hidden_states_dict: Dict[int, torch.Tensor], + labels: torch.Tensor, + loss_fn: torch.nn.Module, e_scale: float = 1.0, - loss_scale_type=LossScaleType.SUM_L, -): + loss_scale_type: LossScaleType = LossScaleType.SUM_L, +) -> torch.Tensor: + """ + Compute the early exit loss for a given model and outputs of intermediate layers. + This function takes in a model, a dictionary of hidden states, labels, a loss function, + and optional parameters for scaling the loss. It computes the early exit loss by + iterating over the hidden states, computing the logits and losses at each layer, + and then scaling and summing these losses. + Args: + model (TransformerDecoder): The model to compute the early exit loss for. + hidden_states_dict (Dict[int, torch.Tensor]): A dictionary of hidden states, + where each key is a layer index and each value is a tensor of shape [b, s, d]. + labels (torch.Tensor): The labels for the input data. + loss_fn (torch.nn.Module): The loss function to use (should be the same as the standard loss function for last layer). + e_scale (float, optional): A scaling factor for the early exit losses. Defaults to 1.0. + loss_scale_type (LossScaleType, optional): The type of loss scaling to use to determine + scale of each layer's loss. Defaults to LossScaleType.SUM_L. + Returns: + torch.Tensor: The computed early exit loss. + """ batch_loss_fn = copy.deepcopy(loss_fn) batch_loss_fn.reduction = "none" @@ -69,8 +87,34 @@ def early_exit_loss( def layer_ids_to_loss_scales( - layer_ids, n_layers, loss_scale_type: LossScaleType, e_scale: float -): + layer_ids: List, + n_layers: int, + loss_scale_type: LossScaleType, + e_scale: float, +) -> torch.Tensor: + """ + Compute the loss scales for a given set of layer IDs and loss scale type. + This function takes in a list of layer IDs, the total number of layers, + a loss scale type, and an early exit scaling factor. It computes the loss + scales based on the specified loss scale type and then normalizes them to + ensure that their sum is 1.0. + Args: + layer_ids (List): A tensor of layer IDs. + n_layers (int): The total number of layers. + loss_scale_type (LossScaleType): The type of loss scaling to use. + e_scale (float): An early exit scaling factor. + Returns: + torch.Tensor: The computed loss scales. + Raises: + ValueError: If the provided loss scale type is not supported. + AssertionError: If the sum of the loss scales is not close to 1.0. + Example: + >>> layer_ids = [0, 1, 2] + >>> n_layers = 3 + >>> loss_scale_type = LossScaleType.SUM_L + >>> e_scale = 1.0 + >>> loss_scales = layer_ids_to_loss_scales(layer_ids, n_layers, loss_scale_type, e_scale) + """ if loss_scale_type == LossScaleType.ONE: loss_scales = torch.ones(len(layer_ids)) elif loss_scale_type == LossScaleType.L: @@ -100,35 +144,45 @@ class EarlyExitCurriculumType(str, Enum): GRADUAL = "gradual" -def setup_early_exit_loss_curriculum( - early_exit_curriculum: EarlyExitCurriculumType, *args, **kwargs -): - if early_exit_curriculum == EarlyExitCurriculumType.NONE: - return None - elif early_exit_curriculum == EarlyExitCurriculumType.ROTATIONAL: - return RotationalEarlyExitCurriculum(*args, **kwargs) - elif early_exit_curriculum == EarlyExitCurriculumType.GRADUAL: - return GradualEarlyExitCurriculum(*args, **kwargs) - else: - raise ValueError(f"Unsupported early loss curriculum {early_exit_curriculum}.") - - # TODO: create a base curriculum class that can be used for other aspects, e.g., dropout, datasets, etc. class EarlyExitCurriculum: + """ + A curriculum for early exit loss training, which controls which layers to use their hidden states + during training. + Args: + do_output_hidden_states (List[bool]): A list indicating whether each layer's hidden state + should be output to calculate their losses. + max_steps (int): The maximum number of steps in the curriculum. + train_last_layer (bool, optional): Whether to always calculate loss for the last layer. Defaults to True. + verbose (bool, optional): Whether to print verbose logs. Defaults to False. + """ + def __init__( - self, do_output_hidden_states, max_steps, train_last_layer=True, verbose=False + self, + do_output_hidden_states: List[bool], + max_steps: int, + train_last_layer: bool = True, + verbose: bool = False, ): self._init_do_output_hidden_states = do_output_hidden_states - self.do_output_hidden_states = do_output_hidden_states + self._do_output_hidden_states = do_output_hidden_states self.train_last_layer = train_last_layer self.verbose = verbose self.max_steps = max_steps - def step(self): + def step(self) -> None: + """ + Perform a step in the curriculum. Should be called at the end of each iteration during training. + """ pass - def get(self): - do_output_hidden_states = np.copy(self.do_output_hidden_states) + def get(self) -> np.ndarray: + """ + Get the current output hidden states. + Returns: + np.ndarray: A list indicating whether we should calculate loss for each layer. + """ + do_output_hidden_states = np.copy(self._do_output_hidden_states) # Ensure last layer is trained if self.train_last_layer: do_output_hidden_states[-1] = True @@ -136,57 +190,104 @@ def get(self): class RotationalEarlyExitCurriculum(EarlyExitCurriculum): - def __init__( - self, do_output_hidden_states, max_steps, train_last_layer=True, verbose=False - ): - super().__init__(do_output_hidden_states, max_steps, train_last_layer, verbose) - self._initial_do_output_hidden_states = np.copy(do_output_hidden_states) + """ + A rotational early exit curriculum, which rotates the layer enablement one step forward + at each step. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._initial_do_output_hidden_states = np.copy(self._do_output_hidden_states) def step(self): + """ + Rotate the layer enablement one step forward. + This method updates the `do_output_hidden_states` attribute by rotating it one position to the right. + """ # Rotate layer enablement one step forward - self.do_output_hidden_states = np.roll(self.do_output_hidden_states, 1) + self._do_output_hidden_states = np.roll(self._do_output_hidden_states, 1) if self.verbose: log.info( - f"Updated self.output_hidden_states to {self.do_output_hidden_states}." + f"Updated self._do_output_hidden_states to {self._do_output_hidden_states}." ) class GradualEarlyExitCurriculum(EarlyExitCurriculum): + """ + A gradual early exit curriculum, which gradually enables more layers (starting from the last layer) as training progresses. + Args: + *args: Positional arguments passed to the parent EarlyExitCurriculum class. + percent_scale (float, optional): A scaling factor to determine at which percentage + of steps, all the layers will be enabled. At `steps = max_steps / percent_scale`, all the layers will be enabled. + **kwargs: Keyword arguments passed to the parent EarlyExitCurriculum class. + """ + def __init__( self, - do_output_hidden_states, - max_steps, - train_last_layer=True, - percent_scale=2, - verbose=False, + *args, + percent_scale: float = 2, + **kwargs, ): - super().__init__(do_output_hidden_states, max_steps, train_last_layer, verbose) - self._final_do_output_hidden_states = np.copy(do_output_hidden_states) + super().__init__(*args, **kwargs) + self._final_do_output_hidden_states = np.copy(self._do_output_hidden_states) self._step = 0 self._percent_scale = percent_scale # Initialize all layers to False - for i in range(len(self.do_output_hidden_states)): - self.do_output_hidden_states[i] = False + for i in range(len(self._do_output_hidden_states)): + self._do_output_hidden_states[i] = False def step(self): + """ + Perform a step in the curriculum. + This method updates the `_do_output_hidden_states` attribute based on the current + step and the percentage of completed training steps. + """ percent_trained = self._step / self.max_steps - n_layers = len(self.do_output_hidden_states) + n_layers = len(self._do_output_hidden_states) # Enable each layer based on proportion of completed training steps - for layer_index in range(len(self.do_output_hidden_states)): + for layer_index in range(len(self._do_output_hidden_states)): should_train = (percent_trained * self._percent_scale) >= ( n_layers - layer_index ) / n_layers - self.do_output_hidden_states[layer_index] = should_train + self._do_output_hidden_states[layer_index] = should_train # Only enable layers that are set by the user - self.do_output_hidden_states = np.logical_and( - self.do_output_hidden_states, self._final_do_output_hidden_states + self._do_output_hidden_states = np.logical_and( + self._do_output_hidden_states, self._final_do_output_hidden_states ) self._step += 1 if self.verbose: log.info( - f"Updated self.do_output_hidden_states to {self.do_output_hidden_states}." + f"Updated self._do_output_hidden_states to {self._do_output_hidden_states}." ) + + +def setup_early_exit_loss_curriculum( + early_exit_curriculum: EarlyExitCurriculumType, *args, **kwargs +) -> Optional[EarlyExitCurriculum]: + """ + Set up an early exit loss curriculum based on the provided type. + This function takes in an early exit curriculum type and optional arguments. + It returns an instance of the corresponding early exit curriculum class, + or None if the curriculum type is NONE. + Args: + early_exit_curriculum (EarlyExitCurriculumType): The type of early exit curriculum to set up. + *args: Optional positional arguments for the early exit curriculum constructor. + **kwargs: Optional keyword arguments for the early exit curriculum constructor. + Returns: + Optional[EarlyExitCurriculum]: + An instance of the corresponding early exit curriculum class, or None. + Raises: + ValueError: If the provided early exit curriculum type is not supported. + """ + if early_exit_curriculum == EarlyExitCurriculumType.NONE: + return None + elif early_exit_curriculum == EarlyExitCurriculumType.ROTATIONAL: + return RotationalEarlyExitCurriculum(*args, **kwargs) + elif early_exit_curriculum == EarlyExitCurriculumType.GRADUAL: + return GradualEarlyExitCurriculum(*args, **kwargs) + else: + raise ValueError(f"Unsupported early loss curriculum {early_exit_curriculum}.") diff --git a/torchtune/modules/layer_dropout.py b/torchtune/modules/layer_dropout.py index cb9fbb4fd9..09f01098ec 100644 --- a/torchtune/modules/layer_dropout.py +++ b/torchtune/modules/layer_dropout.py @@ -6,7 +6,7 @@ import math from enum import Enum -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch @@ -14,7 +14,34 @@ class LayerDropout(torch.nn.Module): - def __init__(self, prob=0.0, dim=0, disable_on_eval=True, seed=None): + """ + A module that applies layer dropout to the input tensor of an underlying module. + It drops a portion of an input tensor, applies the underlying module on the + remaining parts of the tensor, and then concatenates with the dropped portion of the tensor. + When applied during training, it can have a regularization effect, and can potentially speedup training. + Args: + prob (float): The probability of dropping an input. Defaults to 0.0. + dim (Optional[int]): The dimension of input tensor along which to drop layers. Defaults to 0 (i.e., batch size). + disable_on_eval (Optional[bool]): Whether to disable layer dropout during evaluation. Defaults to True. + seed (Optional[int]): The seed for the random number generator. Defaults to None. + Examples: + >>> import torch + >>> # Apply layer dropout to a lambda function + >>> layer_dropout = LayerDropout(prob=0.5) + >>> output = layer_dropout(lambda x: x**2, torch.randn(1)) + >>> # Apply layer dropout to a torch.nn.Linear module + >>> linear = torch.nn.Linear(5, 3) + >>> layer_dropout = LayerDropout(prob=0.5) + >>> output = layer_dropout(linear, torch.randn(1, 5)) + """ + + def __init__( + self, + prob: float = 0.0, + dim: Optional[int] = 0, + disable_on_eval: Optional[bool] = True, + seed: Optional[int] = None, + ): super().__init__() self.prob: float = prob self.dim = dim @@ -25,7 +52,23 @@ def __init__(self, prob=0.0, dim=0, disable_on_eval=True, seed=None): if seed is not None: self.generator.manual_seed(seed) - def forward(self, function: Callable, input: torch.Tensor, *args, **kwargs): + def forward( + self, + function: Union[Callable, torch.nn.Module], + input: torch.Tensor, + *args, + **kwargs, + ) -> torch.Tensor: + """ + Apply layer dropout to the input tensor. + Args: + function (Union[Callable, torch.nn.Module]): The function or module to apply to the input tensor. + input (torch.Tensor): The input tensor. + *args: Additional positional arguments passed to the function. + **kwargs: Additional keyword arguments passed to the function. + Returns: + torch.Tensor: The output tensor after applying layer dropout. + """ n = input.shape[self.dim] if self.prob == 0 or (self.disable_on_eval and self.training is False): @@ -54,6 +97,39 @@ def forward(self, function: Callable, input: torch.Tensor, *args, **kwargs): class ModuleLayerDropoutWrapper(torch.nn.Module): + """ + A wrapper module that adds layer dropout functionality to a given module. + This class wraps a given module and applies layer dropout to it. It also + provides getter and setter methods for the wrapped module's attributes. + Args: + module (torch.nn.Module): The module to wrap. + dropout (LayerDropout): The layer dropout object. + Examples: + >>> import torch + >>> from torch import nn + >>> # Define a simple model + >>> class MyModel(nn.Module): + ... def __init__(self): + ... super().__init__() + ... self.fc1 = nn.Linear(5, 3) + ... self.fc2 = nn.Linear(3, 2) + ... + ... def forward(self, x): + ... return self.fc2(self.fc1(x)) + >>> model = MyModel() + >>> fc1 = model.fc1 + >>> fc2 = model.fc2 + >>> # Apply layer dropout to the model + >>> layer_dropout = LayerDropout(prob=0.5) + >>> model = ModuleLayerDropoutWrapper(model, layer_dropout) + >>> # Accessing attributes of the wrapped model + >>> assert model.dropout.prob == 0.5 + >>> assert model.fc1 == fc1 + >>> assert model.fc2 == fc2 + >>> # Pass an input to the wrapped model as if you are passing it to the original model + >>> output = model(torch.randn(1, 5)) + """ + def __init__(self, module: torch.nn.Module, dropout: LayerDropout): super().__init__() self.module = module @@ -81,9 +157,11 @@ def __getitem__(self, key: int) -> Any: return self.module.__getitem__(key) def state_dict(self, *args, **kwargs): + """Return the state dictionary of the wrapped module.""" return self.module.state_dict(*args, **kwargs) def load_state_dict(self, state_dict, *args, **kwargs): + """Load the state dictionary into the wrapped module.""" self.module.load_state_dict(state_dict, *args, **kwargs) return @@ -98,7 +176,29 @@ class ScaleType(str, Enum): STEP = "step" -def get_scale(scale_type: ScaleType, scale_period: int, val: int): +def get_scale( + scale_type: ScaleType, + scale_period: int, + val: int, +) -> float: + """ + Compute a scaling factor based on the provided scale type, period, and value. + The scaling factor is designed to be 0 when the value is 0 and 1 when the value + reaches or is larger than the scale period. + Args: + scale_type (ScaleType): The type of scaling to use. + scale_period (int): The period over which the scaling factor increases from 0 to 1. + val (int): The current value used to compute the scaling factor. + Returns: + float: The computed scaling factor. + Examples: + >>> get_scale(ScaleType.LINEAR, 10, 5) + 0.5 + >>> get_scale(ScaleType.LINEAR, 10, 0) + 0.0 + >>> get_scale(ScaleType.LOG, 10, 10) + 1.0 + """ if scale_period == 0: return 1.0 @@ -117,12 +217,61 @@ def get_scale(scale_type: ScaleType, scale_period: int, val: int): def prepare_layer_dropout( - model, + model: torch.nn.Module, prob_max: float = 0.0, - prob_layer_scale: ScaleType = ScaleType.EXP, + prob_layer_scale: Optional[ScaleType] = ScaleType.UNIFORM, layers_str: Optional[str] = None, - disable_on_eval: bool = True, -): + disable_on_eval: Optional[bool] = True, +) -> None: + """ + Prepare a model for layer dropout by wrapping each layer with a ModuleLayerDropoutWrapper. + This function takes in a model, the maximum probability of dropping a layer, + the scaling type for the layer dropout probability, a string specifying which + layers to apply dropout to, and a boolean indicating whether to disable dropout + during evaluation. It then wraps each layer of the model inplace with a + ModuleLayerDropoutWrapper, which applies layer dropout to the input tensor. + Args: + model (torch.nn.Module): The model to prepare for layer dropout. + prob_max (float): The maximum probability of dropping a layer. Defaults to 0.0. + prob_layer_scale (Optional[ScaleType]): The scaling type for the dropout probability + across layers. Defaults to ScaleType.UNIFORM. + layers_str (Optional[str]): A string specifying which layers to apply dropout to. + Defaults to None which means apply to all layers. + disable_on_eval (Optional[bool]): Whether to disable dropout during evaluation. Defaults to True. + Returns: + None + Example: + >>> import torch + >>> from torch import nn + >>> # Define a simple model + >>> class MyModel(nn.Module): + ... def __init__(self): + ... super().__init__() + ... self.layers = nn.ModuleList([ + ... nn.Linear(5, 3), + ... nn.Linear(3, 2), + ... nn.Linear(2, 1), + ... nn.Linear(1, 2), + ... nn.Linear(2, 3), + ... ]) + ... + ... def forward(self, x): + ... for layer in self.layers: + ... x = layer(x) + ... return x + >>> model = MyModel() + >>> # Apply layer dropout uniformly to all layers + >>> prepare_layer_dropout(model, prob_max=0.2, prob_layer_scale=ScaleType.UNIFORM) + >>> # Apply layer dropout every other layer, as described in LayerDrop paper + (Fan et al., https://arxiv.org/abs/1909.11556v1) + >>> prepare_layer_dropout(model, prob_max=0.2, prob_layer_scale=ScaleType.UNIFORM, layers_str="::2") + >>> # Apply layer dropout that increases linearly across layers, as described in Progressive Layer + Dropout paper (Zhang et al., https://arxiv.org/abs/2010.13369) + >>> prepare_layer_dropout(model, prob_max=0.2, prob_layer_scale=ScaleType.LINEAR) + >>> # Apply layer dropout that increases exponentially across layers, as described in + LayerSkip paper (Elhoushi et al., https://arxiv.org/abs/2404.16710) + >>> prepare_layer_dropout(model, prob_max=0.2, prob_layer_scale=ScaleType.EXP) + """ num_layers = len(model.layers) has_dropout = ( slice_str_to_array(layers_str, num_layers) From b69f2f30c875871476a3099fe87c120d79b29f4c Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Mon, 25 Nov 2024 04:37:52 +0000 Subject: [PATCH 57/88] refactor layer_dropout and add test cases on wrapper --- tests/torchtune/modules/test_layer_dropout.py | 109 ++++++++++++++---- 1 file changed, 89 insertions(+), 20 deletions(-) diff --git a/tests/torchtune/modules/test_layer_dropout.py b/tests/torchtune/modules/test_layer_dropout.py index 8e28a01c45..73f42998d4 100644 --- a/tests/torchtune/modules/test_layer_dropout.py +++ b/tests/torchtune/modules/test_layer_dropout.py @@ -14,6 +14,7 @@ from torchtune.modules.layer_dropout import ( get_scale, LayerDropout, + ModuleLayerDropoutWrapper, prepare_layer_dropout, ScaleType, ) @@ -22,15 +23,11 @@ class TestLayerDropout: """Class for testing LayerDropout implementation.""" - @pytest.fixture(autouse=True) - def random(self): - torch.manual_seed(0) - @pytest.fixture def input_shape(self) -> Tuple[int, int]: - bsz = 32 - seqlen = 1024 - dim = 4096 + bsz = 8 + seqlen = 256 + dim = 32 return bsz, seqlen, dim @pytest.fixture @@ -45,7 +42,7 @@ def layer_dropout( def test_forward_train_prob_1( self, layer_dropout: LayerDropout, input: torch.Tensor - ) -> None: + ): # With dropout probability = 1.0, we expect output to be the same as input layer_dropout.prob = 1.0 output = layer_dropout.forward(lambda x: x**2, input) @@ -53,15 +50,13 @@ def test_forward_train_prob_1( def test_forward_train_prob_0( self, layer_dropout: LayerDropout, input: torch.Tensor - ) -> None: + ): # With dropout probability = 1.0, we expect the operation to be applied on all elements in the input layer_dropout.prob = 0.0 output = layer_dropout.forward(lambda x: x**2, input) assert torch.allclose(output, input**2) - def test_forward_eval( - self, layer_dropout: LayerDropout, input: torch.Tensor - ) -> None: + def test_forward_eval(self, layer_dropout: LayerDropout, input: torch.Tensor): layer_dropout.prob = 1.0 layer_dropout.eval() @@ -74,7 +69,79 @@ def test_forward_eval( output = layer_dropout.forward(lambda x: x**2, input) assert torch.allclose(output, input) - def test_get_scale_uniform(self) -> None: + +class TestLayerDropoutWrapper: + @pytest.fixture + def input_shape(self) -> Tuple[int, int]: + bsz = 4 + dim = 8 + return (bsz, dim) + + @pytest.fixture + def input(self, input_shape: Tuple[int]) -> torch.Tensor: + return torch.randn(input_shape) + + @pytest.fixture + def model(self, input_shape) -> torch.nn.Module: + _, dim = input_shape + return torch.nn.Sequential( + torch.nn.Linear(dim, 32), torch.nn.ReLU(), torch.nn.Linear(32, dim) + ) + + @pytest.fixture + def linear(self, input_shape) -> torch.nn.Module: + _, dim = input_shape + return torch.nn.Linear(dim, dim) + + def test_linear(self, linear: torch.nn.Module, input: torch.Tensor): + wrapper = ModuleLayerDropoutWrapper(linear, LayerDropout(prob=0.5)) + assert wrapper.module == linear + + # Test output + wrapper.dropout.prob = 1 + assert torch.allclose(wrapper(input), input) + wrapper.dropout.prob = 0 + assert torch.allclose(wrapper(input), linear(input)) + + # Test getters + assert wrapper.in_features == linear.in_features + assert wrapper.out_features == linear.out_features + assert torch.equal(wrapper.weight, linear.weight) + + # Test setters + wrapper.weight.data = wrapper.weight.data * 2 + assert torch.equal(wrapper.weight, linear.weight) + + # Test state_dict + for k in wrapper.state_dict().keys(): + assert torch.equal(wrapper.state_dict()[k], linear.state_dict()[k]) + + def test_model(self, model: torch.nn.Module, input: torch.Tensor): + wrapper = ModuleLayerDropoutWrapper(model, LayerDropout(prob=0.5)) + assert wrapper.module == model + + # Test output + wrapper.dropout.prob = 1 + assert torch.allclose(wrapper(input), input) + wrapper.dropout.prob = 0 + assert torch.allclose(wrapper(input), model(input)) + + # Test getters + assert wrapper[0].in_features == model[0].in_features + assert wrapper[0].out_features == model[0].out_features + assert torch.equal(wrapper[0].weight, model[0].weight) + + # Test setters + wrapper[2].weight.data = wrapper[2].weight.data * 2 + assert torch.equal(wrapper[2].weight, model[2].weight) + + # Test state_dict + for k in wrapper.state_dict().keys(): + assert torch.equal(wrapper.state_dict()[k], model.state_dict()[k]) + + +class TestScales: + def test_get_scale_uniform(self): scale_type = ScaleType.UNIFORM scale_period = 10 @@ -83,7 +150,7 @@ def test_get_scale_uniform(self) -> None: assert_expected(get_scale(scale_type, scale_period, scale_period), 1.0) assert_expected(get_scale(scale_type, scale_period, scale_period * 2), 1.0) - def test_get_scale_linear(self) -> None: + def test_get_scale_linear(self): scale_type = ScaleType.LINEAR scale_period = 10 @@ -92,7 +159,7 @@ def test_get_scale_linear(self) -> None: assert_expected(get_scale(scale_type, scale_period, scale_period), 1.0) assert_expected(get_scale(scale_type, scale_period, scale_period * 2), 1.0) - def test_get_scale_exp(self) -> None: + def test_get_scale_exp(self): scale_type = ScaleType.EXP scale_period = 10 @@ -104,7 +171,7 @@ def test_get_scale_exp(self) -> None: assert_expected(get_scale(scale_type, scale_period, scale_period), 1.0) assert_expected(get_scale(scale_type, scale_period, scale_period * 2), 1.0) - def test_get_scale_log(self) -> None: + def test_get_scale_log(self): scale_type = ScaleType.LOG scale_period = 10 @@ -116,7 +183,7 @@ def test_get_scale_log(self) -> None: assert_expected(get_scale(scale_type, scale_period, scale_period), 1.0) assert_expected(get_scale(scale_type, scale_period, scale_period * 2), 1.0) - def test_get_scale_sin(self) -> None: + def test_get_scale_sin(self): scale_type = ScaleType.SIN scale_period = 10 val = 5 @@ -124,7 +191,9 @@ def test_get_scale_sin(self) -> None: actual_scale = get_scale(scale_type, scale_period, val) assert_expected(actual_scale, expected_scale, atol=1e-7, rtol=1e-3) - def test_prepare_layer_dropout_uniform(self) -> None: + +class TestLayerDopoutModel: + def test_prepare_layer_dropout_uniform(self): class MockModel(torch.nn.Module): def __init__(self): super().__init__() @@ -144,7 +213,7 @@ def __init__(self): else: assert layer.dropout.prob == 0 - def test_prepare_layer_dropout_exp(self) -> None: + def test_prepare_layer_dropout_exp(self): class MockModel(torch.nn.Module): def __init__(self): super().__init__() @@ -166,7 +235,7 @@ def __init__(self): else: assert layer.dropout.prob > 0 and layer.dropout.prob < prob_max - def test_prepare_layer_dropout_linear(self) -> None: + def test_prepare_layer_dropout_linear(self): class MockModel(torch.nn.Module): def __init__(self): super().__init__() From a21cbd3b67bb62939436bb8a9e751496687adf7f Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Mon, 25 Nov 2024 19:37:56 +0000 Subject: [PATCH 58/88] add TODO comment --- recipes/dev/early_exit_finetune_distributed.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index 9f59c40dac..fcd162fe96 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -40,7 +40,8 @@ log = utils.get_logger("DEBUG") - +# TODO: add explanation of EE and LD and cite papers +# TODO: add to .yaml file full test commands and different examples, citing commands to implement different papers. class EarlyExitFinetuneRecipeDistributed(FTRecipeInterface): """ Early exit and layer dropout full finetuning to make the model more robust to early exit and skipping From eb37cb6e59f2ae32eb0c7f8d9b79c2dbfc33549d Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Mon, 25 Nov 2024 19:39:31 +0000 Subject: [PATCH 59/88] fix error in checking if early exit loss is enabled --- recipes/dev/early_exit_finetune_distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index fcd162fe96..613dafeedd 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -813,7 +813,7 @@ def train(self) -> None: num_tokens = 0 # Initialize output hidden states - if self._do_output_hidden_states: + if self._do_output_hidden_states is not None: self._model.output_hidden_states = [ i for i in range(len(self._do_output_hidden_states)) From 2e3f502238be99c79463e4d3caf91ddd956c12c8 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Tue, 26 Nov 2024 08:24:00 +0000 Subject: [PATCH 60/88] change recipe defaults of dataset and layer_drop probability --- recipes/dev/7B_full_early_exit.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/recipes/dev/7B_full_early_exit.yaml b/recipes/dev/7B_full_early_exit.yaml index 69e66d897c..20db08efcf 100644 --- a/recipes/dev/7B_full_early_exit.yaml +++ b/recipes/dev/7B_full_early_exit.yaml @@ -83,7 +83,7 @@ dtype: bf16 metric_logger: _component_: torchtune.training.metric_logging.DiskLogger log_dir: ${output_dir} -output_dir: /tmp/alpaca-llama2-finetune +output_dir: /tmp/topv2-llama2-finetune log_every_n_steps: 1 log_peak_memory_stats: True @@ -121,7 +121,7 @@ early_exit_loss: # Layer Dropout layer_dropout: - prob: 0.5 + prob: 0.2 layers: ":" layers_scale: "exp" disable_on_eval: True From 66a41b210604f426727a70d278513a54611a1670 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Tue, 26 Nov 2024 08:24:28 +0000 Subject: [PATCH 61/88] add detailed docstring to training script --- .../dev/early_exit_finetune_distributed.py | 40 +++++++++++++++++-- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index 613dafeedd..9eb55732fc 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -49,9 +49,43 @@ class EarlyExitFinetuneRecipeDistributed(FTRecipeInterface): training and can be run on a single node (1 to 8 GPUs). Features: - - Early Exit Loss. - - - Layer Dropout. + - Early Exit Loss. This makes the model more robust to exiting early by applying the outputs of intermediate + layers on the model's language model head (a.k.a. unembedding operation) to obtain outputs of earlier + layers, then obtain the losses at such earlier layers. Then the loss of the model during training + would be a weighted average of the losses at different layers. The different arguments you can + configure are: + - ``early_exit_loss.layers`` is a string, whose format mimics indexing in numpy arrays (e.g., `:` + depicts all layers, `0:10:3` depicts layers 0, 3, 6, 9, and `1,5,11` depicts layers 1,5,11), to + represent which layers to apply early exit loss at, + - ``early_exit_loss.scale_type`` and ``early_exit_loss.scale`` determine how we calculate the + weights of losses at different layers when calculating total loss, and + - ``early_exit_loss.curriculum`` depicts how the early exit loss layers change across training + iterations. + See ``torchtune/modules/early_exit_loss.py` for more details of each argument. + To reproduce results of different papers that use early exit loss: + - LayerSkip (https://arxiv.org/abs/2404.16710) results on finetuning on TOPv2: set + ``early_exit_loss.scale=1.0, early_exit_loss.curriculum=gradual early_exit_loss.scale_type=l + early_exit_loss.layers="::"`, + - LITE (https://arxiv.org/abs/2310.18581) results on finetuning Llama2 7B on Alpaca you can set + ``early_exit_loss.layers=8,12,16,20,24,28 early_exit_loss.scale_type=one``. + + - Layer Dropout. (a.k.a. Stochastic Depth) This drops samples stochastically for each layer during training. + "Dropping" a sample at a layer in this context means a sample will pass through the layer without modification. + The different arguments you can configure are: + - ``layer_dropout.prob``: is the (maximum) probability of a sample being dropped at each layer. + - ``layer_dropout.layers``: is a string, whose format mimics indexing in numpy arrays + (same as ``early_exit_loss.layers``), that determines which layers will have layer dropout applied. + - ``layer_dropout.layers_scale``: determines how probability changes across layers from + probability 0 at first layer, to probability ``layer_dropout.prob`` at last layer. + You can choose from ``one`` (all layers have ``layer_dropout.prob``), ``linear``, + ``exp``, ``log``, ``sqrt``. + - ``disable_on_eval``: if True, will only apply layer dropout during training. If False, will + apply to both training and evaluation. + To reproduce results of different papers that use layer dropout: + - LayerDrop(https://arxiv.org/abs/1909.11556) that applies dropout on every other layer, set + ``layer_dropout.prob=0.2 layer_dropout.layers=::2``. + - Progressive Layer Dropping (https://arxiv.org/abs/2010.13369) that increases dropout linearly + across layers, set ``layer_dropout.prob=0.5 layer_dropout.layers="::" layer_dropout.layers_scale=linear`` - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is From 345a0a344654e439e6493c19eb9bafc0109eb488 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Tue, 26 Nov 2024 08:25:10 +0000 Subject: [PATCH 62/88] ensure we set last layer early exit enable correctly --- recipes/dev/early_exit_finetune_distributed.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index 9eb55732fc..0ed7957701 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -718,9 +718,6 @@ def _setup_early_exit_loss( train_last_layer = cfg_early_exit_loss.get("include_last_layer", True) verbose = cfg_early_exit_loss.get("verbose", False) - if train_last_layer: - do_output_hidden_states[len(self._model.layers) - 1] = True - if cfg_early_exit_loss.curriculum: early_exit_loss_curriculum = setup_early_exit_loss_curriculum( early_exit_curriculum=cfg_early_exit_loss.curriculum, @@ -732,6 +729,8 @@ def _setup_early_exit_loss( do_output_hidden_states = early_exit_loss_curriculum.get() else: early_exit_loss_curriculum = None + if train_last_layer: + do_output_hidden_states[len(self._model.layers) - 1] = True return do_output_hidden_states, early_exit_loss_curriculum From 20c618c0289dd34c18d4ffde8a00fee9862aed12 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Tue, 26 Nov 2024 08:25:31 +0000 Subject: [PATCH 63/88] ensure uniform early exit loss works --- torchtune/modules/early_exit_loss.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchtune/modules/early_exit_loss.py b/torchtune/modules/early_exit_loss.py index 6d10c54711..b1057db53f 100644 --- a/torchtune/modules/early_exit_loss.py +++ b/torchtune/modules/early_exit_loss.py @@ -87,7 +87,7 @@ def early_exit_loss( def layer_ids_to_loss_scales( - layer_ids: List, + layer_ids: torch.Tensor, n_layers: int, loss_scale_type: LossScaleType, e_scale: float, @@ -99,7 +99,7 @@ def layer_ids_to_loss_scales( scales based on the specified loss scale type and then normalizes them to ensure that their sum is 1.0. Args: - layer_ids (List): A tensor of layer IDs. + layer_ids (torch.Tensor): A tensor of layer IDs. n_layers (int): The total number of layers. loss_scale_type (LossScaleType): The type of loss scaling to use. e_scale (float): An early exit scaling factor. @@ -116,7 +116,7 @@ def layer_ids_to_loss_scales( >>> loss_scales = layer_ids_to_loss_scales(layer_ids, n_layers, loss_scale_type, e_scale) """ if loss_scale_type == LossScaleType.ONE: - loss_scales = torch.ones(len(layer_ids)) + loss_scales = torch.ones(len(layer_ids), device=layer_ids.device) elif loss_scale_type == LossScaleType.L: loss_scales = torch.Tensor(layer_ids + 1) elif loss_scale_type == LossScaleType.SUM_L: From f0e8d7f0cb46f0d37c0bbedff511bc88b059aae0 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Tue, 26 Nov 2024 09:14:50 +0000 Subject: [PATCH 64/88] add documentation to .yaml file and update doc in .py --- recipes/dev/7B_full_early_exit.yaml | 27 ++++++++++++++----- .../dev/early_exit_finetune_distributed.py | 16 ++++++----- 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/recipes/dev/7B_full_early_exit.yaml b/recipes/dev/7B_full_early_exit.yaml index 20db08efcf..0829e0c2b7 100644 --- a/recipes/dev/7B_full_early_exit.yaml +++ b/recipes/dev/7B_full_early_exit.yaml @@ -1,21 +1,34 @@ -# Config for multi-device full finetuning in full_finetune_distributed.py -# using a Llama2 7B model +# Config for multi-device full finetuning with early exit loss and/or layer dropout +# in dev/early_exit_finetune_distributed.py using a Llama2 7B model on a small TOPv2 +# instruction set. # # This config assumes that you've run the following command before launching # this run: # tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --hf-token # # To launch on 4 devices, run the following command from root: -# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama2/7B_full +# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml # # You can add specific overrides through the command line. For example # to override the checkpointer directory while launching training # you can run: -# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama2/7B_full checkpointer.checkpoint_dir= +# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml checkpointer.checkpoint_dir= +# +# To reproduce experiments of various papers that use early exit loss and/or layer dropout: +# - LayerSkip (https://arxiv.org/abs/2404.16710) on TOPv2: +# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss.scale=1.0, early_exit_loss.curriculum=gradual early_exit_loss.scale_type=l layer_dropout.prob=0.2 layer_dropout.scale=exp +# +# - LITE (https://arxiv.org/abs/2310.18581): +# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml layer_dropout=null early_exit_loss.layers=8,12,16,20,24,28 early_exit_loss.scale_type=one +# +# - LayerDrop (https://arxiv.org/abs/1909.11556): +# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss=null layer_dropout.prob=0.2 layer_dropout.layers=::2 +# +# - Progressive Layer Dropping (https://arxiv.org/abs/2010.13369) (The paper also implements a curriculum for layer drop probability which is not yet implemented.): +# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss=null layer_dropout.prob=0.5 layer_dropout.scale=exp +# +# This config works best for distributed training, hence when the model is being fine-tuned on 2+ GPUs. # -# This config works best when the model is being fine-tuned on 2+ GPUs. -# Single device full finetuning requires more memory optimizations. It's -# best to use 7B_full_single_device.yaml for those cases # Tokenizer diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index 0ed7957701..10468005ba 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -62,11 +62,10 @@ class EarlyExitFinetuneRecipeDistributed(FTRecipeInterface): - ``early_exit_loss.curriculum`` depicts how the early exit loss layers change across training iterations. See ``torchtune/modules/early_exit_loss.py` for more details of each argument. - To reproduce results of different papers that use early exit loss: - - LayerSkip (https://arxiv.org/abs/2404.16710) results on finetuning on TOPv2: set - ``early_exit_loss.scale=1.0, early_exit_loss.curriculum=gradual early_exit_loss.scale_type=l - early_exit_loss.layers="::"`, - - LITE (https://arxiv.org/abs/2310.18581) results on finetuning Llama2 7B on Alpaca you can set + To reproduce experiments of different papers that use early exit loss: + - LayerSkip (https://arxiv.org/abs/2404.16710) for finetuning on TOPv2: set + ``early_exit_loss.scale=1.0, early_exit_loss.curriculum=gradual early_exit_loss.scale_type=l``, + - LITE (https://arxiv.org/abs/2310.18581) for finetuning Llama2 7B on Alpaca you can set ``early_exit_loss.layers=8,12,16,20,24,28 early_exit_loss.scale_type=one``. - Layer Dropout. (a.k.a. Stochastic Depth) This drops samples stochastically for each layer during training. @@ -82,10 +81,13 @@ class EarlyExitFinetuneRecipeDistributed(FTRecipeInterface): - ``disable_on_eval``: if True, will only apply layer dropout during training. If False, will apply to both training and evaluation. To reproduce results of different papers that use layer dropout: - - LayerDrop(https://arxiv.org/abs/1909.11556) that applies dropout on every other layer, set + - LayerDrop (https://arxiv.org/abs/1909.11556) that applies dropout on every other layer, set ``layer_dropout.prob=0.2 layer_dropout.layers=::2``. - Progressive Layer Dropping (https://arxiv.org/abs/2010.13369) that increases dropout linearly - across layers, set ``layer_dropout.prob=0.5 layer_dropout.layers="::" layer_dropout.layers_scale=linear`` + across layers, set ``layer_dropout.prob=0.5 layer_dropout.layers_scale=linear``. + The paper also implements a curriculum for layer drop probability which is not yet implemented. + - LayerSkip (https://arxiv.org/abs/2404.16710) for finetuning on TOPv2: (in addition to early exit loss + arguments above) set ``layer_dropout.prob=0.2 layer_dropout.scale=exp``. - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is From b03cb574168534f631f1ec57655252496f5a3ae5 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Wed, 27 Nov 2024 04:47:48 +0000 Subject: [PATCH 65/88] remove commented lines --- recipes/dev/7B_full_early_exit.yaml | 3 --- recipes/dev/early_exit_finetune_distributed.py | 3 +-- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/recipes/dev/7B_full_early_exit.yaml b/recipes/dev/7B_full_early_exit.yaml index 0829e0c2b7..2e44727f62 100644 --- a/recipes/dev/7B_full_early_exit.yaml +++ b/recipes/dev/7B_full_early_exit.yaml @@ -38,9 +38,6 @@ tokenizer: max_seq_len: null # Dataset -# dataset: -# _component_: torchtune.datasets.alpaca_dataset -# packed: False # True increases speed dataset: _component_: torchtune.datasets.instruct_dataset source: WillHeld/top_v2 diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index 10468005ba..110981bf2e 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -40,8 +40,7 @@ log = utils.get_logger("DEBUG") -# TODO: add explanation of EE and LD and cite papers -# TODO: add to .yaml file full test commands and different examples, citing commands to implement different papers. + class EarlyExitFinetuneRecipeDistributed(FTRecipeInterface): """ Early exit and layer dropout full finetuning to make the model more robust to early exit and skipping From 199b8dd584e6b96a330b0fcb236d84bbfb5075b1 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Wed, 27 Nov 2024 04:49:46 +0000 Subject: [PATCH 66/88] remove check on PyTorch version since we assume latest stable PyTorch --- recipes/dev/early_exit_finetune_distributed.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index 110981bf2e..04125fc4d5 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -171,15 +171,6 @@ def __init__(self, cfg: DictConfig) -> None: "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." ) - if ( - cfg.get("fsdp_cpu_offload", False) - and cfg.optimizer.get("fused", False) - and not utils.torch_version_ge("2.4.0") - ): - raise RuntimeError( - "Using fused optimizer on CPU is only supported in PyTorch nightly." - ) - # logging attributes self._output_dir = cfg.output_dir self._log_every_n_steps = cfg.get("log_every_n_steps", 1) From 6a2d79ba71b8813bcb4f0cdcefae12a0636c6ec7 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Wed, 27 Nov 2024 06:26:49 +0000 Subject: [PATCH 67/88] load curriculum step when resuming --- recipes/dev/early_exit_finetune_distributed.py | 1 + torchtune/modules/early_exit_loss.py | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index 04125fc4d5..656affbbcf 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -716,6 +716,7 @@ def _setup_early_exit_loss( do_output_hidden_states=do_output_hidden_states, max_steps=self.total_epochs * self._steps_per_epoch, train_last_layer=train_last_layer, + last_step=self.global_step, verbose=verbose, ) do_output_hidden_states = early_exit_loss_curriculum.get() diff --git a/torchtune/modules/early_exit_loss.py b/torchtune/modules/early_exit_loss.py index b1057db53f..edb27b9fdd 100644 --- a/torchtune/modules/early_exit_loss.py +++ b/torchtune/modules/early_exit_loss.py @@ -154,6 +154,7 @@ class EarlyExitCurriculum: should be output to calculate their losses. max_steps (int): The maximum number of steps in the curriculum. train_last_layer (bool, optional): Whether to always calculate loss for the last layer. Defaults to True. + last_step (Optional[int]): The last step the curriculum stopped at in a previous run. This is used when resuming training. verbose (bool, optional): Whether to print verbose logs. Defaults to False. """ @@ -162,6 +163,7 @@ def __init__( do_output_hidden_states: List[bool], max_steps: int, train_last_layer: bool = True, + last_step: Optional[int] = None, verbose: bool = False, ): self._init_do_output_hidden_states = do_output_hidden_states @@ -169,6 +171,7 @@ def __init__( self.train_last_layer = train_last_layer self.verbose = verbose self.max_steps = max_steps + self._step = 0 if last_step is None else last_step def step(self) -> None: """ @@ -207,6 +210,7 @@ def step(self): # Rotate layer enablement one step forward self._do_output_hidden_states = np.roll(self._do_output_hidden_states, 1) + self._step += 1 if self.verbose: log.info( f"Updated self._do_output_hidden_states to {self._do_output_hidden_states}." From e5534eaf48ac8b4f42c7629ed76727d2d95e7b86 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Wed, 27 Nov 2024 06:31:40 +0000 Subject: [PATCH 68/88] repeat arguments in derived classes --- torchtune/modules/early_exit_loss.py | 52 +++++++++++++++++++++++----- 1 file changed, 44 insertions(+), 8 deletions(-) diff --git a/torchtune/modules/early_exit_loss.py b/torchtune/modules/early_exit_loss.py index edb27b9fdd..7e0cde61ab 100644 --- a/torchtune/modules/early_exit_loss.py +++ b/torchtune/modules/early_exit_loss.py @@ -154,7 +154,8 @@ class EarlyExitCurriculum: should be output to calculate their losses. max_steps (int): The maximum number of steps in the curriculum. train_last_layer (bool, optional): Whether to always calculate loss for the last layer. Defaults to True. - last_step (Optional[int]): The last step the curriculum stopped at in a previous run. This is used when resuming training. + last_step (Optional[int]): The last step the curriculum stopped at in a previous run. + This is used when resuming training. verbose (bool, optional): Whether to print verbose logs. Defaults to False. """ @@ -196,10 +197,31 @@ class RotationalEarlyExitCurriculum(EarlyExitCurriculum): """ A rotational early exit curriculum, which rotates the layer enablement one step forward at each step. + Args: + do_output_hidden_states (List[bool]): A list indicating whether each layer's hidden state + should be output to calculate their losses. + max_steps (int): The maximum number of steps in the curriculum. + train_last_layer (bool, optional): Whether to always calculate loss for the last layer. Defaults to True. + last_step (Optional[int]): The last step the curriculum stopped at in a previous run. + This is used when resuming training. + verbose (bool, optional): Whether to print verbose logs. Defaults to False. """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__( + self, + do_output_hidden_states: List[bool], + max_steps: int, + train_last_layer: bool = True, + last_step: Optional[int] = None, + verbose: bool = False, + ): + super().__init__( + do_output_hidden_states=do_output_hidden_states, + max_steps=max_steps, + train_last_layer=train_last_layer, + last_step=last_step, + verbose=verbose, + ) self._initial_do_output_hidden_states = np.copy(self._do_output_hidden_states) def step(self): @@ -221,19 +243,33 @@ class GradualEarlyExitCurriculum(EarlyExitCurriculum): """ A gradual early exit curriculum, which gradually enables more layers (starting from the last layer) as training progresses. Args: - *args: Positional arguments passed to the parent EarlyExitCurriculum class. + do_output_hidden_states (List[bool]): A list indicating whether each layer's hidden state + should be output to calculate their losses. + max_steps (int): The maximum number of steps in the curriculum. + train_last_layer (bool, optional): Whether to always calculate loss for the last layer. Defaults to True. + last_step (Optional[int]): The last step the curriculum stopped at in a previous run. + This is used when resuming training. percent_scale (float, optional): A scaling factor to determine at which percentage of steps, all the layers will be enabled. At `steps = max_steps / percent_scale`, all the layers will be enabled. - **kwargs: Keyword arguments passed to the parent EarlyExitCurriculum class. + verbose (bool, optional): Whether to print verbose logs. Defaults to False. """ def __init__( self, - *args, + do_output_hidden_states: List[bool], + max_steps: int, + train_last_layer: bool = True, + last_step: Optional[int] = None, percent_scale: float = 2, - **kwargs, + verbose: bool = False, ): - super().__init__(*args, **kwargs) + super().__init__( + do_output_hidden_states=do_output_hidden_states, + max_steps=max_steps, + train_last_layer=train_last_layer, + last_step=last_step, + verbose=verbose, + ) self._final_do_output_hidden_states = np.copy(self._do_output_hidden_states) self._step = 0 self._percent_scale = percent_scale From d270d1fc1029e2f49fc3cbe1d956f86901b8adb9 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Wed, 27 Nov 2024 06:39:25 +0000 Subject: [PATCH 69/88] rename percent_scale to fraction_scale and change its implementation --- .../torchtune/modules/test_early_exit_loss.py | 2 +- torchtune/modules/early_exit_loss.py | 19 ++++++++++--------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/tests/torchtune/modules/test_early_exit_loss.py b/tests/torchtune/modules/test_early_exit_loss.py index feff2fe264..1421baf165 100644 --- a/tests/torchtune/modules/test_early_exit_loss.py +++ b/tests/torchtune/modules/test_early_exit_loss.py @@ -128,7 +128,7 @@ def test_gradual_early_exit_curriculum(train_last_layer): [True, True, True, True], max_steps=4, train_last_layer=train_last_layer, - percent_scale=1, + fraction_scale=1, ) expected = np.array([False, False, False, train_last_layer]) assert np.array_equal(curriculum.get(), expected) diff --git a/torchtune/modules/early_exit_loss.py b/torchtune/modules/early_exit_loss.py index 7e0cde61ab..747b76a7f7 100644 --- a/torchtune/modules/early_exit_loss.py +++ b/torchtune/modules/early_exit_loss.py @@ -249,8 +249,8 @@ class GradualEarlyExitCurriculum(EarlyExitCurriculum): train_last_layer (bool, optional): Whether to always calculate loss for the last layer. Defaults to True. last_step (Optional[int]): The last step the curriculum stopped at in a previous run. This is used when resuming training. - percent_scale (float, optional): A scaling factor to determine at which percentage - of steps, all the layers will be enabled. At `steps = max_steps / percent_scale`, all the layers will be enabled. + fraction_scale (float, optional): A scaling factor to determine at which fraction + of steps, all the layers will be enabled. At `steps = max_steps * fraction_scale`, all the layers will be enabled. verbose (bool, optional): Whether to print verbose logs. Defaults to False. """ @@ -260,7 +260,7 @@ def __init__( max_steps: int, train_last_layer: bool = True, last_step: Optional[int] = None, - percent_scale: float = 2, + fraction_scale: float = 0.5, verbose: bool = False, ): super().__init__( @@ -272,7 +272,7 @@ def __init__( ) self._final_do_output_hidden_states = np.copy(self._do_output_hidden_states) self._step = 0 - self._percent_scale = percent_scale + self._fraction_scale = fraction_scale # Initialize all layers to False for i in range(len(self._do_output_hidden_states)): @@ -282,15 +282,16 @@ def step(self): """ Perform a step in the curriculum. This method updates the `_do_output_hidden_states` attribute based on the current - step and the percentage of completed training steps. + step and the fraction of completed training steps. """ - percent_trained = self._step / self.max_steps + fraction_trained = self._step / self.max_steps n_layers = len(self._do_output_hidden_states) # Enable each layer based on proportion of completed training steps for layer_index in range(len(self._do_output_hidden_states)): - should_train = (percent_trained * self._percent_scale) >= ( - n_layers - layer_index - ) / n_layers + should_train = ( + fraction_trained + >= self._fraction_scale * (n_layers - layer_index) / n_layers + ) self._do_output_hidden_states[layer_index] = should_train # Only enable layers that are set by the user From e51419c83d9a326154a64db31c0d2576b2ed8bda Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 1 Dec 2024 22:02:12 +0000 Subject: [PATCH 70/88] fixes to docstrings and config examples --- recipes/dev/7B_full_early_exit.yaml | 8 ++++---- recipes/dev/early_exit_finetune_distributed.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/recipes/dev/7B_full_early_exit.yaml b/recipes/dev/7B_full_early_exit.yaml index 2e44727f62..8433309a48 100644 --- a/recipes/dev/7B_full_early_exit.yaml +++ b/recipes/dev/7B_full_early_exit.yaml @@ -16,13 +16,13 @@ # # To reproduce experiments of various papers that use early exit loss and/or layer dropout: # - LayerSkip (https://arxiv.org/abs/2404.16710) on TOPv2: -# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss.scale=1.0, early_exit_loss.curriculum=gradual early_exit_loss.scale_type=l layer_dropout.prob=0.2 layer_dropout.scale=exp +# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss.scale=1.0 early_exit_loss.curriculum=gradual early_exit_loss.scale_type=l layer_dropout.prob=0.2 layer_dropout.scale=exp # # - LITE (https://arxiv.org/abs/2310.18581): -# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml layer_dropout=null early_exit_loss.layers=8,12,16,20,24,28 early_exit_loss.scale_type=one +# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml layer_dropout=null early_exit_loss.layers=8,12,16,20,24,28 early_exit_loss.scale_type=one early_exit_loss.curriculum=null epochs=5 # # - LayerDrop (https://arxiv.org/abs/1909.11556): -# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss=null layer_dropout.prob=0.2 layer_dropout.layers=::2 +# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss=null layer_dropout.prob=0.2 layer_dropout.layers=1::2 # # - Progressive Layer Dropping (https://arxiv.org/abs/2010.13369) (The paper also implements a curriculum for layer drop probability which is not yet implemented.): # tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss=null layer_dropout.prob=0.5 layer_dropout.scale=exp @@ -66,7 +66,7 @@ checkpointer: resume_from_checkpoint: False # Fine-tuning arguments -batch_size: 2 +batch_size: 8 epochs: 1 optimizer: _component_: torch.optim.AdamW diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index 656affbbcf..91d445e22f 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -81,7 +81,7 @@ class EarlyExitFinetuneRecipeDistributed(FTRecipeInterface): apply to both training and evaluation. To reproduce results of different papers that use layer dropout: - LayerDrop (https://arxiv.org/abs/1909.11556) that applies dropout on every other layer, set - ``layer_dropout.prob=0.2 layer_dropout.layers=::2``. + ``layer_dropout.prob=0.2 layer_dropout.layers=1::2``. - Progressive Layer Dropping (https://arxiv.org/abs/2010.13369) that increases dropout linearly across layers, set ``layer_dropout.prob=0.5 layer_dropout.layers_scale=linear``. The paper also implements a curriculum for layer drop probability which is not yet implemented. From 40b798774a32ffd4762b951fa4f49e0aa16b27eb Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 1 Dec 2024 22:02:39 +0000 Subject: [PATCH 71/88] check if cfg_early_exit_loss has curriculum --- recipes/dev/early_exit_finetune_distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index 91d445e22f..f75ae13d80 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -710,7 +710,7 @@ def _setup_early_exit_loss( train_last_layer = cfg_early_exit_loss.get("include_last_layer", True) verbose = cfg_early_exit_loss.get("verbose", False) - if cfg_early_exit_loss.curriculum: + if cfg_early_exit_loss.get("curriculum", None): early_exit_loss_curriculum = setup_early_exit_loss_curriculum( early_exit_curriculum=cfg_early_exit_loss.curriculum, do_output_hidden_states=do_output_hidden_states, From 0c18595c21ccd638e719d05a42914fadb49ddc5d Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 1 Dec 2024 22:06:01 +0000 Subject: [PATCH 72/88] add comment to explain when has no effect --- tests/torchtune/modules/test_early_exit_loss.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/torchtune/modules/test_early_exit_loss.py b/tests/torchtune/modules/test_early_exit_loss.py index 1421baf165..3339e53757 100644 --- a/tests/torchtune/modules/test_early_exit_loss.py +++ b/tests/torchtune/modules/test_early_exit_loss.py @@ -109,6 +109,7 @@ def test_rotational_early_exit_curriculum(train_last_layer): expected = np.array([False, True, train_last_layer]) assert np.array_equal(curriculum.get(), expected) curriculum.step() + # Since the last element is already True on this rotation, the value of `train_last_layer` has no effect. expected = np.array([False, False, True]) assert np.array_equal(curriculum.get(), expected) curriculum.step() @@ -135,6 +136,7 @@ def test_gradual_early_exit_curriculum(train_last_layer): curriculum.step() assert np.array_equal(curriculum.get(), [False, False, False, train_last_layer]) curriculum.step() + # Since the last element is already True on this update, the value of `train_last_layer` has no effect. assert np.array_equal(curriculum.get(), [False, False, False, True]) curriculum.step() assert np.array_equal(curriculum.get(), [False, False, True, True]) From 3e68696cd02c7cbbe2e2db6e2251aadc5cfe285f Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 1 Dec 2024 22:12:40 +0000 Subject: [PATCH 73/88] organize early exit loss tests into classes --- .../torchtune/modules/test_early_exit_loss.py | 290 +++++++++--------- 1 file changed, 141 insertions(+), 149 deletions(-) diff --git a/tests/torchtune/modules/test_early_exit_loss.py b/tests/torchtune/modules/test_early_exit_loss.py index 3339e53757..cb63ddf2eb 100644 --- a/tests/torchtune/modules/test_early_exit_loss.py +++ b/tests/torchtune/modules/test_early_exit_loss.py @@ -29,154 +29,146 @@ def forward( return x # Simply return the input for testing purposes -@pytest.fixture -def mock_model(): - # Create mock components - tok_embeddings = nn.Embedding(1000, 512) # Example vocab size and embedding dim - layers = nn.ModuleList([MockLayer() for _ in range(12)]) # 12 mock layers - norm = nn.LayerNorm(512) # Example layer normalization - output = nn.Linear(512, 1000) # Example output layer - - # Create an instance of TransformerDecoder - model = TransformerDecoder( - tok_embeddings=tok_embeddings, - layers=layers, - max_seq_len=512, - num_heads=8, - head_dim=64, - norm=norm, - output=output, - num_layers=12, - output_hidden_states=[0, 1, 2], # Example layers to output hidden states - ) - return model - - -@pytest.fixture -def hidden_states_dict(): - return {i: torch.randn(4, 5, 512) for i in range(3)} # Adjusted embedding dim - - -@pytest.fixture -def labels(): - return torch.randint(0, 1000, (4, 5)) # Adjusted vocab size - - -@pytest.fixture -def loss_fn(): - return nn.CrossEntropyLoss(ignore_index=-1) - - -def test_early_exit_loss(mock_model, hidden_states_dict, labels, loss_fn): - loss = early_exit_loss(mock_model, hidden_states_dict, labels, loss_fn) - assert isinstance(loss, torch.Tensor) - assert loss.item() >= 0 - - -def test_layer_ids_to_loss_scales(): - layer_ids = torch.tensor([0, 1, 2]) - n_layers = 12 - scales = layer_ids_to_loss_scales(layer_ids, n_layers, LossScaleType.SUM_L, 1.0) - assert torch.isclose(scales.sum(), torch.tensor(1.0)) - - -def test_setup_early_exit_loss_curriculum(): - curriculum = setup_early_exit_loss_curriculum( - EarlyExitCurriculumType.ROTATIONAL, [True, False, True], 100 - ) - assert isinstance(curriculum, RotationalEarlyExitCurriculum) - - curriculum = setup_early_exit_loss_curriculum( - EarlyExitCurriculumType.GRADUAL, [True, False, True], 100 - ) - assert isinstance(curriculum, GradualEarlyExitCurriculum) - - -@pytest.mark.parametrize( - "train_last_layer", - [ - True, - False, - ], -) -def test_rotational_early_exit_curriculum(train_last_layer): - curriculum = RotationalEarlyExitCurriculum( - [True, False, False], max_steps=100, train_last_layer=train_last_layer - ) - expected = np.array([True, False, train_last_layer]) - assert np.array_equal(curriculum.get(), expected) - curriculum.step() - expected = np.array([False, True, train_last_layer]) - assert np.array_equal(curriculum.get(), expected) - curriculum.step() - # Since the last element is already True on this rotation, the value of `train_last_layer` has no effect. - expected = np.array([False, False, True]) - assert np.array_equal(curriculum.get(), expected) - curriculum.step() - expected = np.array([True, False, train_last_layer]) - assert np.array_equal(curriculum.get(), expected) - - -@pytest.mark.parametrize( - "train_last_layer", - [ - True, - False, - ], -) -def test_gradual_early_exit_curriculum(train_last_layer): - curriculum = GradualEarlyExitCurriculum( - [True, True, True, True], - max_steps=4, - train_last_layer=train_last_layer, - fraction_scale=1, +class TestEarlyExitLoss: + @pytest.fixture + def mock_model(self): + # Create mock components + tok_embeddings = nn.Embedding(1000, 512) # Example vocab size and embedding dim + layers = nn.ModuleList([MockLayer() for _ in range(12)]) # 12 mock layers + norm = nn.LayerNorm(512) # Example layer normalization + output = nn.Linear(512, 1000) # Example output layer + + # Create an instance of TransformerDecoder + model = TransformerDecoder( + tok_embeddings=tok_embeddings, + layers=layers, + max_seq_len=512, + num_heads=8, + head_dim=64, + norm=norm, + output=output, + num_layers=12, + output_hidden_states=[0, 1, 2], # Example layers to output hidden states + ) + return model + + @pytest.fixture + def hidden_states_dict(self): + return {i: torch.randn(4, 5, 512) for i in range(3)} # Adjusted embedding dim + + @pytest.fixture + def labels(self): + return torch.randint(0, 1000, (4, 5)) # Adjusted vocab size + + @pytest.fixture + def loss_fn(self): + return nn.CrossEntropyLoss(ignore_index=-1) + + def test_early_exit_loss(self, mock_model, hidden_states_dict, labels, loss_fn): + loss = early_exit_loss(mock_model, hidden_states_dict, labels, loss_fn) + assert isinstance(loss, torch.Tensor) + assert loss.item() >= 0 + + def test_layer_ids_to_loss_scales(self): + layer_ids = torch.tensor([0, 1, 2]) + n_layers = 12 + scales = layer_ids_to_loss_scales(layer_ids, n_layers, LossScaleType.SUM_L, 1.0) + assert torch.isclose(scales.sum(), torch.tensor(1.0)) + + def test_early_exit_loss_vs_manual( + self, mock_model, hidden_states_dict, labels, loss_fn + ): + # Convert to float32 for numeric equivalence + # Calculate early exit loss using the function + calculated_loss = early_exit_loss( + mock_model, + hidden_states_dict, + labels, + loss_fn, + e_scale=1, + loss_scale_type="one", + ) + # Manually calculate the loss for each hidden state + total_loss = 0.0 + num_hidden_states = len(hidden_states_dict) + for i, hidden_state in hidden_states_dict.items(): + # Compute logits for the current hidden state + logits = mock_model.unembed(hidden_state) + labels = labels.reshape(-1) + logits = logits.reshape(-1, logits.size(-1)) + # Compute the loss for the current hidden state + loss = loss_fn(logits, labels) + total_loss += loss + # Average the losses across all hidden states + manual_loss = total_loss / num_hidden_states + # Compare the two losses + assert torch.isclose( + calculated_loss, manual_loss, atol=1e-6 + ), f"Calculated loss: {calculated_loss}, Manual loss: {manual_loss}" + + +class TestEarlyExitLossCurriculum: + def test_setup_early_exit_loss_curriculum(self): + curriculum = setup_early_exit_loss_curriculum( + EarlyExitCurriculumType.ROTATIONAL, [True, False, True], 100 + ) + assert isinstance(curriculum, RotationalEarlyExitCurriculum) + + curriculum = setup_early_exit_loss_curriculum( + EarlyExitCurriculumType.GRADUAL, [True, False, True], 100 + ) + assert isinstance(curriculum, GradualEarlyExitCurriculum) + + @pytest.mark.parametrize( + "train_last_layer", + [ + True, + False, + ], ) - expected = np.array([False, False, False, train_last_layer]) - assert np.array_equal(curriculum.get(), expected) - curriculum.step() - assert np.array_equal(curriculum.get(), [False, False, False, train_last_layer]) - curriculum.step() - # Since the last element is already True on this update, the value of `train_last_layer` has no effect. - assert np.array_equal(curriculum.get(), [False, False, False, True]) - curriculum.step() - assert np.array_equal(curriculum.get(), [False, False, True, True]) - curriculum.step() - assert np.array_equal(curriculum.get(), [False, True, True, True]) - curriculum.step() - assert np.array_equal(curriculum.get(), [True, True, True, True]) - curriculum.step() - assert np.array_equal(curriculum.get(), [True, True, True, True]) - - -def test_early_exit_loss_vs_manual(mock_model, hidden_states_dict, labels, loss_fn): - # Convert to float32 for numeric equivalence - # Calculate early exit loss using the function - calculated_loss = early_exit_loss( - mock_model, - hidden_states_dict, - labels, - loss_fn, - e_scale=1, - loss_scale_type="one", + def test_rotational_early_exit_curriculum(self, train_last_layer): + curriculum = RotationalEarlyExitCurriculum( + [True, False, False], max_steps=100, train_last_layer=train_last_layer + ) + expected = np.array([True, False, train_last_layer]) + assert np.array_equal(curriculum.get(), expected) + curriculum.step() + expected = np.array([False, True, train_last_layer]) + assert np.array_equal(curriculum.get(), expected) + curriculum.step() + # Since the last element is already True on this rotation, the value of `train_last_layer` has no effect. + expected = np.array([False, False, True]) + assert np.array_equal(curriculum.get(), expected) + curriculum.step() + expected = np.array([True, False, train_last_layer]) + assert np.array_equal(curriculum.get(), expected) + + @pytest.mark.parametrize( + "train_last_layer", + [ + True, + False, + ], ) - # Manually calculate the loss for each hidden state - total_loss = 0.0 - num_hidden_states = len(hidden_states_dict) - for i, hidden_state in hidden_states_dict.items(): - # Compute logits for the current hidden state - logits = mock_model.unembed(hidden_state) - labels = labels.reshape(-1) - logits = logits.reshape(-1, logits.size(-1)) - # Compute the loss for the current hidden state - loss = loss_fn(logits, labels) - total_loss += loss - # Average the losses across all hidden states - manual_loss = total_loss / num_hidden_states - # Compare the two losses - assert torch.isclose( - calculated_loss, manual_loss, atol=1e-6 - ), f"Calculated loss: {calculated_loss}, Manual loss: {manual_loss}" - - -if __name__ == "__main__": - pytest.main() + def test_gradual_early_exit_curriculum(self, train_last_layer): + curriculum = GradualEarlyExitCurriculum( + [True, True, True, True], + max_steps=4, + train_last_layer=train_last_layer, + fraction_scale=1, + ) + expected = np.array([False, False, False, train_last_layer]) + assert np.array_equal(curriculum.get(), expected) + curriculum.step() + assert np.array_equal(curriculum.get(), [False, False, False, train_last_layer]) + curriculum.step() + # Since the last element is already True on this update, the value of `train_last_layer` has no effect. + assert np.array_equal(curriculum.get(), [False, False, False, True]) + curriculum.step() + assert np.array_equal(curriculum.get(), [False, False, True, True]) + curriculum.step() + assert np.array_equal(curriculum.get(), [False, True, True, True]) + curriculum.step() + assert np.array_equal(curriculum.get(), [True, True, True, True]) + curriculum.step() + assert np.array_equal(curriculum.get(), [True, True, True, True]) From 418951bae8b60dc79c66e1beb5fd56eaa6b9ce56 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 1 Dec 2024 22:13:50 +0000 Subject: [PATCH 74/88] fix typo --- tests/torchtune/modules/test_layer_dropout.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/torchtune/modules/test_layer_dropout.py b/tests/torchtune/modules/test_layer_dropout.py index 73f42998d4..24494df91f 100644 --- a/tests/torchtune/modules/test_layer_dropout.py +++ b/tests/torchtune/modules/test_layer_dropout.py @@ -192,7 +192,7 @@ def test_get_scale_sin(self): assert_expected(actual_scale, expected_scale, atol=1e-7, rtol=1e-3) -class TestLayerDopoutModel: +class TestLayerDropoutModel: def test_prepare_layer_dropout_uniform(self): class MockModel(torch.nn.Module): def __init__(self): From e5a53f9e598455f91a3c43864c3c665de7efdcf3 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 1 Dec 2024 22:30:00 +0000 Subject: [PATCH 75/88] test all loss scale types --- tests/torchtune/modules/test_early_exit_loss.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/torchtune/modules/test_early_exit_loss.py b/tests/torchtune/modules/test_early_exit_loss.py index cb63ddf2eb..9bef74eba3 100644 --- a/tests/torchtune/modules/test_early_exit_loss.py +++ b/tests/torchtune/modules/test_early_exit_loss.py @@ -6,6 +6,8 @@ # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +import random + import numpy as np import pytest import torch @@ -69,10 +71,16 @@ def test_early_exit_loss(self, mock_model, hidden_states_dict, labels, loss_fn): assert isinstance(loss, torch.Tensor) assert loss.item() >= 0 - def test_layer_ids_to_loss_scales(self): - layer_ids = torch.tensor([0, 1, 2]) + @pytest.mark.parametrize( + "scale_type", + [e.value for e in LossScaleType], + ) + def test_layer_ids_to_loss_scales(self, scale_type): n_layers = 12 - scales = layer_ids_to_loss_scales(layer_ids, n_layers, LossScaleType.SUM_L, 1.0) + n_subset_layers = 5 + layer_ids = torch.tensor(random.sample(range(0, n_layers), n_subset_layers)) + + scales = layer_ids_to_loss_scales(layer_ids, n_layers, scale_type, 1.0) assert torch.isclose(scales.sum(), torch.tensor(1.0)) def test_early_exit_loss_vs_manual( From 3567a2431dc601efd29220a1ed2826e440fd43be Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 1 Dec 2024 22:47:46 +0000 Subject: [PATCH 76/88] use variable number of subset layers --- .../torchtune/modules/test_early_exit_loss.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/tests/torchtune/modules/test_early_exit_loss.py b/tests/torchtune/modules/test_early_exit_loss.py index 9bef74eba3..4f9d271f39 100644 --- a/tests/torchtune/modules/test_early_exit_loss.py +++ b/tests/torchtune/modules/test_early_exit_loss.py @@ -33,7 +33,11 @@ def forward( class TestEarlyExitLoss: @pytest.fixture - def mock_model(self): + def num_layers(self): + return 12 + + @pytest.fixture + def mock_model(self, num_layers): # Create mock components tok_embeddings = nn.Embedding(1000, 512) # Example vocab size and embedding dim layers = nn.ModuleList([MockLayer() for _ in range(12)]) # 12 mock layers @@ -49,7 +53,7 @@ def mock_model(self): head_dim=64, norm=norm, output=output, - num_layers=12, + num_layers=num_layers, output_hidden_states=[0, 1, 2], # Example layers to output hidden states ) return model @@ -75,13 +79,13 @@ def test_early_exit_loss(self, mock_model, hidden_states_dict, labels, loss_fn): "scale_type", [e.value for e in LossScaleType], ) - def test_layer_ids_to_loss_scales(self, scale_type): - n_layers = 12 - n_subset_layers = 5 - layer_ids = torch.tensor(random.sample(range(0, n_layers), n_subset_layers)) - - scales = layer_ids_to_loss_scales(layer_ids, n_layers, scale_type, 1.0) - assert torch.isclose(scales.sum(), torch.tensor(1.0)) + def test_layer_ids_to_loss_scales(self, scale_type, num_layers): + for n_subset_layers in range(1, num_layers + 1): + layer_ids = torch.tensor( + random.sample(range(0, num_layers), n_subset_layers) + ) + scales = layer_ids_to_loss_scales(layer_ids, num_layers, scale_type, 1.0) + assert torch.isclose(scales.sum(), torch.tensor(1.0)) def test_early_exit_loss_vs_manual( self, mock_model, hidden_states_dict, labels, loss_fn From ae2108db094916598ffb1f137f8117b1adcf4cdf Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Sun, 1 Dec 2024 23:58:47 +0000 Subject: [PATCH 77/88] ensure get_scale returns values between 0 and 1 --- tests/torchtune/modules/test_layer_dropout.py | 12 ++++++++---- torchtune/modules/layer_dropout.py | 6 ++++-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/torchtune/modules/test_layer_dropout.py b/tests/torchtune/modules/test_layer_dropout.py index 24494df91f..398267630d 100644 --- a/tests/torchtune/modules/test_layer_dropout.py +++ b/tests/torchtune/modules/test_layer_dropout.py @@ -186,10 +186,14 @@ def test_get_scale_log(self): def test_get_scale_sin(self): scale_type = ScaleType.SIN scale_period = 10 - val = 5 - expected_scale = math.sin(0.5 * math.pi * 5 / 10) - actual_scale = get_scale(scale_type, scale_period, val) - assert_expected(actual_scale, expected_scale, atol=1e-7, rtol=1e-3) + + assert_expected(get_scale(scale_type, scale_period, 0), 0.0) + assert_expected( + get_scale(scale_type, scale_period, scale_period / 2), + math.sin(0.5 * math.pi * 0.5), + ) + assert_expected(get_scale(scale_type, scale_period, scale_period), 1.0) + assert_expected(get_scale(scale_type, scale_period, scale_period * 2), 1.0) class TestLayerDropoutModel: diff --git a/torchtune/modules/layer_dropout.py b/torchtune/modules/layer_dropout.py index 09f01098ec..eb58ed8e27 100644 --- a/torchtune/modules/layer_dropout.py +++ b/torchtune/modules/layer_dropout.py @@ -201,6 +201,8 @@ def get_scale( """ if scale_period == 0: return 1.0 + if val >= scale_period: + return 1.0 # all the equations below aim to make scale = 0 when val=0, and scale = 1 when val=scale_period scale = { @@ -212,8 +214,8 @@ def get_scale( ScaleType.SIGMOID: 1 / (1 + math.exp(-10 * (val / scale_period - 0.5))), }[scale_type] - # after scale_period, scale should be 1 - return min(scale, 1.0) + # ensure returned scale is between 0.0 and 1.0 (inclusive) + return max(min(scale, 1.0), 0.0) def prepare_layer_dropout( From 71707de109ce5f2ef6197cc78d5cf5be8dc4f740 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Mon, 2 Dec 2024 00:02:59 +0000 Subject: [PATCH 78/88] add test cases for sigmoid --- tests/torchtune/modules/test_layer_dropout.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/torchtune/modules/test_layer_dropout.py b/tests/torchtune/modules/test_layer_dropout.py index 398267630d..1aca0d3e9a 100644 --- a/tests/torchtune/modules/test_layer_dropout.py +++ b/tests/torchtune/modules/test_layer_dropout.py @@ -195,6 +195,21 @@ def test_get_scale_sin(self): assert_expected(get_scale(scale_type, scale_period, scale_period), 1.0) assert_expected(get_scale(scale_type, scale_period, scale_period * 2), 1.0) + def test_get_scale_sigmoid(self): + scale_type = ScaleType.SIGMOID + scale_period = 10 + + # sigmoid(0) is close to 0 but not 0, hence adding relatively large rotl and atol + assert_expected( + get_scale(scale_type, scale_period, 0), 0.0, rtol=1e-2, atol=1e-2 + ) + assert_expected( + get_scale(scale_type, scale_period, scale_period / 2), + 0.5, + ) + assert_expected(get_scale(scale_type, scale_period, scale_period), 1.0) + assert_expected(get_scale(scale_type, scale_period, scale_period * 2), 1.0) + class TestLayerDropoutModel: def test_prepare_layer_dropout_uniform(self): From 78aff5af2cc79f74d977510dbfa703af97aab033 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Mon, 2 Dec 2024 00:11:50 +0000 Subject: [PATCH 79/88] make prepare_layer_dropout apply on a list of layers rather than a model --- .../dev/early_exit_finetune_distributed.py | 2 +- tests/torchtune/modules/test_layer_dropout.py | 6 ++--- torchtune/modules/layer_dropout.py | 26 +++++++++---------- 3 files changed, 16 insertions(+), 18 deletions(-) diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index f75ae13d80..85e082654a 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -404,7 +404,7 @@ def setup(self, cfg: DictConfig) -> None: cfg_layer_dropout = cfg.get("layer_dropout", None) if cfg_layer_dropout: prepare_layer_dropout( - self._model, + self._model.layers, prob_max=cfg_layer_dropout.get("prob", 0.0), prob_layer_scale=cfg_layer_dropout.get("layers_scale", "uniform"), layers_str=cfg_layer_dropout.get("layers", ":"), diff --git a/tests/torchtune/modules/test_layer_dropout.py b/tests/torchtune/modules/test_layer_dropout.py index 1aca0d3e9a..9bbd779d79 100644 --- a/tests/torchtune/modules/test_layer_dropout.py +++ b/tests/torchtune/modules/test_layer_dropout.py @@ -224,7 +224,7 @@ def __init__(self): prob_max = 0.5 prob_layer_scale = ScaleType.UNIFORM layers_str = "0:4" - prepare_layer_dropout(model, prob_max, prob_layer_scale, layers_str) + prepare_layer_dropout(model.layers, prob_max, prob_layer_scale, layers_str) for i, layer in enumerate(model.layers): assert hasattr(layer, "dropout") if i in range(0, 4): @@ -244,7 +244,7 @@ def __init__(self): prob_max = 0.5 prob_layer_scale = ScaleType.EXP layers_str = ":" - prepare_layer_dropout(model, prob_max, prob_layer_scale, layers_str) + prepare_layer_dropout(model.layers, prob_max, prob_layer_scale, layers_str) for i, layer in enumerate(model.layers): assert hasattr(layer, "dropout") if i == 0: @@ -266,7 +266,7 @@ def __init__(self): prob_max = 0.5 prob_layer_scale = ScaleType.LINEAR layers_str = ":" - prepare_layer_dropout(model, prob_max, prob_layer_scale, layers_str) + prepare_layer_dropout(model.layers, prob_max, prob_layer_scale, layers_str) for i, layer in enumerate(model.layers): assert hasattr(layer, "dropout") if i == 0: diff --git a/torchtune/modules/layer_dropout.py b/torchtune/modules/layer_dropout.py index eb58ed8e27..eae766bdd6 100644 --- a/torchtune/modules/layer_dropout.py +++ b/torchtune/modules/layer_dropout.py @@ -6,7 +6,7 @@ import math from enum import Enum -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Iterable, Optional, Union import torch @@ -219,21 +219,21 @@ def get_scale( def prepare_layer_dropout( - model: torch.nn.Module, + layers: Union[torch.nn.ModuleList, Iterable[torch.nn.Module]], prob_max: float = 0.0, prob_layer_scale: Optional[ScaleType] = ScaleType.UNIFORM, layers_str: Optional[str] = None, disable_on_eval: Optional[bool] = True, ) -> None: """ - Prepare a model for layer dropout by wrapping each layer with a ModuleLayerDropoutWrapper. - This function takes in a model, the maximum probability of dropping a layer, + Prepare a model's layers for layer dropout by wrapping each layer with a ModuleLayerDropoutWrapper. + This function takes in a list of layers, the maximum probability of dropping a layer, the scaling type for the layer dropout probability, a string specifying which layers to apply dropout to, and a boolean indicating whether to disable dropout during evaluation. It then wraps each layer of the model inplace with a ModuleLayerDropoutWrapper, which applies layer dropout to the input tensor. Args: - model (torch.nn.Module): The model to prepare for layer dropout. + layers (Union[torch.nn.ModuleList, Iterable[torch.nn.Module]]): The list of layers to prepare for layer dropout. prob_max (float): The maximum probability of dropping a layer. Defaults to 0.0. prob_layer_scale (Optional[ScaleType]): The scaling type for the dropout probability across layers. Defaults to ScaleType.UNIFORM. @@ -263,24 +263,24 @@ def prepare_layer_dropout( ... return x >>> model = MyModel() >>> # Apply layer dropout uniformly to all layers - >>> prepare_layer_dropout(model, prob_max=0.2, prob_layer_scale=ScaleType.UNIFORM) + >>> prepare_layer_dropout(model.layers, prob_max=0.2, prob_layer_scale=ScaleType.UNIFORM) >>> # Apply layer dropout every other layer, as described in LayerDrop paper (Fan et al., https://arxiv.org/abs/1909.11556v1) - >>> prepare_layer_dropout(model, prob_max=0.2, prob_layer_scale=ScaleType.UNIFORM, layers_str="::2") + >>> prepare_layer_dropout(model.layers, prob_max=0.2, prob_layer_scale=ScaleType.UNIFORM, layers_str="::2") >>> # Apply layer dropout that increases linearly across layers, as described in Progressive Layer Dropout paper (Zhang et al., https://arxiv.org/abs/2010.13369) - >>> prepare_layer_dropout(model, prob_max=0.2, prob_layer_scale=ScaleType.LINEAR) + >>> prepare_layer_dropout(model.layers, prob_max=0.2, prob_layer_scale=ScaleType.LINEAR) >>> # Apply layer dropout that increases exponentially across layers, as described in LayerSkip paper (Elhoushi et al., https://arxiv.org/abs/2404.16710) - >>> prepare_layer_dropout(model, prob_max=0.2, prob_layer_scale=ScaleType.EXP) + >>> prepare_layer_dropout(model.layers, prob_max=0.2, prob_layer_scale=ScaleType.EXP) """ - num_layers = len(model.layers) + num_layers = len(layers) has_dropout = ( slice_str_to_array(layers_str, num_layers) if layers_str else [True] * num_layers ) - for layer_id in range(len(model.layers)): + for layer_id in range(len(layers)): prob = ( prob_max * get_scale( @@ -299,6 +299,4 @@ def prepare_layer_dropout( layer_dropout = LayerDropout( prob, disable_on_eval=disable_on_eval, seed=layer_id ) - model.layers[layer_id] = ModuleLayerDropoutWrapper( - model.layers[layer_id], layer_dropout - ) + layers[layer_id] = ModuleLayerDropoutWrapper(layers[layer_id], layer_dropout) From 0fb373b8d8da344d52ca0abb39c8a57626126858 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Wed, 4 Dec 2024 00:02:23 -0500 Subject: [PATCH 80/88] Only add `optional` in docstring when argument is optional Co-authored-by: ebsmothers --- torchtune/modules/early_exit_loss.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchtune/modules/early_exit_loss.py b/torchtune/modules/early_exit_loss.py index 747b76a7f7..6517a77c9f 100644 --- a/torchtune/modules/early_exit_loss.py +++ b/torchtune/modules/early_exit_loss.py @@ -246,12 +246,12 @@ class GradualEarlyExitCurriculum(EarlyExitCurriculum): do_output_hidden_states (List[bool]): A list indicating whether each layer's hidden state should be output to calculate their losses. max_steps (int): The maximum number of steps in the curriculum. - train_last_layer (bool, optional): Whether to always calculate loss for the last layer. Defaults to True. + train_last_layer (bool): Whether to always calculate loss for the last layer. Defaults to True. last_step (Optional[int]): The last step the curriculum stopped at in a previous run. This is used when resuming training. - fraction_scale (float, optional): A scaling factor to determine at which fraction - of steps, all the layers will be enabled. At `steps = max_steps * fraction_scale`, all the layers will be enabled. - verbose (bool, optional): Whether to print verbose logs. Defaults to False. + fraction_scale (float): A scaling factor to determine at which fraction + of steps, all the layers will be enabled. At `steps = max_steps * fraction_scale`, all the layers will be enabled. Defaults to 0.5. + verbose (bool): Whether to print verbose logs. Defaults to False. """ def __init__( From b66e23b9e56dae6e10acf260fe6c22eec52fed0d Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Wed, 4 Dec 2024 05:31:04 +0000 Subject: [PATCH 81/88] add Dropout class and prepare_layer_dropout APIs to docs --- docs/source/api_ref_modules.rst | 2 ++ torchtune/modules/layer_dropout.py | 11 ++++++----- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/docs/source/api_ref_modules.rst b/docs/source/api_ref_modules.rst index 36ea3637c8..5eb8fff358 100644 --- a/docs/source/api_ref_modules.rst +++ b/docs/source/api_ref_modules.rst @@ -23,6 +23,8 @@ Modeling Components and Building Blocks TransformerCrossAttentionLayer TransformerDecoder VisionTransformer + LayerDropout + prepare_layer_dropout Losses ------ diff --git a/torchtune/modules/layer_dropout.py b/torchtune/modules/layer_dropout.py index eae766bdd6..f46e9bf10d 100644 --- a/torchtune/modules/layer_dropout.py +++ b/torchtune/modules/layer_dropout.py @@ -17,7 +17,7 @@ class LayerDropout(torch.nn.Module): """ A module that applies layer dropout to the input tensor of an underlying module. It drops a portion of an input tensor, applies the underlying module on the - remaining parts of the tensor, and then concatenates with the dropped portion of the tensor. + remaining parts of the tensor, and then concatenates with the dropped portion of the tensor. When applied during training, it can have a regularization effect, and can potentially speedup training. Args: prob (float): The probability of dropping an input. Defaults to 0.0. @@ -232,13 +232,14 @@ def prepare_layer_dropout( layers to apply dropout to, and a boolean indicating whether to disable dropout during evaluation. It then wraps each layer of the model inplace with a ModuleLayerDropoutWrapper, which applies layer dropout to the input tensor. + Args: layers (Union[torch.nn.ModuleList, Iterable[torch.nn.Module]]): The list of layers to prepare for layer dropout. prob_max (float): The maximum probability of dropping a layer. Defaults to 0.0. - prob_layer_scale (Optional[ScaleType]): The scaling type for the dropout probability - across layers. Defaults to ScaleType.UNIFORM. - layers_str (Optional[str]): A string specifying which layers to apply dropout to. - Defaults to None which means apply to all layers. + prob_layer_scale (Optional[ScaleType]): The scaling type for the dropout probability across layers. Defaults to + ScaleType.UNIFORM. + layers_str (Optional[str]): A string specifying which layers to apply dropout to. Defaults to None which means + apply to all layers. disable_on_eval (Optional[bool]): Whether to disable dropout during evaluation. Defaults to True. Returns: None From cd8be646634cf3369332dbdd5c8bd5ca3a192bab Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Wed, 4 Dec 2024 05:34:20 +0000 Subject: [PATCH 82/88] add empty line between function description and Args --- torchtune/modules/early_exit_loss.py | 9 ++++++++- torchtune/modules/layer_dropout.py | 4 ++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/torchtune/modules/early_exit_loss.py b/torchtune/modules/early_exit_loss.py index 6517a77c9f..6d6aaf0344 100644 --- a/torchtune/modules/early_exit_loss.py +++ b/torchtune/modules/early_exit_loss.py @@ -40,6 +40,7 @@ def early_exit_loss( and optional parameters for scaling the loss. It computes the early exit loss by iterating over the hidden states, computing the logits and losses at each layer, and then scaling and summing these losses. + Args: model (TransformerDecoder): The model to compute the early exit loss for. hidden_states_dict (Dict[int, torch.Tensor]): A dictionary of hidden states, @@ -98,6 +99,7 @@ def layer_ids_to_loss_scales( a loss scale type, and an early exit scaling factor. It computes the loss scales based on the specified loss scale type and then normalizes them to ensure that their sum is 1.0. + Args: layer_ids (torch.Tensor): A tensor of layer IDs. n_layers (int): The total number of layers. @@ -149,6 +151,7 @@ class EarlyExitCurriculum: """ A curriculum for early exit loss training, which controls which layers to use their hidden states during training. + Args: do_output_hidden_states (List[bool]): A list indicating whether each layer's hidden state should be output to calculate their losses. @@ -197,6 +200,7 @@ class RotationalEarlyExitCurriculum(EarlyExitCurriculum): """ A rotational early exit curriculum, which rotates the layer enablement one step forward at each step. + Args: do_output_hidden_states (List[bool]): A list indicating whether each layer's hidden state should be output to calculate their losses. @@ -242,6 +246,7 @@ def step(self): class GradualEarlyExitCurriculum(EarlyExitCurriculum): """ A gradual early exit curriculum, which gradually enables more layers (starting from the last layer) as training progresses. + Args: do_output_hidden_states (List[bool]): A list indicating whether each layer's hidden state should be output to calculate their losses. @@ -250,7 +255,8 @@ class GradualEarlyExitCurriculum(EarlyExitCurriculum): last_step (Optional[int]): The last step the curriculum stopped at in a previous run. This is used when resuming training. fraction_scale (float): A scaling factor to determine at which fraction - of steps, all the layers will be enabled. At `steps = max_steps * fraction_scale`, all the layers will be enabled. Defaults to 0.5. + of steps, all the layers will be enabled. At `steps = max_steps * fraction_scale`, all the layers will be + enabled. Defaults to 0.5. verbose (bool): Whether to print verbose logs. Defaults to False. """ @@ -314,6 +320,7 @@ def setup_early_exit_loss_curriculum( This function takes in an early exit curriculum type and optional arguments. It returns an instance of the corresponding early exit curriculum class, or None if the curriculum type is NONE. + Args: early_exit_curriculum (EarlyExitCurriculumType): The type of early exit curriculum to set up. *args: Optional positional arguments for the early exit curriculum constructor. diff --git a/torchtune/modules/layer_dropout.py b/torchtune/modules/layer_dropout.py index f46e9bf10d..75e28b4ae1 100644 --- a/torchtune/modules/layer_dropout.py +++ b/torchtune/modules/layer_dropout.py @@ -19,6 +19,7 @@ class LayerDropout(torch.nn.Module): It drops a portion of an input tensor, applies the underlying module on the remaining parts of the tensor, and then concatenates with the dropped portion of the tensor. When applied during training, it can have a regularization effect, and can potentially speedup training. + Args: prob (float): The probability of dropping an input. Defaults to 0.0. dim (Optional[int]): The dimension of input tensor along which to drop layers. Defaults to 0 (i.e., batch size). @@ -61,6 +62,7 @@ def forward( ) -> torch.Tensor: """ Apply layer dropout to the input tensor. + Args: function (Union[Callable, torch.nn.Module]): The function or module to apply to the input tensor. input (torch.Tensor): The input tensor. @@ -101,6 +103,7 @@ class ModuleLayerDropoutWrapper(torch.nn.Module): A wrapper module that adds layer dropout functionality to a given module. This class wraps a given module and applies layer dropout to it. It also provides getter and setter methods for the wrapped module's attributes. + Args: module (torch.nn.Module): The module to wrap. dropout (LayerDropout): The layer dropout object. @@ -185,6 +188,7 @@ def get_scale( Compute a scaling factor based on the provided scale type, period, and value. The scaling factor is designed to be 0 when the value is 0 and 1 when the value reaches or is larger than the scale period. + Args: scale_type (ScaleType): The type of scaling to use. scale_period (int): The period over which the scaling factor increases from 0 to 1. From 2675b4c3fbe5ca4a9f0ea0ace4e9a3f485b87cd4 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Wed, 4 Dec 2024 05:38:06 +0000 Subject: [PATCH 83/88] remove assert statement as we added the check in testing --- torchtune/modules/early_exit_loss.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchtune/modules/early_exit_loss.py b/torchtune/modules/early_exit_loss.py index 6d6aaf0344..9d4c3b58da 100644 --- a/torchtune/modules/early_exit_loss.py +++ b/torchtune/modules/early_exit_loss.py @@ -135,7 +135,6 @@ def layer_ids_to_loss_scales( loss_scales = loss_scales * torch.where(layer_ids < n_layers - 1, e_scale, 1.0) # normalize loss scales to ensure that their sum is 1.0 loss_scales = loss_scales / torch.sum(loss_scales) - assert torch.isclose(torch.sum(loss_scales), torch.Tensor([1.0]).to(loss_scales)) return loss_scales From 00d8efa7ca63f6c10db0fc7ace1143f8f6ca71c1 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Thu, 5 Dec 2024 04:54:52 +0000 Subject: [PATCH 84/88] change loss scale from enum to function --- recipes/dev/7B_full_early_exit.yaml | 6 +- .../dev/early_exit_finetune_distributed.py | 6 +- .../torchtune/modules/test_early_exit_loss.py | 25 +++- torchtune/modules/early_exit_loss.py | 120 ++++++++---------- 4 files changed, 79 insertions(+), 78 deletions(-) diff --git a/recipes/dev/7B_full_early_exit.yaml b/recipes/dev/7B_full_early_exit.yaml index 8433309a48..a9e8ab13d7 100644 --- a/recipes/dev/7B_full_early_exit.yaml +++ b/recipes/dev/7B_full_early_exit.yaml @@ -16,10 +16,10 @@ # # To reproduce experiments of various papers that use early exit loss and/or layer dropout: # - LayerSkip (https://arxiv.org/abs/2404.16710) on TOPv2: -# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss.scale=1.0 early_exit_loss.curriculum=gradual early_exit_loss.scale_type=l layer_dropout.prob=0.2 layer_dropout.scale=exp +# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss.scale=1.0 early_exit_loss.curriculum=gradual early_exit_loss.scale_type=torchtune.modules.early_exit_loss.linear_l_loss_scale layer_dropout.prob=0.2 layer_dropout.scale=exp # # - LITE (https://arxiv.org/abs/2310.18581): -# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml layer_dropout=null early_exit_loss.layers=8,12,16,20,24,28 early_exit_loss.scale_type=one early_exit_loss.curriculum=null epochs=5 +# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml layer_dropout=null early_exit_loss.layers=8,12,16,20,24,28 early_exit_loss.scale_type=torchtune.modules.early_exit_loss.uniform_loss_scale early_exit_loss.curriculum=null epochs=5 # # - LayerDrop (https://arxiv.org/abs/1909.11556): # tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss=null layer_dropout.prob=0.2 layer_dropout.layers=1::2 @@ -126,7 +126,7 @@ profiler: early_exit_loss: layers: ":" curriculum: "gradual" - scale_type: "sum_l" + scale_type: torchtune.modules.early_exit_loss.sum_l_loss_scale scale: 1.0 # Layer Dropout diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index 85e082654a..d421bf5e7b 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -245,8 +245,10 @@ def __init__(self, cfg: DictConfig) -> None: if cfg_early_exit_loss: self._do_early_exit_loss = True self._early_exit_loss_scale = cfg_early_exit_loss.get("scale", 1.0) - self._early_exit_loss_scale_type = cfg_early_exit_loss.get( - "scale_type", "one" + self._early_exit_loss_scale_type = _get_component_from_path( + cfg_early_exit_loss.get( + "scale_type", "torchtune.modules.early_exit_loss.sum_l_loss_scale" + ) ) else: self._do_early_exit_loss = False diff --git a/tests/torchtune/modules/test_early_exit_loss.py b/tests/torchtune/modules/test_early_exit_loss.py index 4f9d271f39..7cb9fd3384 100644 --- a/tests/torchtune/modules/test_early_exit_loss.py +++ b/tests/torchtune/modules/test_early_exit_loss.py @@ -17,10 +17,14 @@ early_exit_loss, EarlyExitCurriculumType, GradualEarlyExitCurriculum, - layer_ids_to_loss_scales, - LossScaleType, + inv_l_loss_scale, + inv_sqrt_l_loss_scale, + linear_l_loss_scale, RotationalEarlyExitCurriculum, setup_early_exit_loss_curriculum, + sqrt_l_loss_scale, + sum_l_loss_scale, + uniform_loss_scale, ) # Mock components for TransformerDecoder @@ -76,15 +80,22 @@ def test_early_exit_loss(self, mock_model, hidden_states_dict, labels, loss_fn): assert loss.item() >= 0 @pytest.mark.parametrize( - "scale_type", - [e.value for e in LossScaleType], + "scale_fn", + [ + uniform_loss_scale, + linear_l_loss_scale, + sum_l_loss_scale, + sqrt_l_loss_scale, + inv_l_loss_scale, + inv_sqrt_l_loss_scale, + ], ) - def test_layer_ids_to_loss_scales(self, scale_type, num_layers): + def test_layer_ids_to_loss_scales(self, scale_fn, num_layers): for n_subset_layers in range(1, num_layers + 1): layer_ids = torch.tensor( random.sample(range(0, num_layers), n_subset_layers) ) - scales = layer_ids_to_loss_scales(layer_ids, num_layers, scale_type, 1.0) + scales = scale_fn(layer_ids, num_layers, 1.0) assert torch.isclose(scales.sum(), torch.tensor(1.0)) def test_early_exit_loss_vs_manual( @@ -98,7 +109,7 @@ def test_early_exit_loss_vs_manual( labels, loss_fn, e_scale=1, - loss_scale_type="one", + loss_scale_fn=uniform_loss_scale, ) # Manually calculate the loss for each hidden state total_loss = 0.0 diff --git a/torchtune/modules/early_exit_loss.py b/torchtune/modules/early_exit_loss.py index 9d4c3b58da..0e87e83b07 100644 --- a/torchtune/modules/early_exit_loss.py +++ b/torchtune/modules/early_exit_loss.py @@ -6,7 +6,7 @@ import copy from enum import Enum -from typing import Dict, List, Optional +from typing import Callable, Dict, List, Optional import numpy as np import torch @@ -17,13 +17,52 @@ log = utils.get_logger("DEBUG") -class LossScaleType(str, Enum): - ONE = "one" - L = "l" - SUM_L = "sum_l" - INV_L = "inv_l" - SQRT_L = "sqrt_l" - INV_SQRT_L = "inv_sqrt_l" +def uniform_loss_scale( + layer_ids: torch.Tensor, n_layers: int, e_scale: float = 1.0 +) -> torch.Tensor: + loss_scales = torch.ones(len(layer_ids), device=layer_ids.device) + loss_scales = loss_scales * torch.where(layer_ids < n_layers - 1, e_scale, 1.0) + return loss_scales / torch.sum(loss_scales) + + +def linear_l_loss_scale( + layer_ids: torch.Tensor, n_layers: int, e_scale: float = 1.0 +) -> torch.Tensor: + loss_scales = torch.Tensor(layer_ids + 1) + loss_scales = loss_scales * torch.where(layer_ids < n_layers - 1, e_scale, 1.0) + return loss_scales / torch.sum(loss_scales) + + +def sum_l_loss_scale( + layer_ids: torch.Tensor, n_layers: int, e_scale: float = 1.0 +) -> torch.Tensor: + loss_scales = torch.cumsum(layer_ids + 1, dim=0) + loss_scales = loss_scales * torch.where(layer_ids < n_layers - 1, e_scale, 1.0) + return loss_scales / torch.sum(loss_scales) + + +def sqrt_l_loss_scale( + layer_ids: torch.Tensor, n_layers: int, e_scale: float = 1.0 +) -> torch.Tensor: + loss_scales = torch.sqrt(layer_ids + 1) + loss_scales = loss_scales * torch.where(layer_ids < n_layers - 1, e_scale, 1.0) + return loss_scales / torch.sum(loss_scales) + + +def inv_l_loss_scale( + layer_ids: torch.Tensor, n_layers: int, e_scale: float = 1.0 +) -> torch.Tensor: + loss_scales = 1.0 / (layer_ids + 1) + loss_scales = loss_scales * torch.where(layer_ids < n_layers - 1, e_scale, 1.0) + return loss_scales / torch.sum(loss_scales) + + +def inv_sqrt_l_loss_scale( + layer_ids: torch.Tensor, n_layers: int, e_scale: float = 1.0 +) -> torch.Tensor: + loss_scales = torch.reciprocal(torch.sqrt(layer_ids + 1)) + loss_scales = loss_scales * torch.where(layer_ids < n_layers - 1, e_scale, 1.0) + return loss_scales / torch.sum(loss_scales) def early_exit_loss( @@ -32,7 +71,9 @@ def early_exit_loss( labels: torch.Tensor, loss_fn: torch.nn.Module, e_scale: float = 1.0, - loss_scale_type: LossScaleType = LossScaleType.SUM_L, + loss_scale_fn: Callable[ + [torch.Tensor, int, float], torch.Tensor + ] = uniform_loss_scale, ) -> torch.Tensor: """ Compute the early exit loss for a given model and outputs of intermediate layers. @@ -47,9 +88,9 @@ def early_exit_loss( where each key is a layer index and each value is a tensor of shape [b, s, d]. labels (torch.Tensor): The labels for the input data. loss_fn (torch.nn.Module): The loss function to use (should be the same as the standard loss function for last layer). - e_scale (float, optional): A scaling factor for the early exit losses. Defaults to 1.0. - loss_scale_type (LossScaleType, optional): The type of loss scaling to use to determine - scale of each layer's loss. Defaults to LossScaleType.SUM_L. + e_scale (float): A scaling factor for the early exit losses. Defaults to 1.0. + loss_scale_fn (Callable[[torch.Tensor, int, float], torch.Tensor]): A function to determine scale of each + layer's loss. Defaults to uniform_loss_scale. Returns: torch.Tensor: The computed early exit loss. """ @@ -77,68 +118,15 @@ def early_exit_loss( s_unpadded = (labels != loss_fn.ignore_index).sum() losses_early = losses_early.float().sum(-1) / s_unpadded # Shape: [e] - losses_scales = layer_ids_to_loss_scales( + losses_scales = loss_scale_fn( torch.Tensor(hidden_layer_ids).to(losses_early), len(model.layers), - loss_scale_type, e_scale, ) return torch.sum(losses_scales * losses_early) -def layer_ids_to_loss_scales( - layer_ids: torch.Tensor, - n_layers: int, - loss_scale_type: LossScaleType, - e_scale: float, -) -> torch.Tensor: - """ - Compute the loss scales for a given set of layer IDs and loss scale type. - This function takes in a list of layer IDs, the total number of layers, - a loss scale type, and an early exit scaling factor. It computes the loss - scales based on the specified loss scale type and then normalizes them to - ensure that their sum is 1.0. - - Args: - layer_ids (torch.Tensor): A tensor of layer IDs. - n_layers (int): The total number of layers. - loss_scale_type (LossScaleType): The type of loss scaling to use. - e_scale (float): An early exit scaling factor. - Returns: - torch.Tensor: The computed loss scales. - Raises: - ValueError: If the provided loss scale type is not supported. - AssertionError: If the sum of the loss scales is not close to 1.0. - Example: - >>> layer_ids = [0, 1, 2] - >>> n_layers = 3 - >>> loss_scale_type = LossScaleType.SUM_L - >>> e_scale = 1.0 - >>> loss_scales = layer_ids_to_loss_scales(layer_ids, n_layers, loss_scale_type, e_scale) - """ - if loss_scale_type == LossScaleType.ONE: - loss_scales = torch.ones(len(layer_ids), device=layer_ids.device) - elif loss_scale_type == LossScaleType.L: - loss_scales = torch.Tensor(layer_ids + 1) - elif loss_scale_type == LossScaleType.SUM_L: - loss_scales = torch.cumsum(layer_ids + 1, dim=0) - elif loss_scale_type == LossScaleType.SQRT_L: - loss_scales = torch.sqrt(layer_ids + 1) - elif loss_scale_type == LossScaleType.INV_L: - loss_scales = 1.0 / (layer_ids + 1) - elif loss_scale_type == LossScaleType.INV_SQRT_L: - loss_scales = torch.reciprocal(torch.sqrt(layer_ids + 1)) - else: - raise ValueError(f"Unsupported loss_scale type {loss_scale_type}") - - loss_scales = loss_scales * torch.where(layer_ids < n_layers - 1, e_scale, 1.0) - # normalize loss scales to ensure that their sum is 1.0 - loss_scales = loss_scales / torch.sum(loss_scales) - - return loss_scales - - class EarlyExitCurriculumType(str, Enum): NONE = "none" ROTATIONAL = "rot" From 78b8996545149433a69c1ba5a991c9a596c60708 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Thu, 5 Dec 2024 05:18:34 +0000 Subject: [PATCH 85/88] change curriculum from enum to function --- recipes/dev/7B_full_early_exit.yaml | 4 +-- .../dev/early_exit_finetune_distributed.py | 15 ++++---- .../torchtune/modules/test_early_exit_loss.py | 13 ------- torchtune/modules/early_exit_loss.py | 36 ------------------- 4 files changed, 8 insertions(+), 60 deletions(-) diff --git a/recipes/dev/7B_full_early_exit.yaml b/recipes/dev/7B_full_early_exit.yaml index a9e8ab13d7..dbcdba9190 100644 --- a/recipes/dev/7B_full_early_exit.yaml +++ b/recipes/dev/7B_full_early_exit.yaml @@ -16,7 +16,7 @@ # # To reproduce experiments of various papers that use early exit loss and/or layer dropout: # - LayerSkip (https://arxiv.org/abs/2404.16710) on TOPv2: -# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss.scale=1.0 early_exit_loss.curriculum=gradual early_exit_loss.scale_type=torchtune.modules.early_exit_loss.linear_l_loss_scale layer_dropout.prob=0.2 layer_dropout.scale=exp +# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss.scale=1.0 early_exit_loss.curriculum=torchtune.modules.early_exit_loss.GradualEarlyExitCurriculum early_exit_loss.scale_type=torchtune.modules.early_exit_loss.linear_l_loss_scale layer_dropout.prob=0.2 layer_dropout.scale=exp # # - LITE (https://arxiv.org/abs/2310.18581): # tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml layer_dropout=null early_exit_loss.layers=8,12,16,20,24,28 early_exit_loss.scale_type=torchtune.modules.early_exit_loss.uniform_loss_scale early_exit_loss.curriculum=null epochs=5 @@ -125,7 +125,7 @@ profiler: # Early Exit Loss early_exit_loss: layers: ":" - curriculum: "gradual" + curriculum: torchtune.modules.early_exit_loss.GradualEarlyExitCurriculum scale_type: torchtune.modules.early_exit_loss.sum_l_loss_scale scale: 1.0 diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index d421bf5e7b..e41e98d5c4 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -25,11 +25,7 @@ from torchtune.datasets import ConcatDataset from torchtune.modules.common_utils import slice_str_to_array -from torchtune.modules.early_exit_loss import ( - early_exit_loss, - EarlyExitCurriculum, - setup_early_exit_loss_curriculum, -) +from torchtune.modules.early_exit_loss import early_exit_loss, EarlyExitCurriculum from torchtune.modules.layer_dropout import prepare_layer_dropout from torchtune.recipe_interfaces import FTRecipeInterface from torchtune.training import DummyProfiler, PROFILER_KEY @@ -712,9 +708,11 @@ def _setup_early_exit_loss( train_last_layer = cfg_early_exit_loss.get("include_last_layer", True) verbose = cfg_early_exit_loss.get("verbose", False) - if cfg_early_exit_loss.get("curriculum", None): - early_exit_loss_curriculum = setup_early_exit_loss_curriculum( - early_exit_curriculum=cfg_early_exit_loss.curriculum, + early_exit_loss_curriculum = cfg_early_exit_loss.get("curriculum", None) + if early_exit_loss_curriculum: + early_exit_loss_curriculum = _get_component_from_path( + early_exit_loss_curriculum + )( do_output_hidden_states=do_output_hidden_states, max_steps=self.total_epochs * self._steps_per_epoch, train_last_layer=train_last_layer, @@ -723,7 +721,6 @@ def _setup_early_exit_loss( ) do_output_hidden_states = early_exit_loss_curriculum.get() else: - early_exit_loss_curriculum = None if train_last_layer: do_output_hidden_states[len(self._model.layers) - 1] = True diff --git a/tests/torchtune/modules/test_early_exit_loss.py b/tests/torchtune/modules/test_early_exit_loss.py index 7cb9fd3384..27d421c14d 100644 --- a/tests/torchtune/modules/test_early_exit_loss.py +++ b/tests/torchtune/modules/test_early_exit_loss.py @@ -15,13 +15,11 @@ from torchtune.modules import TransformerDecoder from torchtune.modules.early_exit_loss import ( early_exit_loss, - EarlyExitCurriculumType, GradualEarlyExitCurriculum, inv_l_loss_scale, inv_sqrt_l_loss_scale, linear_l_loss_scale, RotationalEarlyExitCurriculum, - setup_early_exit_loss_curriculum, sqrt_l_loss_scale, sum_l_loss_scale, uniform_loss_scale, @@ -131,17 +129,6 @@ def test_early_exit_loss_vs_manual( class TestEarlyExitLossCurriculum: - def test_setup_early_exit_loss_curriculum(self): - curriculum = setup_early_exit_loss_curriculum( - EarlyExitCurriculumType.ROTATIONAL, [True, False, True], 100 - ) - assert isinstance(curriculum, RotationalEarlyExitCurriculum) - - curriculum = setup_early_exit_loss_curriculum( - EarlyExitCurriculumType.GRADUAL, [True, False, True], 100 - ) - assert isinstance(curriculum, GradualEarlyExitCurriculum) - @pytest.mark.parametrize( "train_last_layer", [ diff --git a/torchtune/modules/early_exit_loss.py b/torchtune/modules/early_exit_loss.py index 0e87e83b07..976f6bee07 100644 --- a/torchtune/modules/early_exit_loss.py +++ b/torchtune/modules/early_exit_loss.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. import copy -from enum import Enum from typing import Callable, Dict, List, Optional import numpy as np @@ -127,12 +126,6 @@ def early_exit_loss( return torch.sum(losses_scales * losses_early) -class EarlyExitCurriculumType(str, Enum): - NONE = "none" - ROTATIONAL = "rot" - GRADUAL = "gradual" - - # TODO: create a base curriculum class that can be used for other aspects, e.g., dropout, datasets, etc. class EarlyExitCurriculum: """ @@ -297,32 +290,3 @@ def step(self): log.info( f"Updated self._do_output_hidden_states to {self._do_output_hidden_states}." ) - - -def setup_early_exit_loss_curriculum( - early_exit_curriculum: EarlyExitCurriculumType, *args, **kwargs -) -> Optional[EarlyExitCurriculum]: - """ - Set up an early exit loss curriculum based on the provided type. - This function takes in an early exit curriculum type and optional arguments. - It returns an instance of the corresponding early exit curriculum class, - or None if the curriculum type is NONE. - - Args: - early_exit_curriculum (EarlyExitCurriculumType): The type of early exit curriculum to set up. - *args: Optional positional arguments for the early exit curriculum constructor. - **kwargs: Optional keyword arguments for the early exit curriculum constructor. - Returns: - Optional[EarlyExitCurriculum]: - An instance of the corresponding early exit curriculum class, or None. - Raises: - ValueError: If the provided early exit curriculum type is not supported. - """ - if early_exit_curriculum == EarlyExitCurriculumType.NONE: - return None - elif early_exit_curriculum == EarlyExitCurriculumType.ROTATIONAL: - return RotationalEarlyExitCurriculum(*args, **kwargs) - elif early_exit_curriculum == EarlyExitCurriculumType.GRADUAL: - return GradualEarlyExitCurriculum(*args, **kwargs) - else: - raise ValueError(f"Unsupported early loss curriculum {early_exit_curriculum}.") From ed33ba9e8082626a41019a4956b4b2c3932f4e33 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Fri, 6 Dec 2024 20:54:27 +0000 Subject: [PATCH 86/88] rename scale_type to scale_fn --- recipes/dev/7B_full_early_exit.yaml | 6 +++--- recipes/dev/early_exit_finetune_distributed.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/recipes/dev/7B_full_early_exit.yaml b/recipes/dev/7B_full_early_exit.yaml index dbcdba9190..e33535e8a3 100644 --- a/recipes/dev/7B_full_early_exit.yaml +++ b/recipes/dev/7B_full_early_exit.yaml @@ -16,10 +16,10 @@ # # To reproduce experiments of various papers that use early exit loss and/or layer dropout: # - LayerSkip (https://arxiv.org/abs/2404.16710) on TOPv2: -# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss.scale=1.0 early_exit_loss.curriculum=torchtune.modules.early_exit_loss.GradualEarlyExitCurriculum early_exit_loss.scale_type=torchtune.modules.early_exit_loss.linear_l_loss_scale layer_dropout.prob=0.2 layer_dropout.scale=exp +# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss.scale=1.0 early_exit_loss.curriculum=torchtune.modules.early_exit_loss.GradualEarlyExitCurriculum early_exit_loss.scale_fn=torchtune.modules.early_exit_loss.linear_l_loss_scale layer_dropout.prob=0.2 layer_dropout.scale=exp # # - LITE (https://arxiv.org/abs/2310.18581): -# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml layer_dropout=null early_exit_loss.layers=8,12,16,20,24,28 early_exit_loss.scale_type=torchtune.modules.early_exit_loss.uniform_loss_scale early_exit_loss.curriculum=null epochs=5 +# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml layer_dropout=null early_exit_loss.layers=8,12,16,20,24,28 early_exit_loss.scale_fn=torchtune.modules.early_exit_loss.uniform_loss_scale early_exit_loss.curriculum=null epochs=5 # # - LayerDrop (https://arxiv.org/abs/1909.11556): # tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss=null layer_dropout.prob=0.2 layer_dropout.layers=1::2 @@ -126,7 +126,7 @@ profiler: early_exit_loss: layers: ":" curriculum: torchtune.modules.early_exit_loss.GradualEarlyExitCurriculum - scale_type: torchtune.modules.early_exit_loss.sum_l_loss_scale + scale_fn: torchtune.modules.early_exit_loss.sum_l_loss_scale scale: 1.0 # Layer Dropout diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index e41e98d5c4..e7b7649526 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -52,16 +52,16 @@ class EarlyExitFinetuneRecipeDistributed(FTRecipeInterface): - ``early_exit_loss.layers`` is a string, whose format mimics indexing in numpy arrays (e.g., `:` depicts all layers, `0:10:3` depicts layers 0, 3, 6, 9, and `1,5,11` depicts layers 1,5,11), to represent which layers to apply early exit loss at, - - ``early_exit_loss.scale_type`` and ``early_exit_loss.scale`` determine how we calculate the + - ``early_exit_loss.scale_fn`` and ``early_exit_loss.scale`` determine how we calculate the weights of losses at different layers when calculating total loss, and - ``early_exit_loss.curriculum`` depicts how the early exit loss layers change across training iterations. See ``torchtune/modules/early_exit_loss.py` for more details of each argument. To reproduce experiments of different papers that use early exit loss: - LayerSkip (https://arxiv.org/abs/2404.16710) for finetuning on TOPv2: set - ``early_exit_loss.scale=1.0, early_exit_loss.curriculum=gradual early_exit_loss.scale_type=l``, + ``early_exit_loss.scale=1.0, early_exit_loss.curriculum=gradual early_exit_loss.scale_fn=l``, - LITE (https://arxiv.org/abs/2310.18581) for finetuning Llama2 7B on Alpaca you can set - ``early_exit_loss.layers=8,12,16,20,24,28 early_exit_loss.scale_type=one``. + ``early_exit_loss.layers=8,12,16,20,24,28 early_exit_loss.scale_fn=one``. - Layer Dropout. (a.k.a. Stochastic Depth) This drops samples stochastically for each layer during training. "Dropping" a sample at a layer in this context means a sample will pass through the layer without modification. @@ -243,7 +243,7 @@ def __init__(self, cfg: DictConfig) -> None: self._early_exit_loss_scale = cfg_early_exit_loss.get("scale", 1.0) self._early_exit_loss_scale_type = _get_component_from_path( cfg_early_exit_loss.get( - "scale_type", "torchtune.modules.early_exit_loss.sum_l_loss_scale" + "scale_fn", "torchtune.modules.early_exit_loss.sum_l_loss_scale" ) ) else: From c7f02de9303806fa3c97de50cc755757ae451aa7 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Fri, 6 Dec 2024 21:01:11 +0000 Subject: [PATCH 87/88] change default --- recipes/dev/7B_full_early_exit.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/recipes/dev/7B_full_early_exit.yaml b/recipes/dev/7B_full_early_exit.yaml index e33535e8a3..7d02a34f0e 100644 --- a/recipes/dev/7B_full_early_exit.yaml +++ b/recipes/dev/7B_full_early_exit.yaml @@ -124,8 +124,8 @@ profiler: # Early Exit Loss early_exit_loss: - layers: ":" - curriculum: torchtune.modules.early_exit_loss.GradualEarlyExitCurriculum + layers: "0::4" + curriculum: torchtune.modules.early_exit_loss.RotationalEarlyExitCurriculum scale_fn: torchtune.modules.early_exit_loss.sum_l_loss_scale scale: 1.0 From 69f840ca057a8d826bb331e4ec9454ae96b27cdb Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Fri, 6 Dec 2024 21:02:24 +0000 Subject: [PATCH 88/88] update docstring --- recipes/dev/early_exit_finetune_distributed.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index e7b7649526..aed914a463 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -59,9 +59,12 @@ class EarlyExitFinetuneRecipeDistributed(FTRecipeInterface): See ``torchtune/modules/early_exit_loss.py` for more details of each argument. To reproduce experiments of different papers that use early exit loss: - LayerSkip (https://arxiv.org/abs/2404.16710) for finetuning on TOPv2: set - ``early_exit_loss.scale=1.0, early_exit_loss.curriculum=gradual early_exit_loss.scale_fn=l``, + ``early_exit_loss.scale=1.0, + early_exit_loss.curriculum=torchtune.modules.early_exit_loss.GradualEarlyExitCurriculum + early_exit_loss.scale_fn=torchtune.modules.early_exit_loss.linear_l_loss_scale``, - LITE (https://arxiv.org/abs/2310.18581) for finetuning Llama2 7B on Alpaca you can set - ``early_exit_loss.layers=8,12,16,20,24,28 early_exit_loss.scale_fn=one``. + ``early_exit_loss.layers=8,12,16,20,24,28 + early_exit_loss.scale_fn=torchtune.modules.early_exit_loss.uniform_loss_scale``. - Layer Dropout. (a.k.a. Stochastic Depth) This drops samples stochastically for each layer during training. "Dropping" a sample at a layer in this context means a sample will pass through the layer without modification.