diff --git a/examples/speculative_decoding/distill_trainer.py b/examples/speculative_decoding/distill_trainer.py
new file mode 100644
index 000000000..d39f9d57e
--- /dev/null
+++ b/examples/speculative_decoding/distill_trainer.py
@@ -0,0 +1,402 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import os
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+from abc import abstractmethod
+from contextlib import nullcontext
+
+import torch
+import torch.distributed as dist
+from torch.distributed.device_mesh import DeviceMesh
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM
+from transformers.optimization import get_linear_schedule_with_warmup
+from transformers.utils import ModelOutput
+
+import modelopt.torch.opt as mto
+import modelopt.torch.speculative as mtsp
+from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG
+
+try:
+    import wandb
+except ImportError:
+    wandb = None
+
+
+mto.enable_huggingface_checkpointing()
+
+# Logging and checkpointing intervals
+LOG_INTERVAL = 100
+SAVE_INTERVAL = 20000
+
+# Shape and dtype description of the distillation signal
+DistillMetadata = dict[str, tuple[torch.Size, torch.dtype]]
+
+
+class BaseDistillTrainer:
+    """
+    Base distill trainer with the basic training loop and overlapped teacher and student steps.
+    Initialized and called on every rank.
+
+    Args:
+        rank: rank of the current process.
+        args: parsed training arguments.
+        teacher_step: teacher step function.
+        student_step: student step function.
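+        tokenizer: tokenizer of the distilled model; saved alongside checkpoints.
+        dataloader: dataloader that yields the training batches.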
+ """ + + def __init__(self, rank, args, tokenizer, dataloader): + self.rank = rank + self.args = args + self.tokenizer = tokenizer + self.dataloader = dataloader + + # Prepare models + if rank in args.student_ranks: + self.model = self._prepare_student_model() + self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.args.lr) + self.scheduler = get_linear_schedule_with_warmup( + self.optimizer, num_warmup_steps=0, num_training_steps=117380 + ) + else: + self.model = self._prepare_teacher_model() + self._print_model_placement(self.model) + + def _print_model_placement(self, module): + for name, param in module.named_parameters(): + print(f"(Rank {self.rank}) {name} ---> {param.device} ") + + def _reset_all_mem_stats(self): + torch.cuda.reset_max_memory_allocated(self.current_rank_device) + + def _print_mem_stats(self): + max_mem = torch.cuda.max_memory_allocated(self.current_rank_device) + print(f"GPU {self.current_rank_device}: Max memory allocated: {max_mem / 1024**3:.2f} GB") + + @property + def current_rank_device(self): + """Return device of the current rank.""" + + @property + def distill_metadata(self): + """Return a DistillMetadata that describe the distillation message received by student.""" + + @abstractmethod + def _prepare_teacher_model(self): + """Return coverted teacher model with correct parallelization.""" + + @abstractmethod + def _prepare_student_model(self): + """Return coverted student model with correct parallelization.""" + + @abstractmethod + def teacher_step(self, *args, **kwargs) -> list[dict[str, torch.Tensor]]: + """Run one student step and return distillation messages for each student rank.""" + + @abstractmethod + def student_step(self, *args, **kwargs) -> ModelOutput: + """Run forward of student step, return a modeloutput object.""" + + def save_pretrained(self, save_path): + """Save the model and tokenizer.""" + if self.rank == self.args.student_ranks[0]: + if isinstance(self.model, torch.nn.parallel.DistributedDataParallel): + self.model.module.save_pretrained(save_path) + else: + self.model.save_pretrained(save_path) + self.tokenizer.save_pretrained(save_path) + print(f"Pretrained model saved to {save_path}") + + def _check_valid_message(self, message: dict[str, torch.Tensor]): + """Check if message in the format of distill_metadata.""" + if set(message.keys()) != set(self.distill_metadata.keys()): + raise ValueError( + f"Message keys: {set(message.keys())} \n" + f"do not match expected keys {set(self.distill_metadata.keys())}" + ) + if len(message) != len(self.distill_metadata): + raise ValueError( + f"Message length: {len(message)} \n" + f"does not match expected {len(self.distill_metadata)}" + ) + for k, v in message.items(): + if v.shape != self.distill_metadata[k][0] or v.dtype != self.distill_metadata[k][1]: + raise ValueError( + f"Invalid message. 
+                    f"Invalid message. {k} has shape {v.shape} and dtype {v.dtype}, \n"
+                    f"expected {self.distill_metadata[k]}"
+                )
+
+    def _init_student_recv_buffer(self):
+        self.student_recv_buffer = {
+            k: torch.empty(v[0], device=self.current_rank_device, dtype=v[1])
+            for k, v in self.distill_metadata.items()
+        }
+
+    def _recv_from_teacher(self):
+        reqs = [
+            dist.irecv(buffer, src=self.args.teacher_ranks[0])
+            for buffer in self.student_recv_buffer.values()
+        ]
+        for req in reqs:
+            req.wait()
+
+    def _clone_recv_buffer(self):
+        """Return a copy of the received tensors for student step input."""
+        return {k: v.clone().detach() for k, v in self.student_recv_buffer.items()}
+
+    def _send_to_student(self, teacher_outputs):
+        if self.rank != self.args.teacher_ranks[0]:
+            return
+        # TODO: use broadcast
+        assert len(teacher_outputs) == len(self.args.student_ranks), (
+            f"Number of teacher outputs {len(teacher_outputs)} does not "
+            f"match number of student ranks {len(self.args.student_ranks)}"
+        )
+        for s in self.args.student_ranks:
+            self._check_valid_message(teacher_outputs[s])
+            reqs = [dist.isend(buffer, dst=s) for buffer in teacher_outputs[s].values()]
+            for req in reqs:
+                req.wait()
+
+    def _get_logging_context(self):
+        print(
+            f"Rank {self.rank} is logging: {wandb is not None and self.rank == self.args.student_ranks[0]}"
+        )
+        if wandb is not None and self.rank == self.args.student_ranks[0]:
+            return wandb.init(
+                entity=os.environ["WANDB_ENTITY"],
+                project=os.environ["WANDB_PROJECT"],
+                config={
+                    "epochs": self.args.epoch,
+                    "lr": self.args.lr,
+                    "batch_size": self.args.batch_size,
+                },
+            )
+        return nullcontext()
+
+    def train(self):
+        """Main training entry point of the composed model."""
+        self._reset_all_mem_stats()
+
+        if self.rank in self.args.student_ranks:
+            with self._get_logging_context() as run:
+                self._init_student_recv_buffer()
+
+                # Student training loop
+                for epoch in range(self.args.epoch):
+                    pbar = (
+                        tqdm(self.dataloader)
+                        if self.rank == self.args.student_ranks[0]
+                        else self.dataloader
+                    )
+                    for i, batch in enumerate(pbar):
+                        global_step = epoch * len(self.dataloader) + i
+                        inputs = {k: v.to(self.model.device) for k, v in batch.items()}
+
+                        # Receive distill messages from the teacher
+                        self._recv_from_teacher()
+
+                        # Run forward of the student step
+                        output = self.student_step(inputs, **self._clone_recv_buffer())
+                        loss = output.loss
+
+                        # Run backward step
+                        loss.backward()
+                        self.optimizer.step()
+                        self.scheduler.step()
+
+                        # Log and save only on student rank 0
+                        if self.rank != self.args.student_ranks[0]:
+                            continue
+
+                        train_metrics = {
+                            "loss": round(loss.item(), 3),
+                            "lr": self.optimizer.param_groups[0]["lr"],
+                            # Attach all float metrics
+                            **{k: round(v, 3) for k, v in output.items() if isinstance(v, float)},
+                        }
+
+                        pbar.set_description(f"Epoch {epoch} Loss {train_metrics['loss']}")
+                        # Skip wandb logging when wandb is unavailable (run is None in that case).
+                        if wandb is not None and global_step % LOG_INTERVAL == 0:
+                            run.log(train_metrics, step=global_step)
+                        if global_step > 0 and global_step % SAVE_INTERVAL == 0:
+                            self.save_pretrained(
+                                f"{self.args.out_path}/epoch_{epoch}_step_{global_step}"
+                            )
+
+        else:
+            # Inference loop
+            for epoch in range(self.args.epoch):
+                for i, batch in enumerate(self.dataloader):
+                    inputs = {k: v.to(self.model.device) for k, v in batch.items()}
+                    with torch.inference_mode():
+                        self._send_to_student(self.teacher_step(self.model, inputs))
+
+        self._print_mem_stats()
+        # Make sure all processes finished before destroy.
+        dist.barrier()
+        # Clean up processes
+        dist.destroy_process_group()
+
+
+class EagleTPTrainer(BaseDistillTrainer):
+    """A subclass of BaseDistillTrainer for online EAGLE training, with base model TP and student DDP."""
+
+    def __init__(self, rank, args, tokenizer, dataloader):
+        # Load eagle config
+        args.eagle_config = EAGLE3_DEFAULT_CFG["config"]
+        if args.eagle_config_path:
+            with open(args.eagle_config_path) as f:
+                custom_config = json.load(f)
+            args.eagle_config["eagle_architecture_config"].update(custom_config)
+
+        super().__init__(rank, args, tokenizer, dataloader)
+
+    @property
+    def current_rank_device(self):
+        if self.rank in self.args.student_ranks:
+            return self.args.student_devices[self.rank]
+        else:
+            return self.args.teacher_devices[self.rank - len(self.args.student_ranks)]
+
+    def _prepare_teacher_model(self):
+        # Load model with TP among teacher ranks.
+        model = AutoModelForCausalLM.from_pretrained(
+            self.args.model_path,
+            torch_dtype="auto",
+            tp_plan="auto",
+            device_mesh=DeviceMesh.from_group(self.args.teacher_pgroup, "cuda"),
+        )
+        # Load eagle config and convert.
+        self.args.eagle_config["eagle_architecture_config"].update(
+            {
+                "hidden_size": model.config.hidden_size,
+                "vocab_size": model.config.vocab_size,
+                "draft_vocab_size": model.config.vocab_size,
+            }
+        )
+        mtsp.convert(model, [("eagle", self.args.eagle_config)])
+        model.eval()
+        return model
+
+    def _prepare_student_model(self):
+        # Load to CPU first to avoid OOM
+        model = AutoModelForCausalLM.from_pretrained(
+            self.args.model_path, torch_dtype="auto", device_map="cpu"
+        )
+        # Hidden size and vocab size must match the base model
+        self.args.eagle_config["eagle_architecture_config"].update(
+            {
+                "hidden_size": model.config.hidden_size,
+                "vocab_size": model.config.vocab_size,
+                "draft_vocab_size": model.config.vocab_size,
+            }
+        )
+        mtsp.convert(
+            model,
+            [("eagle", self.args.eagle_config)],
+        )
+
+        # TODO: copy needed modules and del the rest
+        model.model._modules.pop("layers")
+        model.to(self.current_rank_device)
+
+        model.train()
+        model = torch.nn.parallel.DistributedDataParallel(
+            model,
+            device_ids=[self.current_rank_device],
+            process_group=self.args.student_pgroup,
+            find_unused_parameters=True,
+        )
+        return model
+
+    @property
+    def distill_metadata(self) -> DistillMetadata:
+        """Description of the distillation signal received by the student."""
+        return {
+            "base_model_hidden_states": (
+                torch.Size(
+                    [
+                        int(self.args.batch_size / len(self.args.student_ranks)),
+                        self.args.training_seq_len,
+                        self.args.eagle_config["eagle_architecture_config"]["hidden_size"],
+                    ]
+                ),
+                torch.bfloat16,
+            ),
+            "aux_hidden_states": (
+                torch.Size(
+                    [
+                        int(self.args.batch_size / len(self.args.student_ranks)),
+                        self.args.training_seq_len,
+                        self.args.eagle_config["eagle_architecture_config"]["hidden_size"] * 3,
+                    ]
+                ),
+                torch.bfloat16,
+            ),
+            "base_model_logits": (
+                torch.Size(
+                    [
+                        int(self.args.batch_size / len(self.args.student_ranks)),
+                        self.args.training_seq_len,
+                        self.args.eagle_config["eagle_architecture_config"]["draft_vocab_size"],
+                    ]
+                ),
+                torch.bfloat16,
+            ),
+        }
+
+    def teacher_step(self, model, inputs):
+        # Collect base model outputs.
+        base_model_hidden_states, base_model_logits, _, _ = model._base_model_forward(
+            **inputs,
+            freeze_base_model=True,
+            past_key_values=None,
+        )
+
+        # aux_hidden_states could be on multiple devices. Gather them before concatenating.
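+        # The concatenated width is expected to match the hidden_size * 3 declared for
+        # "aux_hidden_states" in distill_metadata.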
+        aux_hidden_states = torch.cat(
+            [t.to(base_model_logits.device) for t in model.pop_aux_hidden_states()], dim=-1
+        )
+
+        # Chunk the tensors for each student rank.
+        base_model_hidden_states = base_model_hidden_states.chunk(len(self.args.student_ranks))
+        base_model_logits = base_model_logits.chunk(len(self.args.student_ranks))
+        aux_hidden_states = aux_hidden_states.chunk(len(self.args.student_ranks))
+
+        return [
+            {
+                "base_model_hidden_states": base_model_hidden_states[i],
+                "aux_hidden_states": aux_hidden_states[i],
+                "base_model_logits": base_model_logits[i],
+            }
+            for i in range(len(self.args.student_ranks))
+        ]
+
+    def student_step(
+        self,
+        inputs,
+        **distill_msgs,
+    ) -> ModelOutput:
+        self.optimizer.zero_grad()
+
+        # Chunk input_ids and attention_mask for each student rank.
+        inputs = {k: v.chunk(len(self.args.student_ranks))[self.rank] for k, v in inputs.items()}
+
+        # Second stage forward with provided base model outputs.
+        output = self.model(**inputs, base_model_outputs=distill_msgs)
+
+        return output
diff --git a/examples/speculative_decoding/train.py b/examples/speculative_decoding/train.py
new file mode 100644
index 000000000..20ce5e7c4
--- /dev/null
+++ b/examples/speculative_decoding/train.py
@@ -0,0 +1,113 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
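+"""Launcher that spawns teacher and student processes for online EAGLE distillation with EagleTPTrainer."""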
+
+import argparse
+import os
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from distill_trainer import EagleTPTrainer
+from eagle_utils import DataCollatorWithPadding, make_eagle_supervised_data_module
+from transformers import AutoTokenizer
+
+# Fix the random seed for reproducibility
+torch.manual_seed(0)
+
+
+def _setup_distributed(rank, args, backend="nccl"):
+    """Initialize the distributed environment."""
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = args.master_port
+    os.environ["LOCAL_RANK"] = str(rank)
+    # Initialize process group
+    dist.init_process_group(backend, rank=rank, world_size=args.world_size)
+    if rank in args.student_ranks:
+        torch.cuda.set_device(args.student_devices[rank])
+    else:
+        torch.cuda.set_device(args.teacher_devices[rank - len(args.student_ranks)])
+    print(
+        f"Starting process rank={rank}, device={torch.cuda.current_device()}, world_size={args.world_size}"
+    )
+    args.teacher_pgroup = dist.new_group(ranks=args.teacher_ranks)
+    args.student_pgroup = dist.new_group(ranks=args.student_ranks)
+
+
+def train(rank, args):
+    _setup_distributed(rank, args)
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.model_path, model_max_length=args.training_seq_len
+    )
+    data_module = make_eagle_supervised_data_module(tokenizer, args, use_offline_training=False)
+
+    train_dataloader = torch.utils.data.DataLoader(
+        data_module["train_dataset"],
+        batch_size=args.batch_size,
+        shuffle=True,
+        num_workers=0,
+        collate_fn=DataCollatorWithPadding(max_length=args.training_seq_len),
+        drop_last=True,
+    )
+
+    trainer = EagleTPTrainer(rank, args, tokenizer, train_dataloader)
+    trainer.train()
+    trainer.save_pretrained(args.out_path)
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Multi-GPU distributed two-stage forward example")
+    parser.add_argument("--model_path", type=str, default="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+    parser.add_argument("--student_devices", type=int, nargs="+", default=[0, 1, 2, 3])
+    parser.add_argument("--teacher_devices", type=int, nargs="+", default=[4, 5])
+    parser.add_argument(
+        "--data_path", type=str, default="data/magpie_llama3.2_1b_generated/data.cleaned.jsonl"
+    )
+    parser.add_argument("--training_seq_len", type=int, default=1024)
+    parser.add_argument("--eagle_config_path", type=str, default="eagle_config.json")
+    parser.add_argument(
+        "--lazy_preprocess", type=bool, default=True, help="Whether to use lazy preprocessing."
+    )
+    parser.add_argument("--out_path", type=str, default="ckpts/fast-trained")
+    parser.add_argument("--lr", type=float, default=1e-5)
+    parser.add_argument("--epoch", type=int, default=1)
+    parser.add_argument(
+        "--batch_size", type=int, default=4, help="Total batch size across all parallel ranks."
+    )
+    parser.add_argument("--master_port", type=str, default="12357")
+
+    args = parser.parse_args()
+    # TODO: add sanity check for args
+
+    def set_ranks(args):
+        # TODO(hg): This is for the TP-DDP setting only. Add "no-parallel", "MP", "FSDP".
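+        # Student ranks occupy [0, len(student_devices)); teacher ranks follow immediately after,
+        # matching the rank-to-device mapping used in _setup_distributed and the trainer.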
+        args.world_size = len(args.teacher_devices) + len(args.student_devices)
+        args.student_ranks = list(range(len(args.student_devices)))
+        args.teacher_ranks = list(
+            range(len(args.student_devices), len(args.student_devices) + len(args.teacher_devices))
+        )
+
+    set_ranks(args)
+    # Launch multiple processes
+    mp.spawn(
+        train,
+        args=(args,),
+        nprocs=args.world_size,
+        join=True,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py
index ad0b32074..7afbd1e1c 100644
--- a/modelopt/torch/speculative/plugins/transformers.py
+++ b/modelopt/torch/speculative/plugins/transformers.py
@@ -586,10 +586,10 @@ def _base_model_forward(
         self,
         input_ids,
         attention_mask,
-        position_ids,
-        past_key_values,
-        freeze_base_model,
-        labels,
+        position_ids=None,
+        past_key_values=None,
+        freeze_base_model=True,
+        labels=None,
         **kwargs,
     ):
         # TODO: This function still use eagle_module. Ideally we should remove it,
@@ -726,7 +726,7 @@ def forward(
 
         # ====Run eagle forward====
         eagle_loss = None
-        train_accs = []
+        train_accs = {}
         if self.training:
             # In EAGLE-3, we have an additional FC layer to concentrate hidden states from multiple base model layers
             b, seq_length, h = base_model_hidden_states.shape
@@ -770,7 +770,7 @@ def forward(
                 loss_mask[:, 1:],
             )
             eagle_loss = classification_loss
-            train_accs.append(acc)
+            train_accs["train_acc_step0"] = acc
 
             # ====Perform training-time-testing with 3 extra eagle forward passes====
             for ttt_step in range(self.num_ttt_steps):
@@ -811,7 +811,7 @@ def forward(
                     ),
                 )
                 eagle_loss += classification_loss
-                train_accs.append(acc)
+                train_accs[f"train_acc_step{ttt_step + 1}"] = acc
         # Finally, we merge base model loss and eagle loss, raise error if both are None
         if base_model_loss is not None and eagle_loss is not None:
             loss = base_model_loss + eagle_loss
@@ -830,7 +830,7 @@ def forward(
             logits=base_model_logits,
             past_key_values=past_key_values,
             hidden_states=base_model_hidden_states,
-            train_acc=train_accs,
+            **train_accs,
         )
 
     def _eagle_loss(
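For reference, a hypothetical single-node launch of the new trainer (assuming six visible GPUs and the default device split of four student DDP ranks plus two teacher TP ranks defined in train.py; paths should be adjusted to the actual setup):

    python examples/speculative_decoding/train.py \
        --model_path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
        --data_path data/magpie_llama3.2_1b_generated/data.cleaned.jsonl \
        --training_seq_len 1024 --batch_size 4 --epoch 1 \
        --out_path ckpts/fast-trained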