diff --git a/docs/source/api_ref_modules.rst b/docs/source/api_ref_modules.rst
index 36ea3637c8..5eb8fff358 100644
--- a/docs/source/api_ref_modules.rst
+++ b/docs/source/api_ref_modules.rst
@@ -23,6 +23,8 @@ Modeling Components and Building Blocks
     TransformerCrossAttentionLayer
     TransformerDecoder
     VisionTransformer
+    LayerDropout
+    prepare_layer_dropout
 
 Losses
 ------
diff --git a/recipes/dev/7B_full_early_exit.yaml b/recipes/dev/7B_full_early_exit.yaml
new file mode 100644
index 0000000000..7d02a34f0e
--- /dev/null
+++ b/recipes/dev/7B_full_early_exit.yaml
@@ -0,0 +1,137 @@
+# Config for multi-device full finetuning with early exit loss and/or layer dropout
+# in dev/early_exit_finetune_distributed.py using a Llama2 7B model on a small TOPv2
+# instruction set.
+#
+# This config assumes that you've run the following command before launching
+# this run:
+# tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --hf-token <HF_TOKEN>
+#
+# To launch on 4 devices, run the following command from root:
+# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml
+#
+# You can add specific overrides through the command line. For example
+# to override the checkpointer directory while launching training
+# you can run:
+# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
+#
+# To reproduce experiments of various papers that use early exit loss and/or layer dropout:
+# - LayerSkip (https://arxiv.org/abs/2404.16710) on TOPv2:
+# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss.scale=1.0 early_exit_loss.curriculum=torchtune.modules.early_exit_loss.GradualEarlyExitCurriculum early_exit_loss.scale_fn=torchtune.modules.early_exit_loss.linear_l_loss_scale layer_dropout.prob=0.2 layer_dropout.layers_scale=exp
+#
+# - LITE (https://arxiv.org/abs/2310.18581):
+# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml layer_dropout=null early_exit_loss.layers=8,12,16,20,24,28 early_exit_loss.scale_fn=torchtune.modules.early_exit_loss.uniform_loss_scale early_exit_loss.curriculum=null epochs=5
+#
+# - LayerDrop (https://arxiv.org/abs/1909.11556):
+# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss=null layer_dropout.prob=0.2 layer_dropout.layers=1::2
+#
+# - Progressive Layer Dropping (https://arxiv.org/abs/2010.13369) (The paper also implements a curriculum for layer drop probability, which is not yet implemented.):
+# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss=null layer_dropout.prob=0.5 layer_dropout.layers_scale=exp
+#
+# This config works best for distributed training, i.e., when the model is being fine-tuned on 2+ GPUs.
+# + + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/Llama-2-7b-hf/tokenizer.model + max_seq_len: null + +# Dataset +dataset: + _component_: torchtune.datasets.instruct_dataset + source: WillHeld/top_v2 + split: train + column_map: + input: utterance + output: semantic_parse + +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.llama2.llama2_7b + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-2-7b-hf + checkpoint_files: [ + pytorch_model-00001-of-00002.bin, + pytorch_model-00002-of-00002.bin + ] + recipe_checkpoint: null + output_dir: /tmp/Llama-2-7b-hf + model_type: LLAMA2 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 8 +epochs: 1 +optimizer: + _component_: torch.optim.AdamW + fused: True + lr: 2e-5 +loss: + _component_: torch.nn.CrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/topv2-llama2-finetune +log_every_n_steps: 1 +log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 + +# Early Exit Loss +early_exit_loss: + layers: "0::4" + curriculum: torchtune.modules.early_exit_loss.RotationalEarlyExitCurriculum + scale_fn: torchtune.modules.early_exit_loss.sum_l_loss_scale + scale: 1.0 + +# Layer Dropout +layer_dropout: + prob: 0.2 + layers: ":" + layers_scale: "exp" + disable_on_eval: True diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py new file mode 100644 index 0000000000..aed914a463 --- /dev/null +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -0,0 +1,1070 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
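# For illustration, a minimal sketch (not part of the patch itself) of how the `layers`
# strings in the config above are interpreted. `early_exit_loss.layers: "0::4"` and
# `layer_dropout.layers: ":"` use the numpy-style slice syntax handled by
# torchtune.modules.common_utils.slice_str_to_array, which this diff adds further below.
# The layer count of 32 is assumed here (the depth of llama2_7b).
from torchtune.modules.common_utils import slice_str_to_array

n_layers = 32  # assumed llama2_7b depth, for illustration only
early_exit_mask = slice_str_to_array("0::4", n_layers)  # True at layers 0, 4, 8, ..., 28
dropout_mask = slice_str_to_array(":", n_layers)        # True at every layer
assert sum(early_exit_mask) == 8 and all(dropout_mask)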
+ +import sys +import time + +from functools import partial +from typing import Any, Dict, List, Optional, Tuple, Union +from warnings import warn + +import torch +from omegaconf import DictConfig, ListConfig + +from torch import nn +from torch.distributed import destroy_process_group, init_process_group + +from torch.optim import Optimizer +from torch.utils.data import DataLoader, DistributedSampler +from torchtune import config, modules, training, utils +from torchtune.config._utils import _get_component_from_path +from torchtune.data import padded_collate_packed +from torchtune.datasets import ConcatDataset +from torchtune.modules.common_utils import slice_str_to_array + +from torchtune.modules.early_exit_loss import early_exit_loss, EarlyExitCurriculum +from torchtune.modules.layer_dropout import prepare_layer_dropout +from torchtune.recipe_interfaces import FTRecipeInterface +from torchtune.training import DummyProfiler, PROFILER_KEY +from torchtune.training.activations import apply_selective_activation_checkpointing +from torchtune.training.lr_schedulers import get_lr + +from tqdm import tqdm + +log = utils.get_logger("DEBUG") + + +class EarlyExitFinetuneRecipeDistributed(FTRecipeInterface): + """ + Early exit and layer dropout full finetuning to make the model more robust to early exit and skipping + intermediate layers for dense transformer-based LLMs such as Llama2. This recipe supports distributed + training and can be run on a single node (1 to 8 GPUs). + + Features: + - Early Exit Loss. This makes the model more robust to exiting early by applying the outputs of intermediate + layers on the model's language model head (a.k.a. unembedding operation) to obtain outputs of earlier + layers, then obtain the losses at such earlier layers. Then the loss of the model during training + would be a weighted average of the losses at different layers. The different arguments you can + configure are: + - ``early_exit_loss.layers`` is a string, whose format mimics indexing in numpy arrays (e.g., `:` + depicts all layers, `0:10:3` depicts layers 0, 3, 6, 9, and `1,5,11` depicts layers 1,5,11), to + represent which layers to apply early exit loss at, + - ``early_exit_loss.scale_fn`` and ``early_exit_loss.scale`` determine how we calculate the + weights of losses at different layers when calculating total loss, and + - ``early_exit_loss.curriculum`` depicts how the early exit loss layers change across training + iterations. + See ``torchtune/modules/early_exit_loss.py` for more details of each argument. + To reproduce experiments of different papers that use early exit loss: + - LayerSkip (https://arxiv.org/abs/2404.16710) for finetuning on TOPv2: set + ``early_exit_loss.scale=1.0, + early_exit_loss.curriculum=torchtune.modules.early_exit_loss.GradualEarlyExitCurriculum + early_exit_loss.scale_fn=torchtune.modules.early_exit_loss.linear_l_loss_scale``, + - LITE (https://arxiv.org/abs/2310.18581) for finetuning Llama2 7B on Alpaca you can set + ``early_exit_loss.layers=8,12,16,20,24,28 + early_exit_loss.scale_fn=torchtune.modules.early_exit_loss.uniform_loss_scale``. + + - Layer Dropout. (a.k.a. Stochastic Depth) This drops samples stochastically for each layer during training. + "Dropping" a sample at a layer in this context means a sample will pass through the layer without modification. + The different arguments you can configure are: + - ``layer_dropout.prob``: is the (maximum) probability of a sample being dropped at each layer. 
+ - ``layer_dropout.layers``: is a string, whose format mimics indexing in numpy arrays + (same as ``early_exit_loss.layers``), that determines which layers will have layer dropout applied. + - ``layer_dropout.layers_scale``: determines how probability changes across layers from + probability 0 at first layer, to probability ``layer_dropout.prob`` at last layer. + You can choose from ``one`` (all layers have ``layer_dropout.prob``), ``linear``, + ``exp``, ``log``, ``sqrt``. + - ``disable_on_eval``: if True, will only apply layer dropout during training. If False, will + apply to both training and evaluation. + To reproduce results of different papers that use layer dropout: + - LayerDrop (https://arxiv.org/abs/1909.11556) that applies dropout on every other layer, set + ``layer_dropout.prob=0.2 layer_dropout.layers=1::2``. + - Progressive Layer Dropping (https://arxiv.org/abs/2010.13369) that increases dropout linearly + across layers, set ``layer_dropout.prob=0.5 layer_dropout.layers_scale=linear``. + The paper also implements a curriculum for layer drop probability which is not yet implemented. + - LayerSkip (https://arxiv.org/abs/2404.16710) for finetuning on TOPv2: (in addition to early exit loss + arguments above) set ``layer_dropout.prob=0.2 layer_dropout.scale=exp``. + + - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states + is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is + done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config + ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy). + DDP is currently not supported. Training on CPU is not supported. + + - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing`` + flag. Activation checkpointing helps reduce the memory footprint since we no longer keep + activations in memory and instead recompute them during the backward pass. This is especially + helpful for larger batch sizes when you're memory constrained. But these savings in memory + come at the cost of training performance. In most cases training can slow-down quite a bit as + a result of this activation recomputation. + + - Activation Offloading. This can be controlled using the ``enable_activation_offloading`` + flag. Activation offloading is a technique similar to activations checkpointing that helps + reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activations + checkpointing drops the activation in the forward to recompute it later in the backward, + activations offloading will drop the activation in the forward to the CPU and bring it + back during the backward pass. As always, there is a tradeoff--these savings in memory can + come at the cost of training performance and CPU resources. To recover some runtime cost, + we've added an option to enable offloading on a different stream to permit overlapping with + the computation. This option is currently only available on PyTorch 2.5 or later and will + be enabled by default if an acceptable torch version is found. Activation offloading can be + used in conjunction with activation checkpointing. + + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` + flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. 
In + most cases this should halve the memory footprint of full precision (fp32) training, without + loss in model quality (will depend on the model, training data and other settings). For + GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16 + precision are currently not supported. + + - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is + controlled using the ``gradient_accumulation_steps`` flag. + + Total Batch Size = batch_size * number of GPUs * gradient accumulation steps. + + For example: with batch_size=1, nproc_per_node=2 and gradient_accumulation_steps=32 we get a + total batch size of 64. + + Gradient accumulation is especially useful when you are memory constrained. In this case, + accumulating gradients might give you better training speed than enabling activation + checkpointing. + + - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of + training. Optimizer state and recipe state (seed, total_epochs, number of epochs run etc) are + only saved at the end of a given epoch and used in case of resuming training. + + Resuming training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is + currently not supported. + + For more details on the checkpointer, please take a look at + our checkpointer deepdive (https://pytorch.org/torchtune/main/deep_dives/checkpointer.html). + + - Logging. Terminal, Disk, WandB and TensorBoard are all supported. + + - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default, + ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set + ``clip_grad_norm='inf'``. + + For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config + has example commands for how to kick-off training. + + Args: + cfg (DictConfig): OmegaConf object parsed from yaml file + + Raises: + ValueError: If ``dtype`` is set to fp16. + RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. + RuntimeError: If ``left_pad_sequence`` is set as the data collator. + RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA. + RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False. + """ + + def __init__(self, cfg: DictConfig) -> None: + self._device = utils.get_device(device=cfg.device) + self._dtype = training.get_dtype(cfg.dtype, device=self._device) + + if self._dtype == torch.float16: + raise ValueError( + "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." + ) + + # logging attributes + self._output_dir = cfg.output_dir + self._log_every_n_steps = cfg.get("log_every_n_steps", 1) + self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) + + if self._log_peak_memory_stats and self._device.type != "cuda": + log.info( + "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False." + ) + self._log_peak_memory_stats = False + + # _is_rank_zero is used primarily for logging. 
In the future, the logger + # should directly take care of this + _, rank = training.get_world_size_and_rank() + self._is_rank_zero = rank == 0 + + # Training cfg + self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._gradient_accumulation_steps = cfg.gradient_accumulation_steps + self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False) + self._clip_grad_norm = cfg.get("clip_grad_norm", None) + + # Optimizer in backward is not compatible with gradient accumulation or gradient clipping + if self._optimizer_in_bwd: + if self._clip_grad_norm is not None: + raise RuntimeError( + "Gradient clipping is not supported with optimizer in bwd." + "Please set clip_grad_norm=None, or optimizer_in_bwd=False." + ) + if self._gradient_accumulation_steps > 1: + raise RuntimeError( + "Gradient accumulation is not supported with optimizer in bwd." + "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False." + ) + + # activation checkpointing/offloading + self._enable_activation_checkpointing = cfg.get( + "enable_activation_checkpointing", False + ) + self._enable_activation_offloading = cfg.get( + "enable_activation_offloading", False + ) + if self._enable_activation_offloading: + if self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be True when training on CUDA" + ) + if not self._enable_activation_checkpointing: + raise RuntimeError( + "enable_activation_offloading should only be True when enable_activation_checkpointing is True" + ) + elif ( + self._enable_activation_checkpointing + and cfg.checkpointer.model_type != "LLAMA3_VISION" + ): + utils.log_rank_zero( + log, + "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " + "Enabling activation offloading should reduce memory further.", + ) + + # These are public properties which are updated by the checkpoint loader + # when ``resume_from_checkpoint`` is `True` or validated in tests + self.seed = training.set_seed(seed=cfg.seed) + self.epochs_run = 0 + self.total_epochs = cfg.epochs + self.max_steps_per_epoch = cfg.max_steps_per_epoch + self.global_step = 0 + + # Early Exit Properties + cfg_early_exit_loss = cfg.get("early_exit_loss", None) + if cfg_early_exit_loss: + self._do_early_exit_loss = True + self._early_exit_loss_scale = cfg_early_exit_loss.get("scale", 1.0) + self._early_exit_loss_scale_type = _get_component_from_path( + cfg_early_exit_loss.get( + "scale_fn", "torchtune.modules.early_exit_loss.sum_l_loss_scale" + ) + ) + else: + self._do_early_exit_loss = False + self._early_exit_loss_scale = None + self._early_exit_loss_scale_type = None + + def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: + """ + Extract the checkpoint state from file and validate. If resume_from_checkpoint + is True, this also includes the recipe state. + """ + self._checkpointer = config.instantiate( + cfg_checkpointer, + resume_from_checkpoint=self._resume_from_checkpoint, + ) + checkpoint_dict = self._checkpointer.load_checkpoint() + + if self._resume_from_checkpoint: + self._update_recipe_state(checkpoint_dict) + return checkpoint_dict + + def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: + """ + Updates the recipe state from checkpoint. 
+ """ + try: + self.epochs_run = ckpt_dict[training.EPOCHS_KEY] + + # on mismatch, warn the user and prevent the override + if self.seed != ckpt_dict[training.SEED_KEY]: + warn( + message=( + "Config value for seed does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}" + ) + ) + self.seed = ckpt_dict[training.SEED_KEY] + if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]: + warn( + message=( + "Config value for max_steps_per_epoch does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}" + ) + ) + self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY] + + # on mismatch, warn the user but allow the override + if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]: + warn( + message=( + "Config value for total_epochs does not match the checkpoint value, " + f"using the config value: {self.total_epochs}" + ) + ) + + except KeyError as e: + raise KeyError( + "Checkpoint does not contain the required keys needed for updating recipe state. " + "Are you sure you passed in the right recipe checkpoint?" + ) from e + + def setup(self, cfg: DictConfig) -> None: + """ + Setup the recipe. This includes training state (if resume_from_checkpoint is True), + model, tokenizer, loss, optimizer, sampler, and dataloader. + """ + if self._is_rank_zero: + self._metric_logger = config.instantiate(cfg.metric_logger) + + # log config with parameter override + self._metric_logger.log_config(cfg) + + checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) + + self._compile = cfg.get("compile", False) + self._model = self._setup_model( + cfg_model=cfg.model, + enable_activation_checkpointing=self._enable_activation_checkpointing, + enable_activation_offloading=self._enable_activation_offloading, + custom_sharded_layers=cfg.get("custom_sharded_layers", None), + fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), + reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), + model_state_dict=checkpoint_dict[training.MODEL_KEY], + ac_mode=cfg.get("ac_mode", None), + ac_option=cfg.get("ac_option", None), + ) + self._tokenizer = config.instantiate(cfg.tokenizer) + + self._optimizer = self._setup_optimizer( + cfg_optimizer=cfg.optimizer, + optimizer_in_bwd=self._optimizer_in_bwd, + opt_state_dict=( + checkpoint_dict[training.OPT_KEY] + if self._resume_from_checkpoint + else None + ), + ) + + # initialize loss + self._loss_fn = config.instantiate(cfg.loss) + + if self._compile: + training.compile_loss(self._loss_fn, verbose=self._is_rank_zero) + + if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": + # set num_output_chunks for model + self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) + + if self._is_rank_zero: + log.info("Loss is initialized.") + + # sampler and dataloader depend on the tokenizer and loss_fn and should be + # setup after both of these are initialized + collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft") + self._sampler, self._dataloader = self._setup_data( + cfg_dataset=cfg.dataset, + shuffle=cfg.shuffle, + batch_size=cfg.batch_size, + collate_fn=collate_name, + ) + + # Finally update the recipe state which can only be correctly set after all of the + # other components have been initialized and updated. + # + # Number of training steps in each epoch depends on the number of batches produced + # by the dataloader, the max_steps_per_epoch param set by the user and the + # gradient_accumulation_steps param. 
This value is used for logging and tracking + # training state. The computation should happen after the dataloader has been setup + self._steps_per_epoch = ( + len(self._dataloader) // self._gradient_accumulation_steps + ) + if ( + self.max_steps_per_epoch is not None + and self.max_steps_per_epoch < self._steps_per_epoch + ): + self._steps_per_epoch = self.max_steps_per_epoch + self.global_step = self.epochs_run * self._steps_per_epoch + + # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) + # if cfg is missing profiler key or if `cfg.profiler.enabled = False` + self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) + + # Used to ignore labels for loss computation + self.ignore_labels_cache = torch.full( + (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device + ) + + # Setup early exit loss + ( + self._do_output_hidden_states, + self._early_exit_loss_curriculum, + ) = self._setup_early_exit_loss(cfg.get("early_exit_loss", None)) + + # Layer Dropout Setup + cfg_layer_dropout = cfg.get("layer_dropout", None) + if cfg_layer_dropout: + prepare_layer_dropout( + self._model.layers, + prob_max=cfg_layer_dropout.get("prob", 0.0), + prob_layer_scale=cfg_layer_dropout.get("layers_scale", "uniform"), + layers_str=cfg_layer_dropout.get("layers", ":"), + disable_on_eval=cfg_layer_dropout.get("disable_on_eval", True), + ) + + def _setup_profiler( + self, cfg_profiler: Optional[DictConfig] = None + ) -> Union[torch.profiler.profile, DummyProfiler]: + """ + Parses the `profiler` section of top-level `cfg` and sets up profiler + + Args: + cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to + `recipe.main`). Default None. + + Returns: + profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods + for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such + that the instrumented training loop does not need to be changed profiling is disabled. + + The profiler config can be provided in configs under the `profiler` key with the following layout: + + .. 
code-block:: yaml + profiler: + enabled: bool + + #Output directory of trace artifacts + output_dir: str + + #`torch.profiler.ProfilerActivity` types to trace + cpu: bool + cuda: bool + + #Trace options + profile_memory: bool + with_stack: bool + record_shapes: bool + with_flops: bool + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: int + warmup_steps: int + active_steps: int + num_cycles: int + """ + # Missing profiler section in config, assume disabled + if cfg_profiler is None: + cfg_profiler = DictConfig({"enabled": False}) + + # Check that component is included and set correctly + if cfg_profiler.get("_component_", None) is None: + cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler" + else: + assert ( + cfg_profiler.get("_component_") + == "torchtune.training.setup_torch_profiler" + ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`" + + profiler, profiler_cfg = config.instantiate(cfg_profiler) + + if self._is_rank_zero: + log.info(f" Profiler config after instantiation: {profiler_cfg}") + + self.profiler_profile_memory = profiler_cfg.get("profile_memory", False) + if profiler_cfg["enabled"]: + self.profiler_wait_steps = profiler_cfg["wait_steps"] + self.profiler_warmup_steps = profiler_cfg["warmup_steps"] + self.profiler_active_steps = profiler_cfg["active_steps"] + + return profiler + + def _setup_model( + self, + cfg_model: DictConfig, + enable_activation_checkpointing: bool, + enable_activation_offloading: bool, + fsdp_cpu_offload: bool, + reshard_after_forward: bool, + model_state_dict: Dict[str, Any], + custom_sharded_layers: Optional[List[str]] = None, + ac_mode: Optional[str] = None, + ac_option: Optional[int] = None, + ) -> nn.Module: + """ + Model initialization has some important considerations: + a. To minimize GPU peak memory, we initialize the model on meta device with + the right dtype + b. All ranks calls ``load_state_dict`` without peaking CPU RAMs since + full state dicts are loaded with ``torch.load(mmap=True)`` + """ + + if self._is_rank_zero: + log.info( + "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..." + ) + init_start = time.perf_counter() + + with training.set_default_dtype(self._dtype), torch.device("meta"): + model = config.instantiate(cfg_model) + + if self._compile: + training.compile_model(model, verbose=self._is_rank_zero) + + # We currently have two versions of activation checkpointing in this recipe + # for testing and BC purposes. ``enable_activation_checkpointing`` controls + # the older version of AC and this behavior is unchanged + # ac_mode and ac_option together control selective AC. 
This is only enabled + # when these are set AND ``enable_activation_checkpointing`` is set to False + # We'll clean this up as soon as testing of AC is complete + if (not enable_activation_checkpointing) and (ac_mode is not None): + apply_selective_activation_checkpointing( + model, + ac_mode, + ac_option, + ) + + # original activation checkpointing (full) - flip the condition above + if enable_activation_checkpointing and ac_mode is None: + training.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} + ) + + # For FSDP sharding + fsdp_shard_conditions = [ + partial( + training.get_shard_conditions, + names_to_match=custom_sharded_layers, + ) + ] + training.shard_model( + model=model, + shard_conditions=fsdp_shard_conditions, + cpu_offload=fsdp_cpu_offload, + reshard_after_forward=reshard_after_forward, + ) + + with training.set_default_dtype(self._dtype), self._device: + for m in model.modules(): + # RoPE is not covered in state dict + if hasattr(m, "rope_init"): + m.rope_init() + + # This method will convert the full model state dict into a sharded state + # dict and load into the model + training.load_from_full_model_state_dict( + model, + model_state_dict, + self._device, + self._is_rank_zero, + strict=True, + cpu_offload=fsdp_cpu_offload, + ) + + # activation offloading + self.activations_handling_ctx = training.get_act_offloading_ctx_manager( + model, enable_activation_offloading + ) + + # Ensure no params and buffers are on meta device + training.validate_no_params_on_meta_device(model) + + if self._is_rank_zero: + log.info( + f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs" + ) + memory_stats = training.get_memory_stats(device=self._device) + training.log_memory_stats(memory_stats) + + # synchronize before training begins + torch.distributed.barrier() + + return model + + def _setup_optimizer( + self, + cfg_optimizer: DictConfig, + optimizer_in_bwd: bool = False, + opt_state_dict: Optional[Dict[str, Any]] = None, + ) -> Optional[Optimizer]: + if optimizer_in_bwd: + # Maintain a dict of optims for every parameter. + optim_dict = { + param: config.instantiate(cfg_optimizer, [param]) + for param in self._model.parameters() + } + + # Register optimizer step hooks on the model to run optimizer in backward. + training.register_optim_in_bwd_hooks( + model=self._model, optim_dict=optim_dict + ) + # Create a wrapper for checkpoint save/load of optimizer states when running in backward. + self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper( + model=self._model, optim_dict=optim_dict + ) + # Load optimizer states for each param. If optimizer states are being restored in an optimizer in + # backward run, these need to have been saved with the same setting. Cannot restore from runs that + # did not use optimizer in backward. + if opt_state_dict is not None: + for param in opt_state_dict.keys(): + try: + training.load_from_full_optimizer_state_dict( + self._optim_ckpt_wrapper.state_dict()[param], + opt_state_dict[param], + self._device, + ) + except BaseException as e: + raise RuntimeError( + "Failed loading in-backward optimizer checkpoints." + "Please make sure run being restored from was using in-backward optimizer." 
+ ) from e + if self._is_rank_zero: + log.info("In-backward optimizers are set up.") + return None + else: + optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) + if opt_state_dict: + training.load_from_full_optimizer_state_dict( + optimizer, + opt_state_dict, + self._device, + ) + + if self._is_rank_zero: + log.info("Optimizer is initialized.") + return optimizer + + def _setup_data( + self, + cfg_dataset: DictConfig, + shuffle: bool, + batch_size: int, + collate_fn: str, + ) -> Tuple[DistributedSampler, DataLoader]: + """ + All data related setup happens here. Currently this recipe only supports the + DistributedSamplers with Map-style Datasets which fit into memory. Other samplers, + iterable datasets and streaming datasets are not supported. + """ + world_size, rank = training.get_world_size_and_rank() + + if isinstance(cfg_dataset, ListConfig): + datasets = [ + config.instantiate(single_cfg_dataset, self._tokenizer) + for single_cfg_dataset in cfg_dataset + ] + ds = ConcatDataset(datasets=datasets) + packed = False + else: + ds = config.instantiate(cfg_dataset, self._tokenizer) + packed = cfg_dataset.get("packed", False) + + # Instantiate collate_fn + if "left_pad_sequence" in collate_fn: + raise RuntimeError("left_pad_sequence collator is only for inference.") + collate_fn = _get_component_from_path(collate_fn) + + sampler = DistributedSampler( + ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0 + ) + dataloader = DataLoader( + dataset=ds, + batch_size=batch_size, + sampler=sampler, + # dropping last avoids shape issues with compile + flex attention + drop_last=True, + collate_fn=( + partial( + collate_fn, + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, + ) + if not packed + else padded_collate_packed + ), + ) + + if self._is_rank_zero: + log.info("Dataset and Sampler are initialized.") + + return sampler, dataloader + + def _setup_early_exit_loss( + self, + cfg_early_exit_loss: DictConfig, + ) -> Tuple[List[bool], EarlyExitCurriculum]: + """ + All early exit loss related setup happens here. + """ + do_output_hidden_states = None + early_exit_loss_curriculum = None + + if cfg_early_exit_loss: + assert ( + hasattr(self._loss_fn, "reduction") + and self._loss_fn.reduction == "mean" + ), "Currently early exit loss is only implemented for loss functions that apply a mean reduction." + + do_output_hidden_states = slice_str_to_array( + cfg_early_exit_loss.get("layers", ":"), len(self._model.layers) + ) + train_last_layer = cfg_early_exit_loss.get("include_last_layer", True) + verbose = cfg_early_exit_loss.get("verbose", False) + + early_exit_loss_curriculum = cfg_early_exit_loss.get("curriculum", None) + if early_exit_loss_curriculum: + early_exit_loss_curriculum = _get_component_from_path( + early_exit_loss_curriculum + )( + do_output_hidden_states=do_output_hidden_states, + max_steps=self.total_epochs * self._steps_per_epoch, + train_last_layer=train_last_layer, + last_step=self.global_step, + verbose=verbose, + ) + do_output_hidden_states = early_exit_loss_curriculum.get() + else: + if train_last_layer: + do_output_hidden_states[len(self._model.layers) - 1] = True + + return do_output_hidden_states, early_exit_loss_curriculum + + def save_checkpoint( + self, + epoch: int, + ) -> None: + """ + Checkpoint the state of the recipe. 
The constructed checkpoint state dict + contains the following information: + - Model weights with key training.MODEL_KEY + - Relevant recipe state if training is not complete + + Checkpointer will save the model weights and recipe state in + different checkpoint files. To correctly resume training from an intermediate checkpoint, + the model weights and recipe state must be provided. + """ + # final dict passed onto the checkpointer + checkpoint_dict = {} + + intermediate_checkpoint = epoch + 1 < self.total_epochs + + if self._is_rank_zero: + log.info( + "Saving checkpoint. This may take some time. Retrieving full model state dict..." + ) + start = time.perf_counter() + + # To prevent GPU memory from spiking during checkpoint save, + # we consolidate the full model and optim state dicts on CPU for rank 0 + cpu_state_dict = training.gather_cpu_state_dict( + self._model.state_dict(), + self._is_rank_zero, + device=self._device, + ) + + if self._is_rank_zero: + log.info( + f"Getting full model state dict took {time.perf_counter() - start:.2f} secs" + ) + + if intermediate_checkpoint: + start = time.perf_counter() + if self._is_rank_zero: + log.info("Getting optimizer state dict...") + if not self._optimizer_in_bwd: + opt_state_dict = training.get_full_optimizer_state_dict( + self._optimizer, + self._is_rank_zero, + device=self._device, + ) + else: + opt_state_dict = {} + for param, opt in self._optim_ckpt_wrapper.optim_map.items(): + opt_state_dict[param] = training.get_full_optimizer_state_dict( + opt, self._is_rank_zero, device=self._device + ) + if self._is_rank_zero: + log.info( + f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs" + ) + else: + opt_state_dict = None + + # Now that we have the model and opt state dict, create the actual checkpoint dict + # to be sent to the checkpointer and ultimately written to file + + if self._is_rank_zero: + start = time.perf_counter() + checkpoint_dict.update({training.MODEL_KEY: cpu_state_dict}) + + # if training is in-progress, checkpoint the optimizer state and recipe state + # as well. + if intermediate_checkpoint: + checkpoint_dict.update( + { + training.OPT_KEY: opt_state_dict, + training.SEED_KEY: self.seed, + training.EPOCHS_KEY: self.epochs_run, + training.TOTAL_EPOCHS_KEY: self.total_epochs, + training.MAX_STEPS_KEY: self.max_steps_per_epoch, + } + ) + + self._checkpointer.save_checkpoint( + checkpoint_dict, + epoch=epoch, + intermediate_checkpoint=intermediate_checkpoint, + ) + log.info(f"Saving checkpoint took {time.perf_counter() - start:.2f} secs") + + torch.distributed.barrier() + + def train(self) -> None: + """ + The core training loop. 
+ """ + # clean up before training begins + training.cleanup_before_training() + + world_size, rank = training.get_world_size_and_rank() + + # zero out the gradients before starting training + if not self._optimizer_in_bwd: + self._optimizer.zero_grad() + else: + for opt in self._optim_ckpt_wrapper.optim_map.values(): + opt.zero_grad() + + # Initialize tokens count and running loss (for grad accumulation) + t0 = time.perf_counter() + running_loss = 0 + num_tokens = 0 + + # Initialize output hidden states + if self._do_output_hidden_states is not None: + self._model.output_hidden_states = [ + i + for i in range(len(self._do_output_hidden_states)) + if self._do_output_hidden_states[i] + ] + + self._profiler.start() + # self.epochs_run should be non-zero when we're resuming from a checkpoint + for curr_epoch in range(self.epochs_run, self.total_epochs): + # Update the sampler to ensure data is correctly shuffled across epochs + # in case shuffle is True + self._sampler.set_epoch(curr_epoch) + + pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0)) + for idx, batch in enumerate(self._dataloader): + if ( + self.max_steps_per_epoch is not None + and (idx // self._gradient_accumulation_steps) + == self.max_steps_per_epoch + ): + break + + # Start tracking CUDA memory for active steps for just the first epoch + if ( + self._is_rank_zero + and curr_epoch == 0 + and self.profiler_profile_memory + and idx == self.profiler_wait_steps + self.profiler_warmup_steps + ): + torch.cuda.memory._record_memory_history() + + utils.batch_to_device(batch, self._device) + + # Calculate the number of unmasked tokens in the current batch + # and increment the total number of tokens seen in the step + current_num_tokens = ( + batch["labels"] != self._loss_fn.ignore_index + ).sum() + num_tokens += current_num_tokens + + # Shape [b, s], needed for the loss not the model + labels = batch.pop("labels") + + with self.activations_handling_ctx: + outputs = self._model(**batch) + if self._model.output_hidden_states: + logits = outputs.pop(-1) + hidden_states = { + i: h + for i, h in zip(self._model.output_hidden_states, outputs) + } + else: + logits = outputs + + # Shift labels to compute loss + # equivalent to doing labels[..., 1:] and logits[..., :-1, :] + # But this way we dont need to slice the logits. We just add an ignore index to labels. 
+ labels = torch.hstack( + (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]) + ) + if not isinstance(logits, list): + labels = labels.reshape(-1) + logits = logits.reshape(-1, logits.size(-1)) + + # Compute loss + # Loss is normalized by default so we multiply by the number of tokens + # This way we can normalize by the total number of tokens if we're accumulating gradients + if self._model.output_hidden_states: + current_loss = ( + early_exit_loss( + self._model, + hidden_states, + labels, + self._loss_fn, + self._early_exit_loss_scale, + self._early_exit_loss_scale_type, + ) + * current_num_tokens + ) + else: + current_loss = self._loss_fn(logits, labels) * current_num_tokens + + # free logits otherwise it peaks backward memory + del logits + + running_loss += current_loss + + # For optimizer in backward, we need to normalize before calling backward + # This case and gradient accumulation are mutually exclusive + if self._optimizer_in_bwd: + torch.distributed.all_reduce(num_tokens) + torch.distributed.all_reduce(running_loss) + current_loss = current_loss / num_tokens + + current_loss.backward() + + # Step with optimizer + if (idx + 1) % self._gradient_accumulation_steps == 0: + if not self._optimizer_in_bwd: + # Get total number of tokens across all ranks to normalize gradients + torch.distributed.all_reduce(num_tokens) + # This will ensure that the logged loss matches what we're optimizing + torch.distributed.all_reduce(running_loss) + # Manually scale the gradients from unnormalized loss by total # of tokens + training.scale_grads(self._model, 1 / num_tokens) + if self._clip_grad_norm is not None: + grad_norm = torch.nn.utils.clip_grad_norm_( + self._model.parameters(), + max_norm=float(self._clip_grad_norm), + ) + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + + # Update the number of steps when the weights are updated + self.global_step += 1 + + loss_to_log = running_loss.item() / num_tokens + pbar.update(1) + pbar.set_description( + f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" + ) + + # Log per-step metrics + if ( + self.global_step % self._log_every_n_steps == 0 + and self._is_rank_zero + ): + time_per_step = time.perf_counter() - t0 + log_dict = { + "loss": loss_to_log, + "lr": get_lr( + ( + self._optimizer + if not self._optimizer_in_bwd + else self._optim_ckpt_wrapper + ), + ), + "tokens_per_second_per_gpu": num_tokens + / (time_per_step * world_size), + } + if self._log_peak_memory_stats: + log_dict.update( + training.get_memory_stats(device=self._device) + ) + if self._clip_grad_norm is not None: + log_dict.update({"grad_norm": grad_norm}) + self._metric_logger.log_dict( + log_dict, + step=self.global_step, + ) + + # Reset running stats for the next step + running_loss = 0 + num_tokens = 0 + t0 = time.perf_counter() + + # Update Early Exit Layers/Scales + if self._early_exit_loss_curriculum: + self._early_exit_loss_curriculum.step() + self._do_output_hidden_states = ( + self._early_exit_loss_curriculum.get() + ) + self._model.output_hidden_states = [ + i + for i in range(len(self._do_output_hidden_states)) + if self._do_output_hidden_states[i] + ] + + # Stop tracking CUDA memory now that active steps are complete + if ( + self._is_rank_zero + and curr_epoch == 0 + and self.profiler_profile_memory + and idx + == self.profiler_wait_steps + + self.profiler_warmup_steps + + self.profiler_active_steps + ): + torch.cuda.memory._record_memory_history(enabled=None) + + # Step profiler + # Note that this is called within gradient 
accumulation block, hence + # will include multiple forward / backward passes if gradient accumulation > 1 + self._profiler.step() + + self.epochs_run += 1 + self.save_checkpoint(epoch=curr_epoch) + + self._profiler.stop() + + def cleanup(self) -> None: + if self._is_rank_zero: + self._metric_logger.close() + destroy_process_group() + + +@config.parse +def recipe_main(cfg: DictConfig) -> None: + """ + Entry point for the recipe. + + Configurable parameters are read in the following order: + - Parameters specified in config (see available configs through ``tune ls``) + - Overwritten by arguments from the command-line + """ + if not training.is_distributed(): + raise RuntimeError( + "Distributed finetune recipe should be run via a distributed launcher." + "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]" + ) + init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl") + if cfg.get("fsdp_cpu_offload", False): + # Utilize all available CPU cores for intra-op parallelism. This provides ~2x + # speed up when benchmarking fused AdamW on CPU + training.set_torch_num_threads() + + config.log_config(recipe_name="EarlyExitFinetuneRecipeDistributed", cfg=cfg) + + recipe = EarlyExitFinetuneRecipeDistributed(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + +if __name__ == "__main__": + sys.exit(recipe_main()) diff --git a/tests/torchtune/modules/test_common_utils.py b/tests/torchtune/modules/test_common_utils.py index 41dc472f00..0e2a6400e0 100644 --- a/tests/torchtune/modules/test_common_utils.py +++ b/tests/torchtune/modules/test_common_utils.py @@ -14,6 +14,7 @@ llama3_2_vision_encoder, ) from torchtune.modules import delete_kv_caches, disable_kv_cache, local_kv_cache +from torchtune.modules.common_utils import slice_str_to_array from torchtune.modules.model_fusion import DeepFusionModel @@ -191,3 +192,28 @@ def test_disable_kv_cache_raises_error_caches_not_setup(self, model, request): with pytest.raises(ValueError, match="Model caches must be setup"): with disable_kv_cache(model): pass + + +class TestSliceStrToArray: + def test_single_index(self): + assert slice_str_to_array("0", 5) == [True, False, False, False, False] + + def test_slice_with_start_and_end(self): + assert slice_str_to_array("1:3", 5) == [False, True, True, False, False] + + def test_slice_with_start_and_step(self): + assert slice_str_to_array("1::2", 5) == [False, True, False, True, False] + + def test_slice_with_start_end_and_step(self): + assert slice_str_to_array("1:4:2", 5) == [False, True, False, True, False] + + def test_multiple_indices(self): + assert slice_str_to_array("0,2,4", 6) == [True, False, True, False, True, False] + + def test_out_of_range_index(self): + with pytest.raises(AssertionError): + slice_str_to_array("10", 5) + + def test_invalid_slice_format(self): + with pytest.raises(AssertionError): + slice_str_to_array("1:2:3:4", 5) diff --git a/tests/torchtune/modules/test_early_exit_loss.py b/tests/torchtune/modules/test_early_exit_loss.py new file mode 100644 index 0000000000..27d421c14d --- /dev/null +++ b/tests/torchtune/modules/test_early_exit_loss.py @@ -0,0 +1,184 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
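# For illustration, a minimal sketch (not the recipe's implementation) of the quantity the
# tests below exercise: early exit loss unembeds each selected hidden state, computes the
# loss against the shifted labels, and combines the per-layer losses with weights that the
# scale_fn normalizes to sum to 1. With uniform scales this reduces to a plain mean over the
# selected layers, which is exactly what test_early_exit_loss_vs_manual checks using the
# model's `unembed` helper.
def manual_uniform_early_exit_loss(model, hidden_states_dict, labels, loss_fn):
    flat_labels = labels.reshape(-1)
    losses = []
    for hidden_state in hidden_states_dict.values():
        logits = model.unembed(hidden_state)          # [b, s, vocab]
        logits = logits.reshape(-1, logits.size(-1))  # [b * s, vocab]
        losses.append(loss_fn(logits, flat_labels))
    return sum(losses) / len(losses)                  # uniform weights: 1 / num selected layers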
+ +import random + +import numpy as np +import pytest +import torch +import torch.nn as nn +from torchtune.modules import TransformerDecoder +from torchtune.modules.early_exit_loss import ( + early_exit_loss, + GradualEarlyExitCurriculum, + inv_l_loss_scale, + inv_sqrt_l_loss_scale, + linear_l_loss_scale, + RotationalEarlyExitCurriculum, + sqrt_l_loss_scale, + sum_l_loss_scale, + uniform_loss_scale, +) + +# Mock components for TransformerDecoder +class MockLayer(nn.Module): + def forward( + self, x, mask=None, encoder_input=None, encoder_mask=None, input_pos=None + ): + return x # Simply return the input for testing purposes + + +class TestEarlyExitLoss: + @pytest.fixture + def num_layers(self): + return 12 + + @pytest.fixture + def mock_model(self, num_layers): + # Create mock components + tok_embeddings = nn.Embedding(1000, 512) # Example vocab size and embedding dim + layers = nn.ModuleList([MockLayer() for _ in range(12)]) # 12 mock layers + norm = nn.LayerNorm(512) # Example layer normalization + output = nn.Linear(512, 1000) # Example output layer + + # Create an instance of TransformerDecoder + model = TransformerDecoder( + tok_embeddings=tok_embeddings, + layers=layers, + max_seq_len=512, + num_heads=8, + head_dim=64, + norm=norm, + output=output, + num_layers=num_layers, + output_hidden_states=[0, 1, 2], # Example layers to output hidden states + ) + return model + + @pytest.fixture + def hidden_states_dict(self): + return {i: torch.randn(4, 5, 512) for i in range(3)} # Adjusted embedding dim + + @pytest.fixture + def labels(self): + return torch.randint(0, 1000, (4, 5)) # Adjusted vocab size + + @pytest.fixture + def loss_fn(self): + return nn.CrossEntropyLoss(ignore_index=-1) + + def test_early_exit_loss(self, mock_model, hidden_states_dict, labels, loss_fn): + loss = early_exit_loss(mock_model, hidden_states_dict, labels, loss_fn) + assert isinstance(loss, torch.Tensor) + assert loss.item() >= 0 + + @pytest.mark.parametrize( + "scale_fn", + [ + uniform_loss_scale, + linear_l_loss_scale, + sum_l_loss_scale, + sqrt_l_loss_scale, + inv_l_loss_scale, + inv_sqrt_l_loss_scale, + ], + ) + def test_layer_ids_to_loss_scales(self, scale_fn, num_layers): + for n_subset_layers in range(1, num_layers + 1): + layer_ids = torch.tensor( + random.sample(range(0, num_layers), n_subset_layers) + ) + scales = scale_fn(layer_ids, num_layers, 1.0) + assert torch.isclose(scales.sum(), torch.tensor(1.0)) + + def test_early_exit_loss_vs_manual( + self, mock_model, hidden_states_dict, labels, loss_fn + ): + # Convert to float32 for numeric equivalence + # Calculate early exit loss using the function + calculated_loss = early_exit_loss( + mock_model, + hidden_states_dict, + labels, + loss_fn, + e_scale=1, + loss_scale_fn=uniform_loss_scale, + ) + # Manually calculate the loss for each hidden state + total_loss = 0.0 + num_hidden_states = len(hidden_states_dict) + for i, hidden_state in hidden_states_dict.items(): + # Compute logits for the current hidden state + logits = mock_model.unembed(hidden_state) + labels = labels.reshape(-1) + logits = logits.reshape(-1, logits.size(-1)) + # Compute the loss for the current hidden state + loss = loss_fn(logits, labels) + total_loss += loss + # Average the losses across all hidden states + manual_loss = total_loss / num_hidden_states + # Compare the two losses + assert torch.isclose( + calculated_loss, manual_loss, atol=1e-6 + ), f"Calculated loss: {calculated_loss}, Manual loss: {manual_loss}" + + +class TestEarlyExitLossCurriculum: + @pytest.mark.parametrize( + 
"train_last_layer", + [ + True, + False, + ], + ) + def test_rotational_early_exit_curriculum(self, train_last_layer): + curriculum = RotationalEarlyExitCurriculum( + [True, False, False], max_steps=100, train_last_layer=train_last_layer + ) + expected = np.array([True, False, train_last_layer]) + assert np.array_equal(curriculum.get(), expected) + curriculum.step() + expected = np.array([False, True, train_last_layer]) + assert np.array_equal(curriculum.get(), expected) + curriculum.step() + # Since the last element is already True on this rotation, the value of `train_last_layer` has no effect. + expected = np.array([False, False, True]) + assert np.array_equal(curriculum.get(), expected) + curriculum.step() + expected = np.array([True, False, train_last_layer]) + assert np.array_equal(curriculum.get(), expected) + + @pytest.mark.parametrize( + "train_last_layer", + [ + True, + False, + ], + ) + def test_gradual_early_exit_curriculum(self, train_last_layer): + curriculum = GradualEarlyExitCurriculum( + [True, True, True, True], + max_steps=4, + train_last_layer=train_last_layer, + fraction_scale=1, + ) + expected = np.array([False, False, False, train_last_layer]) + assert np.array_equal(curriculum.get(), expected) + curriculum.step() + assert np.array_equal(curriculum.get(), [False, False, False, train_last_layer]) + curriculum.step() + # Since the last element is already True on this update, the value of `train_last_layer` has no effect. + assert np.array_equal(curriculum.get(), [False, False, False, True]) + curriculum.step() + assert np.array_equal(curriculum.get(), [False, False, True, True]) + curriculum.step() + assert np.array_equal(curriculum.get(), [False, True, True, True]) + curriculum.step() + assert np.array_equal(curriculum.get(), [True, True, True, True]) + curriculum.step() + assert np.array_equal(curriculum.get(), [True, True, True, True]) diff --git a/tests/torchtune/modules/test_layer_dropout.py b/tests/torchtune/modules/test_layer_dropout.py new file mode 100644 index 0000000000..9bbd779d79 --- /dev/null +++ b/tests/torchtune/modules/test_layer_dropout.py @@ -0,0 +1,279 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import math +from typing import Tuple + +import pytest +import torch +from tests.test_utils import assert_expected +from torchtune.modules.layer_dropout import ( + get_scale, + LayerDropout, + ModuleLayerDropoutWrapper, + prepare_layer_dropout, + ScaleType, +) + + +class TestLayerDropout: + """Class for testing LayerDropout implementation.""" + + @pytest.fixture + def input_shape(self) -> Tuple[int, int]: + bsz = 8 + seqlen = 256 + dim = 32 + return bsz, seqlen, dim + + @pytest.fixture + def input(self, input_shape: Tuple[int]) -> torch.Tensor: + return torch.randn(input_shape) + + @pytest.fixture + def layer_dropout( + self, prob: float = 0.5, disable_on_eval: bool = True + ) -> LayerDropout: + return LayerDropout(prob=prob, disable_on_eval=disable_on_eval) + + def test_forward_train_prob_1( + self, layer_dropout: LayerDropout, input: torch.Tensor + ): + # With dropout probability = 1.0, we expect output to be the same as input + layer_dropout.prob = 1.0 + output = layer_dropout.forward(lambda x: x**2, input) + assert torch.allclose(output, input) + + def test_forward_train_prob_0( + self, layer_dropout: LayerDropout, input: torch.Tensor + ): + # With dropout probability = 1.0, we expect the operation to be applied on all elements in the input + layer_dropout.prob = 0.0 + output = layer_dropout.forward(lambda x: x**2, input) + assert torch.allclose(output, input**2) + + def test_forward_eval(self, layer_dropout: LayerDropout, input: torch.Tensor): + layer_dropout.prob = 1.0 + layer_dropout.eval() + + layer_dropout.disable_on_eval = True + output = layer_dropout.forward(lambda x: x**2, input) + assert torch.allclose(output, input**2) + + layer_dropout.disable_on_eval = False + with torch.no_grad(): + output = layer_dropout.forward(lambda x: x**2, input) + assert torch.allclose(output, input) + + +class TestLayerDropoutWrapper: + @pytest.fixture + def input_shape(self) -> Tuple[int, int]: + bsz = 4 + dim = 8 + return (bsz, dim) + + @pytest.fixture + def input(self, input_shape: Tuple[int]) -> torch.Tensor: + return torch.randn(input_shape) + + @pytest.fixture + def model(self, input_shape) -> torch.nn.Module: + _, dim = input_shape + return torch.nn.Sequential( + torch.nn.Linear(dim, 32), torch.nn.ReLU(), torch.nn.Linear(32, dim) + ) + + @pytest.fixture + def linear(self, input_shape) -> torch.nn.Module: + _, dim = input_shape + return torch.nn.Linear(dim, dim) + + def test_linear(self, linear: torch.nn.Module, input: torch.Tensor): + wrapper = ModuleLayerDropoutWrapper(linear, LayerDropout(prob=0.5)) + assert wrapper.module == linear + + # Test output + wrapper.dropout.prob = 1 + assert torch.allclose(wrapper(input), input) + wrapper.dropout.prob = 0 + assert torch.allclose(wrapper(input), linear(input)) + + # Test getters + assert wrapper.in_features == linear.in_features + assert wrapper.out_features == linear.out_features + assert torch.equal(wrapper.weight, linear.weight) + + # Test setters + wrapper.weight.data = wrapper.weight.data * 2 + assert torch.equal(wrapper.weight, linear.weight) + + # Test state_dict + for k in wrapper.state_dict().keys(): + assert torch.equal(wrapper.state_dict()[k], linear.state_dict()[k]) + + def test_model(self, model: torch.nn.Module, input: torch.Tensor): + wrapper = ModuleLayerDropoutWrapper(model, LayerDropout(prob=0.5)) + assert wrapper.module == model + + # Test output + wrapper.dropout.prob = 1 + assert torch.allclose(wrapper(input), input) + wrapper.dropout.prob = 0 + assert torch.allclose(wrapper(input), model(input)) + + # Test getters + 
assert wrapper[0].in_features == model[0].in_features + assert wrapper[0].out_features == model[0].out_features + assert torch.equal(wrapper[0].weight, model[0].weight) + + # Test setters + wrapper[2].weight.data = wrapper[2].weight.data * 2 + assert torch.equal(wrapper[2].weight, model[2].weight) + + # Test state_dict + for k in wrapper.state_dict().keys(): + assert torch.equal(wrapper.state_dict()[k], model.state_dict()[k]) + + +class TestScales: + def test_get_scale_uniform(self): + scale_type = ScaleType.UNIFORM + scale_period = 10 + + assert_expected(get_scale(scale_type, scale_period, 0), 1.0) + assert_expected(get_scale(scale_type, scale_period, scale_period / 2), 1.0) + assert_expected(get_scale(scale_type, scale_period, scale_period), 1.0) + assert_expected(get_scale(scale_type, scale_period, scale_period * 2), 1.0) + + def test_get_scale_linear(self): + scale_type = ScaleType.LINEAR + scale_period = 10 + + assert_expected(get_scale(scale_type, scale_period, 0), 0.0) + assert_expected(get_scale(scale_type, scale_period, scale_period / 2), 1 / 2) + assert_expected(get_scale(scale_type, scale_period, scale_period), 1.0) + assert_expected(get_scale(scale_type, scale_period, scale_period * 2), 1.0) + + def test_get_scale_exp(self): + scale_type = ScaleType.EXP + scale_period = 10 + + assert_expected(get_scale(scale_type, scale_period, 0), 0.0) + assert_expected( + get_scale(scale_type, scale_period, scale_period / 2), + math.pow(2, 1 / 2) - 1, + ) + assert_expected(get_scale(scale_type, scale_period, scale_period), 1.0) + assert_expected(get_scale(scale_type, scale_period, scale_period * 2), 1.0) + + def test_get_scale_log(self): + scale_type = ScaleType.LOG + scale_period = 10 + + assert_expected(get_scale(scale_type, scale_period, 0), 0.0) + assert_expected( + get_scale(scale_type, scale_period, scale_period / 2), + math.log(5 + 1, scale_period + 1), + ) + assert_expected(get_scale(scale_type, scale_period, scale_period), 1.0) + assert_expected(get_scale(scale_type, scale_period, scale_period * 2), 1.0) + + def test_get_scale_sin(self): + scale_type = ScaleType.SIN + scale_period = 10 + + assert_expected(get_scale(scale_type, scale_period, 0), 0.0) + assert_expected( + get_scale(scale_type, scale_period, scale_period / 2), + math.sin(0.5 * math.pi * 0.5), + ) + assert_expected(get_scale(scale_type, scale_period, scale_period), 1.0) + assert_expected(get_scale(scale_type, scale_period, scale_period * 2), 1.0) + + def test_get_scale_sigmoid(self): + scale_type = ScaleType.SIGMOID + scale_period = 10 + + # sigmoid(0) is close to 0 but not 0, hence adding relatively large rotl and atol + assert_expected( + get_scale(scale_type, scale_period, 0), 0.0, rtol=1e-2, atol=1e-2 + ) + assert_expected( + get_scale(scale_type, scale_period, scale_period / 2), + 0.5, + ) + assert_expected(get_scale(scale_type, scale_period, scale_period), 1.0) + assert_expected(get_scale(scale_type, scale_period, scale_period * 2), 1.0) + + +class TestLayerDropoutModel: + def test_prepare_layer_dropout_uniform(self): + class MockModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.ModuleList( + [torch.nn.Linear(10, 10) for _ in range(5)] + ) + + model = MockModel() + prob_max = 0.5 + prob_layer_scale = ScaleType.UNIFORM + layers_str = "0:4" + prepare_layer_dropout(model.layers, prob_max, prob_layer_scale, layers_str) + for i, layer in enumerate(model.layers): + assert hasattr(layer, "dropout") + if i in range(0, 4): + assert layer.dropout.prob == prob_max + else: + assert 
+
+    def test_prepare_layer_dropout_exp(self):
+        class MockModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.layers = torch.nn.ModuleList(
+                    [torch.nn.Linear(10, 10) for _ in range(5)]
+                )
+
+        model = MockModel()
+        prob_max = 0.5
+        prob_layer_scale = ScaleType.EXP
+        layers_str = ":"
+        prepare_layer_dropout(model.layers, prob_max, prob_layer_scale, layers_str)
+        for i, layer in enumerate(model.layers):
+            assert hasattr(layer, "dropout")
+            if i == 0:
+                assert layer.dropout.prob == 0
+            elif i == len(model.layers) - 1:
+                assert layer.dropout.prob == prob_max
+            else:
+                assert layer.dropout.prob > 0 and layer.dropout.prob < prob_max
+
+    def test_prepare_layer_dropout_linear(self):
+        class MockModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.layers = torch.nn.ModuleList(
+                    [torch.nn.Linear(10, 10) for _ in range(5)]
+                )
+
+        model = MockModel()
+        prob_max = 0.5
+        prob_layer_scale = ScaleType.LINEAR
+        layers_str = ":"
+        prepare_layer_dropout(model.layers, prob_max, prob_layer_scale, layers_str)
+        for i, layer in enumerate(model.layers):
+            assert hasattr(layer, "dropout")
+            if i == 0:
+                assert layer.dropout.prob == 0
+            elif i == len(model.layers) - 1:
+                assert layer.dropout.prob == prob_max
+            elif i == len(model.layers) // 2:
+                assert layer.dropout.prob == prob_max / 2
+            else:
+                assert layer.dropout.prob >= 0.0 and layer.dropout.prob <= prob_max
diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py
index 5bbf860482..5776269b8d 100644
--- a/torchtune/_recipe_registry.py
+++ b/torchtune/_recipe_registry.py
@@ -411,6 +411,17 @@ class Recipe:
         ],
         supports_distributed=False,
     ),
+    Recipe(
+        name="dev/early_exit_finetune_distributed",
+        file_path="dev/early_exit_finetune_distributed.py",
+        configs=[
+            Config(
+                name="llama2/7B_full_early_exit",
+                file_path="dev/7B_full_early_exit.yaml",
+            ),
+        ],
+        supports_distributed=True,
+    ),
     Recipe(
         name="eleuther_eval",
         file_path="eleuther_eval.py",
diff --git a/torchtune/modules/__init__.py b/torchtune/modules/__init__.py
index 29c014c33e..3554698d42 100644
--- a/torchtune/modules/__init__.py
+++ b/torchtune/modules/__init__.py
@@ -14,6 +14,7 @@
 )
 from .feed_forward import FeedForward  # noqa
 from .kv_cache import KVCache  # noqa
+from .layer_dropout import LayerDropout, prepare_layer_dropout  # noqa
 from .layer_norm import Fp32LayerNorm  # noqa
 from .low_precision import FrozenNF4Linear  # noqa
 from .position_embeddings import (  # noqa
@@ -30,6 +31,7 @@
 )
 from .vision_transformer import VisionTransformer
 
+
 __all__ = [
     "MultiHeadAttention",
     "TanhGate",
@@ -51,4 +53,6 @@
     "local_kv_cache",
     "delete_kv_caches",
     "disable_kv_cache",
+    "LayerDropout",
+    "prepare_layer_dropout",
 ]
diff --git a/torchtune/modules/common_utils.py b/torchtune/modules/common_utils.py
index 38b5776f88..16df7801cb 100644
--- a/torchtune/modules/common_utils.py
+++ b/torchtune/modules/common_utils.py
@@ -59,6 +59,80 @@ def reparametrize_as_dtype_state_dict_post_hook(
         state_dict[k] = state_dict[k].cpu()
+
+def slice_str_to_array(slice_str: str, length: int) -> list[bool]:
+    """
+    Convert a string representing a Python slice or index into a boolean array.
+
+    The resulting array will have the same length as the specified `length` parameter.
+    Each element in the array corresponds to an index in the original sequence,
+    with `True` indicating that the index is included in the slice and `False` otherwise.
+
+    Args:
+        slice_str (str): A string representing a Python slice or index, e.g. "1:3", ":5", "2::3", "0,4,5".
+ length (int): The length of the original sequence. + + Returns: + list[bool]: A boolean array representing the slice. + + Examples: + >>> slice_str_to_array("1:3", 5) + [False, True, True, False, False] + >>> slice_str_to_array(":", 5) + [True, True, True, True, True] + >>> slice_str_to_array("::2", 5) + [True, False, True, False, True] + >>> slice_str_to_array("1::2", 5) + [False, True, False, True, False] + >>> slice_str_to_array("2:5:2", 6) + [False, False, True, False, True, False] + >>> slice_str_to_array("0,4,5", 7) + [True, False, False, False, True, True, False] + """ + + assert "," not in slice_str or ":" not in slice_str, "Cannot mix commas and colons" + + if "," in slice_str: + indices = [int(i) for i in slice_str.split(",")] + assert all(0 <= i < length for i in indices), "Index out of range" + result = [False] * length + for i in indices: + result[i] = True + return result + + parts = slice_str.split(":") + assert len(parts) <= 3, "Invalid slice format" + start, end, step = None, None, None + + if len(parts) == 1 and parts[0] != "": + start = int(parts[0]) + end = start + 1 + step = 1 + elif len(parts) == 2: + start = int(parts[0]) if parts[0] != "" else None + end = int(parts[1]) if parts[1] != "" else None + elif len(parts) == 3: + start = int(parts[0]) if parts[0] != "" else None + end = int(parts[1]) if parts[1] != "" else None + step = int(parts[2]) if parts[2] != "" else None + + assert start is None or 0 <= start < length, "Start index out of range" + assert end is None or 0 <= end < length, "End index out of range" + assert step is None or step != 0, "Step cannot be zero" + + result = [False] * length + slice_indices = range( + start if start is not None else 0, + end if end is not None else length, + step if step is not None else 1, + ) + + for i in slice_indices: + if 0 <= i < length: + result[i] = True + + return result + + def _low_ram_reparametrize_as_dtype_state_dict_post_hook( model: nn.Module, state_dict: Dict[str, Any], diff --git a/torchtune/modules/early_exit_loss.py b/torchtune/modules/early_exit_loss.py new file mode 100644 index 0000000000..976f6bee07 --- /dev/null +++ b/torchtune/modules/early_exit_loss.py @@ -0,0 +1,292 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
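+
+# The loss-scale helpers below weight each early-exit layer's loss before the scaled
+# losses are summed. As an illustrative example (not output from any particular run):
+# for hidden layers [0, 4, 8] of a 12-layer model with e_scale=1.0,
+# uniform_loss_scale returns [1/3, 1/3, 1/3], linear_l_loss_scale returns
+# [1/15, 5/15, 9/15], and inv_l_loss_scale returns weights proportional to
+# [1, 1/5, 1/9], each normalized to sum to 1.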
+ +import copy +from typing import Callable, Dict, List, Optional + +import numpy as np +import torch + +from torchtune import utils +from torchtune.modules.transformer import TransformerDecoder + +log = utils.get_logger("DEBUG") + + +def uniform_loss_scale( + layer_ids: torch.Tensor, n_layers: int, e_scale: float = 1.0 +) -> torch.Tensor: + loss_scales = torch.ones(len(layer_ids), device=layer_ids.device) + loss_scales = loss_scales * torch.where(layer_ids < n_layers - 1, e_scale, 1.0) + return loss_scales / torch.sum(loss_scales) + + +def linear_l_loss_scale( + layer_ids: torch.Tensor, n_layers: int, e_scale: float = 1.0 +) -> torch.Tensor: + loss_scales = torch.Tensor(layer_ids + 1) + loss_scales = loss_scales * torch.where(layer_ids < n_layers - 1, e_scale, 1.0) + return loss_scales / torch.sum(loss_scales) + + +def sum_l_loss_scale( + layer_ids: torch.Tensor, n_layers: int, e_scale: float = 1.0 +) -> torch.Tensor: + loss_scales = torch.cumsum(layer_ids + 1, dim=0) + loss_scales = loss_scales * torch.where(layer_ids < n_layers - 1, e_scale, 1.0) + return loss_scales / torch.sum(loss_scales) + + +def sqrt_l_loss_scale( + layer_ids: torch.Tensor, n_layers: int, e_scale: float = 1.0 +) -> torch.Tensor: + loss_scales = torch.sqrt(layer_ids + 1) + loss_scales = loss_scales * torch.where(layer_ids < n_layers - 1, e_scale, 1.0) + return loss_scales / torch.sum(loss_scales) + + +def inv_l_loss_scale( + layer_ids: torch.Tensor, n_layers: int, e_scale: float = 1.0 +) -> torch.Tensor: + loss_scales = 1.0 / (layer_ids + 1) + loss_scales = loss_scales * torch.where(layer_ids < n_layers - 1, e_scale, 1.0) + return loss_scales / torch.sum(loss_scales) + + +def inv_sqrt_l_loss_scale( + layer_ids: torch.Tensor, n_layers: int, e_scale: float = 1.0 +) -> torch.Tensor: + loss_scales = torch.reciprocal(torch.sqrt(layer_ids + 1)) + loss_scales = loss_scales * torch.where(layer_ids < n_layers - 1, e_scale, 1.0) + return loss_scales / torch.sum(loss_scales) + + +def early_exit_loss( + model: TransformerDecoder, + hidden_states_dict: Dict[int, torch.Tensor], + labels: torch.Tensor, + loss_fn: torch.nn.Module, + e_scale: float = 1.0, + loss_scale_fn: Callable[ + [torch.Tensor, int, float], torch.Tensor + ] = uniform_loss_scale, +) -> torch.Tensor: + """ + Compute the early exit loss for a given model and outputs of intermediate layers. + This function takes in a model, a dictionary of hidden states, labels, a loss function, + and optional parameters for scaling the loss. It computes the early exit loss by + iterating over the hidden states, computing the logits and losses at each layer, + and then scaling and summing these losses. + + Args: + model (TransformerDecoder): The model to compute the early exit loss for. + hidden_states_dict (Dict[int, torch.Tensor]): A dictionary of hidden states, + where each key is a layer index and each value is a tensor of shape [b, s, d]. + labels (torch.Tensor): The labels for the input data. + loss_fn (torch.nn.Module): The loss function to use (should be the same as the standard loss function for last layer). + e_scale (float): A scaling factor for the early exit losses. Defaults to 1.0. + loss_scale_fn (Callable[[torch.Tensor, int, float], torch.Tensor]): A function to determine scale of each + layer's loss. Defaults to uniform_loss_scale. + Returns: + torch.Tensor: The computed early exit loss. 
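+    Example:
+        An illustrative sketch; the shapes, vocab size, choice of layers 0 and 4, and the
+        loss function below are assumptions rather than values from a specific recipe:
+            >>> # model is a TransformerDecoder, e.g. torchtune.models.llama2.llama2_7b()
+            >>> b, s, d, vocab_size = 2, 16, 4096, 32_000
+            >>> hidden_states_dict = {0: torch.randn(b, s, d), 4: torch.randn(b, s, d)}
+            >>> labels = torch.randint(0, vocab_size, (b, s))
+            >>> loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)
+            >>> loss = early_exit_loss(model, hidden_states_dict, labels, loss_fn)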
+ """ + batch_loss_fn = copy.deepcopy(loss_fn) + batch_loss_fn.reduction = "none" + + e = len(hidden_states_dict) + # List of e tensors with shape [b, s, d] + hidden_states = tuple(hidden_states_dict.values()) + hidden_layer_ids = tuple(hidden_states_dict.keys()) + # Shape: [e, b, s, d] + hidden_states_stacked = torch.stack(hidden_states) + # Shape: [e, b, s, out_dim] + logits_early = model.unembed(hidden_states_stacked) + # Shape: [e*b*s, out_dim] + logits_early = logits_early.reshape(-1, logits_early.size(-1)) + logits_early = logits_early.contiguous() + # Shape: [e*b*s] + labels_repeated = labels.repeat(e, 1).reshape(-1) + # Compute early losses: Shape: [e*b*s] + losses_early = batch_loss_fn(logits_early, labels_repeated) + # Shape: [e, b*s] + losses_early = losses_early.view(e, -1) + # Shape: [e] + s_unpadded = (labels != loss_fn.ignore_index).sum() + losses_early = losses_early.float().sum(-1) / s_unpadded + # Shape: [e] + losses_scales = loss_scale_fn( + torch.Tensor(hidden_layer_ids).to(losses_early), + len(model.layers), + e_scale, + ) + + return torch.sum(losses_scales * losses_early) + + +# TODO: create a base curriculum class that can be used for other aspects, e.g., dropout, datasets, etc. +class EarlyExitCurriculum: + """ + A curriculum for early exit loss training, which controls which layers to use their hidden states + during training. + + Args: + do_output_hidden_states (List[bool]): A list indicating whether each layer's hidden state + should be output to calculate their losses. + max_steps (int): The maximum number of steps in the curriculum. + train_last_layer (bool, optional): Whether to always calculate loss for the last layer. Defaults to True. + last_step (Optional[int]): The last step the curriculum stopped at in a previous run. + This is used when resuming training. + verbose (bool, optional): Whether to print verbose logs. Defaults to False. + """ + + def __init__( + self, + do_output_hidden_states: List[bool], + max_steps: int, + train_last_layer: bool = True, + last_step: Optional[int] = None, + verbose: bool = False, + ): + self._init_do_output_hidden_states = do_output_hidden_states + self._do_output_hidden_states = do_output_hidden_states + self.train_last_layer = train_last_layer + self.verbose = verbose + self.max_steps = max_steps + self._step = 0 if last_step is None else last_step + + def step(self) -> None: + """ + Perform a step in the curriculum. Should be called at the end of each iteration during training. + """ + pass + + def get(self) -> np.ndarray: + """ + Get the current output hidden states. + Returns: + np.ndarray: A list indicating whether we should calculate loss for each layer. + """ + do_output_hidden_states = np.copy(self._do_output_hidden_states) + # Ensure last layer is trained + if self.train_last_layer: + do_output_hidden_states[-1] = True + return do_output_hidden_states + + +class RotationalEarlyExitCurriculum(EarlyExitCurriculum): + """ + A rotational early exit curriculum, which rotates the layer enablement one step forward + at each step. + + Args: + do_output_hidden_states (List[bool]): A list indicating whether each layer's hidden state + should be output to calculate their losses. + max_steps (int): The maximum number of steps in the curriculum. + train_last_layer (bool, optional): Whether to always calculate loss for the last layer. Defaults to True. + last_step (Optional[int]): The last step the curriculum stopped at in a previous run. + This is used when resuming training. 
+ verbose (bool, optional): Whether to print verbose logs. Defaults to False. + """ + + def __init__( + self, + do_output_hidden_states: List[bool], + max_steps: int, + train_last_layer: bool = True, + last_step: Optional[int] = None, + verbose: bool = False, + ): + super().__init__( + do_output_hidden_states=do_output_hidden_states, + max_steps=max_steps, + train_last_layer=train_last_layer, + last_step=last_step, + verbose=verbose, + ) + self._initial_do_output_hidden_states = np.copy(self._do_output_hidden_states) + + def step(self): + """ + Rotate the layer enablement one step forward. + This method updates the `do_output_hidden_states` attribute by rotating it one position to the right. + """ + # Rotate layer enablement one step forward + self._do_output_hidden_states = np.roll(self._do_output_hidden_states, 1) + + self._step += 1 + if self.verbose: + log.info( + f"Updated self._do_output_hidden_states to {self._do_output_hidden_states}." + ) + + +class GradualEarlyExitCurriculum(EarlyExitCurriculum): + """ + A gradual early exit curriculum, which gradually enables more layers (starting from the last layer) as training progresses. + + Args: + do_output_hidden_states (List[bool]): A list indicating whether each layer's hidden state + should be output to calculate their losses. + max_steps (int): The maximum number of steps in the curriculum. + train_last_layer (bool): Whether to always calculate loss for the last layer. Defaults to True. + last_step (Optional[int]): The last step the curriculum stopped at in a previous run. + This is used when resuming training. + fraction_scale (float): A scaling factor to determine at which fraction + of steps, all the layers will be enabled. At `steps = max_steps * fraction_scale`, all the layers will be + enabled. Defaults to 0.5. + verbose (bool): Whether to print verbose logs. Defaults to False. + """ + + def __init__( + self, + do_output_hidden_states: List[bool], + max_steps: int, + train_last_layer: bool = True, + last_step: Optional[int] = None, + fraction_scale: float = 0.5, + verbose: bool = False, + ): + super().__init__( + do_output_hidden_states=do_output_hidden_states, + max_steps=max_steps, + train_last_layer=train_last_layer, + last_step=last_step, + verbose=verbose, + ) + self._final_do_output_hidden_states = np.copy(self._do_output_hidden_states) + self._step = 0 + self._fraction_scale = fraction_scale + + # Initialize all layers to False + for i in range(len(self._do_output_hidden_states)): + self._do_output_hidden_states[i] = False + + def step(self): + """ + Perform a step in the curriculum. + This method updates the `_do_output_hidden_states` attribute based on the current + step and the fraction of completed training steps. + """ + fraction_trained = self._step / self.max_steps + n_layers = len(self._do_output_hidden_states) + # Enable each layer based on proportion of completed training steps + for layer_index in range(len(self._do_output_hidden_states)): + should_train = ( + fraction_trained + >= self._fraction_scale * (n_layers - layer_index) / n_layers + ) + self._do_output_hidden_states[layer_index] = should_train + + # Only enable layers that are set by the user + self._do_output_hidden_states = np.logical_and( + self._do_output_hidden_states, self._final_do_output_hidden_states + ) + + self._step += 1 + if self.verbose: + log.info( + f"Updated self._do_output_hidden_states to {self._do_output_hidden_states}." 
+ ) diff --git a/torchtune/modules/layer_dropout.py b/torchtune/modules/layer_dropout.py new file mode 100644 index 0000000000..75e28b4ae1 --- /dev/null +++ b/torchtune/modules/layer_dropout.py @@ -0,0 +1,307 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +from enum import Enum +from typing import Any, Callable, Iterable, Optional, Union + +import torch + +from torchtune.modules.common_utils import slice_str_to_array + + +class LayerDropout(torch.nn.Module): + """ + A module that applies layer dropout to the input tensor of an underlying module. + It drops a portion of an input tensor, applies the underlying module on the + remaining parts of the tensor, and then concatenates with the dropped portion of the tensor. + When applied during training, it can have a regularization effect, and can potentially speedup training. + + Args: + prob (float): The probability of dropping an input. Defaults to 0.0. + dim (Optional[int]): The dimension of input tensor along which to drop layers. Defaults to 0 (i.e., batch size). + disable_on_eval (Optional[bool]): Whether to disable layer dropout during evaluation. Defaults to True. + seed (Optional[int]): The seed for the random number generator. Defaults to None. + Examples: + >>> import torch + >>> # Apply layer dropout to a lambda function + >>> layer_dropout = LayerDropout(prob=0.5) + >>> output = layer_dropout(lambda x: x**2, torch.randn(1)) + >>> # Apply layer dropout to a torch.nn.Linear module + >>> linear = torch.nn.Linear(5, 3) + >>> layer_dropout = LayerDropout(prob=0.5) + >>> output = layer_dropout(linear, torch.randn(1, 5)) + """ + + def __init__( + self, + prob: float = 0.0, + dim: Optional[int] = 0, + disable_on_eval: Optional[bool] = True, + seed: Optional[int] = None, + ): + super().__init__() + self.prob: float = prob + self.dim = dim + self.disable_on_eval: bool = disable_on_eval + self.generator = torch.Generator(device="cpu") + self.inferred: float = None + + if seed is not None: + self.generator.manual_seed(seed) + + def forward( + self, + function: Union[Callable, torch.nn.Module], + input: torch.Tensor, + *args, + **kwargs, + ) -> torch.Tensor: + """ + Apply layer dropout to the input tensor. + + Args: + function (Union[Callable, torch.nn.Module]): The function or module to apply to the input tensor. + input (torch.Tensor): The input tensor. + *args: Additional positional arguments passed to the function. + **kwargs: Additional keyword arguments passed to the function. + Returns: + torch.Tensor: The output tensor after applying layer dropout. 
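+        Example:
+            A sketch of the intended behavior; the doubling function, input shape, and seed
+            below are illustrative assumptions:
+                >>> layer_dropout = LayerDropout(prob=0.5, seed=0)
+                >>> x = torch.randn(8, 16)
+                >>> out = layer_dropout(lambda t: t * 2, x)
+                >>> # Rows of ``x`` drawn for dropping are returned unchanged; the
+                >>> # remaining rows are doubled by the wrapped function.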
+ """ + n = input.shape[self.dim] + + if self.prob == 0 or (self.disable_on_eval and self.training is False): + self.inferred = 1.0 + return function(input, *args, **kwargs) + + skip = ( + torch.bernoulli(torch.Tensor((n) * [self.prob]), generator=self.generator) + .to(input.device) + .to(input.dtype) + ) + self.inferred = 1 - torch.mean(skip) + ind_selected = (skip == 0).nonzero().squeeze() + + if ind_selected.numel() > 0: + x_selected = torch.index_select(input, self.dim, ind_selected) + out_selected = function(x_selected, *args, **kwargs) + + out = input.clone() + assert ( + self.dim == 0 + ), "Currently only supporting dropping elements along the 0th dimension" + if ind_selected.numel() > 0: + out[ind_selected] = out_selected + return out + + +class ModuleLayerDropoutWrapper(torch.nn.Module): + """ + A wrapper module that adds layer dropout functionality to a given module. + This class wraps a given module and applies layer dropout to it. It also + provides getter and setter methods for the wrapped module's attributes. + + Args: + module (torch.nn.Module): The module to wrap. + dropout (LayerDropout): The layer dropout object. + Examples: + >>> import torch + >>> from torch import nn + >>> # Define a simple model + >>> class MyModel(nn.Module): + ... def __init__(self): + ... super().__init__() + ... self.fc1 = nn.Linear(5, 3) + ... self.fc2 = nn.Linear(3, 2) + ... + ... def forward(self, x): + ... return self.fc2(self.fc1(x)) + >>> model = MyModel() + >>> fc1 = model.fc1 + >>> fc2 = model.fc2 + >>> # Apply layer dropout to the model + >>> layer_dropout = LayerDropout(prob=0.5) + >>> model = ModuleLayerDropoutWrapper(model, layer_dropout) + >>> # Accessing attributes of the wrapped model + >>> assert model.dropout.prob == 0.5 + >>> assert model.fc1 == fc1 + >>> assert model.fc2 == fc2 + >>> # Pass an input to the wrapped model as if you are passing it to the original model + >>> output = model(torch.randn(1, 5)) + """ + + def __init__(self, module: torch.nn.Module, dropout: LayerDropout): + super().__init__() + self.module = module + self.dropout = dropout + + def forward(self, input: torch.Tensor, *args, **kwargs): + return self.dropout(self.module, input, *args, **kwargs) + + def __getattr__(self, name: str) -> Any: + """Forward missing attributes to wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self.module, name) # fallback to wrapped module + + def __setattr__(self, name: str, value: Any) -> Any: + """Forward missing attributes to wrapped module.""" + try: + return super().__setattr__(name, value) # defer to nn.Module's logic + except AttributeError: + return setattr(self.module, name, value) # fallback to wrapped module + + def __getitem__(self, key: int) -> Any: + """Forward indexing calls in case the module is a nn.Sequential.""" + return self.module.__getitem__(key) + + def state_dict(self, *args, **kwargs): + """Return the state dictionary of the wrapped module.""" + return self.module.state_dict(*args, **kwargs) + + def load_state_dict(self, state_dict, *args, **kwargs): + """Load the state dictionary into the wrapped module.""" + self.module.load_state_dict(state_dict, *args, **kwargs) + return + + +class ScaleType(str, Enum): + UNIFORM = "uniform" + EXP = "exp" + LINEAR = "linear" + LOG = "log" + SIN = "sin" + SIGMOID = "sigmoid" + STEP = "step" + + +def get_scale( + scale_type: ScaleType, + scale_period: int, + val: int, +) -> float: + """ + Compute a scaling factor based on the provided 
scale type, period, and value. + The scaling factor is designed to be 0 when the value is 0 and 1 when the value + reaches or is larger than the scale period. + + Args: + scale_type (ScaleType): The type of scaling to use. + scale_period (int): The period over which the scaling factor increases from 0 to 1. + val (int): The current value used to compute the scaling factor. + Returns: + float: The computed scaling factor. + Examples: + >>> get_scale(ScaleType.LINEAR, 10, 5) + 0.5 + >>> get_scale(ScaleType.LINEAR, 10, 0) + 0.0 + >>> get_scale(ScaleType.LOG, 10, 10) + 1.0 + """ + if scale_period == 0: + return 1.0 + if val >= scale_period: + return 1.0 + + # all the equations below aim to make scale = 0 when val=0, and scale = 1 when val=scale_period + scale = { + ScaleType.UNIFORM: 1.0, + ScaleType.EXP: math.pow(2, val / scale_period) - 1, + ScaleType.LINEAR: val / scale_period, + ScaleType.LOG: math.log(val + 1, scale_period + 1), + ScaleType.SIN: math.sin(0.5 * math.pi * val / scale_period), + ScaleType.SIGMOID: 1 / (1 + math.exp(-10 * (val / scale_period - 0.5))), + }[scale_type] + + # ensure returned scale is between 0.0 and 1.0 (inclusive) + return max(min(scale, 1.0), 0.0) + + +def prepare_layer_dropout( + layers: Union[torch.nn.ModuleList, Iterable[torch.nn.Module]], + prob_max: float = 0.0, + prob_layer_scale: Optional[ScaleType] = ScaleType.UNIFORM, + layers_str: Optional[str] = None, + disable_on_eval: Optional[bool] = True, +) -> None: + """ + Prepare a model's layers for layer dropout by wrapping each layer with a ModuleLayerDropoutWrapper. + This function takes in a list of layers, the maximum probability of dropping a layer, + the scaling type for the layer dropout probability, a string specifying which + layers to apply dropout to, and a boolean indicating whether to disable dropout + during evaluation. It then wraps each layer of the model inplace with a + ModuleLayerDropoutWrapper, which applies layer dropout to the input tensor. + + Args: + layers (Union[torch.nn.ModuleList, Iterable[torch.nn.Module]]): The list of layers to prepare for layer dropout. + prob_max (float): The maximum probability of dropping a layer. Defaults to 0.0. + prob_layer_scale (Optional[ScaleType]): The scaling type for the dropout probability across layers. Defaults to + ScaleType.UNIFORM. + layers_str (Optional[str]): A string specifying which layers to apply dropout to. Defaults to None which means + apply to all layers. + disable_on_eval (Optional[bool]): Whether to disable dropout during evaluation. Defaults to True. + Returns: + None + Example: + >>> import torch + >>> from torch import nn + >>> # Define a simple model + >>> class MyModel(nn.Module): + ... def __init__(self): + ... super().__init__() + ... self.layers = nn.ModuleList([ + ... nn.Linear(5, 3), + ... nn.Linear(3, 2), + ... nn.Linear(2, 1), + ... nn.Linear(1, 2), + ... nn.Linear(2, 3), + ... ]) + ... + ... def forward(self, x): + ... for layer in self.layers: + ... x = layer(x) + ... 
return x + >>> model = MyModel() + >>> # Apply layer dropout uniformly to all layers + >>> prepare_layer_dropout(model.layers, prob_max=0.2, prob_layer_scale=ScaleType.UNIFORM) + >>> # Apply layer dropout every other layer, as described in LayerDrop paper + (Fan et al., https://arxiv.org/abs/1909.11556v1) + >>> prepare_layer_dropout(model.layers, prob_max=0.2, prob_layer_scale=ScaleType.UNIFORM, layers_str="::2") + >>> # Apply layer dropout that increases linearly across layers, as described in Progressive Layer + Dropout paper (Zhang et al., https://arxiv.org/abs/2010.13369) + >>> prepare_layer_dropout(model.layers, prob_max=0.2, prob_layer_scale=ScaleType.LINEAR) + >>> # Apply layer dropout that increases exponentially across layers, as described in + LayerSkip paper (Elhoushi et al., https://arxiv.org/abs/2404.16710) + >>> prepare_layer_dropout(model.layers, prob_max=0.2, prob_layer_scale=ScaleType.EXP) + """ + num_layers = len(layers) + has_dropout = ( + slice_str_to_array(layers_str, num_layers) + if layers_str + else [True] * num_layers + ) + for layer_id in range(len(layers)): + prob = ( + prob_max + * get_scale( + scale_type=prob_layer_scale, + scale_period=num_layers - 1, + val=layer_id, + ) + if has_dropout[layer_id] + else 0.0 + ) + assert ( + prob >= 0.0 and prob <= prob_max + ), f"prob={prob} should be between 0 and {prob_max}" + # We would like each layer to have a different seed, so that we don't have the same samples skipped across layers. + # Hence, we use the layer_id as a seed for each layer's dropout. + layer_dropout = LayerDropout( + prob, disable_on_eval=disable_on_eval, seed=layer_id + ) + layers[layer_id] = ModuleLayerDropoutWrapper(layers[layer_id], layer_dropout) diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index 85d6b22869..66ac92002f 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -640,6 +640,15 @@ def forward( input_pos=input_pos, ) + # shape: [b, seq_len, out_dim] + output = self.unembed(h) + + # Output list if hidden states are requested, otherwise just the output + # TODO: always output a list to have a consistent output type + output = output if not hidden else [*hidden, output] + return output + + def unembed(self, h): # shape: [b, s, d] h = self.norm(h) @@ -649,7 +658,4 @@ def forward( # shape: [b, seq_len, out_dim] output = self.output(h).float() - # Output list if hidden states are requested, otherwise just the output - # TODO: always output a list to have a consistent output type - output = output if not hidden else [*hidden, output] return output
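
For orientation, below is a minimal sketch of how the layer dropout and early exit loss pieces
introduced in this patch are intended to compose in a single training step. The builder
arguments, the layer selections, and the way the early exit loss is simply added to the
final-layer loss are illustrative assumptions, not an excerpt from the recipe above.

    import torch
    from torchtune.models.llama2 import llama2
    from torchtune.modules import prepare_layer_dropout
    from torchtune.modules.common_utils import slice_str_to_array
    from torchtune.modules.early_exit_loss import early_exit_loss, uniform_loss_scale

    # Tiny decoder so the sketch is cheap to run (sizes are arbitrary).
    model = llama2(
        vocab_size=128, num_layers=8, num_heads=4, num_kv_heads=4, embed_dim=64, max_seq_len=128
    )

    # Layer dropout: wrap every other layer; the drop probability scales uniformly up to prob_max.
    prepare_layer_dropout(model.layers, prob_max=0.2, layers_str="::2")

    # Early exit loss: request hidden states for layers 0 and 4 ("0::4" over 8 layers).
    do_output = slice_str_to_array("0::4", len(model.layers))
    model.output_hidden_states = [i for i, flag in enumerate(do_output) if flag]

    tokens = torch.randint(0, 128, (2, 16))
    labels = torch.randint(0, 128, (2, 16))
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)

    # With hidden states requested, forward returns [hidden_0, hidden_4, logits].
    outputs = model(tokens)
    hidden_states = dict(zip(model.output_hidden_states, outputs[:-1]))
    logits = outputs[-1]

    # Standard last-layer loss plus the scaled early exit losses.
    loss = loss_fn(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
    loss = loss + early_exit_loss(
        model, hidden_states, labels, loss_fn, e_scale=1.0, loss_scale_fn=uniform_loss_scale
    )
    loss.backward()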