Mixture of Experts (#115)
* update

* works on one gpu

* added moe params

* works on multiple gpus

* update

* eval works

* update

* update

* update

* update

* update

* removed experiment dir

* removed experiments dir

* removed custom fsdp

* update

* update

* added expert gradient norm

* update

* update

* Update model.py

* added load balancing loss

* update

* update

* update

* update

* update

* update

* update

* removed unnecessary dir

* Remove the now ignored directory experiments

* update

* update

* update

* update

* update

* update

* update

* black formatting

* removed custom moe dir

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

---------

Co-authored-by: Suchin Gururangan <[email protected]>
7 people authored Dec 18, 2023
1 parent f159628 commit 5610963
Showing 10 changed files with 306 additions and 9 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -140,3 +140,4 @@ out*
tests/assets/*
.vscode/
checkpoints/
experiments/
114 changes: 114 additions & 0 deletions MOE.md
@@ -0,0 +1,114 @@
# Mixture of Experts Language Models

## Dependencies

Our mixture-of-experts implementation depends on [megablocks](https://github.com/stanford-futuredata/megablocks) and the version of xformers that is compatible with torch 2.1:

```
pip install megablocks
pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu121
```

## Train MoE

To train an MoE, add the MoE-related arguments (`--moe-*`) to the training command:

```
torchrun --nproc-per-node 8 -m open_lm.main \
--train-num-samples 10000000000 \
--workers 2 \
--dataset-manifest "s3://laion-west/rpj_tokenized_upsampled_eleutherai/manifest.jsonl" "s3://laion-west/2T_no_rpj_tokenized_upsampled_25k_shards/manifest.jsonl" \
--train-data-mix-weights 0.725 0.275 \
--precision amp_bfloat16 \
--batch-size 8 \
--accum-freq 4 \
--log-every-n-steps 20 \
--grad-clip-norm 1 \
--lr 5e-4 \
--warmup 200 \
--model open_lm_41m \
--wd 0.1 \
--beta2 0.95 \
--epochs 50 \
--report-to wandb \
--moe-freq 2 \
--moe-num-experts 8 \
--moe-top-k 2 \
--moe-capacity-factor 1.25 --moe-loss-weight 0.1 \
--disable-meta-device \
--wandb-project-name moe \
--name test$RANDOM \
--logs /fsx/home-$USER/experiments/moe \
--resume latest \
--seed 124 \
--data-key 'json' \
--fsdp --fsdp-amp \
--model-norm gain_only_layer_norm \
--lr-scheduler cosine \
--lr-cooldown-end 0.00001
```

With `--moe-freq 2`, the above command uses an MoE FFN in every other Transformer block. You can use an arbitrary number of experts; you are limited only by the total RAM across all GPUs.
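
For reference, the rule that decides which blocks get an MoE FFN is the `layer_id % moe_freq == 0` check added to `open_lm/model.py` in this commit. A minimal sketch (the helper function and the 24-layer example are illustrative only):

```python
# Minimal sketch of the --moe-freq layer-selection rule; mirrors the
# `layer_id % moe_freq == 0` check in open_lm/model.py.
def ffn_type_for_layer(layer_id: int, moe_freq: int, default_ffn: str = "swiglu") -> str:
    """Return the FFN type used for a given Transformer block."""
    if moe_freq > 0 and layer_id % moe_freq == 0:
        return "moe"
    return default_ffn

# With --moe-freq 2 and a hypothetical 24-layer model, blocks 0, 2, 4, ... use MoE FFNs.
print([i for i in range(24) if ffn_type_for_layer(i, moe_freq=2) == "moe"])
```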


You can also pass `--moe-expert-model-parallelism`, which distributes experts across GPUs. However, if the number of GPUs is larger than the number of experts, an additional num_gpus / num_experts tensor parallelism is applied on top. This is currently not eval-friendly, so we do not recommend using it yet.

You can evaluate the MoE in the same way as dense models:

```
torchrun --nproc-per-node 8 -m open_lm.main \
--val-data "pipe:aws s3 cp s3://laion-west/lmdata/validation_data_tokenized/open_lm//shard_00000000.tar -" \
--workers 6 \
--precision amp_bfloat16 \
--batch-size 8 \
--log-every-n-steps 1 \
--model open_lm_41m \
--fsdp --fsdp-amp \
--moe-num-experts 64 --moe-freq 2 \
--data-key json \
--train-num-samples 1000000000 \
--model-norm gain_only_layer_norm \
--name $RANDOM \
--resume /fsx/home-suching/experiments/mix_wo/test8086/checkpoints/epoch_1.pt \
--logs /fsx/home-$USER/experiments/eval
```


## Benchmarking

To benchmark your own results against ours, here are the perplexities we obtain with our implementation across a range of compute budgets and model sizes on our A100 cluster:

### Compute budgets

| Compute type | 41M | 87M | 160M | 410M | 830M |
|--------------|------|------|------|------|------|
| Number of nodes | 1 | 1 | 1 | 2 | 4 |
| Number of tokens | 20.0B | 20.0B | 20.0B | 20.0B | 20.0B |

### Perplexity
| Number of Experts | 41M | 87M | 160M | 410M | 830M |
|--------------|------|------|------|------|------|
| 1 | 27.61 | 18.68 | 14.87 | 10.54 | 9.39 |
| 8 | 19.85 | 14.66 | 12.26 | 9.82 | 8.84 |
| 32 | 20.55 | 15.28 | 14.62 | | |
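
If you want to sanity-check these perplexities against raw eval losses, the conversion is the usual one; a minimal sketch, assuming the reported loss is the mean token-level cross-entropy in nats:

```python
import math

def perplexity_from_loss(mean_ce_loss_nats: float) -> float:
    """Perplexity is exp of the mean token-level cross-entropy (in nats)."""
    return math.exp(mean_ce_loss_nats)

# e.g. a mean eval loss of about 2.99 nats corresponds to a perplexity of roughly 19.9
print(round(perplexity_from_loss(2.99), 2))
```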


### Tokens/sec/GPU

| Number of Experts | 41M | 87M | 160M | 410M | 830M |
|--------------|------|------|------|------|------|
| 1 | 141.2K | 106.0K | 95.5K | 30.3K | 16.0K |
| 8 | 69.5K | 66.6K | 66.2K | 18.5K | 9.2K |

### Training Parameters

| Number of Experts | 41M | 87M | 160M | 410M | 830M |
|--------------|------|------|------|------|------|
| 8 experts | 68.9M | 165.4M | 360.6M | 1.1B | 2.4B |
| 32 experts | 164.5M | 439.9M | 1.0B | 3.5B | 7.9B |

### Inference Parameters

| Number of Experts | 41M | 87M | 160M | 410M | 830M |
|--------------|------|------|------|------|------|
| 2 experts | 45.0M | 96.8M | 190.7M | 509.2M | 1.1B |
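
The gap between the training and inference parameter counts above comes from top-k routing: every expert's weights are stored (and trained), but only the top-k experts are applied to each token at inference. A rough back-of-the-envelope sketch of the MoE FFN contribution (the dimensions are hypothetical placeholders rather than the exact open_lm configs, and routers, attention, embeddings, and norms are ignored):

```python
def moe_ffn_params(dim: int, n_layers: int, moe_freq: int, num_experts: int, top_k: int):
    """Rough count of MoE FFN parameters: total stored vs. active per token."""
    expert_params = 2 * dim * (4 * dim)  # w1 and w2 of a 4x MLP, matching ffn_hidden_size = 4 * dim
    n_moe_layers = sum(1 for layer_id in range(n_layers) if moe_freq > 0 and layer_id % moe_freq == 0)
    total = n_moe_layers * num_experts * expert_params  # stored and trained
    active = n_moe_layers * top_k * expert_params       # touched per token at inference
    return total, active

# Hypothetical config: dim=512, 8 layers, MoE every other layer, 8 experts, top-2 routing.
total, active = moe_ffn_params(dim=512, n_layers=8, moe_freq=2, num_experts=8, top_k=2)
print(f"MoE FFN params: {total / 1e6:.1f}M stored, {active / 1e6:.1f}M active per token")
```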
54 changes: 52 additions & 2 deletions open_lm/model.py
@@ -19,6 +19,16 @@
from open_lm.positional_embedding.rotary import RotaryWithCast
from open_lm.positional_embedding.llama_rotary import LLaMARotaryWithCast


# from open_lm.moe.mixture_of_experts import MoE
try:
    from megablocks.layers.moe import MoE
    from megablocks.layers.arguments import Arguments as MoEArgs
except ImportError:
    import logging

    logging.warning("Megablocks not installed. To train MoE, install with pip install megablocks.")

try:  # optional import
    from mamba_ssm import MambaLMHeadModel
except ImportError:
@@ -77,6 +87,13 @@ class Params:
    weight_tying: bool = False
    norm_type: nn.Module = nn.LayerNorm
    apply_qk_norm: bool = False
    moe_loss_weight: float = 0.1
    moe_capacity_factor: float = 1.25
    moe_expert_model_parallelism: bool = False
    moe_weight_parallelism: bool = False
    moe_num_experts: int = 8
    moe_top_k: int = 2
    moe_freq: int = 0
    positional_embedding_type: str = "rotary"
    ffn_type: str = "swiglu"

@@ -237,10 +254,10 @@ def __init__(self, layer_id, args: Params):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim

        self.head_dim = args.dim // args.n_heads
        self.attention = CustomAttn(layer_id, args)
        self._ffn_type = args.ffn_type

        if args.ffn_type == "swiglu":
            # this follows llama / lit llama -- go to multiple of 256
            self.hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256)
@@ -251,6 +268,21 @@ def __init__(self, layer_id, args: Params):
            self._ff_w1 = nn.Linear(args.dim, self.hidden_dim, bias=False)
            self._ff_w2 = nn.Linear(self.hidden_dim, args.dim, bias=False)
            self.feed_forward = nn.Sequential(self._ff_w1, nn.GELU(approximate="none"), self._ff_w2)
        elif args.ffn_type == "moe":
            moe_args = MoEArgs(
                hidden_size=args.dim,
                ffn_hidden_size=args.dim * 4,
                moe_num_experts=args.moe_num_experts,
                moe_weight_parallelism=args.moe_weight_parallelism,
                moe_expert_model_parallelism=args.moe_expert_model_parallelism,
                moe_top_k=args.moe_top_k,
                moe_capacity_factor=args.moe_capacity_factor,
                moe_loss_weight=args.moe_loss_weight,
                device=torch.cuda.current_device(),
                bf16=False,
                fp16=False,
            )
            self.feed_forward = MoE(moe_args)

        self.layer_id = layer_id
        self.attention_norm = args.norm_type(
@@ -289,7 +321,11 @@ def forward(self, x, past_key_value=None, use_cache=False):
            use_cache=use_cache,
        )
        h = x + h
        out = h + self.feed_forward(self.ffn_norm(h))
        if self._ffn_type == "moe":
            ffn_out, _ = self.feed_forward(self.ffn_norm(h))
        else:
            ffn_out = self.feed_forward(self.ffn_norm(h))
        out = h + ffn_out
        return out, past_key_value


@@ -298,8 +334,10 @@ def __init__(self, params):
        super().__init__()
        # for convenience we often share param names with llama
        self.params = params
        self.dim = params.dim
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers
        self.moe_num_experts = params.moe_num_experts
        self.seq_len = params.seq_len
        self.post_embed_norm = (
            params.norm_type(
@@ -314,7 +352,12 @@ def __init__(self, params):
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)

        self.layers = torch.nn.ModuleList()
        ffn_type_ = params.ffn_type
        for layer_id in range(params.n_layers):
            if params.moe_freq > 0 and layer_id % params.moe_freq == 0:
                params.ffn_type = "moe"
            else:
                params.ffn_type = ffn_type_
            self.layers.append(Block(layer_id, params))

        # get class for normalization layers
@@ -405,6 +448,13 @@ def create_params(args):
        apply_qk_norm=cfg.get("qk_norm", args.qk_norm),
        positional_embedding_type=cfg.get("positional_embedding_type", args.positional_embedding_type),
        ffn_type=cfg.get("ffn_type", args.ffn_type),
        moe_num_experts=cfg.get("moe_num_experts", args.moe_num_experts),
        moe_loss_weight=cfg.get("moe_loss_weight", args.moe_loss_weight),
        moe_expert_model_parallelism=cfg.get("moe_expert_model_parallelism", args.moe_expert_model_parallelism),
        moe_weight_parallelism=cfg.get("moe_weight_parallelism", args.moe_weight_parallelism),
        moe_capacity_factor=cfg.get("moe_capacity_factor", args.moe_capacity_factor),
        moe_freq=cfg.get("moe_freq", args.moe_freq),
        moe_top_k=cfg.get("moe_top_k", args.moe_top_k),
    )


44 changes: 44 additions & 0 deletions open_lm/params.py
@@ -56,6 +56,50 @@ def add_model_args(parser):
        default="rotary",
        help="Type of positional embedding to use. This might be overridden by the model config.",
    )
    parser.add_argument(
        "--moe-freq",
        type=int,
        default=0,
        help="If set > 0, add an MoE FFN layer every moe_freq layers.",
    )
    parser.add_argument(
        "--moe-num-experts",
        type=int,
        default=None,
        help="Number of experts for MoE",
    )

    parser.add_argument(
        "--moe-weight-parallelism",
        action="store_true",
        help="Add weight parallelism to MoE",
    )

    parser.add_argument(
        "--moe-expert-model-parallelism",
        action="store_true",
        help="Add expert model parallelism to MoE",
    )

    parser.add_argument(
        "--moe-capacity-factor",
        type=float,
        default=1.25,
        help="MoE capacity factor",
    )

    parser.add_argument(
        "--moe-loss-weight",
        type=float,
        default=0.1,
        help="MoE loss weight",
    )
    parser.add_argument(
        "--moe-top-k",
        type=int,
        default=2,
        help="MoE top k experts",
    )


def check_replacement_type(replacement, original):