[MoE][PoC] model code
ghstack-source-id: 43510c8443b3fcd2261ec6175216db0453c47bac
Pull Request resolved: #730
tianyu-l committed Dec 12, 2024
1 parent 0186284 commit 45679da
Showing 2 changed files with 268 additions and 8 deletions.
72 changes: 64 additions & 8 deletions torchtitan/models/llama/model.py
@@ -34,6 +34,13 @@ class ModelArgs:
depth_init: bool = True
norm_type: str = "rmsnorm"

# MoE args
enable_moe: bool = True
num_experts: int = 8
capacity_factor: float = 1.0
use_shared_expert: bool = True
auto_scale_hidden_dim: bool = True
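# note (interpretation, not from the commit message): capacity_factor controls how many
# tokens each expert processes; auto_scale_hidden_dim divides the expert hidden dim by
# (capacity_factor + shared expert) so total MoE compute stays roughly comparable to the
# dense FFN it replaces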


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
"""
@@ -283,12 +290,55 @@ def __init__(self, layer_id: int, model_args: ModelArgs):
self.n_heads = model_args.n_heads
self.dim = model_args.dim
self.attention = Attention(model_args)
self.enable_moe = model_args.enable_moe

if not self.enable_moe:
self.feed_forward = FeedForward(
dim=model_args.dim,
hidden_dim=4 * model_args.dim,
multiple_of=model_args.multiple_of,
ffn_dim_multiplier=model_args.ffn_dim_multiplier,
)
else:
from torchtitan.models.llama.moe_layer import (
ExpertChoiceTopKRouter,
GroupedExperts,
MoE,
)

hidden_dim_denom = 1
if model_args.auto_scale_hidden_dim:
hidden_dim_denom = model_args.capacity_factor + int(
model_args.use_shared_expert
)

dim = model_args.dim
hidden_dim = 4 * model_args.dim
hidden_dim = int(2 * hidden_dim / 3)
if model_args.ffn_dim_multiplier is not None:
hidden_dim = int(model_args.ffn_dim_multiplier * hidden_dim)
if model_args.auto_scale_hidden_dim:
hidden_dim = int(hidden_dim / hidden_dim_denom)
hidden_dim += -hidden_dim % model_args.multiple_of
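# worked example (illustrative values, not from this commit): with dim=256,
# multiple_of=256, ffn_dim_multiplier=None, capacity_factor=1.0, and
# use_shared_expert=True, the dense hidden dim would be int(2 * 1024 / 3) = 682;
# auto-scaling divides by denom = 1.0 + 1 = 2, giving 341, which is then rounded
# up to the next multiple of 256, i.e. hidden_dim = 512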

num_experts = model_args.num_experts
self.moe = MoE(
experts=GroupedExperts(
dim_in=dim, dim_out=hidden_dim, num_experts=num_experts
),
router=ExpertChoiceTopKRouter(
gate=nn.Linear(dim, num_experts, bias=False),
dim=dim,
num_experts=num_experts,
capacity_factor=model_args.capacity_factor,
),
shared_expert=(
GroupedExperts(dim_in=dim, dim_out=hidden_dim, num_experts=1)
if model_args.use_shared_expert
else None
),
)

self.layer_id = layer_id
self.num_layers = model_args.n_layers

@@ -321,14 +371,20 @@ def forward(
"""
h = x + self.attention(self.attention_norm(x), freqs_cis)
if not self.enable_moe:
out = h + self.feed_forward(self.ffn_norm(h))
else:
out = h + self.moe(self.ffn_norm(h))
return out

def init_weights(self):
for norm in (self.attention_norm, self.ffn_norm):
norm.reset_parameters()
self.attention.init_weights(self.weight_init_std)
if not self.enable_moe:
self.feed_forward.init_weights(self.weight_init_std)
else:
self.moe.init_weights(self.weight_init_std)


class Transformer(nn.Module):
204 changes: 204 additions & 0 deletions torchtitan/models/llama/moe_layer.py
@@ -0,0 +1,204 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Optional

import torch
import torch.nn.functional as F
from torch import nn


class GroupedExperts(nn.Module):
"""This class implements the grouped experts layer used in Mixture of Experts. Each expert
is a variant of the Gated Linear Units network. See more details in https://arxiv.org/pdf/2002.05202.
Args:
dim_in (int): Input dimension.
dim_out (int): Output dimension.
num_experts (int): Number of experts in this grouped experts layer. Default is 1.
swiglu (bool): Whether to use gated linear unit. Default is True.
activation (Callable): Activation function to use. Default is F.silu.
"""

def __init__(
self,
*,
dim_in: int,
dim_out: int,
num_experts: int = 1,
swiglu: bool = True,
activation: Callable = F.silu,
):
super().__init__()
self.dim_in = dim_in
self.num_experts = num_experts
self.gate_proj = nn.Parameter(torch.empty(num_experts, dim_in, dim_out))
self.down_proj = nn.Parameter(torch.empty(num_experts, dim_out, dim_in))
if swiglu:
self.up_proj = nn.Parameter(torch.empty(num_experts, dim_in, dim_out))
self.act_fn = F.silu
else:
self.up_proj = None
self.act_fn = activation

def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
"""
Args:
x (torch.Tensor): with shape (num_experts, tokens_per_expert, dim_in) for Expert Choice (EC).
Returns:
torch.Tensor: with shape (num_experts, tokens_per_expert, dim_in) for Expert Choice (EC).
"""
# Expert Choice (EC) forward
# x shape (num_experts, tokens_per_expert, dim_in)
h = self.act_fn(torch.bmm(x, self.gate_proj))
if self.up_proj is not None:
h = h * torch.bmm(x, self.up_proj)
# out shape (num_experts, tokens_per_expert, dim_in)
out = torch.bmm(h, self.down_proj)
return out

def init_weights(self, init_std: float):
nn.init.trunc_normal_(self.gate_proj, mean=0.0, std=0.02)
if self.up_proj is not None:
nn.init.trunc_normal_(self.up_proj, mean=0.0, std=init_std)
nn.init.trunc_normal_(self.down_proj, mean=0.0, std=init_std)


class ExpertChoiceTopKRouter(nn.Module):
"""This class implements experts choice routing. Each experts will select it's top K tokens based on
the router scores. Refer to more details in https://arxiv.org/abs/2202.09368
Args:
gate (nn.Module): Gate module to calculate the scores, typically nn.Linear(dim, num_experts).
dim (int): Dimension of input tokens.
num_experts (int): Number of experts in each moe layer.
capacity_factor (float): Capacity factor determines how many tokens each expert can choose.
expert capacity = (number of tokens * capacity factor) / number of experts.
use_sigmoid (bool): Whether to use sigmoid or softmax for router scores. Default is True.
"""

def __init__(
self,
*,
gate: nn.Module,
dim: int,
num_experts: int,
capacity_factor: float,
use_sigmoid: bool = True,
):
super().__init__()
self.gate = gate
self.dim = dim
self.num_experts = num_experts
self.capacity_factor = capacity_factor
self.use_sigmoid = use_sigmoid

def forward(self, x: torch.Tensor):
"""
Args:
x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``.
Returns:
top_scores (torch.Tensor): router scores of the tokens selected by each expert, with shape
``(num_experts, tokens_per_expert)``.
selected_token_indices (torch.Tensor): indices of the selected tokens, with shape
``(num_experts, tokens_per_expert)``.
"""
# scores shape (num_experts, bs*slen)
scores = self.gate(x).transpose(0, 1)
# By default, we perform sigmoid and softmax in float32 to avoid loss explosion.
if self.use_sigmoid:
scores = torch.sigmoid(scores.to(torch.float32)).to(x.dtype)
else:
scores = F.softmax(scores.to(torch.float32), dim=0).to(x.dtype)
tokens_per_expert = int(x.shape[0] * self.capacity_factor / self.num_experts)
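# round tokens_per_expert up to a multiple of 8, presumably for hardware-friendly tensor shapes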
tokens_per_expert += -tokens_per_expert % 8
# Take the smaller of tokens_per_expert and the number of tokens
tokens_per_expert = min(tokens_per_expert, x.shape[0])
# top_scores shape (num_experts, tokens_per_expert)
top_scores, selected_token_indices = torch.topk(
scores, k=tokens_per_expert, dim=1
)
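# in expert choice, the topk above runs along the token dimension: each expert picks its own
# top-k tokens, so a given token may be chosen by several experts or by none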

return top_scores, selected_token_indices

def init_weights(self, init_std: float):
nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std)


class MoE(nn.Module):
"""This class implements the moe layer which is Mixture of Experts. Mixture of Experts
typically consists of a set of expert networks, alongside with a router, which directs input tokens
to the appropriate experts. See more details in https://arxiv.org/pdf/2407.06204.
Args:
experts (nn.Module): experts module.
router (nn.Module): router module.
shared_expert (Optional[nn.Module]): shared expert module. Default is None.
"""

def __init__(
self,
*,
experts: nn.Module,
router: nn.Module,
shared_expert: Optional[nn.Module] = None,
):
super().__init__()
self.experts = experts
self.router = router
self.shared_expert = shared_expert

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x (torch.Tensor): Input tensor with shape ``(bz, slen, dim)``.
Returns:
out (torch.Tensor): Output tensor with shape ``(bz, slen, dim)``.
"""
bz, slen, dim = x.shape

# flatten the batch and sequence dimensions: x shape (bz*slen, dim)
x = x.reshape(bz * slen, dim)
top_scores, selected_token_indices = self.router(x)
num_experts, _ = top_scores.shape

# token_indices shape (num_experts*tokens_per_expert, dim)
token_indices = selected_token_indices.reshape(-1, 1).expand(-1, dim)
# routed_input shape (num_experts*tokens_per_expert, dim)
routed_input = torch.gather(x, dim=0, index=token_indices)
routed_input = routed_input * top_scores.reshape(-1, 1)

# routed_input shape (num_experts, tokens_per_expert, dim_in)
routed_input = routed_input.reshape(num_experts, -1, dim)
# routed_output shape (num_experts, tokens_per_expert, dim)
routed_output = self.experts(routed_input)
# routed_output shape (num_experts*tokens_per_expert, dim)
routed_output = routed_output.reshape(-1, dim)

# shared expert
if self.shared_expert is not None:
out = self.shared_expert(x.reshape(1, bz * slen, dim)).reshape(
bz * slen, dim
)
else:
out = torch.zeros_like(x.reshape(bz * slen, dim))

# add experts output
# doing this in place might be faster
out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
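# tokens that no expert selected keep only the shared-expert output (or zeros)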
out = out.reshape(bz, slen, dim)
return out

def init_weights(self, init_std: float):
self.experts.init_weights(init_std)
self.router.init_weights(init_std)
if self.shared_expert is not None:
self.shared_expert.init_weights(init_std)
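
For reference, a minimal standalone sketch (illustrative sizes, not part of the diff above) showing how these pieces fit together, mirroring the wiring in TransformerBlock:

import torch
from torch import nn

from torchtitan.models.llama.moe_layer import ExpertChoiceTopKRouter, GroupedExperts, MoE

dim, hidden_dim, num_experts = 256, 512, 8  # illustrative sizes

moe = MoE(
    experts=GroupedExperts(dim_in=dim, dim_out=hidden_dim, num_experts=num_experts),
    router=ExpertChoiceTopKRouter(
        gate=nn.Linear(dim, num_experts, bias=False),
        dim=dim,
        num_experts=num_experts,
        capacity_factor=1.0,
    ),
    shared_expert=GroupedExperts(dim_in=dim, dim_out=hidden_dim, num_experts=1),
)
moe.init_weights(init_std=0.02)  # init_std value is illustrative

x = torch.randn(2, 128, dim)  # (bz, slen, dim)
out = moe(x)                  # same shape as x: (2, 128, dim)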
