[MoE][PoC] model code #730
base: gh/tianyu-l/24/base
First file: the Llama model definition (ModelArgs and TransformerBlock).
@@ -34,6 +34,13 @@ class ModelArgs:
     depth_init: bool = True
     norm_type: str = "rmsnorm"
 
+    # MoE args
+    enable_moe: bool = True
+    num_experts: int = 8
+    capacity_factor: float = 1.0
+    use_shared_expert: bool = True
+    auto_scale_hidden_dim: bool = True
+
 
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
     """
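To show how the new flags are meant to be consumed, here is a minimal sketch of constructing a config with them; the field names come from the diff above, while the concrete values and the assumption that the remaining ModelArgs fields have defaults are illustrative only.

# Illustrative sketch only: setting the MoE fields added to ModelArgs above.
moe_args = ModelArgs(
    enable_moe=True,             # swap the dense FeedForward for the MoE layer
    num_experts=8,               # number of routed experts per MoE layer
    capacity_factor=1.0,         # expert capacity ~= tokens * capacity_factor / num_experts
    use_shared_expert=True,      # add one always-active shared expert
    auto_scale_hidden_dim=True,  # shrink per-expert hidden dim to keep cost close to the dense FFN
)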
@@ -283,12 +290,55 @@ def __init__(self, layer_id: int, model_args: ModelArgs):
         self.n_heads = model_args.n_heads
         self.dim = model_args.dim
         self.attention = Attention(model_args)
-        self.feed_forward = FeedForward(
-            dim=model_args.dim,
-            hidden_dim=4 * model_args.dim,
-            multiple_of=model_args.multiple_of,
-            ffn_dim_multiplier=model_args.ffn_dim_multiplier,
-        )
+        self.enable_moe = model_args.enable_moe
+
+        if not self.enable_moe:
+            self.feed_forward = FeedForward(
+                dim=model_args.dim,
+                hidden_dim=4 * model_args.dim,
+                multiple_of=model_args.multiple_of,
+                ffn_dim_multiplier=model_args.ffn_dim_multiplier,
+            )
+        else:
+            from torchtitan.models.llama.moe_layer import (
+                ExpertChoiceTopKRouter,
+                GroupedExperts,
+                MoE,
+            )
+
+            hidden_dim_denom = 1
+            if model_args.auto_scale_hidden_dim:
+                hidden_dim_denom = model_args.capacity_factor + int(
+                    model_args.use_shared_expert
+                )
+
+            dim = model_args.dim
+            hidden_dim = 4 * model_args.dim
+            hidden_dim = int(2 * hidden_dim / 3)
+            if model_args.ffn_dim_multiplier is not None:
+                hidden_dim = int(model_args.ffn_dim_multiplier * hidden_dim)
+            if model_args.auto_scale_hidden_dim:
+                hidden_dim = int(hidden_dim / hidden_dim_denom)
+            hidden_dim += -hidden_dim % model_args.multiple_of
+
+            num_experts = model_args.num_experts
+            self.moe = MoE(
+                experts=GroupedExperts(
+                    dim_in=dim, dim_out=hidden_dim, num_experts=num_experts
+                ),
+                router=ExpertChoiceTopKRouter(
+                    gate=nn.Linear(dim, num_experts, bias=False),
+                    dim=dim,
+                    num_experts=num_experts,
+                    capacity_factor=model_args.capacity_factor,
+                ),
+                shared_expert=(
+                    GroupedExperts(dim_in=dim, dim_out=hidden_dim, num_experts=1)
+                    if model_args.use_shared_expert
+                    else None
+                ),
+            )
+
         self.layer_id = layer_id
         self.num_layers = model_args.n_layers
 
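To make the hidden_dim scaling above concrete, here is a worked example; the numbers (dim=4096, multiple_of=256, ffn_dim_multiplier=None, capacity_factor=1.0, use_shared_expert=True) are assumed for illustration and are not values this PR fixes.

# Worked example of the MoE hidden_dim computation above (assumed values).
capacity_factor, use_shared_expert = 1.0, True
hidden_dim_denom = capacity_factor + int(use_shared_expert)   # 2.0
hidden_dim = 4 * 4096                                         # 16384, same starting point as the dense FFN
hidden_dim = int(2 * hidden_dim / 3)                          # 10922
hidden_dim = int(hidden_dim / hidden_dim_denom)               # 5461, halved so routed + shared experts cost ~ dense
hidden_dim += -hidden_dim % 256                               # round up to multiple_of -> 5632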
@@ -321,14 +371,20 @@ def forward(
 
         """
         h = x + self.attention(self.attention_norm(x), freqs_cis)
-        out = h + self.feed_forward(self.ffn_norm(h))
+        if not self.enable_moe:

Review comment on the line above: wrt the above, this is what I mean about implying it's dense but not really making it clear... is_dense_model is more precise.

Review comment (follow-up): I modified it to [...] in my copy of your PR, mostly showing how it reads with the is_blah_model approach.

+            out = h + self.feed_forward(self.ffn_norm(h))
+        else:
+            out = h + self.moe(self.ffn_norm(h))
         return out
 
     def init_weights(self):
         for norm in (self.attention_norm, self.ffn_norm):
             norm.reset_parameters()
         self.attention.init_weights(self.weight_init_std)
-        self.feed_forward.init_weights(self.weight_init_std)
+        if not self.enable_moe:

Review comment on the line above: same as above.

+            self.feed_forward.init_weights(self.weight_init_std)
+        else:
+            self.moe.init_weights(self.weight_init_std)
 
 
 class Transformer(nn.Module):
Second file (new): torchtitan/models/llama/moe_layer.py

@@ -0,0 +1,204 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Optional

import torch
import torch.nn.functional as F
from torch import nn

class GroupedExperts(nn.Module):
    """This class implements the grouped experts layer used in Mixture of Experts. Each expert
    is a variant of the gated linear unit network. See more details in https://arxiv.org/pdf/2002.05202.

    Args:
        dim_in (int): Input dimension.
        dim_out (int): Output dimension.
        num_experts (int): Number of experts in this grouped experts layer. Default is 1.
        swiglu (bool): Whether to use the gated linear unit. Default is True.
        activation (Callable): Activation function to use. Default is F.silu.
    """

    def __init__(
        self,
        *,
        dim_in: int,
        dim_out: int,
        num_experts: int = 1,
        swiglu: bool = True,
        activation: Callable = F.silu,
    ):
        super().__init__()
        self.dim_in = dim_in
        self.num_experts = num_experts
        self.gate_proj = nn.Parameter(torch.empty(num_experts, dim_in, dim_out))
        self.down_proj = nn.Parameter(torch.empty(num_experts, dim_out, dim_in))
        if swiglu:
            self.up_proj = nn.Parameter(torch.empty(num_experts, dim_in, dim_out))
            self.act_fn = F.silu
        else:
            self.up_proj = None
            self.act_fn = activation

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): with shape (num_experts, tokens_per_expert, dim_in) for Expert Choice (EC).

        Returns:
            torch.Tensor: with shape (num_experts, tokens_per_expert, dim_in) for Expert Choice (EC).
        """
        # Expert Choice (EC) forward
        # x shape (num_experts, tokens_per_expert, dim_in)
        h = self.act_fn(torch.bmm(x, self.gate_proj))
        if self.up_proj is not None:
            h = h * torch.bmm(x, self.up_proj)
        # out shape (num_experts, tokens_per_expert, dim_in): down_proj maps back to dim_in
        out = torch.bmm(h, self.down_proj)
        return out

    def init_weights(self, init_std: float):
        nn.init.trunc_normal_(self.gate_proj, mean=0.0, std=0.02)
        if self.up_proj is not None:
            nn.init.trunc_normal_(self.up_proj, mean=0.0, std=init_std)
        nn.init.trunc_normal_(self.down_proj, mean=0.0, std=init_std)
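As a quick sanity check on the shapes documented above, here is a standalone sketch of driving GroupedExperts in expert-choice mode; the sizes are arbitrary and only meant to show that the output comes back at dim_in.

# Shape check for GroupedExperts (arbitrary sizes, illustrative only).
experts = GroupedExperts(dim_in=256, dim_out=512, num_experts=8)
experts.init_weights(init_std=0.02)
x = torch.randn(8, 16, 256)   # (num_experts, tokens_per_expert, dim_in)
out = experts(x)
print(out.shape)              # torch.Size([8, 16, 256]) -- down_proj maps back to dim_in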
class ExpertChoiceTopKRouter(nn.Module):
    """This class implements expert choice routing. Each expert selects its top-k tokens based on
    the router scores. Refer to https://arxiv.org/abs/2202.09368 for more details.

    Args:
        gate (nn.Module): Gate module to calculate the scores, typically nn.Linear(dim, num_experts).
        dim (int): Dimension of input tokens.
        num_experts (int): Number of experts in each MoE layer.
        capacity_factor (float): Determines how many tokens each expert can choose:
            expert capacity = (number of tokens * capacity factor) / number of experts.
        use_sigmoid (bool): Whether to use sigmoid or softmax for router scores. Default is True.
    """

    def __init__(
        self,
        *,
        gate: nn.Module,
        dim: int,
        num_experts: int,
        capacity_factor: float,
        use_sigmoid: bool = True,
    ):
        super().__init__()
        self.gate = gate
        self.dim = dim
        self.num_experts = num_experts
        self.capacity_factor = capacity_factor
        self.use_sigmoid = use_sigmoid

    def forward(self, x: torch.Tensor):
        """
        Args:
            x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``.

        Returns:
            top_scores (torch.Tensor): Routing scores of the selected tokens, with shape
                ``(num_experts, tokens_per_expert)``.
            selected_token_indices (torch.Tensor): Indices of the tokens selected by each expert,
                with shape ``(num_experts, tokens_per_expert)``.
        """
        # scores shape (num_experts, bs*slen)
        scores = self.gate(x).transpose(0, 1)
        # By default, we perform sigmoid and softmax in float32 to avoid loss explosion.
        if self.use_sigmoid:
            scores = torch.sigmoid(scores.to(torch.float32)).to(x.dtype)
        else:
            scores = F.softmax(scores.to(torch.float32), dim=0).to(x.dtype)
        tokens_per_expert = int(x.shape[0] * self.capacity_factor / self.num_experts)
        # round tokens_per_expert up to a multiple of 8
        tokens_per_expert += -tokens_per_expert % 8
        # Take the smaller of tokens_per_expert and the number of tokens
        tokens_per_expert = min(tokens_per_expert, x.shape[0])
        # top_scores shape (num_experts, tokens_per_expert)
        top_scores, selected_token_indices = torch.topk(
            scores, k=tokens_per_expert, dim=1
        )

        return top_scores, selected_token_indices

    def init_weights(self, init_std: float):
        nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std)
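The capacity formula in the docstring is easiest to see with numbers; the token count below is an assumption for illustration, not something taken from this PR.

# Expert-choice capacity, following the arithmetic in ExpertChoiceTopKRouter.forward (assumed sizes).
num_tokens = 4 * 2048                                                 # bs * slen
capacity_factor, num_experts = 1.0, 8
tokens_per_expert = int(num_tokens * capacity_factor / num_experts)  # 1024
tokens_per_expert += -tokens_per_expert % 8                          # round up to a multiple of 8 (already 1024)
tokens_per_expert = min(tokens_per_expert, num_tokens)               # capped at the total token count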
class MoE(nn.Module):
    """This class implements the MoE (Mixture of Experts) layer. A Mixture of Experts
    typically consists of a set of expert networks together with a router, which directs
    input tokens to the appropriate experts. See more details in https://arxiv.org/pdf/2407.06204.

    Args:
        experts (nn.Module): experts module.
        router (nn.Module): router module.
        shared_expert (Optional[nn.Module]): shared expert module. Default is None.
    """

    def __init__(
        self,
        *,
        experts: nn.Module,
        router: nn.Module,
        shared_expert: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.experts = experts
        self.router = router
        self.shared_expert = shared_expert

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input tensor with shape ``(bz, slen, dim)``.

        Returns:
            out (torch.Tensor): Output tensor with shape ``(bz, slen, dim)``.
        """
        bz, slen, dim = x.shape

        # flatten the batch and sequence dimensions before routing (EC)
        x = x.reshape(bz * slen, dim)
        top_scores, selected_token_indices = self.router(x)
        num_experts, _ = top_scores.shape

        # token_indices shape (num_experts*tokens_per_expert, dim)
        token_indices = selected_token_indices.reshape(-1, 1).expand(-1, dim)
        # routed_input shape (num_experts*tokens_per_expert, dim)
        routed_input = torch.gather(x, dim=0, index=token_indices)
        routed_input = routed_input * top_scores.reshape(-1, 1)

        # routed_input shape (num_experts, tokens_per_expert, dim)
        routed_input = routed_input.reshape(num_experts, -1, dim)
        # routed_output shape (num_experts, tokens_per_expert, dim)
        routed_output = self.experts(routed_input)
        # routed_output shape (num_experts*tokens_per_expert, dim)
        routed_output = routed_output.reshape(-1, dim)

        # shared expert output forms the base; tokens not chosen by any expert
        # receive only this (or zeros if there is no shared expert)
        if self.shared_expert is not None:
            out = self.shared_expert(x.reshape(1, bz * slen, dim)).reshape(
                bz * slen, dim
            )
        else:
            out = torch.zeros_like(x.reshape(bz * slen, dim))

        # add the routed experts' output back at the selected token positions
        # doing it in place might be faster
        out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
        out = out.reshape(bz, slen, dim)
        return out

    def init_weights(self, init_std: float):
        self.experts.init_weights(init_std)
        self.router.init_weights(init_std)
        if self.shared_expert is not None:
            self.shared_expert.init_weights(init_std)
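To see how the three modules compose, here is a minimal end-to-end sketch that mirrors the wiring in TransformerBlock.__init__ from the first file; the sizes are made up and none of the torchtitan parallelism plumbing is involved.

# Minimal sketch: wiring router, experts, and shared expert together (illustrative sizes).
dim, hidden_dim, num_experts = 256, 512, 8
moe = MoE(
    experts=GroupedExperts(dim_in=dim, dim_out=hidden_dim, num_experts=num_experts),
    router=ExpertChoiceTopKRouter(
        gate=nn.Linear(dim, num_experts, bias=False),
        dim=dim,
        num_experts=num_experts,
        capacity_factor=1.0,
    ),
    shared_expert=GroupedExperts(dim_in=dim, dim_out=hidden_dim, num_experts=1),
)
moe.init_weights(init_std=0.02)
x = torch.randn(2, 128, dim)   # (bz, slen, dim)
print(moe(x).shape)            # torch.Size([2, 128, 256])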
Review comment: minor nit - I think this is expressed better as self.is_dense_model and self.is_moe_model.
'Enable' is not the same thing as 'is', so I think it's cleaner to express its state of being via 'is'.
In addition, I think the later checks read more cleanly as self.is_dense_model/layer; that makes the checks easier to follow than 'not self.enable_moe', which may not really mean it's a dense model/layer later on if we start enabling other arches.