diff --git a/docs/source/api_ref_modules.rst b/docs/source/api_ref_modules.rst
index 36ea3637c8..5eb8fff358 100644
--- a/docs/source/api_ref_modules.rst
+++ b/docs/source/api_ref_modules.rst
@@ -23,6 +23,8 @@ Modeling Components and Building Blocks
     TransformerCrossAttentionLayer
     TransformerDecoder
     VisionTransformer
+    LayerDropout
+    prepare_layer_dropout
 
 Losses
 ------
diff --git a/recipes/dev/7B_full_early_exit.yaml b/recipes/dev/7B_full_early_exit.yaml
new file mode 100644
index 0000000000..7d02a34f0e
--- /dev/null
+++ b/recipes/dev/7B_full_early_exit.yaml
@@ -0,0 +1,137 @@
+# Config for multi-device full finetuning with early exit loss and/or layer dropout
+# in dev/early_exit_finetune_distributed.py using a Llama2 7B model on a small TOPv2
+# instruction set.
+#
+# This config assumes that you've run the following command before launching
+# this run:
+#   tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --hf-token <HF_TOKEN>
+#
+# To launch on 4 devices, run the following command from root:
+#   tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml
+#
+# You can add specific overrides through the command line. For example
+# to override the checkpointer directory while launching training
+# you can run:
+#   tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
+#
+# To reproduce experiments of various papers that use early exit loss and/or layer dropout:
+# - LayerSkip (https://arxiv.org/abs/2404.16710) on TOPv2:
+#   tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss.scale=1.0 early_exit_loss.curriculum=torchtune.modules.early_exit_loss.GradualEarlyExitCurriculum early_exit_loss.scale_fn=torchtune.modules.early_exit_loss.linear_l_loss_scale layer_dropout.prob=0.2 layer_dropout.layers_scale=exp
+#
+# - LITE (https://arxiv.org/abs/2310.18581):
+#   tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml layer_dropout=null early_exit_loss.layers=8,12,16,20,24,28 early_exit_loss.scale_fn=torchtune.modules.early_exit_loss.uniform_loss_scale early_exit_loss.curriculum=null epochs=5
+#
+# - LayerDrop (https://arxiv.org/abs/1909.11556):
+#   tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss=null layer_dropout.prob=0.2 layer_dropout.layers=1::2
+#
+# - Progressive Layer Dropping (https://arxiv.org/abs/2010.13369) (the paper also uses a curriculum for the layer drop probability, which is not yet implemented here):
+#   tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss=null layer_dropout.prob=0.5 layer_dropout.layers_scale=exp
+#
+# This config works best for distributed training, i.e., when the model is being fine-tuned on 2+ GPUs.
+#
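+# Note on the layer-selection strings used above (e.g., early_exit_loss.layers and layer_dropout.layers):
+# they mimic numpy-style slicing over layer indices. As a rough illustration, assuming the 32-layer
+# Llama2 7B used in this config:
+#   ":"                 selects all layers
+#   "0::4"              selects layers 0, 4, 8, ..., 28
+#   "1::2"              selects every other layer starting from layer 1
+#   "8,12,16,20,24,28"  selects exactly those layer indices
+#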
+
+
+# Tokenizer
+tokenizer:
+  _component_: torchtune.models.llama2.llama2_tokenizer
+  path: /tmp/Llama-2-7b-hf/tokenizer.model
+  max_seq_len: null
+
+# Dataset
+dataset:
+  _component_: torchtune.datasets.instruct_dataset
+  source: WillHeld/top_v2
+  split: train
+  column_map:
+    input: utterance
+    output: semantic_parse
+
+seed: null
+shuffle: True
+
+# Model Arguments
+model:
+  _component_: torchtune.models.llama2.llama2_7b
+
+checkpointer:
+  _component_: torchtune.training.FullModelHFCheckpointer
+  checkpoint_dir: /tmp/Llama-2-7b-hf
+  checkpoint_files: [
+    pytorch_model-00001-of-00002.bin,
+    pytorch_model-00002-of-00002.bin
+  ]
+  recipe_checkpoint: null
+  output_dir: /tmp/Llama-2-7b-hf
+  model_type: LLAMA2
+resume_from_checkpoint: False
+
+# Fine-tuning arguments
+batch_size: 8
+epochs: 1
+optimizer:
+  _component_: torch.optim.AdamW
+  fused: True
+  lr: 2e-5
+loss:
+  _component_: torch.nn.CrossEntropyLoss
+max_steps_per_epoch: null
+gradient_accumulation_steps: 1  # Use to increase virtual batch size
+compile: False  # pytorch compile, set to true for better perf/memory
+optimizer_in_bwd: False  # True saves memory. Requires gradient_accumulation_steps=1
+
+# Training env
+device: cuda
+
+# Memory management
+enable_activation_checkpointing: True  # True reduces memory
+enable_activation_offloading: False  # True reduces memory
+
+# Reduced precision
+dtype: bf16
+
+# Logging
+metric_logger:
+  _component_: torchtune.training.metric_logging.DiskLogger
+  log_dir: ${output_dir}
+output_dir: /tmp/topv2-llama2-finetune
+log_every_n_steps: 1
+log_peak_memory_stats: True
+
+# Profiler (disabled)
+profiler:
+  _component_: torchtune.training.setup_torch_profiler
+  enabled: False
+
+  #Output directory of trace artifacts
+  output_dir: ${output_dir}/profiling_outputs
+
+  #`torch.profiler.ProfilerActivity` types to trace
+  cpu: True
+  cuda: True
+
+  #trace options passed to `torch.profiler.profile`
+  profile_memory: False
+  with_stack: False
+  record_shapes: True
+  with_flops: False
+
+  # `torch.profiler.schedule` options:
+  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+  wait_steps: 5
+  warmup_steps: 3
+  active_steps: 2
+  num_cycles: 1
+
+# Early Exit Loss
+early_exit_loss:
+  layers: "0::4"
+  curriculum: torchtune.modules.early_exit_loss.RotationalEarlyExitCurriculum
+  scale_fn: torchtune.modules.early_exit_loss.sum_l_loss_scale
+  scale: 1.0
+
+# Layer Dropout
+layer_dropout:
+  prob: 0.2
+  layers: ":"
+  layers_scale: "exp"
+  disable_on_eval: True
diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py
new file mode 100644
index 0000000000..aed914a463
--- /dev/null
+++ b/recipes/dev/early_exit_finetune_distributed.py
@@ -0,0 +1,1070 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import sys
+import time
+
+from functools import partial
+from typing import Any, Dict, List, Optional, Tuple, Union
+from warnings import warn
+
+import torch
+from omegaconf import DictConfig, ListConfig
+
+from torch import nn
+from torch.distributed import destroy_process_group, init_process_group
+
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader, DistributedSampler
+from torchtune import config, modules, training, utils
+from torchtune.config._utils import _get_component_from_path
+from torchtune.data import padded_collate_packed
+from torchtune.datasets import ConcatDataset
+from torchtune.modules.common_utils import slice_str_to_array
+
+from torchtune.modules.early_exit_loss import early_exit_loss, EarlyExitCurriculum
+from torchtune.modules.layer_dropout import prepare_layer_dropout
+from torchtune.recipe_interfaces import FTRecipeInterface
+from torchtune.training import DummyProfiler, PROFILER_KEY
+from torchtune.training.activations import apply_selective_activation_checkpointing
+from torchtune.training.lr_schedulers import get_lr
+
+from tqdm import tqdm
+
+log = utils.get_logger("DEBUG")
+
+
+class EarlyExitFinetuneRecipeDistributed(FTRecipeInterface):
+    """
+    Full finetuning with early exit loss and/or layer dropout, which makes dense transformer-based LLMs such as
+    Llama2 more robust to exiting early or to skipping intermediate layers. This recipe supports distributed
+    training and can be run on a single node (1 to 8 GPUs).
+
+    Features:
+        - Early Exit Loss. This makes the model more robust to exiting early by applying the outputs of intermediate
+            layers to the model's language model head (a.k.a. unembedding operation) to obtain predictions from
+            earlier layers, then computing the loss at each of those layers. The total training loss is a weighted
+            average of the losses at the different layers. The arguments you can configure are:
+                - ``early_exit_loss.layers``: a string whose format mimics indexing in numpy arrays (e.g., ``:``
+                depicts all layers, ``0:10:3`` depicts layers 0, 3, 6, 9, and ``1,5,11`` depicts layers 1, 5, 11)
+                and determines which layers to apply early exit loss at (see the snippet below this list),
+                - ``early_exit_loss.scale_fn`` and ``early_exit_loss.scale``: determine how we calculate the
+                weights of the losses at different layers when computing the total loss, and
+                - ``early_exit_loss.curriculum``: determines how the set of early exit loss layers changes across
+                training iterations.
+            See ``torchtune/modules/early_exit_loss.py`` for more details on each argument.
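+
+            As a minimal sketch of how the layer string is interpreted, the same helper this recipe uses
+            internally, ``torchtune.modules.common_utils.slice_str_to_array``, turns the string into a boolean
+            mask over layer indices:
+
+            .. code-block:: python
+
+                from torchtune.modules.common_utils import slice_str_to_array
+
+                # For a hypothetical 12-layer model, "0::4" selects layers 0, 4, and 8.
+                slice_str_to_array("0::4", 12)
+                # [True, False, False, False, True, False, False, False, True, False, False, False]
+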
+            To reproduce experiments of different papers that use early exit loss:
+                - LayerSkip (https://arxiv.org/abs/2404.16710) for finetuning on TOPv2: set
+                ``early_exit_loss.scale=1.0
+                    early_exit_loss.curriculum=torchtune.modules.early_exit_loss.GradualEarlyExitCurriculum
+                    early_exit_loss.scale_fn=torchtune.modules.early_exit_loss.linear_l_loss_scale``,
+                - LITE (https://arxiv.org/abs/2310.18581) for finetuning Llama2 7B on Alpaca: set
+                ``early_exit_loss.layers=8,12,16,20,24,28
+                    early_exit_loss.scale_fn=torchtune.modules.early_exit_loss.uniform_loss_scale``.
+
+        - Layer Dropout. (a.k.a. Stochastic Depth) This drops samples stochastically for each layer during training.
+            "Dropping" a sample at a layer in this context means a sample will pass through the layer without modification.
+            The different arguments you can configure are:
+                - ``layer_dropout.prob``: the (maximum) probability of a sample being dropped at each layer.
+                - ``layer_dropout.layers``: a string whose format mimics indexing in numpy arrays
+                    (same as ``early_exit_loss.layers``) that determines which layers have layer dropout applied.
+                - ``layer_dropout.layers_scale``: determines how the dropout probability changes across layers,
+                    from 0 at the first layer to ``layer_dropout.prob`` at the last layer.
+                    You can choose from ``uniform`` (all layers have ``layer_dropout.prob``), ``linear``,
+                    ``exp``, ``log``, ``sqrt``.
+                - ``layer_dropout.disable_on_eval``: if True, layer dropout is only applied during training.
+                    If False, it is applied during both training and evaluation (see the sketch at the end of
+                    this section).
+            To reproduce results of different papers that use layer dropout:
+                - LayerDrop (https://arxiv.org/abs/1909.11556) that applies dropout on every other layer, set
+                    ``layer_dropout.prob=0.2 layer_dropout.layers=1::2``.
+                - Progressive Layer Dropping (https://arxiv.org/abs/2010.13369) that increases dropout linearly
+                    across layers, set ``layer_dropout.prob=0.5 layer_dropout.layers_scale=linear``.
+                    The paper also uses a curriculum for the layer drop probability, which is not yet implemented here.
+                - LayerSkip (https://arxiv.org/abs/2404.16710) for finetuning on TOPv2: (in addition to early exit loss
+                    arguments above) set ``layer_dropout.prob=0.2 layer_dropout.layers_scale=exp``.
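+
+            Under the hood, this recipe wires these options into the model in ``setup()`` via
+            ``torchtune.modules.layer_dropout.prepare_layer_dropout``. A minimal sketch of that call, using the
+            values from the accompanying ``7B_full_early_exit.yaml`` config (``model`` here stands for the
+            instantiated model):
+
+            .. code-block:: python
+
+                from torchtune.modules.layer_dropout import prepare_layer_dropout
+
+                prepare_layer_dropout(
+                    model.layers,
+                    prob_max=0.2,            # layer_dropout.prob
+                    prob_layer_scale="exp",  # layer_dropout.layers_scale
+                    layers_str=":",          # layer_dropout.layers
+                    disable_on_eval=True,    # layer_dropout.disable_on_eval
+                )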
+
+        - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states
+            is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
+            done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config
+            ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy).
+            DDP is currently not supported. Training on CPU is not supported.
+
+        - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing``
+            flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
+            activations in memory and instead recompute them during the backward pass. This is especially
+            helpful for larger batch sizes when you're memory constrained. But these savings in memory
+            come at the cost of training performance. In most cases training can slow down quite a bit as
+            a result of this activation recomputation.
+
+        - Activation Offloading. This can be controlled using the ``enable_activation_offloading``
+            flag. Activation offloading is a technique similar to activations checkpointing that helps
+            reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activations
+            checkpointing drops the activation in the forward to recompute it later in the backward,
+            activations offloading will drop the activation in the forward to the CPU and bring it
+            back during the backward pass. As always, there is a tradeoff--these savings in memory can
+            come at the cost of training performance and CPU resources. To recover some runtime cost,
+            we've added an option to enable offloading on a different stream to permit overlapping with
+            the computation. This option is currently only available on PyTorch 2.5 or later and will
+            be enabled by default if an acceptable torch version is found. Activation offloading can be
+            used in conjunction with activation checkpointing.
+
+        - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
+            flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
+            most cases this should halve the memory footprint of full precision (fp32) training, without
+            loss in model quality (will depend on the model, training data and other settings). For
+            GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16
+            precision are currently not supported.
+
+        - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is
+            controlled using the ``gradient_accumulation_steps`` flag.
+
+                Total Batch Size = batch_size * number of GPUs * gradient accumulation steps.
+
+            For example: with batch_size=1, nproc_per_node=2 and gradient_accumulation_steps=32 we get a
+            total batch size of 64.
+
+            Gradient accumulation is especially useful when you are memory constrained. In this case,
+            accumulating gradients might give you better training speed than enabling activation
+            checkpointing.
+
+        - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of
+            training. Optimizer state and recipe state (seed, total_epochs, number of epochs run etc) are
+            only saved at the end of a given epoch and used in case of resuming training.
+
+            Resuming training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is
+            currently not supported.
+
+            For more details on the checkpointer, please take a look at
+            our checkpointer deepdive (https://pytorch.org/torchtune/main/deep_dives/checkpointer.html).
+
+        - Logging. Terminal, Disk, WandB and TensorBoard are all supported.
+
+        - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default,
+            ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set
+            ``clip_grad_norm='inf'``.
+
+    For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config
+    has example commands for how to kick-off training.
+
+    Args:
+        cfg (DictConfig): OmegaConf object parsed from yaml file
+
+    Raises:
+        ValueError: If ``dtype`` is set to fp16.
+        RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
+        RuntimeError: If ``left_pad_sequence`` is set as the data collator.
+        RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA.
+        RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False.
+    """
+
+    def __init__(self, cfg: DictConfig) -> None:
+        self._device = utils.get_device(device=cfg.device)
+        self._dtype = training.get_dtype(cfg.dtype, device=self._device)
+
+        if self._dtype == torch.float16:
+            raise ValueError(
+                "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
+            )
+
+        # logging attributes
+        self._output_dir = cfg.output_dir
+        self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
+        self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)
+
+        if self._log_peak_memory_stats and self._device.type != "cuda":
+            log.info(
+                "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False."
+            )
+            self._log_peak_memory_stats = False
+
+        # _is_rank_zero is used primarily for logging. In the future, the logger
+        # should directly take care of this
+        _, rank = training.get_world_size_and_rank()
+        self._is_rank_zero = rank == 0
+
+        # Training cfg
+        self._resume_from_checkpoint = cfg.resume_from_checkpoint
+        self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
+        self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
+        self._clip_grad_norm = cfg.get("clip_grad_norm", None)
+
+        # Optimizer in backward is not compatible with gradient accumulation or gradient clipping
+        if self._optimizer_in_bwd:
+            if self._clip_grad_norm is not None:
+                raise RuntimeError(
+                    "Gradient clipping is not supported with optimizer in bwd."
+                    "Please set clip_grad_norm=None, or optimizer_in_bwd=False."
+                )
+            if self._gradient_accumulation_steps > 1:
+                raise RuntimeError(
+                    "Gradient accumulation is not supported with optimizer in bwd."
+                    "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
+                )
+
+        # activation checkpointing/offloading
+        self._enable_activation_checkpointing = cfg.get(
+            "enable_activation_checkpointing", False
+        )
+        self._enable_activation_offloading = cfg.get(
+            "enable_activation_offloading", False
+        )
+        if self._enable_activation_offloading:
+            if self._device.type != "cuda":
+                raise RuntimeError(
+                    "enable_activation_offloading should only be True when training on CUDA"
+                )
+            if not self._enable_activation_checkpointing:
+                raise RuntimeError(
+                    "enable_activation_offloading should only be True when enable_activation_checkpointing is True"
+                )
+        elif (
+            self._enable_activation_checkpointing
+            and cfg.checkpointer.model_type != "LLAMA3_VISION"
+        ):
+            utils.log_rank_zero(
+                log,
+                "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. "
+                "Enabling activation offloading should reduce memory further.",
+            )
+
+        # These are public properties which are updated by the checkpoint loader
+        # when ``resume_from_checkpoint`` is `True` or validated in tests
+        self.seed = training.set_seed(seed=cfg.seed)
+        self.epochs_run = 0
+        self.total_epochs = cfg.epochs
+        self.max_steps_per_epoch = cfg.max_steps_per_epoch
+        self.global_step = 0
+
+        # Early Exit Properties
+        cfg_early_exit_loss = cfg.get("early_exit_loss", None)
+        if cfg_early_exit_loss:
+            self._do_early_exit_loss = True
+            self._early_exit_loss_scale = cfg_early_exit_loss.get("scale", 1.0)
+            self._early_exit_loss_scale_type = _get_component_from_path(
+                cfg_early_exit_loss.get(
+                    "scale_fn", "torchtune.modules.early_exit_loss.sum_l_loss_scale"
+                )
+            )
+        else:
+            self._do_early_exit_loss = False
+            self._early_exit_loss_scale = None
+            self._early_exit_loss_scale_type = None
+
+    def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
+        """
+        Extract the checkpoint state from file and validate. If resume_from_checkpoint
+        is True, this also includes the recipe state.
+        """
+        self._checkpointer = config.instantiate(
+            cfg_checkpointer,
+            resume_from_checkpoint=self._resume_from_checkpoint,
+        )
+        checkpoint_dict = self._checkpointer.load_checkpoint()
+
+        if self._resume_from_checkpoint:
+            self._update_recipe_state(checkpoint_dict)
+        return checkpoint_dict
+
+    def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
+        """
+        Updates the recipe state from checkpoint.
+        """
+        try:
+            self.epochs_run = ckpt_dict[training.EPOCHS_KEY]
+
+            # on mismatch, warn the user and prevent the override
+            if self.seed != ckpt_dict[training.SEED_KEY]:
+                warn(
+                    message=(
+                        "Config value for seed does not match the checkpoint value, "
+                        f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}"
+                    )
+                )
+                self.seed = ckpt_dict[training.SEED_KEY]
+            if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]:
+                warn(
+                    message=(
+                        "Config value for max_steps_per_epoch does not match the checkpoint value, "
+                        f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}"
+                    )
+                )
+                self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY]
+
+            # on mismatch, warn the user but allow the override
+            if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]:
+                warn(
+                    message=(
+                        "Config value for total_epochs does not match the checkpoint value, "
+                        f"using the config value: {self.total_epochs}"
+                    )
+                )
+
+        except KeyError as e:
+            raise KeyError(
+                "Checkpoint does not contain the required keys needed for updating recipe state. "
+                "Are you sure you passed in the right recipe checkpoint?"
+            ) from e
+
+    def setup(self, cfg: DictConfig) -> None:
+        """
+        Setup the recipe. This includes training state (if resume_from_checkpoint is True),
+        model, tokenizer, loss, optimizer, sampler, and dataloader.
+        """
+        if self._is_rank_zero:
+            self._metric_logger = config.instantiate(cfg.metric_logger)
+
+            # log config with parameter override
+            self._metric_logger.log_config(cfg)
+
+        checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
+
+        self._compile = cfg.get("compile", False)
+        self._model = self._setup_model(
+            cfg_model=cfg.model,
+            enable_activation_checkpointing=self._enable_activation_checkpointing,
+            enable_activation_offloading=self._enable_activation_offloading,
+            custom_sharded_layers=cfg.get("custom_sharded_layers", None),
+            fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
+            reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
+            model_state_dict=checkpoint_dict[training.MODEL_KEY],
+            ac_mode=cfg.get("ac_mode", None),
+            ac_option=cfg.get("ac_option", None),
+        )
+        self._tokenizer = config.instantiate(cfg.tokenizer)
+
+        self._optimizer = self._setup_optimizer(
+            cfg_optimizer=cfg.optimizer,
+            optimizer_in_bwd=self._optimizer_in_bwd,
+            opt_state_dict=(
+                checkpoint_dict[training.OPT_KEY]
+                if self._resume_from_checkpoint
+                else None
+            ),
+        )
+
+        # initialize loss
+        self._loss_fn = config.instantiate(cfg.loss)
+
+        if self._compile:
+            training.compile_loss(self._loss_fn, verbose=self._is_rank_zero)
+
+        if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss":
+            # set num_output_chunks for model
+            self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
+
+        if self._is_rank_zero:
+            log.info("Loss is initialized.")
+
+        # sampler and dataloader depend on the tokenizer and loss_fn and should be
+        # setup after both of these are initialized
+        collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft")
+        self._sampler, self._dataloader = self._setup_data(
+            cfg_dataset=cfg.dataset,
+            shuffle=cfg.shuffle,
+            batch_size=cfg.batch_size,
+            collate_fn=collate_name,
+        )
+
+        # Finally update the recipe state which can only be correctly set after all of the
+        # other components have been initialized and updated.
+        #
+        # Number of training steps in each epoch depends on the number of batches produced
+        # by the dataloader, the max_steps_per_epoch param set by the user and the
+        # gradient_accumulation_steps param. This value is used for logging and tracking
+        # training state. The computation should happen after the dataloader has been setup
+        self._steps_per_epoch = (
+            len(self._dataloader) // self._gradient_accumulation_steps
+        )
+        if (
+            self.max_steps_per_epoch is not None
+            and self.max_steps_per_epoch < self._steps_per_epoch
+        ):
+            self._steps_per_epoch = self.max_steps_per_epoch
+        self.global_step = self.epochs_run * self._steps_per_epoch
+
+        # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method)
+        # if cfg is missing profiler key or if `cfg.profiler.enabled = False`
+        self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None))
+
+        # Used to ignore labels for loss computation
+        self.ignore_labels_cache = torch.full(
+            (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device
+        )
+
+        # Setup early exit loss
+        (
+            self._do_output_hidden_states,
+            self._early_exit_loss_curriculum,
+        ) = self._setup_early_exit_loss(cfg.get("early_exit_loss", None))
+
+        # Layer Dropout Setup
+        cfg_layer_dropout = cfg.get("layer_dropout", None)
+        if cfg_layer_dropout:
+            prepare_layer_dropout(
+                self._model.layers,
+                prob_max=cfg_layer_dropout.get("prob", 0.0),
+                prob_layer_scale=cfg_layer_dropout.get("layers_scale", "uniform"),
+                layers_str=cfg_layer_dropout.get("layers", ":"),
+                disable_on_eval=cfg_layer_dropout.get("disable_on_eval", True),
+            )
+
+    def _setup_profiler(
+        self, cfg_profiler: Optional[DictConfig] = None
+    ) -> Union[torch.profiler.profile, DummyProfiler]:
+        """
+        Parses the `profiler` section of top-level `cfg` and sets up profiler
+
+        Args:
+            cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to
+                `recipe.main`). Default None.
+
+        Returns:
+            profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods
+            for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` when the profiler
+            is not enabled, so that the instrumented training loop does not need to change when profiling is disabled.
+
+        The profiler config can be provided in configs under the `profiler` key with the following layout:
+
+        .. code-block:: yaml
+
+            profiler:
+                enabled: bool
+
+                # Output directory of trace artifacts
+                output_dir: str
+
+                # `torch.profiler.ProfilerActivity` types to trace
+                cpu: bool
+                cuda: bool
+
+                # Trace options passed to `torch.profiler.profile`
+                profile_memory: bool
+                with_stack: bool
+                record_shapes: bool
+                with_flops: bool
+
+                # `torch.profiler.schedule` options:
+                # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+                wait_steps: int
+                warmup_steps: int
+                active_steps: int
+                num_cycles: int
+        """
+        # Missing profiler section in config, assume disabled
+        if cfg_profiler is None:
+            cfg_profiler = DictConfig({"enabled": False})
+
+        # Check that component is included and set correctly
+        if cfg_profiler.get("_component_", None) is None:
+            cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler"
+        else:
+            assert (
+                cfg_profiler.get("_component_")
+                == "torchtune.training.setup_torch_profiler"
+            ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`"
+
+        profiler, profiler_cfg = config.instantiate(cfg_profiler)
+
+        if self._is_rank_zero:
+            log.info(f" Profiler config after instantiation: {profiler_cfg}")
+
+            self.profiler_profile_memory = profiler_cfg.get("profile_memory", False)
+            if profiler_cfg["enabled"]:
+                self.profiler_wait_steps = profiler_cfg["wait_steps"]
+                self.profiler_warmup_steps = profiler_cfg["warmup_steps"]
+                self.profiler_active_steps = profiler_cfg["active_steps"]
+
+        return profiler
+
+    def _setup_model(
+        self,
+        cfg_model: DictConfig,
+        enable_activation_checkpointing: bool,
+        enable_activation_offloading: bool,
+        fsdp_cpu_offload: bool,
+        reshard_after_forward: bool,
+        model_state_dict: Dict[str, Any],
+        custom_sharded_layers: Optional[List[str]] = None,
+        ac_mode: Optional[str] = None,
+        ac_option: Optional[int] = None,
+    ) -> nn.Module:
+        """
+        Model initialization has some important considerations:
+           a. To minimize GPU peak memory, we initialize the model on meta device with
+              the right dtype
+           b. All ranks call ``load_state_dict`` without peaking CPU RAM usage since
+              full state dicts are loaded with ``torch.load(mmap=True)``
+        """
+
+        if self._is_rank_zero:
+            log.info(
+                "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..."
+            )
+            init_start = time.perf_counter()
+
+        with training.set_default_dtype(self._dtype), torch.device("meta"):
+            model = config.instantiate(cfg_model)
+
+        if self._compile:
+            training.compile_model(model, verbose=self._is_rank_zero)
+
+        # We currently have two versions of activation checkpointing in this recipe
+        # for testing and BC purposes. ``enable_activation_checkpointing`` controls
+        # the older version of AC and this behavior is unchanged
+        # ac_mode and ac_option together control selective AC. This is only enabled
+        # when these are set AND ``enable_activation_checkpointing`` is set to False
+        # We'll clean this up as soon as testing of AC is complete
+        if (not enable_activation_checkpointing) and (ac_mode is not None):
+            apply_selective_activation_checkpointing(
+                model,
+                ac_mode,
+                ac_option,
+            )
+
+        # original activation checkpointing (full) - flip the condition above
+        if enable_activation_checkpointing and ac_mode is None:
+            training.set_activation_checkpointing(
+                model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
+            )
+
+        # For FSDP sharding
+        fsdp_shard_conditions = [
+            partial(
+                training.get_shard_conditions,
+                names_to_match=custom_sharded_layers,
+            )
+        ]
+        training.shard_model(
+            model=model,
+            shard_conditions=fsdp_shard_conditions,
+            cpu_offload=fsdp_cpu_offload,
+            reshard_after_forward=reshard_after_forward,
+        )
+
+        with training.set_default_dtype(self._dtype), self._device:
+            for m in model.modules():
+                # RoPE is not covered in state dict
+                if hasattr(m, "rope_init"):
+                    m.rope_init()
+
+        # This method will convert the full model state dict into a sharded state
+        # dict and load into the model
+        training.load_from_full_model_state_dict(
+            model,
+            model_state_dict,
+            self._device,
+            self._is_rank_zero,
+            strict=True,
+            cpu_offload=fsdp_cpu_offload,
+        )
+
+        # activation offloading
+        self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
+            model, enable_activation_offloading
+        )
+
+        # Ensure no params and buffers are on meta device
+        training.validate_no_params_on_meta_device(model)
+
+        if self._is_rank_zero:
+            log.info(
+                f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs"
+            )
+            memory_stats = training.get_memory_stats(device=self._device)
+            training.log_memory_stats(memory_stats)
+
+        # synchronize before training begins
+        torch.distributed.barrier()
+
+        return model
+
+    def _setup_optimizer(
+        self,
+        cfg_optimizer: DictConfig,
+        optimizer_in_bwd: bool = False,
+        opt_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> Optional[Optimizer]:
+        if optimizer_in_bwd:
+            # Maintain a dict of optims for every parameter.
+            optim_dict = {
+                param: config.instantiate(cfg_optimizer, [param])
+                for param in self._model.parameters()
+            }
+
+            # Register optimizer step hooks on the model to run optimizer in backward.
+            training.register_optim_in_bwd_hooks(
+                model=self._model, optim_dict=optim_dict
+            )
+            # Create a wrapper for checkpoint save/load of optimizer states when running in backward.
+            self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper(
+                model=self._model, optim_dict=optim_dict
+            )
+            # Load optimizer states for each param. If optimizer states are being restored in an optimizer in
+            # backward run, these need to have been saved with the same setting. Cannot restore from runs that
+            # did not use optimizer in backward.
+            if opt_state_dict is not None:
+                for param in opt_state_dict.keys():
+                    try:
+                        training.load_from_full_optimizer_state_dict(
+                            self._optim_ckpt_wrapper.state_dict()[param],
+                            opt_state_dict[param],
+                            self._device,
+                        )
+                    except BaseException as e:
+                        raise RuntimeError(
+                            "Failed loading in-backward optimizer checkpoints."
+                            "Please make sure run being restored from was using in-backward optimizer."
+                        ) from e
+            if self._is_rank_zero:
+                log.info("In-backward optimizers are set up.")
+            return None
+        else:
+            optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
+            if opt_state_dict:
+                training.load_from_full_optimizer_state_dict(
+                    optimizer,
+                    opt_state_dict,
+                    self._device,
+                )
+
+            if self._is_rank_zero:
+                log.info("Optimizer is initialized.")
+            return optimizer
+
+    def _setup_data(
+        self,
+        cfg_dataset: DictConfig,
+        shuffle: bool,
+        batch_size: int,
+        collate_fn: str,
+    ) -> Tuple[DistributedSampler, DataLoader]:
+        """
+        All data related setup happens here. Currently this recipe only supports the
+        DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
+        iterable datasets and streaming datasets are not supported.
+        """
+        world_size, rank = training.get_world_size_and_rank()
+
+        if isinstance(cfg_dataset, ListConfig):
+            datasets = [
+                config.instantiate(single_cfg_dataset, self._tokenizer)
+                for single_cfg_dataset in cfg_dataset
+            ]
+            ds = ConcatDataset(datasets=datasets)
+            packed = False
+        else:
+            ds = config.instantiate(cfg_dataset, self._tokenizer)
+            packed = cfg_dataset.get("packed", False)
+
+        # Instantiate collate_fn
+        if "left_pad_sequence" in collate_fn:
+            raise RuntimeError("left_pad_sequence collator is only for inference.")
+        collate_fn = _get_component_from_path(collate_fn)
+
+        sampler = DistributedSampler(
+            ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
+        )
+        dataloader = DataLoader(
+            dataset=ds,
+            batch_size=batch_size,
+            sampler=sampler,
+            # dropping last avoids shape issues with compile + flex attention
+            drop_last=True,
+            collate_fn=(
+                partial(
+                    collate_fn,
+                    padding_idx=self._tokenizer.pad_id,
+                    ignore_idx=self._loss_fn.ignore_index,
+                )
+                if not packed
+                else padded_collate_packed
+            ),
+        )
+
+        if self._is_rank_zero:
+            log.info("Dataset and Sampler are initialized.")
+
+        return sampler, dataloader
+
+    def _setup_early_exit_loss(
+        self,
+        cfg_early_exit_loss: DictConfig,
+    ) -> Tuple[List[bool], EarlyExitCurriculum]:
+        """
+        All early exit loss related setup happens here.
+        """
+        do_output_hidden_states = None
+        early_exit_loss_curriculum = None
+
+        if cfg_early_exit_loss:
+            assert (
+                hasattr(self._loss_fn, "reduction")
+                and self._loss_fn.reduction == "mean"
+            ), "Currently early exit loss is only implemented for loss functions that apply a mean reduction."
+
+            do_output_hidden_states = slice_str_to_array(
+                cfg_early_exit_loss.get("layers", ":"), len(self._model.layers)
+            )
+            train_last_layer = cfg_early_exit_loss.get("include_last_layer", True)
+            verbose = cfg_early_exit_loss.get("verbose", False)
+
+            early_exit_loss_curriculum = cfg_early_exit_loss.get("curriculum", None)
+            if early_exit_loss_curriculum:
+                early_exit_loss_curriculum = _get_component_from_path(
+                    early_exit_loss_curriculum
+                )(
+                    do_output_hidden_states=do_output_hidden_states,
+                    max_steps=self.total_epochs * self._steps_per_epoch,
+                    train_last_layer=train_last_layer,
+                    last_step=self.global_step,
+                    verbose=verbose,
+                )
+                do_output_hidden_states = early_exit_loss_curriculum.get()
+            else:
+                if train_last_layer:
+                    do_output_hidden_states[len(self._model.layers) - 1] = True
+
+        return do_output_hidden_states, early_exit_loss_curriculum
+
+    def save_checkpoint(
+        self,
+        epoch: int,
+    ) -> None:
+        """
+        Checkpoint the state of the recipe. The constructed checkpoint state dict
+        contains the following information:
+        - Model weights with key training.MODEL_KEY
+        - Relevant recipe state if training is not complete
+
+        Checkpointer will save the model weights and recipe state in
+        different checkpoint files. To correctly resume training from an intermediate checkpoint,
+        the model weights and recipe state must be provided.
+        """
+        # final dict passed onto the checkpointer
+        checkpoint_dict = {}
+
+        intermediate_checkpoint = epoch + 1 < self.total_epochs
+
+        if self._is_rank_zero:
+            log.info(
+                "Saving checkpoint. This may take some time. Retrieving full model state dict..."
+            )
+            start = time.perf_counter()
+
+        # To prevent GPU memory from spiking during checkpoint save,
+        # we consolidate the full model and optim state dicts on CPU for rank 0
+        cpu_state_dict = training.gather_cpu_state_dict(
+            self._model.state_dict(),
+            self._is_rank_zero,
+            device=self._device,
+        )
+
+        if self._is_rank_zero:
+            log.info(
+                f"Getting full model state dict took {time.perf_counter() - start:.2f} secs"
+            )
+
+        if intermediate_checkpoint:
+            start = time.perf_counter()
+            if self._is_rank_zero:
+                log.info("Getting optimizer state dict...")
+            if not self._optimizer_in_bwd:
+                opt_state_dict = training.get_full_optimizer_state_dict(
+                    self._optimizer,
+                    self._is_rank_zero,
+                    device=self._device,
+                )
+            else:
+                opt_state_dict = {}
+                for param, opt in self._optim_ckpt_wrapper.optim_map.items():
+                    opt_state_dict[param] = training.get_full_optimizer_state_dict(
+                        opt, self._is_rank_zero, device=self._device
+                    )
+            if self._is_rank_zero:
+                log.info(
+                    f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs"
+                )
+        else:
+            opt_state_dict = None
+
+        # Now that we have the model and opt state dict, create the actual checkpoint dict
+        # to be sent to the checkpointer and ultimately written to file
+
+        if self._is_rank_zero:
+            start = time.perf_counter()
+            checkpoint_dict.update({training.MODEL_KEY: cpu_state_dict})
+
+            # if training is in-progress, checkpoint the optimizer state and recipe state
+            # as well.
+            if intermediate_checkpoint:
+                checkpoint_dict.update(
+                    {
+                        training.OPT_KEY: opt_state_dict,
+                        training.SEED_KEY: self.seed,
+                        training.EPOCHS_KEY: self.epochs_run,
+                        training.TOTAL_EPOCHS_KEY: self.total_epochs,
+                        training.MAX_STEPS_KEY: self.max_steps_per_epoch,
+                    }
+                )
+
+            self._checkpointer.save_checkpoint(
+                checkpoint_dict,
+                epoch=epoch,
+                intermediate_checkpoint=intermediate_checkpoint,
+            )
+            log.info(f"Saving checkpoint took {time.perf_counter() - start:.2f} secs")
+
+        torch.distributed.barrier()
+
+    def train(self) -> None:
+        """
+        The core training loop.
+        """
+        # clean up before training begins
+        training.cleanup_before_training()
+
+        world_size, rank = training.get_world_size_and_rank()
+
+        # zero out the gradients before starting training
+        if not self._optimizer_in_bwd:
+            self._optimizer.zero_grad()
+        else:
+            for opt in self._optim_ckpt_wrapper.optim_map.values():
+                opt.zero_grad()
+
+        # Initialize tokens count and running loss (for grad accumulation)
+        t0 = time.perf_counter()
+        running_loss = 0
+        num_tokens = 0
+
+        # Initialize output hidden states
+        if self._do_output_hidden_states is not None:
+            self._model.output_hidden_states = [
+                i
+                for i in range(len(self._do_output_hidden_states))
+                if self._do_output_hidden_states[i]
+            ]
+
+        self._profiler.start()
+        # self.epochs_run should be non-zero when we're resuming from a checkpoint
+        for curr_epoch in range(self.epochs_run, self.total_epochs):
+            # Update the sampler to ensure data is correctly shuffled across epochs
+            # in case shuffle is True
+            self._sampler.set_epoch(curr_epoch)
+
+            pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0))
+            for idx, batch in enumerate(self._dataloader):
+                if (
+                    self.max_steps_per_epoch is not None
+                    and (idx // self._gradient_accumulation_steps)
+                    == self.max_steps_per_epoch
+                ):
+                    break
+
+                # Start tracking CUDA memory for active steps for just the first epoch
+                if (
+                    self._is_rank_zero
+                    and curr_epoch == 0
+                    and self.profiler_profile_memory
+                    and idx == self.profiler_wait_steps + self.profiler_warmup_steps
+                ):
+                    torch.cuda.memory._record_memory_history()
+
+                utils.batch_to_device(batch, self._device)
+
+                # Calculate the number of unmasked tokens in the current batch
+                # and increment the total number of tokens seen in the step
+                current_num_tokens = (
+                    batch["labels"] != self._loss_fn.ignore_index
+                ).sum()
+                num_tokens += current_num_tokens
+
+                # Shape [b, s], needed for the loss not the model
+                labels = batch.pop("labels")
+
+                with self.activations_handling_ctx:
+                    outputs = self._model(**batch)
+                    if self._model.output_hidden_states:
+                        logits = outputs.pop(-1)
+                        hidden_states = {
+                            i: h
+                            for i, h in zip(self._model.output_hidden_states, outputs)
+                        }
+                    else:
+                        logits = outputs
+
+                # Shift labels to compute loss
+                # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
+                # But this way we don't need to slice the logits. We just add an ignore index to labels.
+                labels = torch.hstack(
+                    (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
+                )
+                if not isinstance(logits, list):
+                    labels = labels.reshape(-1)
+                    logits = logits.reshape(-1, logits.size(-1))
+
+                # Compute loss
+                # Loss is normalized by default so we multiply by the number of tokens
+                # This way we can normalize by the total number of tokens if we're accumulating gradients
+                if self._model.output_hidden_states:
+                    current_loss = (
+                        early_exit_loss(
+                            self._model,
+                            hidden_states,
+                            labels,
+                            self._loss_fn,
+                            self._early_exit_loss_scale,
+                            self._early_exit_loss_scale_type,
+                        )
+                        * current_num_tokens
+                    )
+                else:
+                    current_loss = self._loss_fn(logits, labels) * current_num_tokens
+
+                # free logits otherwise it peaks backward memory
+                del logits
+
+                running_loss += current_loss
+
+                # For optimizer in backward, we need to normalize before calling backward
+                # This case and gradient accumulation are mutually exclusive
+                if self._optimizer_in_bwd:
+                    torch.distributed.all_reduce(num_tokens)
+                    torch.distributed.all_reduce(running_loss)
+                    current_loss = current_loss / num_tokens
+
+                current_loss.backward()
+
+                # Step with optimizer
+                if (idx + 1) % self._gradient_accumulation_steps == 0:
+                    if not self._optimizer_in_bwd:
+                        # Get total number of tokens across all ranks to normalize gradients
+                        torch.distributed.all_reduce(num_tokens)
+                        # This will ensure that the logged loss matches what we're optimizing
+                        torch.distributed.all_reduce(running_loss)
+                        # Manually scale the gradients from unnormalized loss by total # of tokens
+                        training.scale_grads(self._model, 1 / num_tokens)
+                        if self._clip_grad_norm is not None:
+                            grad_norm = torch.nn.utils.clip_grad_norm_(
+                                self._model.parameters(),
+                                max_norm=float(self._clip_grad_norm),
+                            )
+                        self._optimizer.step()
+                        self._optimizer.zero_grad(set_to_none=True)
+
+                    # Update the number of steps when the weights are updated
+                    self.global_step += 1
+
+                    loss_to_log = running_loss.item() / num_tokens
+                    pbar.update(1)
+                    pbar.set_description(
+                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
+                    )
+
+                    # Log per-step metrics
+                    if (
+                        self.global_step % self._log_every_n_steps == 0
+                        and self._is_rank_zero
+                    ):
+                        time_per_step = time.perf_counter() - t0
+                        log_dict = {
+                            "loss": loss_to_log,
+                            "lr": get_lr(
+                                (
+                                    self._optimizer
+                                    if not self._optimizer_in_bwd
+                                    else self._optim_ckpt_wrapper
+                                ),
+                            ),
+                            "tokens_per_second_per_gpu": num_tokens
+                            / (time_per_step * world_size),
+                        }
+                        if self._log_peak_memory_stats:
+                            log_dict.update(
+                                training.get_memory_stats(device=self._device)
+                            )
+                        if self._clip_grad_norm is not None:
+                            log_dict.update({"grad_norm": grad_norm})
+                        self._metric_logger.log_dict(
+                            log_dict,
+                            step=self.global_step,
+                        )
+
+                    # Reset running stats for the next step
+                    running_loss = 0
+                    num_tokens = 0
+                    t0 = time.perf_counter()
+
+                    # Update Early Exit Layers/Scales
+                    if self._early_exit_loss_curriculum:
+                        self._early_exit_loss_curriculum.step()
+                        self._do_output_hidden_states = (
+                            self._early_exit_loss_curriculum.get()
+                        )
+                        self._model.output_hidden_states = [
+                            i
+                            for i in range(len(self._do_output_hidden_states))
+                            if self._do_output_hidden_states[i]
+                        ]
+
+                    # Stop tracking CUDA memory now that active steps are complete
+                    if (
+                        self._is_rank_zero
+                        and curr_epoch == 0
+                        and self.profiler_profile_memory
+                        and idx
+                        == self.profiler_wait_steps
+                        + self.profiler_warmup_steps
+                        + self.profiler_active_steps
+                    ):
+                        torch.cuda.memory._record_memory_history(enabled=None)
+
+                    # Step profiler
+                    # Note that this is called within the gradient accumulation block, hence
+                    # will include multiple forward / backward passes if gradient accumulation > 1
+                    self._profiler.step()
+
+            self.epochs_run += 1
+            self.save_checkpoint(epoch=curr_epoch)
+
+        self._profiler.stop()
+
+    def cleanup(self) -> None:
+        if self._is_rank_zero:
+            self._metric_logger.close()
+        destroy_process_group()
+
+
+@config.parse
+def recipe_main(cfg: DictConfig) -> None:
+    """
+    Entry point for the recipe.
+
+    Configurable parameters are read in the following order:
+        - Parameters specified in config (see available configs through ``tune ls``)
+        - Overwritten by arguments from the command-line
+    """
+    if not training.is_distributed():
+        raise RuntimeError(
+            "Distributed finetune recipe should be run via a distributed launcher."
+            "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
+        )
+    init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
+    if cfg.get("fsdp_cpu_offload", False):
+        # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
+        # speed up when benchmarking fused AdamW on CPU
+        training.set_torch_num_threads()
+
+    config.log_config(recipe_name="EarlyExitFinetuneRecipeDistributed", cfg=cfg)
+
+    recipe = EarlyExitFinetuneRecipeDistributed(cfg=cfg)
+    recipe.setup(cfg=cfg)
+    recipe.train()
+    recipe.cleanup()
+
+
+if __name__ == "__main__":
+    sys.exit(recipe_main())
diff --git a/tests/torchtune/modules/test_common_utils.py b/tests/torchtune/modules/test_common_utils.py
index 41dc472f00..0e2a6400e0 100644
--- a/tests/torchtune/modules/test_common_utils.py
+++ b/tests/torchtune/modules/test_common_utils.py
@@ -14,6 +14,7 @@
     llama3_2_vision_encoder,
 )
 from torchtune.modules import delete_kv_caches, disable_kv_cache, local_kv_cache
+from torchtune.modules.common_utils import slice_str_to_array
 from torchtune.modules.model_fusion import DeepFusionModel
 
 
@@ -191,3 +192,28 @@ def test_disable_kv_cache_raises_error_caches_not_setup(self, model, request):
         with pytest.raises(ValueError, match="Model caches must be setup"):
             with disable_kv_cache(model):
                 pass
+
+
+class TestSliceStrToArray:
+    def test_single_index(self):
+        assert slice_str_to_array("0", 5) == [True, False, False, False, False]
+
+    def test_slice_with_start_and_end(self):
+        assert slice_str_to_array("1:3", 5) == [False, True, True, False, False]
+
+    def test_slice_with_start_and_step(self):
+        assert slice_str_to_array("1::2", 5) == [False, True, False, True, False]
+
+    def test_slice_with_start_end_and_step(self):
+        assert slice_str_to_array("1:4:2", 5) == [False, True, False, True, False]
+
+    def test_multiple_indices(self):
+        assert slice_str_to_array("0,2,4", 6) == [True, False, True, False, True, False]
+
+    def test_out_of_range_index(self):
+        with pytest.raises(AssertionError):
+            slice_str_to_array("10", 5)
+
+    def test_invalid_slice_format(self):
+        with pytest.raises(AssertionError):
+            slice_str_to_array("1:2:3:4", 5)
diff --git a/tests/torchtune/modules/test_early_exit_loss.py b/tests/torchtune/modules/test_early_exit_loss.py
new file mode 100644
index 0000000000..27d421c14d
--- /dev/null
+++ b/tests/torchtune/modules/test_early_exit_loss.py
@@ -0,0 +1,184 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import random
+
+import numpy as np
+import pytest
+import torch
+import torch.nn as nn
+from torchtune.modules import TransformerDecoder
+from torchtune.modules.early_exit_loss import (
+    early_exit_loss,
+    GradualEarlyExitCurriculum,
+    inv_l_loss_scale,
+    inv_sqrt_l_loss_scale,
+    linear_l_loss_scale,
+    RotationalEarlyExitCurriculum,
+    sqrt_l_loss_scale,
+    sum_l_loss_scale,
+    uniform_loss_scale,
+)
+
+# Mock components for TransformerDecoder
+class MockLayer(nn.Module):
+    def forward(
+        self, x, mask=None, encoder_input=None, encoder_mask=None, input_pos=None
+    ):
+        return x  # Simply return the input for testing purposes
+
+
+class TestEarlyExitLoss:
+    @pytest.fixture
+    def num_layers(self):
+        return 12
+
+    @pytest.fixture
+    def mock_model(self, num_layers):
+        # Create mock components
+        tok_embeddings = nn.Embedding(1000, 512)  # Example vocab size and embedding dim
+        layers = nn.ModuleList([MockLayer() for _ in range(num_layers)])  # mock layers
+        norm = nn.LayerNorm(512)  # Example layer normalization
+        output = nn.Linear(512, 1000)  # Example output layer
+
+        # Create an instance of TransformerDecoder
+        model = TransformerDecoder(
+            tok_embeddings=tok_embeddings,
+            layers=layers,
+            max_seq_len=512,
+            num_heads=8,
+            head_dim=64,
+            norm=norm,
+            output=output,
+            num_layers=num_layers,
+            output_hidden_states=[0, 1, 2],  # Example layers to output hidden states
+        )
+        return model
+
+    @pytest.fixture
+    def hidden_states_dict(self):
+        return {i: torch.randn(4, 5, 512) for i in range(3)}  # Adjusted embedding dim
+
+    @pytest.fixture
+    def labels(self):
+        return torch.randint(0, 1000, (4, 5))  # Adjusted vocab size
+
+    @pytest.fixture
+    def loss_fn(self):
+        return nn.CrossEntropyLoss(ignore_index=-1)
+
+    def test_early_exit_loss(self, mock_model, hidden_states_dict, labels, loss_fn):
+        loss = early_exit_loss(mock_model, hidden_states_dict, labels, loss_fn)
+        assert isinstance(loss, torch.Tensor)
+        assert loss.item() >= 0
+
+    @pytest.mark.parametrize(
+        "scale_fn",
+        [
+            uniform_loss_scale,
+            linear_l_loss_scale,
+            sum_l_loss_scale,
+            sqrt_l_loss_scale,
+            inv_l_loss_scale,
+            inv_sqrt_l_loss_scale,
+        ],
+    )
+    def test_layer_ids_to_loss_scales(self, scale_fn, num_layers):
+        for n_subset_layers in range(1, num_layers + 1):
+            layer_ids = torch.tensor(
+                random.sample(range(0, num_layers), n_subset_layers)
+            )
+            scales = scale_fn(layer_ids, num_layers, 1.0)
+            assert torch.isclose(scales.sum(), torch.tensor(1.0))
+
+    def test_early_exit_loss_vs_manual(
+        self, mock_model, hidden_states_dict, labels, loss_fn
+    ):
+        # Calculate early exit loss using the function under test
+        calculated_loss = early_exit_loss(
+            mock_model,
+            hidden_states_dict,
+            labels,
+            loss_fn,
+            e_scale=1,
+            loss_scale_fn=uniform_loss_scale,
+        )
+        # Manually calculate the loss for each hidden state
+        total_loss = 0.0
+        num_hidden_states = len(hidden_states_dict)
+        for i, hidden_state in hidden_states_dict.items():
+            # Compute logits for the current hidden state
+            logits = mock_model.unembed(hidden_state)
+            labels = labels.reshape(-1)
+            logits = logits.reshape(-1, logits.size(-1))
+            # Compute the loss for the current hidden state
+            loss = loss_fn(logits, labels)
+            total_loss += loss
+        # Average the losses across all hidden states
+        manual_loss = total_loss / num_hidden_states
+        # Compare the two losses
+        assert torch.isclose(
+            calculated_loss, manual_loss, atol=1e-6
+        ), f"Calculated loss: {calculated_loss}, Manual loss: {manual_loss}"
+
+
+class TestEarlyExitLossCurriculum:
+    @pytest.mark.parametrize(
+        "train_last_layer",
+        [
+            True,
+            False,
+        ],
+    )
+    def test_rotational_early_exit_curriculum(self, train_last_layer):
+        curriculum = RotationalEarlyExitCurriculum(
+            [True, False, False], max_steps=100, train_last_layer=train_last_layer
+        )
+        expected = np.array([True, False, train_last_layer])
+        assert np.array_equal(curriculum.get(), expected)
+        curriculum.step()
+        expected = np.array([False, True, train_last_layer])
+        assert np.array_equal(curriculum.get(), expected)
+        curriculum.step()
+        # Since the last element is already True on this rotation, the value of `train_last_layer` has no effect.
+        expected = np.array([False, False, True])
+        assert np.array_equal(curriculum.get(), expected)
+        curriculum.step()
+        expected = np.array([True, False, train_last_layer])
+        assert np.array_equal(curriculum.get(), expected)
+
+    @pytest.mark.parametrize(
+        "train_last_layer",
+        [
+            True,
+            False,
+        ],
+    )
+    def test_gradual_early_exit_curriculum(self, train_last_layer):
+        curriculum = GradualEarlyExitCurriculum(
+            [True, True, True, True],
+            max_steps=4,
+            train_last_layer=train_last_layer,
+            fraction_scale=1,
+        )
+        expected = np.array([False, False, False, train_last_layer])
+        assert np.array_equal(curriculum.get(), expected)
+        curriculum.step()
+        assert np.array_equal(curriculum.get(), [False, False, False, train_last_layer])
+        curriculum.step()
+        # Since the last element is already True on this update, the value of `train_last_layer` has no effect.
+        assert np.array_equal(curriculum.get(), [False, False, False, True])
+        curriculum.step()
+        assert np.array_equal(curriculum.get(), [False, False, True, True])
+        curriculum.step()
+        assert np.array_equal(curriculum.get(), [False, True, True, True])
+        curriculum.step()
+        assert np.array_equal(curriculum.get(), [True, True, True, True])
+        curriculum.step()
+        assert np.array_equal(curriculum.get(), [True, True, True, True])
diff --git a/tests/torchtune/modules/test_layer_dropout.py b/tests/torchtune/modules/test_layer_dropout.py
new file mode 100644
index 0000000000..9bbd779d79
--- /dev/null
+++ b/tests/torchtune/modules/test_layer_dropout.py
@@ -0,0 +1,279 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import math
+from typing import Tuple
+
+import pytest
+import torch
+from tests.test_utils import assert_expected
+from torchtune.modules.layer_dropout import (
+    get_scale,
+    LayerDropout,
+    ModuleLayerDropoutWrapper,
+    prepare_layer_dropout,
+    ScaleType,
+)
+
+
+class TestLayerDropout:
+    """Class for testing LayerDropout implementation."""
+
+    @pytest.fixture
+    def input_shape(self) -> Tuple[int, int, int]:
+        bsz = 8
+        seqlen = 256
+        dim = 32
+        return bsz, seqlen, dim
+
+    @pytest.fixture
+    def input(self, input_shape: Tuple[int]) -> torch.Tensor:
+        return torch.randn(input_shape)
+
+    @pytest.fixture
+    def layer_dropout(
+        self, prob: float = 0.5, disable_on_eval: bool = True
+    ) -> LayerDropout:
+        return LayerDropout(prob=prob, disable_on_eval=disable_on_eval)
+
+    def test_forward_train_prob_1(
+        self, layer_dropout: LayerDropout, input: torch.Tensor
+    ):
+        # With dropout probability = 1.0, we expect output to be the same as input
+        layer_dropout.prob = 1.0
+        output = layer_dropout.forward(lambda x: x**2, input)
+        assert torch.allclose(output, input)
+
+    def test_forward_train_prob_0(
+        self, layer_dropout: LayerDropout, input: torch.Tensor
+    ):
+        # With dropout probability = 0.0, we expect the operation to be applied to all elements in the input
+        layer_dropout.prob = 0.0
+        output = layer_dropout.forward(lambda x: x**2, input)
+        assert torch.allclose(output, input**2)
+
+    def test_forward_eval(self, layer_dropout: LayerDropout, input: torch.Tensor):
+        layer_dropout.prob = 1.0
+        layer_dropout.eval()
+
+        layer_dropout.disable_on_eval = True
+        output = layer_dropout.forward(lambda x: x**2, input)
+        assert torch.allclose(output, input**2)
+
+        layer_dropout.disable_on_eval = False
+        with torch.no_grad():
+            output = layer_dropout.forward(lambda x: x**2, input)
+        assert torch.allclose(output, input)
+
+
+class TestLayerDropoutWrapper:
+    @pytest.fixture
+    def input_shape(self) -> Tuple[int, int]:
+        bsz = 4
+        dim = 8
+        return (bsz, dim)
+
+    @pytest.fixture
+    def input(self, input_shape: Tuple[int]) -> torch.Tensor:
+        return torch.randn(input_shape)
+
+    @pytest.fixture
+    def model(self, input_shape) -> torch.nn.Module:
+        _, dim = input_shape
+        return torch.nn.Sequential(
+            torch.nn.Linear(dim, 32), torch.nn.ReLU(), torch.nn.Linear(32, dim)
+        )
+
+    @pytest.fixture
+    def linear(self, input_shape) -> torch.nn.Module:
+        _, dim = input_shape
+        return torch.nn.Linear(dim, dim)
+
+    def test_linear(self, linear: torch.nn.Module, input: torch.Tensor):
+        wrapper = ModuleLayerDropoutWrapper(linear, LayerDropout(prob=0.5))
+        assert wrapper.module == linear
+
+        # Test output
+        wrapper.dropout.prob = 1
+        assert torch.allclose(wrapper(input), input)
+        wrapper.dropout.prob = 0
+        assert torch.allclose(wrapper(input), linear(input))
+
+        # Test getters
+        assert wrapper.in_features == linear.in_features
+        assert wrapper.out_features == linear.out_features
+        assert torch.equal(wrapper.weight, linear.weight)
+
+        # Test setters
+        wrapper.weight.data = wrapper.weight.data * 2
+        assert torch.equal(wrapper.weight, linear.weight)
+
+        # Test state_dict
+        for k in wrapper.state_dict().keys():
+            assert torch.equal(wrapper.state_dict()[k], linear.state_dict()[k])
+
+    def test_model(self, model: torch.nn.Module, input: torch.Tensor):
+        wrapper = ModuleLayerDropoutWrapper(model, LayerDropout(prob=0.5))
+        assert wrapper.module == model
+
+        # Test output
+        wrapper.dropout.prob = 1
+        assert torch.allclose(wrapper(input), input)
+        wrapper.dropout.prob = 0
+        assert torch.allclose(wrapper(input), model(input))
+
+        # Test getters
+        assert wrapper[0].in_features == model[0].in_features
+        assert wrapper[0].out_features == model[0].out_features
+        assert torch.equal(wrapper[0].weight, model[0].weight)
+
+        # Test setters
+        wrapper[2].weight.data = wrapper[2].weight.data * 2
+        assert torch.equal(wrapper[2].weight, model[2].weight)
+
+        # Test state_dict
+        for k in wrapper.state_dict().keys():
+            assert torch.equal(wrapper.state_dict()[k], model.state_dict()[k])
+
+
+class TestScales:
+    def test_get_scale_uniform(self):
+        scale_type = ScaleType.UNIFORM
+        scale_period = 10
+
+        assert_expected(get_scale(scale_type, scale_period, 0), 1.0)
+        assert_expected(get_scale(scale_type, scale_period, scale_period / 2), 1.0)
+        assert_expected(get_scale(scale_type, scale_period, scale_period), 1.0)
+        assert_expected(get_scale(scale_type, scale_period, scale_period * 2), 1.0)
+
+    def test_get_scale_linear(self):
+        scale_type = ScaleType.LINEAR
+        scale_period = 10
+
+        assert_expected(get_scale(scale_type, scale_period, 0), 0.0)
+        assert_expected(get_scale(scale_type, scale_period, scale_period / 2), 1 / 2)
+        assert_expected(get_scale(scale_type, scale_period, scale_period), 1.0)
+        assert_expected(get_scale(scale_type, scale_period, scale_period * 2), 1.0)
+
+    def test_get_scale_exp(self):
+        scale_type = ScaleType.EXP
+        scale_period = 10
+
+        assert_expected(get_scale(scale_type, scale_period, 0), 0.0)
+        assert_expected(
+            get_scale(scale_type, scale_period, scale_period / 2),
+            math.pow(2, 1 / 2) - 1,
+        )
+        assert_expected(get_scale(scale_type, scale_period, scale_period), 1.0)
+        assert_expected(get_scale(scale_type, scale_period, scale_period * 2), 1.0)
+
+    def test_get_scale_log(self):
+        scale_type = ScaleType.LOG
+        scale_period = 10
+
+        assert_expected(get_scale(scale_type, scale_period, 0), 0.0)
+        assert_expected(
+            get_scale(scale_type, scale_period, scale_period / 2),
+            math.log(5 + 1, scale_period + 1),
+        )
+        assert_expected(get_scale(scale_type, scale_period, scale_period), 1.0)
+        assert_expected(get_scale(scale_type, scale_period, scale_period * 2), 1.0)
+
+    def test_get_scale_sin(self):
+        scale_type = ScaleType.SIN
+        scale_period = 10
+
+        assert_expected(get_scale(scale_type, scale_period, 0), 0.0)
+        assert_expected(
+            get_scale(scale_type, scale_period, scale_period / 2),
+            math.sin(0.5 * math.pi * 0.5),
+        )
+        assert_expected(get_scale(scale_type, scale_period, scale_period), 1.0)
+        assert_expected(get_scale(scale_type, scale_period, scale_period * 2), 1.0)
+
+    def test_get_scale_sigmoid(self):
+        scale_type = ScaleType.SIGMOID
+        scale_period = 10
+
+        # sigmoid(0) is close to 0 but not exactly 0, hence the relatively large rtol and atol
+        assert_expected(
+            get_scale(scale_type, scale_period, 0), 0.0, rtol=1e-2, atol=1e-2
+        )
+        assert_expected(
+            get_scale(scale_type, scale_period, scale_period / 2),
+            0.5,
+        )
+        assert_expected(get_scale(scale_type, scale_period, scale_period), 1.0)
+        assert_expected(get_scale(scale_type, scale_period, scale_period * 2), 1.0)
+
+
+class TestLayerDropoutModel:
+    def test_prepare_layer_dropout_uniform(self):
+        class MockModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.layers = torch.nn.ModuleList(
+                    [torch.nn.Linear(10, 10) for _ in range(5)]
+                )
+
+        model = MockModel()
+        prob_max = 0.5
+        prob_layer_scale = ScaleType.UNIFORM
+        layers_str = "0:4"
+        prepare_layer_dropout(model.layers, prob_max, prob_layer_scale, layers_str)
+        for i, layer in enumerate(model.layers):
+            assert hasattr(layer, "dropout")
+            if i in range(0, 4):
+                assert layer.dropout.prob == prob_max
+            else:
+                assert layer.dropout.prob == 0
+
+    def test_prepare_layer_dropout_exp(self):
+        class MockModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.layers = torch.nn.ModuleList(
+                    [torch.nn.Linear(10, 10) for _ in range(5)]
+                )
+
+        model = MockModel()
+        prob_max = 0.5
+        prob_layer_scale = ScaleType.EXP
+        layers_str = ":"
+        prepare_layer_dropout(model.layers, prob_max, prob_layer_scale, layers_str)
+        for i, layer in enumerate(model.layers):
+            assert hasattr(layer, "dropout")
+            if i == 0:
+                assert layer.dropout.prob == 0
+            elif i == len(model.layers) - 1:
+                assert layer.dropout.prob == prob_max
+            else:
+                assert layer.dropout.prob > 0 and layer.dropout.prob < prob_max
+
+    def test_prepare_layer_dropout_linear(self):
+        class MockModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.layers = torch.nn.ModuleList(
+                    [torch.nn.Linear(10, 10) for _ in range(5)]
+                )
+
+        model = MockModel()
+        prob_max = 0.5
+        prob_layer_scale = ScaleType.LINEAR
+        layers_str = ":"
+        prepare_layer_dropout(model.layers, prob_max, prob_layer_scale, layers_str)
+        for i, layer in enumerate(model.layers):
+            assert hasattr(layer, "dropout")
+            if i == 0:
+                assert layer.dropout.prob == 0
+            elif i == len(model.layers) - 1:
+                assert layer.dropout.prob == prob_max
+            elif i == len(model.layers) // 2:
+                assert layer.dropout.prob == prob_max / 2
+            else:
+                assert layer.dropout.prob >= 0.0 and layer.dropout.prob <= prob_max
diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py
index 5bbf860482..5776269b8d 100644
--- a/torchtune/_recipe_registry.py
+++ b/torchtune/_recipe_registry.py
@@ -411,6 +411,17 @@ class Recipe:
         ],
         supports_distributed=False,
     ),
+    Recipe(
+        name="dev/early_exit_finetune_distributed",
+        file_path="dev/early_exit_finetune_distributed.py",
+        configs=[
+            Config(
+                name="llama2/7B_full_early_exit",
+                file_path="dev/7B_full_early_exit.yaml",
+            ),
+        ],
+        supports_distributed=True,
+    ),
     Recipe(
         name="eleuther_eval",
         file_path="eleuther_eval.py",
diff --git a/torchtune/modules/__init__.py b/torchtune/modules/__init__.py
index 29c014c33e..3554698d42 100644
--- a/torchtune/modules/__init__.py
+++ b/torchtune/modules/__init__.py
@@ -14,6 +14,7 @@
 )
 from .feed_forward import FeedForward  # noqa
 from .kv_cache import KVCache  # noqa
+from .layer_dropout import LayerDropout, prepare_layer_dropout  # noqa
 from .layer_norm import Fp32LayerNorm  # noqa
 from .low_precision import FrozenNF4Linear  # noqa
 from .position_embeddings import (  # noqa
@@ -30,6 +31,7 @@
 )
 from .vision_transformer import VisionTransformer
 
+
 __all__ = [
     "MultiHeadAttention",
     "TanhGate",
@@ -51,4 +53,6 @@
     "local_kv_cache",
     "delete_kv_caches",
     "disable_kv_cache",
+    "LayerDropout",
+    "prepare_layer_dropout",
 ]
diff --git a/torchtune/modules/common_utils.py b/torchtune/modules/common_utils.py
index 38b5776f88..16df7801cb 100644
--- a/torchtune/modules/common_utils.py
+++ b/torchtune/modules/common_utils.py
@@ -59,6 +59,80 @@ def reparametrize_as_dtype_state_dict_post_hook(
                 state_dict[k] = state_dict[k].cpu()
 
 
+def slice_str_to_array(slice_str: str, length: int) -> list[bool]:
+    """
+    Convert a string representing a Python slice or index into a boolean array.
+
+    The resulting array will have the same length as the specified `length` parameter.
+    Each element in the array corresponds to an index in the original sequence,
+    with `True` indicating that the index is included in the slice and `False` otherwise.
+
+    Args:
+        slice_str (str): A string representing a Python slice or index, e.g. "1:3", ":5", "2::3", "0,4,5".
+        length (int): The length of the original sequence.
+
+    Returns:
+        list[bool]: A boolean array representing the slice.
+
+    Examples:
+        >>> slice_str_to_array("1:3", 5)
+        [False, True, True, False, False]
+        >>> slice_str_to_array(":", 5)
+        [True, True, True, True, True]
+        >>> slice_str_to_array("::2", 5)
+        [True, False, True, False, True]
+        >>> slice_str_to_array("1::2", 5)
+        [False, True, False, True, False]
+        >>> slice_str_to_array("2:5:2", 6)
+        [False, False, True, False, True, False]
+        >>> slice_str_to_array("0,4,5", 7)
+        [True, False, False, False, True, True, False]
+    """
+
+    assert "," not in slice_str or ":" not in slice_str, "Cannot mix commas and colons"
+
+    if "," in slice_str:
+        indices = [int(i) for i in slice_str.split(",")]
+        assert all(0 <= i < length for i in indices), "Index out of range"
+        result = [False] * length
+        for i in indices:
+            result[i] = True
+        return result
+
+    parts = slice_str.split(":")
+    assert len(parts) <= 3, "Invalid slice format"
+    start, end, step = None, None, None
+
+    if len(parts) == 1 and parts[0] != "":
+        start = int(parts[0])
+        end = start + 1
+        step = 1
+    elif len(parts) == 2:
+        start = int(parts[0]) if parts[0] != "" else None
+        end = int(parts[1]) if parts[1] != "" else None
+    elif len(parts) == 3:
+        start = int(parts[0]) if parts[0] != "" else None
+        end = int(parts[1]) if parts[1] != "" else None
+        step = int(parts[2]) if parts[2] != "" else None
+
+    assert start is None or 0 <= start < length, "Start index out of range"
+    assert end is None or 0 <= end <= length, "End index out of range"
+    assert step is None or step != 0, "Step cannot be zero"
+
+    result = [False] * length
+    slice_indices = range(
+        start if start is not None else 0,
+        end if end is not None else length,
+        step if step is not None else 1,
+    )
+
+    for i in slice_indices:
+        if 0 <= i < length:
+            result[i] = True
+
+    return result
+
+
 def _low_ram_reparametrize_as_dtype_state_dict_post_hook(
     model: nn.Module,
     state_dict: Dict[str, Any],
diff --git a/torchtune/modules/early_exit_loss.py b/torchtune/modules/early_exit_loss.py
new file mode 100644
index 0000000000..976f6bee07
--- /dev/null
+++ b/torchtune/modules/early_exit_loss.py
@@ -0,0 +1,292 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+from typing import Callable, Dict, List, Optional
+
+import numpy as np
+import torch
+
+from torchtune import utils
+from torchtune.modules.transformer import TransformerDecoder
+
+log = utils.get_logger("DEBUG")
+
+
+def uniform_loss_scale(
+    layer_ids: torch.Tensor, n_layers: int, e_scale: float = 1.0
+) -> torch.Tensor:
+    loss_scales = torch.ones(len(layer_ids), device=layer_ids.device)
+    loss_scales = loss_scales * torch.where(layer_ids < n_layers - 1, e_scale, 1.0)
+    return loss_scales / torch.sum(loss_scales)
+
+
+def linear_l_loss_scale(
+    layer_ids: torch.Tensor, n_layers: int, e_scale: float = 1.0
+) -> torch.Tensor:
+    loss_scales = torch.Tensor(layer_ids + 1)
+    loss_scales = loss_scales * torch.where(layer_ids < n_layers - 1, e_scale, 1.0)
+    return loss_scales / torch.sum(loss_scales)
+
+
+def sum_l_loss_scale(
+    layer_ids: torch.Tensor, n_layers: int, e_scale: float = 1.0
+) -> torch.Tensor:
+    loss_scales = torch.cumsum(layer_ids + 1, dim=0)
+    loss_scales = loss_scales * torch.where(layer_ids < n_layers - 1, e_scale, 1.0)
+    return loss_scales / torch.sum(loss_scales)
+
+
+def sqrt_l_loss_scale(
+    layer_ids: torch.Tensor, n_layers: int, e_scale: float = 1.0
+) -> torch.Tensor:
+    loss_scales = torch.sqrt(layer_ids + 1)
+    loss_scales = loss_scales * torch.where(layer_ids < n_layers - 1, e_scale, 1.0)
+    return loss_scales / torch.sum(loss_scales)
+
+
+def inv_l_loss_scale(
+    layer_ids: torch.Tensor, n_layers: int, e_scale: float = 1.0
+) -> torch.Tensor:
+    loss_scales = 1.0 / (layer_ids + 1)
+    loss_scales = loss_scales * torch.where(layer_ids < n_layers - 1, e_scale, 1.0)
+    return loss_scales / torch.sum(loss_scales)
+
+
+def inv_sqrt_l_loss_scale(
+    layer_ids: torch.Tensor, n_layers: int, e_scale: float = 1.0
+) -> torch.Tensor:
+    loss_scales = torch.reciprocal(torch.sqrt(layer_ids + 1))
+    loss_scales = loss_scales * torch.where(layer_ids < n_layers - 1, e_scale, 1.0)
+    return loss_scales / torch.sum(loss_scales)
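+
+# Worked example (illustrative): for layer_ids = tensor([0., 1., 2.]), n_layers = 3 and
+# e_scale = 1.0, linear_l_loss_scale weights the layers proportionally to (1, 2, 3) and
+# normalizes them to sum to 1, i.e. tensor([0.1667, 0.3333, 0.5000]); with e_scale < 1.0
+# the non-final layers are down-weighted before normalization.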
+
+
+def early_exit_loss(
+    model: TransformerDecoder,
+    hidden_states_dict: Dict[int, torch.Tensor],
+    labels: torch.Tensor,
+    loss_fn: torch.nn.Module,
+    e_scale: float = 1.0,
+    loss_scale_fn: Callable[
+        [torch.Tensor, int, float], torch.Tensor
+    ] = uniform_loss_scale,
+) -> torch.Tensor:
+    """
+    Compute the early exit loss for a given model and outputs of intermediate layers.
+    This function takes in a model, a dictionary of hidden states, labels, a loss function,
+    and optional parameters for scaling the loss. It computes the early exit loss by
+    iterating over the hidden states, computing the logits and losses at each layer,
+    and then scaling and summing these losses.
+
+    Args:
+        model (TransformerDecoder): The model to compute the early exit loss for.
+        hidden_states_dict (Dict[int, torch.Tensor]): A dictionary of hidden states,
+            where each key is a layer index and each value is a tensor of shape [b, s, d].
+        labels (torch.Tensor): The labels for the input data.
+        loss_fn (torch.nn.Module): The loss function to use (should be the same as the standard loss function for last layer).
+        e_scale (float): A scaling factor for the early exit losses. Defaults to 1.0.
+        loss_scale_fn (Callable[[torch.Tensor, int, float], torch.Tensor]): A function to determine scale of each
+            layer's loss. Defaults to uniform_loss_scale.
+    Returns:
+        torch.Tensor: The computed early exit loss.
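+    Example:
+        An illustrative usage sketch inside a training step; ``model``, ``tokens``, ``labels``
+        and ``loss_fn`` are placeholders, and it is assumed that ``model.output_hidden_states``
+        lists the layer indices whose hidden states are returned (before the final logits)
+        by the forward pass:
+
+        >>> # outputs = model(tokens)
+        >>> # hidden_states_dict = dict(zip(model.output_hidden_states, outputs[:-1]))
+        >>> # loss = early_exit_loss(model, hidden_states_dict, labels, loss_fn)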
+    """
+    batch_loss_fn = copy.deepcopy(loss_fn)
+    batch_loss_fn.reduction = "none"
+
+    e = len(hidden_states_dict)
+    # List of e tensors with shape [b, s, d]
+    hidden_states = tuple(hidden_states_dict.values())
+    hidden_layer_ids = tuple(hidden_states_dict.keys())
+    # Shape: [e, b, s, d]
+    hidden_states_stacked = torch.stack(hidden_states)
+    # Shape: [e, b, s, out_dim]
+    logits_early = model.unembed(hidden_states_stacked)
+    # Shape: [e*b*s, out_dim]
+    logits_early = logits_early.reshape(-1, logits_early.size(-1))
+    logits_early = logits_early.contiguous()
+    # Shape: [e*b*s]
+    labels_repeated = labels.repeat(e, 1).reshape(-1)
+    # Compute early losses: Shape: [e*b*s]
+    losses_early = batch_loss_fn(logits_early, labels_repeated)
+    # Shape: [e, b*s]
+    losses_early = losses_early.view(e, -1)
+    # Shape: [e]
+    s_unpadded = (labels != loss_fn.ignore_index).sum()
+    losses_early = losses_early.float().sum(-1) / s_unpadded
+    # Shape: [e]
+    losses_scales = loss_scale_fn(
+        torch.Tensor(hidden_layer_ids).to(losses_early),
+        len(model.layers),
+        e_scale,
+    )
+
+    return torch.sum(losses_scales * losses_early)
+
+
+# TODO: create a base curriculum class that can be used for other aspects, e.g., dropout, datasets, etc.
+class EarlyExitCurriculum:
+    """
+    A curriculum for early exit loss training, which controls which layers' hidden states
+    are used to compute early exit losses as training progresses.
+
+    Args:
+        do_output_hidden_states (List[bool]): A list indicating whether each layer's hidden state
+            should be output to calculate their losses.
+        max_steps (int): The maximum number of steps in the curriculum.
+        train_last_layer (bool, optional): Whether to always calculate loss for the last layer. Defaults to True.
+        last_step (Optional[int]): The last step the curriculum stopped at in a previous run.
+            This is used when resuming training.
+        verbose (bool, optional): Whether to print verbose logs. Defaults to False.
+    """
+
+    def __init__(
+        self,
+        do_output_hidden_states: List[bool],
+        max_steps: int,
+        train_last_layer: bool = True,
+        last_step: Optional[int] = None,
+        verbose: bool = False,
+    ):
+        self._init_do_output_hidden_states = do_output_hidden_states
+        self._do_output_hidden_states = do_output_hidden_states
+        self.train_last_layer = train_last_layer
+        self.verbose = verbose
+        self.max_steps = max_steps
+        self._step = 0 if last_step is None else last_step
+
+    def step(self) -> None:
+        """
+        Perform a step in the curriculum. Should be called at the end of each iteration during training.
+        """
+        pass
+
+    def get(self) -> np.ndarray:
+        """
+        Get the current layer enablement mask.
+        Returns:
+            np.ndarray: A boolean array indicating whether loss should be calculated for each layer.
+        """
+        do_output_hidden_states = np.copy(self._do_output_hidden_states)
+        # Ensure last layer is trained
+        if self.train_last_layer:
+            do_output_hidden_states[-1] = True
+        return do_output_hidden_states
+
+
+class RotationalEarlyExitCurriculum(EarlyExitCurriculum):
+    """
+    A rotational early exit curriculum, which rotates the layer enablement one step forward
+    at each step.
+
+    Args:
+        do_output_hidden_states (List[bool]): A list indicating whether each layer's hidden state
+            should be output to calculate their losses.
+        max_steps (int): The maximum number of steps in the curriculum.
+        train_last_layer (bool, optional): Whether to always calculate loss for the last layer. Defaults to True.
+        last_step (Optional[int]): The last step the curriculum stopped at in a previous run.
+            This is used when resuming training.
+        verbose (bool, optional): Whether to print verbose logs. Defaults to False.
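+
+    Example:
+        An illustrative sketch of the rotation over a 3-layer schedule, using
+        ``train_last_layer=False`` so only the rotation itself is visible:
+
+        >>> curriculum = RotationalEarlyExitCurriculum(
+        ...     [True, False, False], max_steps=100, train_last_layer=False
+        ... )
+        >>> curriculum.get().tolist()
+        [True, False, False]
+        >>> curriculum.step()
+        >>> curriculum.get().tolist()
+        [False, True, False]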
+    """
+
+    def __init__(
+        self,
+        do_output_hidden_states: List[bool],
+        max_steps: int,
+        train_last_layer: bool = True,
+        last_step: Optional[int] = None,
+        verbose: bool = False,
+    ):
+        super().__init__(
+            do_output_hidden_states=do_output_hidden_states,
+            max_steps=max_steps,
+            train_last_layer=train_last_layer,
+            last_step=last_step,
+            verbose=verbose,
+        )
+        self._initial_do_output_hidden_states = np.copy(self._do_output_hidden_states)
+
+    def step(self):
+        """
+        Rotate the layer enablement one step forward.
+        This method updates the `_do_output_hidden_states` attribute by rotating it one position to the right.
+        """
+        # Rotate layer enablement one step forward
+        self._do_output_hidden_states = np.roll(self._do_output_hidden_states, 1)
+
+        self._step += 1
+        if self.verbose:
+            log.info(
+                f"Updated self._do_output_hidden_states to {self._do_output_hidden_states}."
+            )
+
+
+class GradualEarlyExitCurriculum(EarlyExitCurriculum):
+    """
+    A gradual early exit curriculum, which gradually enables more layers (starting from the last layer) as training progresses.
+
+    Args:
+        do_output_hidden_states (List[bool]): A list indicating whether each layer's hidden state
+            should be output to calculate their losses.
+        max_steps (int): The maximum number of steps in the curriculum.
+        train_last_layer (bool): Whether to always calculate loss for the last layer. Defaults to True.
+        last_step (Optional[int]): The last step the curriculum stopped at in a previous run.
+            This is used when resuming training.
+        fraction_scale (float): A scaling factor that determines when all layers become enabled:
+            at `step = max_steps * fraction_scale`, all the requested layers will be enabled.
+            Defaults to 0.5.
+        verbose (bool): Whether to print verbose logs. Defaults to False.
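+
+    Example:
+        An illustrative sketch (values chosen for brevity) with 4 layers, ``fraction_scale=1``
+        and ``train_last_layer=False``; layers are enabled from the deepest to the shallowest:
+
+        >>> curriculum = GradualEarlyExitCurriculum(
+        ...     [True, True, True, True], max_steps=4, train_last_layer=False, fraction_scale=1
+        ... )
+        >>> curriculum.get().tolist()
+        [False, False, False, False]
+        >>> for _ in range(2):
+        ...     curriculum.step()
+        >>> curriculum.get().tolist()
+        [False, False, False, True]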
+    """
+
+    def __init__(
+        self,
+        do_output_hidden_states: List[bool],
+        max_steps: int,
+        train_last_layer: bool = True,
+        last_step: Optional[int] = None,
+        fraction_scale: float = 0.5,
+        verbose: bool = False,
+    ):
+        super().__init__(
+            do_output_hidden_states=do_output_hidden_states,
+            max_steps=max_steps,
+            train_last_layer=train_last_layer,
+            last_step=last_step,
+            verbose=verbose,
+        )
+        self._final_do_output_hidden_states = np.copy(self._do_output_hidden_states)
+        self._step = 0
+        self._fraction_scale = fraction_scale
+
+        # Initialize all layers to False
+        for i in range(len(self._do_output_hidden_states)):
+            self._do_output_hidden_states[i] = False
+
+    def step(self):
+        """
+        Perform a step in the curriculum.
+        This method updates the `_do_output_hidden_states` attribute based on the current
+            step and the fraction of completed training steps.
+        """
+        fraction_trained = self._step / self.max_steps
+        n_layers = len(self._do_output_hidden_states)
+        # Enable each layer based on proportion of completed training steps
+        for layer_index in range(len(self._do_output_hidden_states)):
+            should_train = (
+                fraction_trained
+                >= self._fraction_scale * (n_layers - layer_index) / n_layers
+            )
+            self._do_output_hidden_states[layer_index] = should_train
+
+        # Only enable layers that are set by the user
+        self._do_output_hidden_states = np.logical_and(
+            self._do_output_hidden_states, self._final_do_output_hidden_states
+        )
+
+        self._step += 1
+        if self.verbose:
+            log.info(
+                f"Updated self._do_output_hidden_states to {self._do_output_hidden_states}."
+            )
diff --git a/torchtune/modules/layer_dropout.py b/torchtune/modules/layer_dropout.py
new file mode 100644
index 0000000000..75e28b4ae1
--- /dev/null
+++ b/torchtune/modules/layer_dropout.py
@@ -0,0 +1,307 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from enum import Enum
+from typing import Any, Callable, Iterable, Optional, Union
+
+import torch
+
+from torchtune.modules.common_utils import slice_str_to_array
+
+
+class LayerDropout(torch.nn.Module):
+    """
+    A module that applies layer dropout to the input tensor of an underlying module.
+    It drops a portion of the input tensor, applies the underlying module to the
+    remaining portion, and passes the dropped portion through unchanged.
+    When applied during training, it can have a regularization effect and can potentially speed up training.
+
+    Args:
+        prob (float): The probability of dropping an input. Defaults to 0.0.
+        dim (Optional[int]): The dimension of the input tensor along which elements are dropped. Defaults to 0 (i.e., the batch dimension).
+        disable_on_eval (Optional[bool]): Whether to disable layer dropout during evaluation. Defaults to True.
+        seed (Optional[int]): The seed for the random number generator. Defaults to None.
+    Examples:
+        >>> import torch
+        >>> # Apply layer dropout to a lambda function
+        >>> layer_dropout = LayerDropout(prob=0.5)
+        >>> output = layer_dropout(lambda x: x**2, torch.randn(1))
+        >>> # Apply layer dropout to a torch.nn.Linear module
+        >>> linear = torch.nn.Linear(5, 3)
+        >>> layer_dropout = LayerDropout(prob=0.5)
+        >>> output = layer_dropout(linear, torch.randn(1, 5))
+    """
+
+    def __init__(
+        self,
+        prob: float = 0.0,
+        dim: Optional[int] = 0,
+        disable_on_eval: Optional[bool] = True,
+        seed: Optional[int] = None,
+    ):
+        super().__init__()
+        self.prob: float = prob
+        self.dim = dim
+        self.disable_on_eval: bool = disable_on_eval
+        self.generator = torch.Generator(device="cpu")
+        # Fraction of the input that was processed (not dropped) in the last forward pass
+        self.inferred: Optional[float] = None
+
+        if seed is not None:
+            self.generator.manual_seed(seed)
+
+    def forward(
+        self,
+        function: Union[Callable, torch.nn.Module],
+        input: torch.Tensor,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        """
+        Apply layer dropout to the input tensor.
+
+        Args:
+            function (Union[Callable, torch.nn.Module]): The function or module to apply to the input tensor.
+            input (torch.Tensor): The input tensor.
+            *args: Additional positional arguments passed to the function.
+            **kwargs: Additional keyword arguments passed to the function.
+        Returns:
+            torch.Tensor: The output tensor after applying layer dropout.
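+
+        Example:
+            A minimal check of the ``prob=0.0`` path, where the wrapped op runs on every
+            element of the batch:
+
+            >>> dropout = LayerDropout(prob=0.0)
+            >>> x = torch.ones(2, 3)
+            >>> torch.equal(dropout(lambda t: t * 2, x), x * 2)
+            True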
+        """
+        n = input.shape[self.dim]
+
+        if self.prob == 0 or (self.disable_on_eval and self.training is False):
+            self.inferred = 1.0
+            return function(input, *args, **kwargs)
+
+        skip = (
+            torch.bernoulli(torch.Tensor((n) * [self.prob]), generator=self.generator)
+            .to(input.device)
+            .to(input.dtype)
+        )
+        self.inferred = 1 - torch.mean(skip)
+        ind_selected = (skip == 0).nonzero().squeeze()
+
+        if ind_selected.numel() > 0:
+            x_selected = torch.index_select(input, self.dim, ind_selected)
+            out_selected = function(x_selected, *args, **kwargs)
+
+        out = input.clone()
+        assert (
+            self.dim == 0
+        ), "Currently only supporting dropping elements along the 0th dimension"
+        if ind_selected.numel() > 0:
+            out[ind_selected] = out_selected
+        return out
+
+
+class ModuleLayerDropoutWrapper(torch.nn.Module):
+    """
+    A wrapper module that adds layer dropout functionality to a given module.
+    This class wraps a given module and applies layer dropout to it. It also
+    provides getter and setter methods for the wrapped module's attributes.
+
+    Args:
+        module (torch.nn.Module): The module to wrap.
+        dropout (LayerDropout): The layer dropout object.
+    Examples:
+        >>> import torch
+        >>> from torch import nn
+        >>> # Define a simple model
+        >>> class MyModel(nn.Module):
+        ...     def __init__(self):
+        ...         super().__init__()
+        ...         self.fc1 = nn.Linear(5, 3)
+        ...         self.fc2 = nn.Linear(3, 2)
+        ...
+        ...     def forward(self, x):
+        ...         return self.fc2(self.fc1(x))
+        >>> model = MyModel()
+        >>> fc1 = model.fc1
+        >>> fc2 = model.fc2
+        >>> # Apply layer dropout to the model
+        >>> layer_dropout = LayerDropout(prob=0.5)
+        >>> model = ModuleLayerDropoutWrapper(model, layer_dropout)
+        >>> # Accessing attributes of the wrapped model
+        >>> assert model.dropout.prob == 0.5
+        >>> assert model.fc1 == fc1
+        >>> assert model.fc2 == fc2
+        >>> # Pass an input to the wrapped model as if you are passing it to the original model
+        >>> output = model(torch.randn(1, 5))
+    """
+
+    def __init__(self, module: torch.nn.Module, dropout: LayerDropout):
+        super().__init__()
+        self.module = module
+        self.dropout = dropout
+
+    def forward(self, input: torch.Tensor, *args, **kwargs):
+        return self.dropout(self.module, input, *args, **kwargs)
+
+    def __getattr__(self, name: str) -> Any:
+        """Forward missing attributes to wrapped module."""
+        try:
+            return super().__getattr__(name)  # defer to nn.Module's logic
+        except AttributeError:
+            return getattr(self.module, name)  # fallback to wrapped module
+
+    def __setattr__(self, name: str, value: Any) -> Any:
+        """Forward missing attributes to wrapped module."""
+        try:
+            return super().__setattr__(name, value)  # defer to nn.Module's logic
+        except AttributeError:
+            return setattr(self.module, name, value)  # fallback to wrapped module
+
+    def __getitem__(self, key: int) -> Any:
+        """Forward indexing calls in case the module is a nn.Sequential."""
+        return self.module.__getitem__(key)
+
+    def state_dict(self, *args, **kwargs):
+        """Return the state dictionary of the wrapped module."""
+        return self.module.state_dict(*args, **kwargs)
+
+    def load_state_dict(self, state_dict, *args, **kwargs):
+        """Load the state dictionary into the wrapped module."""
+        self.module.load_state_dict(state_dict, *args, **kwargs)
+        return
+
+
+class ScaleType(str, Enum):
+    UNIFORM = "uniform"
+    EXP = "exp"
+    LINEAR = "linear"
+    LOG = "log"
+    SIN = "sin"
+    SIGMOID = "sigmoid"
+    STEP = "step"
+
+
+def get_scale(
+    scale_type: ScaleType,
+    scale_period: int,
+    val: int,
+) -> float:
+    """
+    Compute a scaling factor based on the provided scale type, period, and value.
+    The scaling factor is designed to be 0 when the value is 0 and 1 when the value
+    reaches or is larger than the scale period.
+
+    Args:
+        scale_type (ScaleType): The type of scaling to use.
+        scale_period (int): The period over which the scaling factor increases from 0 to 1.
+        val (int): The current value used to compute the scaling factor.
+    Returns:
+        float: The computed scaling factor.
+    Examples:
+        >>> get_scale(ScaleType.LINEAR, 10, 5)
+        0.5
+        >>> get_scale(ScaleType.LINEAR, 10, 0)
+        0.0
+        >>> get_scale(ScaleType.LOG, 10, 10)
+        1.0
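+        >>> # EXP follows 2 ** (val / scale_period) - 1 before clamping (illustrative):
+        >>> round(get_scale(ScaleType.EXP, 10, 5), 3)
+        0.414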
+    """
+    if scale_period == 0:
+        return 1.0
+    if val >= scale_period:
+        return 1.0
+
+    # all the equations below aim to make scale = 0 when val=0, and scale = 1 when val=scale_period
+    scale = {
+        ScaleType.UNIFORM: 1.0,
+        ScaleType.EXP: math.pow(2, val / scale_period) - 1,
+        ScaleType.LINEAR: val / scale_period,
+        ScaleType.LOG: math.log(val + 1, scale_period + 1),
+        ScaleType.SIN: math.sin(0.5 * math.pi * val / scale_period),
+        ScaleType.SIGMOID: 1 / (1 + math.exp(-10 * (val / scale_period - 0.5))),
+    }[scale_type]
+
+    # ensure returned scale is between 0.0 and 1.0 (inclusive)
+    return max(min(scale, 1.0), 0.0)
+
+
+def prepare_layer_dropout(
+    layers: Union[torch.nn.ModuleList, Iterable[torch.nn.Module]],
+    prob_max: float = 0.0,
+    prob_layer_scale: Optional[ScaleType] = ScaleType.UNIFORM,
+    layers_str: Optional[str] = None,
+    disable_on_eval: Optional[bool] = True,
+) -> None:
+    """
+    Prepare a model's layers for layer dropout by wrapping each layer with a ModuleLayerDropoutWrapper.
+    This function takes in a list of layers, the maximum probability of dropping a layer,
+    the scaling type for the layer dropout probability, a string specifying which
+    layers to apply dropout to, and a boolean indicating whether to disable dropout
+    during evaluation. It then wraps each layer of the model inplace with a
+    ModuleLayerDropoutWrapper, which applies layer dropout to the input tensor.
+
+    Args:
+        layers (Union[torch.nn.ModuleList, Iterable[torch.nn.Module]]): The list of layers to prepare for layer dropout.
+        prob_max (float): The maximum probability of dropping a layer. Defaults to 0.0.
+        prob_layer_scale (Optional[ScaleType]): The scaling type for the dropout probability across layers. Defaults to
+            ScaleType.UNIFORM.
+        layers_str (Optional[str]): A string specifying which layers to apply dropout to. Defaults to None which means
+            apply to all layers.
+        disable_on_eval (Optional[bool]): Whether to disable dropout during evaluation. Defaults to True.
+    Returns:
+        None
+    Example:
+        >>> import torch
+        >>> from torch import nn
+        >>> # Define a simple model
+        >>> class MyModel(nn.Module):
+        ...     def __init__(self):
+        ...         super().__init__()
+        ...         self.layers = nn.ModuleList([
+        ...             nn.Linear(5, 3),
+        ...             nn.Linear(3, 2),
+        ...             nn.Linear(2, 1),
+        ...             nn.Linear(1, 2),
+        ...             nn.Linear(2, 3),
+        ...         ])
+        ...
+        ...     def forward(self, x):
+        ...         for layer in self.layers:
+        ...             x = layer(x)
+        ...         return x
+        >>> model = MyModel()
+        >>> # Apply layer dropout uniformly to all layers
+        >>> prepare_layer_dropout(model.layers, prob_max=0.2, prob_layer_scale=ScaleType.UNIFORM)
+        >>> # Apply layer dropout every other layer, as described in the LayerDrop paper
+        >>> # (Fan et al., https://arxiv.org/abs/1909.11556v1)
+        >>> prepare_layer_dropout(model.layers, prob_max=0.2, prob_layer_scale=ScaleType.UNIFORM, layers_str="::2")
+        >>> # Apply layer dropout that increases linearly across layers, as described in the
+        >>> # Progressive Layer Dropping paper (Zhang et al., https://arxiv.org/abs/2010.13369)
+        >>> prepare_layer_dropout(model.layers, prob_max=0.2, prob_layer_scale=ScaleType.LINEAR)
+        >>> # Apply layer dropout that increases exponentially across layers, as described in the
+        >>> # LayerSkip paper (Elhoushi et al., https://arxiv.org/abs/2404.16710)
+        >>> prepare_layer_dropout(model.layers, prob_max=0.2, prob_layer_scale=ScaleType.EXP)
+    """
+    num_layers = len(layers)
+    has_dropout = (
+        slice_str_to_array(layers_str, num_layers)
+        if layers_str
+        else [True] * num_layers
+    )
+    for layer_id in range(len(layers)):
+        prob = (
+            prob_max
+            * get_scale(
+                scale_type=prob_layer_scale,
+                scale_period=num_layers - 1,
+                val=layer_id,
+            )
+            if has_dropout[layer_id]
+            else 0.0
+        )
+        assert (
+            prob >= 0.0 and prob <= prob_max
+        ), f"prob={prob} should be between 0 and {prob_max}"
+        # We would like each layer to have a different seed, so that we don't have the same samples skipped across layers.
+        # Hence, we use the layer_id as a seed for each layer's dropout.
+        layer_dropout = LayerDropout(
+            prob, disable_on_eval=disable_on_eval, seed=layer_id
+        )
+        layers[layer_id] = ModuleLayerDropoutWrapper(layers[layer_id], layer_dropout)
diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py
index 85d6b22869..66ac92002f 100644
--- a/torchtune/modules/transformer.py
+++ b/torchtune/modules/transformer.py
@@ -640,6 +640,15 @@ def forward(
                 input_pos=input_pos,
             )
 
+        # shape: [b, seq_len, out_dim]
+        output = self.unembed(h)
+
+        # Output list if hidden states are requested, otherwise just the output
+        # TODO: always output a list to have a consistent output type
+        output = output if not hidden else [*hidden, output]
+        return output
+
+    def unembed(self, h):
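+        """Apply the final norm and output projection(s) to hidden states ``h``.
+
+        Split out from ``forward`` so that callers (e.g. early exit loss) can compute
+        logits from intermediate hidden states using the same head as the final layer.
+        """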
         # shape: [b, s, d]
         h = self.norm(h)
 
@@ -649,7 +658,4 @@ def forward(
             # shape: [b, seq_len, out_dim]
             output = self.output(h).float()
 
-        # Output list if hidden states are requested, otherwise just the output
-        # TODO: always output a list to have a consistent output type
-        output = output if not hidden else [*hidden, output]
         return output