From 45679da64a7cb8f6c2daeb7e2895f68db19030ac Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Wed, 11 Dec 2024 16:44:54 -0800
Subject: [PATCH] [MoE][PoC] model code

ghstack-source-id: 43510c8443b3fcd2261ec6175216db0453c47bac
Pull Request resolved: https://github.com/pytorch/torchtitan/pull/730
---
 torchtitan/models/llama/model.py     |  72 ++++++++--
 torchtitan/models/llama/moe_layer.py | 204 +++++++++++++++++++++++++++
 2 files changed, 268 insertions(+), 8 deletions(-)
 create mode 100644 torchtitan/models/llama/moe_layer.py

diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py
index 641ef6de9..c1413700d 100644
--- a/torchtitan/models/llama/model.py
+++ b/torchtitan/models/llama/model.py
@@ -34,6 +34,13 @@ class ModelArgs:
     depth_init: bool = True
     norm_type: str = "rmsnorm"
 
+    # MoE args
+    enable_moe: bool = True
+    num_experts: int = 8
+    capacity_factor: float = 1.0
+    use_shared_expert: bool = True
+    auto_scale_hidden_dim: bool = True
+
 
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
     """
@@ -283,12 +290,55 @@ def __init__(self, layer_id: int, model_args: ModelArgs):
         self.n_heads = model_args.n_heads
         self.dim = model_args.dim
         self.attention = Attention(model_args)
-        self.feed_forward = FeedForward(
-            dim=model_args.dim,
-            hidden_dim=4 * model_args.dim,
-            multiple_of=model_args.multiple_of,
-            ffn_dim_multiplier=model_args.ffn_dim_multiplier,
-        )
+        self.enable_moe = model_args.enable_moe
+
+        if not self.enable_moe:
+            self.feed_forward = FeedForward(
+                dim=model_args.dim,
+                hidden_dim=4 * model_args.dim,
+                multiple_of=model_args.multiple_of,
+                ffn_dim_multiplier=model_args.ffn_dim_multiplier,
+            )
+        else:
+            from torchtitan.models.llama.moe_layer import (
+                ExpertChoiceTopKRouter,
+                GroupedExperts,
+                MoE,
+            )
+
+            hidden_dim_denom = 1
+            if model_args.auto_scale_hidden_dim:
+                hidden_dim_denom = model_args.capacity_factor + int(
+                    model_args.use_shared_expert
+                )
+
+            dim = model_args.dim
+            hidden_dim = 4 * model_args.dim
+            hidden_dim = int(2 * hidden_dim / 3)
+            if model_args.ffn_dim_multiplier is not None:
+                hidden_dim = int(model_args.ffn_dim_multiplier * hidden_dim)
+            if model_args.auto_scale_hidden_dim:
+                hidden_dim = int(hidden_dim / hidden_dim_denom)
+            hidden_dim += -hidden_dim % model_args.multiple_of
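+            # Worked example (illustrative numbers, not from this PR): with
+            # dim=4096, multiple_of=256, ffn_dim_multiplier=None,
+            # capacity_factor=1.0 and a shared expert, the denominator is
+            # 1.0 + 1 = 2.0, so hidden_dim goes int(2 * 16384 / 3) = 10922
+            # -> int(10922 / 2.0) = 5461 -> rounded up to a multiple of 256,
+            # i.e. 5632. Each expert gets roughly half the dense FFN width, so
+            # the activated compute stays comparable to the dense baseline.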
+
+            num_experts = model_args.num_experts
+            self.moe = MoE(
+                experts=GroupedExperts(
+                    dim_in=dim, dim_out=hidden_dim, num_experts=num_experts
+                ),
+                router=ExpertChoiceTopKRouter(
+                    gate=nn.Linear(dim, num_experts, bias=False),
+                    dim=dim,
+                    num_experts=num_experts,
+                    capacity_factor=model_args.capacity_factor,
+                ),
+                shared_expert=(
+                    GroupedExperts(dim_in=dim, dim_out=hidden_dim, num_experts=1)
+                    if model_args.use_shared_expert
+                    else None
+                ),
+            )
+
         self.layer_id = layer_id
         self.num_layers = model_args.n_layers
 
@@ -321,14 +371,20 @@ def forward(
         """
         h = x + self.attention(self.attention_norm(x), freqs_cis)
-        out = h + self.feed_forward(self.ffn_norm(h))
+        if not self.enable_moe:
+            out = h + self.feed_forward(self.ffn_norm(h))
+        else:
+            out = h + self.moe(self.ffn_norm(h))
         return out
 
     def init_weights(self):
         for norm in (self.attention_norm, self.ffn_norm):
             norm.reset_parameters()
         self.attention.init_weights(self.weight_init_std)
-        self.feed_forward.init_weights(self.weight_init_std)
+        if not self.enable_moe:
+            self.feed_forward.init_weights(self.weight_init_std)
+        else:
+            self.moe.init_weights(self.weight_init_std)
 
 
 class Transformer(nn.Module):
diff --git a/torchtitan/models/llama/moe_layer.py b/torchtitan/models/llama/moe_layer.py
new file mode 100644
index 000000000..b44bc6391
--- /dev/null
+++ b/torchtitan/models/llama/moe_layer.py
@@ -0,0 +1,204 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Callable, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class GroupedExperts(nn.Module):
+    """This class implements the grouped experts layer used in Mixture of Experts. Each expert
+    is a variant of the Gated Linear Unit (GLU) network. See more details in https://arxiv.org/pdf/2002.05202.
+
+    Args:
+        dim_in (int): Input dimension.
+        dim_out (int): Output dimension.
+        num_experts (int): Number of experts in this grouped experts layer. Default is 1.
+        swiglu (bool): Whether to use the gated linear unit. Default is True.
+        activation (Callable): Activation function to use; only used when ``swiglu=False``. Default is F.silu.
+    """
+
+    def __init__(
+        self,
+        *,
+        dim_in: int,
+        dim_out: int,
+        num_experts: int = 1,
+        swiglu: bool = True,
+        activation: Callable = F.silu,
+    ):
+        super().__init__()
+        self.dim_in = dim_in
+        self.num_experts = num_experts
+        self.gate_proj = nn.Parameter(torch.empty(num_experts, dim_in, dim_out))
+        self.down_proj = nn.Parameter(torch.empty(num_experts, dim_out, dim_in))
+        if swiglu:
+            self.up_proj = nn.Parameter(torch.empty(num_experts, dim_in, dim_out))
+            self.act_fn = F.silu
+        else:
+            self.up_proj = None
+            self.act_fn = activation
+
+    def forward(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Args:
+            x (torch.Tensor): with shape (num_experts, tokens_per_expert, dim_in) for Expert Choice (EC).
+
+        Returns:
+            torch.Tensor: with shape (num_experts, tokens_per_expert, dim_in) for Expert Choice (EC).
+        """
+        # Expert Choice (EC) forward
+        # x shape (num_experts, tokens_per_expert, dim_in)
+        h = self.act_fn(torch.bmm(x, self.gate_proj))
+        if self.up_proj is not None:
+            h = h * torch.bmm(x, self.up_proj)
+        # out shape (num_experts, tokens_per_expert, dim_out)
+        out = torch.bmm(h, self.down_proj)
+        return out
+
+    def init_weights(self, init_std: float):
+        nn.init.trunc_normal_(self.gate_proj, mean=0.0, std=0.02)
+        if self.up_proj is not None:
+            nn.init.trunc_normal_(self.up_proj, mean=0.0, std=init_std)
+        nn.init.trunc_normal_(self.down_proj, mean=0.0, std=init_std)
+
+
+class ExpertChoiceTopKRouter(nn.Module):
+    """This class implements expert-choice routing. Each expert selects its top-k tokens based on
+    the router scores. See more details in https://arxiv.org/abs/2202.09368.
+
+    Args:
+        gate (nn.Module): Gate module to calculate the scores, typically nn.Linear(dim, num_experts).
+        dim (int): Dimension of input tokens.
+        num_experts (int): Number of experts in each moe layer.
+        capacity_factor (float): Capacity factor determines how many tokens each expert can choose.
+            expert capacity = (number of tokens * capacity factor) / number of experts.
+        use_sigmoid (bool): Whether to use sigmoid or softmax for router scores. Default is True.
+    """
+
+    def __init__(
+        self,
+        *,
+        gate: nn.Module,
+        dim: int,
+        num_experts: int,
+        capacity_factor: float,
+        use_sigmoid: bool = True,
+    ):
+        super().__init__()
+        self.gate = gate
+        self.dim = dim
+        self.num_experts = num_experts
+        self.capacity_factor = capacity_factor
+        self.use_sigmoid = use_sigmoid
+
+    def forward(self, x: torch.Tensor):
+        """
+        Args:
+            x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``.
+
+        Returns:
+            top_scores (torch.Tensor): Routing scores of the selected tokens with shape
+                ``(num_experts, tokens_per_expert)``.
+            selected_token_indices (torch.Tensor): Indices of the selected tokens with shape
+                ``(num_experts, tokens_per_expert)``.
+        """
+        # scores shape (num_experts, bs*slen)
+        scores = self.gate(x).transpose(0, 1)
+        # By default, we perform sigmoid and softmax in float32 to avoid loss explosion.
+        if self.use_sigmoid:
+            scores = torch.sigmoid(scores.to(torch.float32)).to(x.dtype)
+        else:
+            scores = F.softmax(scores.to(torch.float32), dim=0).to(x.dtype)
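+        # Capacity example (illustrative numbers, not from this PR): with
+        # bs*slen = 4096 tokens, capacity_factor = 1.0 and num_experts = 8,
+        # each expert selects int(4096 * 1.0 / 8) = 512 tokens; the value is
+        # then rounded up to a multiple of 8 for hardware-friendly shapes and
+        # capped at the total token count.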
+        tokens_per_expert = int(x.shape[0] * self.capacity_factor / self.num_experts)
+        # Round up to a multiple of 8
+        tokens_per_expert += -tokens_per_expert % 8
+        # Take the smaller of tokens_per_expert and the number of tokens
+        tokens_per_expert = min(tokens_per_expert, x.shape[0])
+        # top_scores shape (num_experts, tokens_per_expert)
+        top_scores, selected_token_indices = torch.topk(
+            scores, k=tokens_per_expert, dim=1
+        )
+
+        return top_scores, selected_token_indices
+
+    def init_weights(self, init_std: float):
+        nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std)
+
+
+class MoE(nn.Module):
+    """This class implements the Mixture of Experts (MoE) layer. A MoE layer typically
+    consists of a set of expert networks, along with a router that directs input tokens
+    to the appropriate experts. See more details in https://arxiv.org/pdf/2407.06204.
+
+    Args:
+        experts (nn.Module): experts module.
+        router (nn.Module): router module.
+        shared_expert (Optional[nn.Module]): shared expert module. Default is None.
+    """
+
+    def __init__(
+        self,
+        *,
+        experts: nn.Module,
+        router: nn.Module,
+        shared_expert: Optional[nn.Module] = None,
+    ):
+        super().__init__()
+        self.experts = experts
+        self.router = router
+        self.shared_expert = shared_expert
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            x (torch.Tensor): Input tensor with shape ``(bz, slen, dim)``.
+
+        Returns:
+            out (torch.Tensor): Output tensor with shape ``(bz, slen, dim)``.
+        """
+        bz, slen, dim = x.shape
+
+        # x shape (bz*slen, dim)
+        x = x.reshape(bz * slen, dim)
+        # top_scores shape (num_experts, tokens_per_expert)
+        top_scores, selected_token_indices = self.router(x)
+        num_experts, _ = top_scores.shape
+
+        # token_indices shape (num_experts*tokens_per_expert, dim)
+        token_indices = selected_token_indices.reshape(-1, 1).expand(-1, dim)
+        # routed_input shape (num_experts*tokens_per_expert, dim)
+        routed_input = torch.gather(x, dim=0, index=token_indices)
+        routed_input = routed_input * top_scores.reshape(-1, 1)
+
+        # routed_input shape (num_experts, tokens_per_expert, dim_in)
+        routed_input = routed_input.reshape(num_experts, -1, dim)
+        # routed_output shape (num_experts, tokens_per_expert, dim_out)
+        routed_output = self.experts(routed_input)
+        # routed_output shape (num_experts*tokens_per_expert, dim_out)
+        routed_output = routed_output.reshape(-1, dim)
+
+        # shared expert
+        if self.shared_expert is not None:
+            out = self.shared_expert(x.reshape(1, bz * slen, dim)).reshape(
+                bz * slen, dim
+            )
+        else:
+            out = torch.zeros_like(x.reshape(bz * slen, dim))
+
+        # add experts output
+        # NOTE: doing this in place might be faster
+        out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
+        out = out.reshape(bz, slen, dim)
+        return out
+
+    def init_weights(self, init_std: float):
+        self.experts.init_weights(init_std)
+        self.router.init_weights(init_std)
+        if self.shared_expert is not None:
+            self.shared_expert.init_weights(init_std)
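+
+
+if __name__ == "__main__":
+    # Minimal smoke test (a sketch added for illustration; the sizes below are
+    # arbitrary and not part of this PR). It wires the three modules together
+    # the same way TransformerBlock does and checks the output shape.
+    dim, hidden_dim, num_experts = 256, 512, 8
+    moe = MoE(
+        experts=GroupedExperts(dim_in=dim, dim_out=hidden_dim, num_experts=num_experts),
+        router=ExpertChoiceTopKRouter(
+            gate=nn.Linear(dim, num_experts, bias=False),
+            dim=dim,
+            num_experts=num_experts,
+            capacity_factor=1.0,
+        ),
+        shared_expert=GroupedExperts(dim_in=dim, dim_out=hidden_dim, num_experts=1),
+    )
+    moe.init_weights(init_std=0.02)
+    out = moe(torch.randn(2, 16, dim))
+    assert out.shape == (2, 16, dim)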