diff --git a/examples/moe_dualpipe/README.md b/examples/moe_dualpipe/README.md new file mode 100644 index 0000000..5235fc8 --- /dev/null +++ b/examples/moe_dualpipe/README.md @@ -0,0 +1,28 @@ + + + +## Clone the DualPipe & Setup Environment + +```bash +git clone https://github.com/deepseek-ai/DualPipe.git +cd DualPipe +conda create -n dualpipe python=3.10 -y +conda activate dualpipe +pip install -r requirements.txt +pip install -e . +``` + +## Naive Implementation for Single-GPU and Multi-GPU Training of MoE Models +```bash +MASTER_ADDR=localhost MASTER_PORT=12355 WORLD_SIZE=4 python examples/moe_train_basic.py +``` + +### Parameters +- WORLD_SIZE=4: Uses 4 GPUs for pipeline parallelism +- MASTER_ADDR: Master node address +- MASTER_PORT: Communication port +- `test_moe_basic()`: Tests basic functionality of the MoE model + + + + diff --git a/examples/moe_dualpipe/dualpipe/__init__.py b/examples/moe_dualpipe/dualpipe/__init__.py new file mode 100644 index 0000000..de18a21 --- /dev/null +++ b/examples/moe_dualpipe/dualpipe/__init__.py @@ -0,0 +1,17 @@ +__version__ = "1.0.0" + +from dualpipe.dualpipe import DualPipe +from dualpipe.dualpipev import DualPipeV +from dualpipe.comm import ( + set_p2p_tensor_shapes, + set_p2p_tensor_dtype, +) +from dualpipe.utils import WeightGradStore + +__all__ = [ + "DualPipe", + "DualPipeV", + "WeightGradStore", + "set_p2p_tensor_shapes", + "set_p2p_tensor_dtype", +] diff --git a/examples/moe_dualpipe/dualpipe/comm.py b/examples/moe_dualpipe/dualpipe/comm.py new file mode 100644 index 0000000..e779a77 --- /dev/null +++ b/examples/moe_dualpipe/dualpipe/comm.py @@ -0,0 +1,38 @@ +from typing import List, Tuple + +import torch +import torch.distributed as dist + + +TENSOR_SHAPES: List[Tuple[int]] = None +TENSOR_DTYPE: torch.dtype = None + + +def set_p2p_tensor_shapes(shapes: List[Tuple[int]]): + global TENSOR_SHAPES + TENSOR_SHAPES = shapes + + +def set_p2p_tensor_dtype(dtype: torch.dtype): + global TENSOR_DTYPE + TENSOR_DTYPE = dtype + + +def 
build_from_tensor_shapes(): + return [torch.empty(s, dtype=TENSOR_DTYPE, device="cuda", requires_grad=True) for s in TENSOR_SHAPES] + + +def append_irecv(ops: List[dist.P2POp], src: int, group: dist.ProcessGroup) -> List[torch.Tensor]: + tensors = build_from_tensor_shapes() + src = dist.distributed_c10d.get_global_rank(group, src) + for tensor in tensors: + if tensor is not None: + ops.append(dist.P2POp(dist.irecv, tensor, src)) + return tensors + + +def append_isend(ops: List[dist.P2POp], tensors: List[torch.Tensor], dst: int, group: dist.ProcessGroup) -> None: + dst = dist.distributed_c10d.get_global_rank(group, dst) + for tensor in tensors: + if tensor is not None: + ops.append(dist.P2POp(dist.isend, tensor, dst)) diff --git a/examples/moe_dualpipe/dualpipe/dualpipe.py b/examples/moe_dualpipe/dualpipe/dualpipe.py new file mode 100644 index 0000000..ae67f4a --- /dev/null +++ b/examples/moe_dualpipe/dualpipe/dualpipe.py @@ -0,0 +1,440 @@ +from typing import Tuple, List, Union, Callable, Optional + +import torch +import torch.nn as nn +import torch.distributed as dist + +import dualpipe.comm as comm +from dualpipe.utils import WeightGradStore, run_backward, scatter, gather + + +class DualPipe(nn.Module): + def __init__( + self, + modules: Tuple[nn.Module, nn.Module], + batch_dim: int = 0, + process_group: Optional[dist.ProcessGroup] = None, + rank_mapping: Optional[List[int]] = None, + ) -> None: + super().__init__() + + assert next(modules[0].parameters()).device == torch.device(torch.cuda.current_device()) + self.module = nn.ModuleList(modules) + self.overlapped_forward_backward = type(modules[0]) == type(modules[1]) and hasattr(type(modules[0]), "overlapped_forward_backward") + self.batch_dim = batch_dim + self.group = process_group or dist.distributed_c10d._get_default_group() + self.num_ranks = self.group.size() + + # rank_mapping: Map rank in process_group to actual pp rank. + # rank_inverse_mapping: Map actual pp rank to rank in process_group. 
+ if rank_mapping is None: + rank_mapping = list(range(self.num_ranks)) + rank_inverse_mapping = [None] * (self.num_ranks + 1) + for i in range(self.num_ranks): + rank_inverse_mapping[rank_mapping[i]] = i + + self.rank = rank_mapping[self.group.rank()] + self.first_rank = rank_inverse_mapping[0] + self.prev_rank = rank_inverse_mapping[self.rank - 1] + self.next_rank = rank_inverse_mapping[self.rank + 1] + self.last_rank = rank_inverse_mapping[self.num_ranks - 1] + + self.is_first_rank = self.rank == 0 + self.is_last_rank = self.rank == self.num_ranks - 1 + self.is_in_second_half = self.rank >= self.num_ranks // 2 + self.is_middle_rank = (self.rank == self.num_ranks // 2 - 1) or (self.rank == self.num_ranks // 2) + + def _reset_states(self) -> None: + WeightGradStore.clear() + + self.input_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], []) + self.output_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], []) + self.input_grad_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], []) + self.output_grad_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], []) + self.labels: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = None + self.loss_chunks: List[torch.Tensor] = [] + self.criterion: Callable = None + + self.current_f_chunk_id: List[int] = [0, 0] + self.current_b_chunk_id: List[int] = [0, 0] + self.current_send_f_chunk_id: List[int] = [0, 0] + self.current_send_b_chunk_id: List[int] = [0, 0] + self.current_recv_f_chunk_id: List[int] = [0, 0] + self.current_recv_b_chunk_id: List[int] = [0, 0] + self.comm_ops: List[dist.P2POp] = [] + self.to_free: List[torch.Tensor] = [] + + def _forward_compute_chunk(self, phase: int) -> None: + phase ^= self.is_in_second_half + chunk_id = self.current_f_chunk_id[phase] + self.current_f_chunk_id[phase] += 1 + inputs = self.input_chunks[phase][chunk_id] + if self.forward_only: + self.input_chunks[phase][chunk_id] = None + + is_last_stage 
= (self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0) + + outputs = self.module[phase](*inputs) + outputs = [outputs] if isinstance(outputs, torch.Tensor) else outputs + if is_last_stage and self.criterion is not None: + labels = self.labels[phase][chunk_id] + loss = self.criterion(*outputs, *labels) + self.loss_chunks.append(loss) + + if (not is_last_stage) or self.return_outputs: + self.output_chunks[phase].append(outputs) + + def _backward_compute_chunk(self, phase: int, enable_zb: bool = False) -> None: + if self.forward_only: + return + + phase ^= self.is_in_second_half + chunk_id = self.current_b_chunk_id[phase] + self.current_b_chunk_id[phase] += 1 + + is_last_stage = (self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0) + + WeightGradStore.enabled = enable_zb + if is_last_stage: + loss = self.loss_chunks[chunk_id] + loss.backward() + loss.detach_() + else: + outputs = self.output_chunks[phase][chunk_id] + if not self.return_outputs: + self.output_chunks[phase][chunk_id] = None + output_grads = self.output_grad_chunks[phase][chunk_id] + self.output_grad_chunks[phase][chunk_id] = None + non_empty = [(t, g) for t, g in zip(outputs, output_grads) if g is not None] + outputs, output_grads = list(zip(*non_empty)) or ((), ()) + if len(outputs) > 0: + run_backward(outputs, output_grads) + WeightGradStore.enabled = False + if enable_zb: + WeightGradStore.flush() + + inputs = self.input_chunks[phase][chunk_id] + self.input_chunks[phase][chunk_id] = None + input_grads = [t.grad for t in inputs] + self.input_grad_chunks[phase].append(input_grads) + + def _forward_backward_compute_chunk(self, phase0: int, phase1: int) -> None: + if self.forward_only: + self._forward_compute_chunk(phase0) + return + + if not self.overlapped_forward_backward: + self._forward_compute_chunk(phase0) + self._backward_compute_chunk(phase1) + return + + # pre-forward + phase0 ^= self.is_in_second_half + chunk_id0 = self.current_f_chunk_id[phase0] + 
self.current_f_chunk_id[phase0] += 1 + module0 = self.module[phase0] + inputs0 = self.input_chunks[phase0][chunk_id0] + is_last_stage0 = (self.is_first_rank and phase0 == 1) or (self.is_last_rank and phase0 == 0) + + if is_last_stage0 and self.criterion is not None: + labels0 = self.labels[phase0][chunk_id0] + criterion0 = self.criterion + else: + labels0 = [] + criterion0 = None + + # pre-backward + phase1 ^= self.is_in_second_half + chunk_id1 = self.current_b_chunk_id[phase1] + self.current_b_chunk_id[phase1] += 1 + module1 = self.module[phase1] + is_last_stage1 = (self.is_first_rank and phase1 == 1) or (self.is_last_rank and phase1 == 0) + + if is_last_stage1: + loss1 = self.loss_chunks[chunk_id1] + outputs1 = [] + output_grads1 = [] + else: + loss1 = None + outputs1 = self.output_chunks[phase1][chunk_id1] + if not self.return_outputs: + self.output_chunks[phase1][chunk_id1] = None + output_grads1 = self.output_grad_chunks[phase1][chunk_id1] + self.output_grad_chunks[phase1][chunk_id1] = None + non_empty = [(t, g) for t, g in zip(outputs1, output_grads1) if g is not None] + outputs1, output_grads1 = list(zip(*non_empty)) + + # forward & backward + outputs0, loss0 = type(module0).overlapped_forward_backward( + module0, inputs0, criterion0, labels0, + module1, loss1, outputs1, output_grads1, + ) + + # post-forward + if (not is_last_stage0) or self.return_outputs: + self.output_chunks[phase0].append(outputs0) + if is_last_stage0 and self.criterion is not None: + self.loss_chunks.append(loss0) + + # post-backward + inputs = self.input_chunks[phase1][chunk_id1] + self.input_chunks[phase1][chunk_id1] = None + input_grads1 = [t.grad for t in inputs] + self.input_grad_chunks[phase1].append(input_grads1) + + def _forward_chunk(self, phase: int, recv: bool = True, send: bool = True) -> None: + if recv: + self._recv_forward(phase) + self._commit_and_wait_comm() + + self._forward_compute_chunk(phase) + + if send: + self._send_forward(phase) + + def _backward_chunk(self, 
phase: int, enable_zb: bool = False, recv: bool = True, send: bool = True) -> None: + if recv: + self._recv_backward(phase) + self._commit_and_wait_comm() + + self._backward_compute_chunk(phase, enable_zb) + + if send: + self._send_backward(phase) + + def _forward_backward_chunk(self, phase0: int, phase1: int, recv0: bool = True) -> None: + if recv0: + self._recv_forward(phase0) + self._recv_backward(phase1) + self._commit_and_wait_comm() + + self._forward_backward_compute_chunk(phase0, phase1) + + self._send_forward(phase0) + self._send_backward(phase1) + + def _weight_chunk(self) -> None: + if self.forward_only: + return + + self._commit_and_wait_comm() + + # Assume FIFO + WeightGradStore.pop() + + def _free_tensors(self) -> None: + for tensor in self.to_free: + assert tensor._base is None, f"pipeline stage should not return view tensors {dist.get_rank(), tensor.shape}" + tensor.data = torch.Tensor() + self.to_free = [] + + def _recv_forward(self, phase: int) -> None: + phase ^= self.is_in_second_half + is_first_stage = (self.is_first_rank and phase == 0) or (self.is_last_rank and phase == 1) + if is_first_stage: + return + + self.current_recv_f_chunk_id[phase] += 1 + tensors = comm.append_irecv(self.comm_ops, self.prev_rank if phase == 0 else self.next_rank, self.group) + self.input_chunks[phase].append(tensors) + + def _send_forward(self, phase: int) -> None: + phase ^= self.is_in_second_half + is_last_stage = (self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0) + if is_last_stage: + return + + chunk_id = self.current_send_f_chunk_id[phase] + self.current_send_f_chunk_id[phase] += 1 + tensors = self.output_chunks[phase][chunk_id] + + comm.append_isend(self.comm_ops, tensors, self.next_rank if phase == 0 else self.prev_rank, self.group) + + if not self.return_outputs: + self.to_free.extend(tensors) + + def _recv_backward(self, phase: int) -> None: + if self.forward_only: + return + + phase ^= self.is_in_second_half + is_last_stage = 
(self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0) + if is_last_stage: + return + + self.current_recv_b_chunk_id[phase] += 1 + tensors = comm.append_irecv(self.comm_ops, self.next_rank if phase == 0 else self.prev_rank, self.group) + self.output_grad_chunks[phase].append(tensors) + + def _send_backward(self, phase: int) -> None: + if self.forward_only: + return + + phase ^= self.is_in_second_half + is_first_stage = (self.is_first_rank and phase == 0) or (self.is_last_rank and phase == 1) + if is_first_stage: + return + + chunk_id = self.current_send_b_chunk_id[phase] + self.current_send_b_chunk_id[phase] += 1 + tensors = self.input_grad_chunks[phase][chunk_id] + self.input_grad_chunks[phase][chunk_id] = None + + comm.append_isend(self.comm_ops, tensors, self.prev_rank if phase == 0 else self.next_rank, self.group) + + def _commit_and_wait_comm(self) -> None: + if not self.comm_ops: + return + reqs = dist.batch_isend_irecv(self.comm_ops) + for req in reqs: + req.wait() + self.comm_ops = [] + self._free_tensors() + + def step( + self, + *inputs: Optional[torch.Tensor], + num_chunks: int = 0, + criterion: Optional[Callable] = None, + labels: List[Optional[torch.Tensor]] = [], + return_outputs: bool = False, + ) -> Tuple[Optional[torch.Tensor], Optional[Union[torch.Tensor, Tuple[torch.Tensor]]]]: + """ + Execute a training or inference step. + + Arguments: + *inputs: Module inputs. Required only on the first/last ranks. + num_chunks: The number of micro-batches. + criterion: Loss function, invoked as ``criterion(*outputs, *labels)``. Required only on the first/last ranks. + labels: Labels of the loss function. Required only on the first/last ranks. + labels on the first rank corresponds to inputs on the last rank. + labels on the last rank corresponds to inputs on the first rank. + return_outputs: Whether to return outputs on the first/last ranks. Default: ``False``. + + Returns: (loss, outputs) + loss: Loss for the batch. 
+ loss on the first rank corresponds to inputs on the last rank. + loss on the last rank corresponds to inputs on the first rank. + Otherwise: ``None``. + outputs: Returned only if ``return_outputs=True``. + outputs on the first rank corresponds to inputs on the last rank. + outputs on the last rank corresponds to inputs on the first rank. + Otherwise: ``None``. + + """ + assert comm.TENSOR_SHAPES is not None and comm.TENSOR_DTYPE is not None, \ + "You need to call set_p2p_tensor_shapes and set_p2p_tensor_dtype before doing a step." + self.forward_only = not torch.is_grad_enabled() + self.return_outputs = return_outputs + + rank = self.rank + num_ranks = self.num_ranks + assert num_ranks % 2 == 0 + assert num_chunks > 0 and num_chunks % 2 == 0 and num_chunks >= num_ranks * 2, f"{num_chunks=}, {num_ranks=}" + num_half_ranks = num_ranks // 2 + half_rank = min(rank, num_ranks - 1 - rank) + half_num_chunks = num_chunks // 2 + self.num_half_ranks = num_half_ranks + self.half_rank = half_rank + + if not self.forward_only and (self.is_first_rank or self.is_last_rank): + assert criterion is not None + + self._reset_states() + + inputs = scatter(inputs, half_num_chunks, self.batch_dim) + labels = scatter(labels, half_num_chunks, self.batch_dim) + if self.is_first_rank: + self.input_chunks = (inputs, []) + self.labels = ([], labels) + elif self.is_last_rank: + self.input_chunks = ([], inputs) + self.labels = (labels, []) + self.criterion = criterion + + # For the first half of the ranks: phase 0 means forward direction, phase 1 means reverse direction. + # For the second half of the ranks: phase 0 means reverse direction, phase 1 means forward direction. 
+ + # Step 1: nF0 + step_1 = (num_half_ranks - half_rank - 1) * 2 + for i in range(step_1): + self._forward_chunk(0) + + # Step 2: nF0F1 + step_2 = half_rank + 1 + self._recv_forward(0) + for i in range(step_2): + self._forward_chunk(0, recv=False, send=self.is_middle_rank) + self._recv_forward(0) + self._forward_chunk(1, send=(not self.is_middle_rank) or (i < step_2 - 1)) + if not self.is_middle_rank: + self._send_forward(0) + + # Step 3: nB1W1F1 (Use zero bubble) + step_3 = num_half_ranks - half_rank - 1 + for i in range(step_3): + self._backward_chunk(1, enable_zb=True) + self._recv_forward(1) + self._weight_chunk() + self._forward_chunk(1, recv=False) + + # Step 4 (Main step): nF0B1F1B0 + step_4 = half_num_chunks - num_ranks + half_rank + 1 + for i in range(step_4): + if i == 0: + if self.is_middle_rank: + # NOTE: We don't overlap these two chunks to further reduce bubble size. + self._forward_chunk(0, recv=False, send=False) + self._send_forward(1) + self._backward_chunk(1, send=False) + self._send_forward(0) + self._send_backward(1) + else: + self._forward_backward_chunk(0, 1, recv0=False) + else: + self._forward_backward_chunk(0, 1) + self._forward_backward_chunk(1, 0) + + # Step 5: nB1F1B0 + step_5 = num_half_ranks - half_rank - 1 + for i in range(step_5): + self._backward_chunk(1) + self._forward_backward_chunk(1, 0) + + # Step 6: nB1B0 (The second half of the chunks use zero bubble) + step_6 = half_rank + 1 + enable_zb = False + for i in range(step_6): + if i == step_6 // 2 and half_rank % 2 == 1: + enable_zb = True + self._backward_chunk(1, enable_zb=enable_zb) + if i == step_6 // 2 and half_rank % 2 == 0: + enable_zb = True + self._backward_chunk(0, enable_zb=enable_zb) + + # Step 7: nWB0 (Use zero bubble) + step_7 = num_half_ranks - half_rank - 1 + for i in range(step_7): + self._weight_chunk() + self._backward_chunk(0, enable_zb=True) + + # Step 8: nW + step_8 = half_rank + 1 + for i in range(step_8): + self._weight_chunk() + assert 
WeightGradStore.funcs_queue.empty() + + self._commit_and_wait_comm() + + loss, outputs = None, None + if self.is_first_rank or self.is_last_rank: + if criterion is not None: + loss = torch.stack(self.loss_chunks) + if return_outputs: + outputs = gather(self.output_chunks[self.is_first_rank], self.batch_dim) + if len(outputs) == 1: + outputs = outputs[0] + + self._reset_states() + + return loss, outputs \ No newline at end of file diff --git a/examples/moe_dualpipe/dualpipe/dualpipev.py b/examples/moe_dualpipe/dualpipe/dualpipev.py new file mode 100644 index 0000000..cb8dc76 --- /dev/null +++ b/examples/moe_dualpipe/dualpipe/dualpipev.py @@ -0,0 +1,411 @@ +from typing import Tuple, List, Union, Callable, Optional + +import torch +import torch.nn as nn +import torch.distributed as dist + +import dualpipe.comm as comm +from dualpipe.utils import WeightGradStore, run_backward, scatter, gather + + +class DualPipeV(nn.Module): + def __init__( + self, + modules: Tuple[nn.Module, nn.Module], + batch_dim: int = 0, + process_group: Optional[dist.ProcessGroup] = None, + rank_mapping: Optional[List[int]] = None, + ) -> None: + super().__init__() + + assert next(modules[0].parameters()).device == torch.device(torch.cuda.current_device()) + self.module = nn.ModuleList(modules) + self.overlapped_forward_backward = type(modules[0]) == type(modules[1]) and hasattr(type(modules[0]), "overlapped_forward_backward") + self.batch_dim = batch_dim + self.group = process_group or dist.distributed_c10d._get_default_group() + self.num_ranks = self.group.size() + + # rank_mapping: Map rank in process_group to actual pp rank. + # rank_inverse_mapping: Map actual pp rank to rank in process_group. 
+ if rank_mapping is None: + rank_mapping = list(range(self.num_ranks)) + rank_inverse_mapping = [None] * (self.num_ranks + 1) + for i in range(self.num_ranks): + rank_inverse_mapping[rank_mapping[i]] = i + + self.rank = rank_mapping[self.group.rank()] + self.prev_rank = rank_inverse_mapping[self.rank - 1] + self.next_rank = rank_inverse_mapping[self.rank + 1] + + self.is_first_rank = self.rank == 0 + self.is_last_rank = self.rank == self.num_ranks - 1 + + def _reset_states(self) -> None: + WeightGradStore.clear() + + self.input_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], []) + self.output_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], []) + self.input_grad_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], []) + self.output_grad_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], []) + self.labels: List[List[torch.Tensor]] = None + self.loss_chunks: List[torch.Tensor] = [] + self.criterion: Callable = None + + self.current_f_chunk_id: List[int] = [0, 0] + self.current_b_chunk_id: List[int] = [0, 0] + self.current_send_f_chunk_id: List[int] = [0, 0] + self.current_send_b_chunk_id: List[int] = [0, 0] + self.current_recv_f_chunk_id: List[int] = [0, 0] + self.current_recv_b_chunk_id: List[int] = [0, 0] + self.comm_ops: List[dist.P2POp] = [] + self.to_free: List[torch.Tensor] = [] + + def _forward_compute_chunk(self, phase: int) -> None: + chunk_id = self.current_f_chunk_id[phase] + self.current_f_chunk_id[phase] += 1 + inputs = self.input_chunks[phase][chunk_id] + if self.forward_only: + self.input_chunks[phase][chunk_id] = None + + is_last_stage = (self.is_first_rank and phase == 1) + + outputs = self.module[phase](*inputs) + outputs = [outputs] if isinstance(outputs, torch.Tensor) else outputs + if is_last_stage and self.criterion is not None: + labels = self.labels[chunk_id] + loss = self.criterion(*outputs, *labels) + self.loss_chunks.append(loss) + + if 
self.is_last_rank and phase == 0: + self.input_chunks[1].append([output.detach().requires_grad_() for output in outputs]) + if (not is_last_stage) or self.return_outputs: + self.output_chunks[phase].append(outputs) + + def _backward_compute_chunk(self, phase: int, enable_zb: bool = False) -> None: + if self.forward_only: + return + + chunk_id = self.current_b_chunk_id[phase] + self.current_b_chunk_id[phase] += 1 + + is_last_stage = (self.is_first_rank and phase == 1) + + WeightGradStore.enabled = enable_zb + if is_last_stage: + loss = self.loss_chunks[chunk_id] + loss.backward() + loss.detach_() + else: + outputs = self.output_chunks[phase][chunk_id] + if not self.return_outputs: + self.output_chunks[phase][chunk_id] = None + output_grads = self.output_grad_chunks[phase][chunk_id] + self.output_grad_chunks[phase][chunk_id] = None + non_empty = [(t, g) for t, g in zip(outputs, output_grads) if g is not None] + outputs, output_grads = list(zip(*non_empty)) or ((), ()) + if len(outputs) > 0: + run_backward(outputs, output_grads) + WeightGradStore.enabled = False + if enable_zb: + WeightGradStore.flush() + + inputs = self.input_chunks[phase][chunk_id] + self.input_chunks[phase][chunk_id] = None + input_grads = [t.grad for t in inputs] + if self.is_last_rank and phase == 1: + self.output_grad_chunks[0].append(input_grads) + else: + self.input_grad_chunks[phase].append(input_grads) + + def _forward_backward_compute_chunk(self, phase0: int, phase1: int) -> None: + if self.forward_only: + self._forward_compute_chunk(phase0) + return + + if not self.overlapped_forward_backward: + self._forward_compute_chunk(phase0) + self._backward_compute_chunk(phase1) + return + + # pre-forward + chunk_id0 = self.current_f_chunk_id[phase0] + self.current_f_chunk_id[phase0] += 1 + module0 = self.module[phase0] + inputs0 = self.input_chunks[phase0][chunk_id0] + is_last_stage0 = (self.is_first_rank and phase0 == 1) + + if is_last_stage0 and self.criterion is not None: + labels0 = self.labels[chunk_id0] 
+ criterion0 = self.criterion + else: + labels0 = [] + criterion0 = None + + # pre-backward + chunk_id1 = self.current_b_chunk_id[phase1] + self.current_b_chunk_id[phase1] += 1 + module1 = self.module[phase1] + is_last_stage1 = (self.is_first_rank and phase1 == 1) + + if is_last_stage1: + loss1 = self.loss_chunks[chunk_id1] + outputs1 = [] + output_grads1 = [] + else: + loss1 = None + outputs1 = self.output_chunks[phase1][chunk_id1] + if not self.return_outputs: + self.output_chunks[phase1][chunk_id1] = None + output_grads1 = self.output_grad_chunks[phase1][chunk_id1] + self.output_grad_chunks[phase1][chunk_id1] = None + non_empty = [(t, g) for t, g in zip(outputs1, output_grads1) if g is not None] + outputs1, output_grads1 = list(zip(*non_empty)) + + # forward & backward + outputs0, loss0 = type(module0).overlapped_forward_backward( + module0, inputs0, criterion0, labels0, + module1, loss1, outputs1, output_grads1, + ) + + # post-forward + if self.is_last_rank and phase0 == 0: + self.input_chunks[1].append([output.detach().requires_grad_() for output in outputs0]) + if (not is_last_stage0) or self.return_outputs: + self.output_chunks[phase0].append(outputs0) + if is_last_stage0 and self.criterion is not None: + self.loss_chunks.append(loss0) + + # post-backward + inputs = self.input_chunks[phase1][chunk_id1] + self.input_chunks[phase1][chunk_id1] = None + input_grads1 = [t.grad for t in inputs] + if self.is_last_rank and phase1 == 1: + self.output_grad_chunks[0].append(input_grads1) + else: + self.input_grad_chunks[phase1].append(input_grads1) + + def _forward_chunk(self, phase: int, recv: bool = True, send: bool = True) -> None: + if recv: + self._recv_forward(phase) + self._commit_and_wait_comm() + + self._forward_compute_chunk(phase) + + if send: + self._send_forward(phase) + + def _backward_chunk(self, phase: int, enable_zb: bool = False, recv: bool = True, send: bool = True) -> None: + if recv: + self._recv_backward(phase) + self._commit_and_wait_comm() + + 
self._backward_compute_chunk(phase, enable_zb) + + if send: + self._send_backward(phase) + + def _forward_backward_chunk(self, phase0: int, phase1: int, recv0: bool = True) -> None: + if recv0: + self._recv_forward(phase0) + self._recv_backward(phase1) + self._commit_and_wait_comm() + + self._forward_backward_compute_chunk(phase0, phase1) + + self._send_forward(phase0) + self._send_backward(phase1) + + def _weight_chunk(self) -> None: + if self.forward_only: + return + + self._commit_and_wait_comm() + + # Assume FIFO + WeightGradStore.pop() + + def _free_tensors(self) -> None: + for tensor in self.to_free: + assert tensor._base is None, f"pipeline stage should not return view tensors {dist.get_rank(), tensor.shape}" + tensor.data = torch.Tensor() + self.to_free = [] + + def _recv_forward(self, phase: int) -> None: + if (self.is_first_rank and phase == 0) or (self.is_last_rank and phase == 1): + return + + self.current_recv_f_chunk_id[phase] += 1 + tensors = comm.append_irecv(self.comm_ops, self.prev_rank if phase == 0 else self.next_rank, self.group) + self.input_chunks[phase].append(tensors) + + def _send_forward(self, phase: int) -> None: + if (self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0): + return + + chunk_id = self.current_send_f_chunk_id[phase] + self.current_send_f_chunk_id[phase] += 1 + tensors = self.output_chunks[phase][chunk_id] + + comm.append_isend(self.comm_ops, tensors, self.next_rank if phase == 0 else self.prev_rank, self.group) + + if not self.return_outputs: + self.to_free.extend(tensors) + + def _recv_backward(self, phase: int) -> None: + if self.forward_only: + return + + if (self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0): + return + + self.current_recv_b_chunk_id[phase] += 1 + tensors = comm.append_irecv(self.comm_ops, self.next_rank if phase == 0 else self.prev_rank, self.group) + self.output_grad_chunks[phase].append(tensors) + + def _send_backward(self, phase: int) -> None: + if 
self.forward_only: + return + + if (self.is_first_rank and phase == 0) or (self.is_last_rank and phase == 1): + return + + chunk_id = self.current_send_b_chunk_id[phase] + self.current_send_b_chunk_id[phase] += 1 + tensors = self.input_grad_chunks[phase][chunk_id] + self.input_grad_chunks[phase][chunk_id] = None + + comm.append_isend(self.comm_ops, tensors, self.prev_rank if phase == 0 else self.next_rank, self.group) + + def _commit_and_wait_comm(self) -> None: + if not self.comm_ops: + return + reqs = dist.batch_isend_irecv(self.comm_ops) + for req in reqs: + req.wait() + self.comm_ops = [] + self._free_tensors() + + def step( + self, + *inputs: Optional[torch.Tensor], + num_chunks: int = 0, + criterion: Optional[Callable] = None, + labels: List[Optional[torch.Tensor]] = [], + return_outputs: bool = False, + ) -> Tuple[Optional[torch.Tensor], Optional[Union[torch.Tensor, Tuple[torch.Tensor]]]]: + """ + Execute a training or inference step. + + Arguments: + *inputs: Module inputs. Required only on the first rank. + num_chunks: The number of micro-batches. + criterion: Loss function, invoked as ``criterion(*outputs, *labels)``. Required only on the first rank. + labels: Labels of the loss function. Required only on the first rank. + return_outputs: Whether to return outputs on the first rank. Default: ``False``. + + Returns: (loss, outputs) + loss: Loss for the batch. Returned only on the first rank. + outputs: Module outputs. Returned only if ``return_outputs=True`` and on the first rank. + + """ + assert comm.TENSOR_SHAPES is not None and comm.TENSOR_DTYPE is not None, \ + "You need to call set_p2p_tensor_shapes and set_p2p_tensor_dtype before executing a step." 
+ self.forward_only = not torch.is_grad_enabled() + self.return_outputs = return_outputs + + rank = self.rank + num_ranks = self.num_ranks + assert num_chunks > 0 and num_chunks >= num_ranks * 2, f"{num_chunks=}, {num_ranks=}" + + if not self.forward_only and self.is_first_rank: + assert criterion is not None + + self._reset_states() + + if self.is_first_rank: + self.input_chunks = (scatter(inputs, num_chunks, self.batch_dim), []) + self.labels = scatter(labels, num_chunks, self.batch_dim) + self.criterion = criterion + + # Step 1: nF0 + step_1 = (num_ranks - rank - 1) * 2 + for i in range(step_1): + self._forward_chunk(0) + + # Step 2: nF0F1 + step_2 = rank + 1 + self._recv_forward(0) + for i in range(step_2): + self._forward_chunk(0, recv=False, send=False) + self._recv_forward(0) + self._forward_chunk(1, send=(not self.is_last_rank) or (i < step_2 - 1)) + self._send_forward(0) + + # Step 3: nB1W1F1 (Use zero bubble) + step_3 = num_ranks - rank - 1 + for i in range(step_3): + self._backward_chunk(1, enable_zb=True) + self._recv_forward(1) + self._weight_chunk() + self._forward_chunk(1, recv=False) + + # Step 4 (Main step): nF0B1F1B0 + step_4 = num_chunks - num_ranks * 2 + rank + 1 + for i in range(step_4): + if i == 0: + if self.is_last_rank: + # NOTE: We don't overlap these two chunks to further reduce bubble size. 
class WeightGradStore:
    """Class-level store that batches callables and runs them later, FIFO.

    Callables are collected with ``put``, sealed into one batch with
    ``flush``, and executed batch-by-batch with ``pop``. All state lives on
    the class itself, so it is shared by every user in the process.
    """

    # Flag toggled and read by callers; nothing inside this class consults it.
    enabled: bool = False
    # Callables collected since the last flush().
    cache: List[Callable] = []
    # FIFO of flushed batches; each entry is a list of callables.
    funcs_queue = queue.Queue()

    @classmethod
    def put(cls, func: Callable) -> None:
        """Append ``func`` to the batch currently being collected."""
        cls.cache.append(func)

    @classmethod
    def flush(cls) -> None:
        """Enqueue the collected batch (even if empty) and start a new one."""
        cls.funcs_queue.put(cls.cache)
        cls.cache = []

    @classmethod
    def pop(cls) -> None:
        """Run every callable of the oldest flushed batch, in insertion order.

        Raises:
            AssertionError: if no batch has been flushed.
        """
        # A non-blocking get keeps the empty-queue case an immediate error.
        # A bare ``assert`` is stripped under ``python -O``, after which
        # ``Queue.get()`` would block forever on an empty queue.
        try:
            funcs = cls.funcs_queue.get_nowait()
        except queue.Empty:
            raise AssertionError("Pop empty queue.") from None
        for func in funcs:
            func()

    @classmethod
    def clear(cls) -> None:
        """Discard the pending batch and every flushed batch."""
        cls.cache = []
        cls.funcs_queue = queue.Queue()
class MoERouter(nn.Module):
    """Softmax top-k router that selects experts for each token.

    A bias-free linear classifier scores every expert, the scores are
    normalised with softmax, and the ``top_k`` highest-probability experts are
    selected per token. While training, a z-loss on the logits and a
    load-balancing loss on the probabilities are attached to the autograd
    graph through ``MoEAuxLossAutoScaler``; the forward values are unchanged.
    """

    def __init__(self, config: "MoEConfig"):
        super().__init__()
        self.classifier = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.num_experts = config.num_experts
        self.top_k = config.top_k
        self.moe_aux_loss_coeff = config.moe_aux_loss_coeff
        self.moe_z_loss_coeff = config.moe_z_loss_coeff

    def apply_z_loss(self, logits: torch.Tensor) -> torch.Tensor:
        """Attach the z-loss (mean squared log-sum-exp of the logits).

        Only active in training mode with a positive ``moe_z_loss_coeff``;
        otherwise ``logits`` is returned untouched.
        """
        if self.moe_z_loss_coeff > 0 and self.training:
            z_loss = torch.mean(torch.square(torch.logsumexp(logits, dim=-1))) * self.moe_z_loss_coeff
            logits = MoEAuxLossAutoScaler.apply(logits, z_loss)
        return logits

    def apply_load_balancing_loss(self, router_probs: torch.Tensor, tokens_per_expert: torch.Tensor) -> torch.Tensor:
        """Attach an MSE penalty between mean expert probability and uniform load.

        Args:
            router_probs: probabilities whose last dimension is ``num_experts``.
            tokens_per_expert: per-expert assignment counts. Kept for
                interface compatibility; the loss below does not read it.

        Returns:
            ``router_probs``, wrapped by the auto-scaler when the loss is active.
        """
        if self.moe_aux_loss_coeff > 0 and self.training:
            # Average over *all* tokens. Using shape[0] as the token count is
            # only right for 2-D input; for (batch, seq, experts) it would
            # average over the batch dimension alone.
            expert_load = router_probs.reshape(-1, self.num_experts).mean(dim=0)
            # Ideal load: every expert receives an equal share.
            target_load = torch.full_like(expert_load, 1.0 / self.num_experts)
            aux_loss = F.mse_loss(expert_load, target_load) * self.moe_aux_loss_coeff
            router_probs = MoEAuxLossAutoScaler.apply(router_probs, aux_loss)
        return router_probs

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Route ``x`` and return the selected experts with their weights.

        Args:
            x: tensor whose last dimension is ``hidden_size``.

        Returns:
            ``(top_k_probs, top_k_indices)``, both shaped ``x.shape[:-1] + (top_k,)``.
            ``top_k_probs[..., k]`` is the routing probability of expert
            ``top_k_indices[..., k]``.
        """
        router_logits = self.apply_z_loss(self.classifier(x))
        router_probs = F.softmax(router_logits, dim=-1)

        _, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)

        # Count the assignments of every expert across all top-k slots at once.
        flat_indices = top_k_indices.reshape(-1)
        tokens_per_expert = torch.zeros(self.num_experts, device=router_probs.device)
        tokens_per_expert.scatter_add_(0, flat_indices, torch.ones_like(flat_indices, dtype=torch.float))

        router_probs = self.apply_load_balancing_loss(router_probs, tokens_per_expert)

        # Gather from the (possibly loss-wrapped) probabilities so the
        # returned weights belong to the selected experts. Returning the full
        # distribution made callers index it by slot number, i.e. use the
        # probability of expert k instead of the k-th chosen expert.
        top_k_probs = router_probs.gather(-1, top_k_indices)

        return top_k_probs, top_k_indices
def moe_criterion(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Return the mean squared error between ``output`` and ``target``."""
    task_loss = F.mse_loss(output, target, reduction="mean")
    return task_loss
def cal_diff(x: torch.Tensor, y: torch.Tensor) -> float:
    """Return ``1 - 2 * sum(x * y) / sum(x * x + y * y)`` computed in float64.

    The value is 0 for identical tensors, 1 for orthogonal ones and 2 for
    ``y == -x``. When the denominator is zero (both tensors all-zero or
    empty) the inputs are indistinguishable, so 0.0 is returned instead of
    raising ``ZeroDivisionError``.
    """
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum().item()
    if denominator == 0:
        return 0.0
    return 1 - 2 * (x * y).sum().item() / denominator
if __name__ == "__main__":
    # To exercise a single device first, call test_single_gpu() here.

    # Use the largest even number of visible GPUs and run the multi-GPU test.
    num_gpus = torch.cuda.device_count() // 2 * 2
    if num_gpus < 2:
        # With no worker processes nothing would actually be exercised, yet
        # the script would still report success.
        raise SystemExit(
            f"Need at least 2 GPUs for the multi-GPU test, found {torch.cuda.device_count()}"
        )
    print(f"Testing with {num_gpus} GPUs")
    test_multi_gpu(num_gpus)
    print(f"Test passed with {num_gpus} GPUs")