
Moe multipod #67

Open · wants to merge 17 commits into base: alanwaketan/moe
14 changes: 6 additions & 8 deletions examples/pytorch/language-modeling/run_clm.py
@@ -65,7 +65,6 @@

logger = logging.getLogger(__name__)


MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

@@ -554,7 +553,8 @@ def tokenize_function(examples):
)
return output

with training_args.main_process_first(desc="dataset map tokenization"):
# with training_args.main_process_first(desc="dataset map tokenization"):
if True:
if not data_args.streaming:
tokenized_datasets = raw_datasets.map(
tokenize_function,
@@ -618,7 +618,8 @@ def group_texts(examples):
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
# https://huggingface.co/docs/datasets/process#map

with training_args.main_process_first(desc="grouping texts together"):
# with training_args.main_process_first(desc="grouping texts together"):
if True:
if not data_args.streaming:
lm_datasets = tokenized_datasets.map(
group_texts,
@@ -724,19 +725,16 @@ def compute_metrics(eval_preds):
# Data collator will default to DataCollatorWithPadding, so we change it.
data_collator=default_data_collator,
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None,
preprocess_logits_for_metrics=preprocess_logits_for_metrics
if training_args.do_eval and not is_torch_xla_available()
else None,
)

# Training
if training_args.do_train:
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
# import pdb; pdb.set_trace()
train_result = trainer.train()
trainer.save_model() # Saves the tokenizer too for easy upload

metrics = train_result.metrics
8 changes: 8 additions & 0 deletions fsdp-config.json
@@ -0,0 +1,8 @@
{
"fsdp_transformer_layer_cls_to_wrap": [
"MixtralDecoderLayer"
],
"xla": true,
"xla_fsdp_v2": true,
"xla_fsdp_grad_ckpt": true
}
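
The JSON above is consumed through the Hugging Face Trainer's FSDP options. A minimal sketch of that wiring, assuming the standard TrainingArguments parameter names (the output path and batch size are illustrative, not taken from this PR):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="/tmp/mixtral-fsdp",      # illustrative path
    per_device_train_batch_size=8,       # illustrative value
    fsdp="full_shard",
    fsdp_config="fsdp-config.json",      # picks up the xla / xla_fsdp_v2 / xla_fsdp_grad_ckpt keys above
)

With xla_fsdp_v2 enabled, the Trainer wraps the listed MixtralDecoderLayer modules with the SPMD-based FSDPv2 implementation rather than the eager FSDP wrapper.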
29 changes: 29 additions & 0 deletions mixtral_8x22b.json
@@ -0,0 +1,29 @@
{
"architectures": [
"MixtralForCausalLM"
],
"attention_dropout": 0.0,
"bos_token_id": 1,
"eos_token_id": 2,
"hidden_act": "silu",
"hidden_size": 6144,
"initializer_range": 0.02,
"intermediate_size": 16384,
"max_position_embeddings": 65536,
"model_type": "mixtral",
"num_attention_heads": 48,
"num_experts_per_tok": 2,
"num_hidden_layers": 56,
"num_key_value_heads": 8,
"num_local_experts": 8,
"output_router_logits": false,
"rms_norm_eps": 1e-05,
"rope_theta": 1000000.0,
"router_aux_loss_coef": 0.001,
"sliding_window": null,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.38.0",
"use_cache": false,
"vocab_size": 32000
}
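
A minimal sketch of turning a config file like this into a freshly initialized model for from-scratch training; the file path and the from_json_file route are illustrative (run_clm.py would more likely receive the file via --config_name):

from transformers import MixtralConfig, MixtralForCausalLM

config = MixtralConfig.from_json_file("mixtral_8x22b.json")
model = MixtralForCausalLM(config)  # random init on host; a full 8x22B instantiation needs substantial host memory
print(f"{sum(p.numel() for p in model.parameters()) / 1e9:.1f}B parameters")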
105 changes: 68 additions & 37 deletions src/transformers/models/mixtral/modeling_mixtral.py
@@ -72,11 +72,13 @@
import torch_xla.distributed.spmd as xs
import torch_xla.core.xla_model as xm
import torch_xla
import os


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "MixtralConfig"
NUM_TPU_SLICE = int(os.environ.get('NUM_TPU_SLICE', 1))


def load_balancing_loss_func(
@@ -395,8 +397,12 @@ def forward(
query_states /= math.sqrt(self.head_dim)
partition_spec = None
if xs.get_global_mesh() is not None:
partition_spec = ('fsdp', 'tensor', None, None)
if NUM_TPU_SLICE == 1:
partition_spec = ('fsdp', 'tensor', None, None)
else:
partition_spec = (('dcn','fsdp'), 'tensor', None, None)
attn_output = flash_attention(query_states, key_states, value_states, causal=True, partition_spec=partition_spec)
# attn_output = FlashAttention.apply(query_states, key_states, value_states, True, None, None, 1.0, None, partition_spec, None)

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
@@ -865,7 +871,6 @@ def _eager_gmm_backward(grad_output, lhs, rhs, group_sizes):
start += size
return torch.cat(grad_lhs), torch.stack(grad_rhs)


@staticmethod
@xp.trace_me("gmm_forward")
def forward(ctx, hidden_states: torch.Tensor, top_ks: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, w3: torch.Tensor) -> torch.Tensor:
@@ -894,11 +899,19 @@ def forward(ctx, hidden_states: torch.Tensor, top_ks: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, w3: torch.Tensor) -> torch.Tensor:

# Enter manual sharding zone
if xs.get_global_mesh() is not None:
hidden_states = xs.enable_manual_sharding(hidden_states, ('fsdp', None)).global_tensor
top_ks = xs.enable_manual_sharding(top_ks, ('fsdp', None)).global_tensor
w1 = xs.enable_manual_sharding(full_w1, (None, None, 'tensor')).global_tensor
w2 = xs.enable_manual_sharding(full_w2, (None, 'tensor', None)).global_tensor
w3 = xs.enable_manual_sharding(full_w3, (None, None, 'tensor')).global_tensor
if NUM_TPU_SLICE == 1:
hidden_states = xs.enable_manual_sharding(hidden_states, ('fsdp', None)).global_tensor
top_ks = xs.enable_manual_sharding(top_ks, ('fsdp', None)).global_tensor
w1 = xs.enable_manual_sharding(full_w1, (None, None, 'tensor')).global_tensor
w2 = xs.enable_manual_sharding(full_w2, (None, 'tensor', None)).global_tensor
w3 = xs.enable_manual_sharding(full_w3, (None, None, 'tensor')).global_tensor
else:
hidden_states = xs.enable_manual_sharding(hidden_states, (('dcn', 'fsdp'), None)).global_tensor
top_ks = xs.enable_manual_sharding(top_ks, (('dcn', 'fsdp'), None)).global_tensor
w1 = xs.enable_manual_sharding(full_w1, (None, None, 'tensor')).global_tensor
w2 = xs.enable_manual_sharding(full_w2, (None, 'tensor', None)).global_tensor
w3 = xs.enable_manual_sharding(full_w3, (None, None, 'tensor')).global_tensor


# We want to create one big batch of tokens that has all top-k choices in it.
# Our tokens will thus be duplicated k-times in the batch. To do this we,
Expand All @@ -917,12 +930,12 @@ def forward(ctx, hidden_states: torch.Tensor, top_ks: torch.Tensor, w1: torch.Te

# Replicated MixtralBlockSparseTop2MLP.forward
# Here we just use silu and ignore the configuration given we need to manually write the backward pass.
gmm1 = gmm(hidden_states_sorted, w1, group_sizes)
gmm3 = gmm(hidden_states_sorted, w3, group_sizes)
gmm1 = gmm(hidden_states_sorted, w1, group_sizes, tiling=(512, 1024, 1024))
gmm3 = gmm(hidden_states_sorted, w3, group_sizes, tiling=(512, 1024, 1024))
# Should I save silu activations?
silu = F.silu(gmm1)
sgmm = silu * gmm3
gmm2 = gmm(sgmm, w2, group_sizes)
gmm2 = gmm(sgmm, w2, group_sizes, tiling=(512, 1024, 1024))
current_hidden_states = gmm2[hidden_states_reverse_order].reshape(-1, k, n)

# Exit manual sharding zone
@@ -939,15 +952,21 @@ def forward(ctx, hidden_states: torch.Tensor, top_ks: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, w3: torch.Tensor) -> torch.Tensor:
# Only reduce-scatter along tensor axis.
current_hidden_states = torch_xla.torch_xla._XLAC._xla_spmd_reduce_scatter(xm.REDUCE_SUM, current_hidden_states, 1.0, -1, device_ids.shape[-1], device_ids.tolist())


current_hidden_states = xs.disable_manual_sharding(current_hidden_states, ('fsdp', None, 'tensor'), (m, k, n)).global_tensor

# Checkpoints for backward
hidden_states_sorted = xs.disable_manual_sharding(hidden_states_sorted, ('fsdp', None), (m * k, n)).global_tensor
gmm1 = xs.disable_manual_sharding(gmm1, ('fsdp', 'tensor'), (m * k, l)).global_tensor
gmm3 = xs.disable_manual_sharding(gmm3, ('fsdp', 'tensor'), (m * k, l)).global_tensor
silu = xs.disable_manual_sharding(silu, ('fsdp', 'tensor'), (m * k, l)).global_tensor
sgmm = xs.disable_manual_sharding(sgmm, ('fsdp', 'tensor'), (m * k, l)).global_tensor
if NUM_TPU_SLICE == 1:
current_hidden_states = xs.disable_manual_sharding(current_hidden_states, ('fsdp', None, 'tensor'), (m, k, n)).global_tensor
hidden_states_sorted = xs.disable_manual_sharding(hidden_states_sorted, ('fsdp', None), (m * k, n)).global_tensor
gmm1 = xs.disable_manual_sharding(gmm1, ('fsdp', 'tensor'), (m * k, l)).global_tensor
gmm3 = xs.disable_manual_sharding(gmm3, ('fsdp', 'tensor'), (m * k, l)).global_tensor
silu = xs.disable_manual_sharding(silu, ('fsdp', 'tensor'), (m * k, l)).global_tensor
sgmm = xs.disable_manual_sharding(sgmm, ('fsdp', 'tensor'), (m * k, l)).global_tensor
else:
current_hidden_states = xs.disable_manual_sharding(current_hidden_states, (('dcn','fsdp'), None, 'tensor'), (m, k, n)).global_tensor
# Checkpoints for backward
hidden_states_sorted = xs.disable_manual_sharding(hidden_states_sorted, (('dcn', 'fsdp'), None), (m * k, n)).global_tensor
gmm1 = xs.disable_manual_sharding(gmm1, (('dcn', 'fsdp'), 'tensor'), (m * k, l)).global_tensor
gmm3 = xs.disable_manual_sharding(gmm3, (('dcn', 'fsdp'), 'tensor'), (m * k, l)).global_tensor
silu = xs.disable_manual_sharding(silu, (('dcn', 'fsdp'), 'tensor'), (m * k, l)).global_tensor
sgmm = xs.disable_manual_sharding(sgmm, (('dcn', 'fsdp'), 'tensor'), (m * k, l)).global_tensor

# Save for backward
ctx.save_for_backward(hidden_states_sorted, full_w1, full_w2, full_w3, gmm1, gmm3, silu, sgmm, hidden_states_order, hidden_states_reverse_order, group_sizes)
@@ -960,7 +979,6 @@ def forward(ctx, hidden_states: torch.Tensor, top_ks: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, w3: torch.Tensor) -> torch.Tensor:
@xp.trace_me("gmm_backward")
def backward(ctx, grad_output):
from torch_xla.experimental.custom_kernel import _histogram, gmm_backward

device = grad_output.device
if device == torch.device('cpu'):
gmm_backward = Gmm._eager_gmm_backward
@@ -978,46 +996,61 @@ def backward(ctx, grad_output):

# Enter manual sharding zone
if xs.get_global_mesh() is not None:
hidden_states_sorted = xs.enable_manual_sharding(hidden_states_sorted, ('fsdp', None)).global_tensor
if NUM_TPU_SLICE == 1:
hidden_states_sorted = xs.enable_manual_sharding(hidden_states_sorted, ('fsdp', None)).global_tensor
else:
hidden_states_sorted = xs.enable_manual_sharding(hidden_states_sorted, (('dcn', 'fsdp'), None)).global_tensor
w1 = xs.enable_manual_sharding(full_w1, (None, None, 'tensor')).global_tensor
w2 = xs.enable_manual_sharding(full_w2, (None, 'tensor', None)).global_tensor
w3 = xs.enable_manual_sharding(full_w3, (None, None, 'tensor')).global_tensor
gmm1 = xs.enable_manual_sharding(gmm1, ('fsdp', 'tensor')).global_tensor
gmm3 = xs.enable_manual_sharding(gmm3, ('fsdp', 'tensor')).global_tensor
silu = xs.enable_manual_sharding(silu, ('fsdp', 'tensor')).global_tensor
sgmm = xs.enable_manual_sharding(sgmm, ('fsdp', 'tensor')).global_tensor
grad_output = xs.enable_manual_sharding(grad_output, ('fsdp', None, None)).global_tensor
temp_sharding_spec = ('fsdp', 'tensor') if NUM_TPU_SLICE == 1 else (('dcn', 'fsdp'), 'tensor')
gmm1 = xs.enable_manual_sharding(gmm1, temp_sharding_spec).global_tensor
gmm3 = xs.enable_manual_sharding(gmm3, temp_sharding_spec).global_tensor
silu = xs.enable_manual_sharding(silu, temp_sharding_spec).global_tensor
sgmm = xs.enable_manual_sharding(sgmm, temp_sharding_spec).global_tensor
if NUM_TPU_SLICE == 1:
grad_output = xs.enable_manual_sharding(grad_output, ('fsdp', None, None)).global_tensor
else:
grad_output = xs.enable_manual_sharding(grad_output, (('dcn', 'fsdp'), None, None)).global_tensor


grad_output_sorted = grad_output.reshape(-1, n)[hidden_states_order]
grad_output, grad_w2 = gmm_backward(grad_output_sorted, sgmm, w2, group_sizes)
grad_output, grad_w2 = gmm_backward(grad_output_sorted, sgmm, w2, group_sizes, tiling=(512, 1024, 1024))

grad_gmm1 = gmm3 * grad_output
grad_gmm1 = torch.ops.aten.silu_backward(grad_gmm1, gmm1)
grad_gmm1, grad_w1 = gmm_backward(grad_gmm1, hidden_states_sorted, w1, group_sizes)

grad_gmm1, grad_w1 = gmm_backward(grad_gmm1, hidden_states_sorted, w1, group_sizes, tiling=(512, 1024, 1024))

grad_gmm3 = silu * grad_output
grad_gmm3, grad_w3 = gmm_backward(grad_gmm3, hidden_states_sorted, w3, group_sizes)
grad_gmm3, grad_w3 = gmm_backward(grad_gmm3, hidden_states_sorted, w3, group_sizes, tiling=(512, 1024, 1024))

grad_output = grad_gmm1 + grad_gmm3

grad_output = grad_output[hidden_states_reverse_order]
grad_output = grad_output.reshape(-1, k, grad_output.shape[-1]).sum(dim=1)

# Exit manual sharding zone
if xs.get_global_mesh() is not None:
if not hasattr(ctx, "device_ids"):
# Here we do a manual reduce scatter as SPMD will not be able to infer this after the manual sharding zone.
groups = [xs.get_global_mesh().device_ids] # a single group across the whole world
if NUM_TPU_SLICE == 1:
groups = [xs.get_global_mesh().device_ids] # a single group across the whole world
else:
groups = [list(range(i*256, (i+1)*256)) for i in range(NUM_TPU_SLICE)]
world_size = len(groups[0])
grad_w1 = torch_xla.torch_xla._XLAC._xla_spmd_reduce_scatter(xm.REDUCE_SUM, grad_w1, 1 / world_size, -1, world_size, groups)
grad_w2 = torch_xla.torch_xla._XLAC._xla_spmd_reduce_scatter(xm.REDUCE_SUM, grad_w2, 1 / world_size, -2, world_size, groups)
grad_w3 = torch_xla.torch_xla._XLAC._xla_spmd_reduce_scatter(xm.REDUCE_SUM, grad_w3, 1 / world_size, -1, world_size, groups)

grad_output = xs.disable_manual_sharding(grad_output, (0, None), (m, n)).global_tensor
if NUM_TPU_SLICE == 1:
grad_output = xs.disable_manual_sharding(grad_output, ('fsdp', None), (m, n)).global_tensor
else:
grad_output = xs.disable_manual_sharding(grad_output, (('dcn', 'fsdp'), None), (m, n)).global_tensor
# TODO: make the 0s more programmatic.
grad_w1 = xs.disable_manual_sharding(grad_w1, (None, None, 0), w1.shape).global_tensor
grad_w2 = xs.disable_manual_sharding(grad_w2, (None, 0, None), w2.shape).global_tensor
grad_w3 = xs.disable_manual_sharding(grad_w3, (None, None, 0), w3.shape).global_tensor
# grad_w* sharding isn't affected by multipod.
grad_w1 = xs.disable_manual_sharding(grad_w1, (None, None, 'fsdp'), w1.shape).global_tensor
grad_w2 = xs.disable_manual_sharding(grad_w2, (None, 'fsdp', None), w2.shape).global_tensor
grad_w3 = xs.disable_manual_sharding(grad_w3, (None, None, 'fsdp'), w3.shape).global_tensor
else: # 2d sharding
device_ids = ctx.device_ids

@@ -1036,7 +1069,6 @@ def backward(ctx, grad_output):
grad_w1 = xs.disable_manual_sharding(grad_w1, (None, 'fsdp', 'tensor'), full_w1.shape).global_tensor
grad_w2 = xs.disable_manual_sharding(grad_w2, (None, 'tensor', 'fsdp'), full_w2.shape).global_tensor
grad_w3 = xs.disable_manual_sharding(grad_w3, (None, 'fsdp', 'tensor'), full_w3.shape).global_tensor

return grad_output, None, grad_w1, grad_w2, grad_w3


@@ -1091,7 +1123,6 @@ def __init__(self, config):

# gating
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

if not self.gmm or self.gmm_stack:
self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
else:
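
For context, the NUM_TPU_SLICE branches above assume a 3-axis SPMD mesh: a leading 'dcn' axis spanning slices, an 'fsdp' axis spanning the chips inside each slice, and a trivial 'tensor' axis; the hard-coded groups of 256 in the backward pass additionally assume 256 devices per slice. A minimal sketch of how such a mesh could be built (not part of this PR; axis sizes are illustrative):

import os
import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

num_slices = int(os.environ.get("NUM_TPU_SLICE", 1))
num_devices = xr.global_runtime_device_count()            # e.g. 512 across two 256-chip slices
mesh_shape = (num_slices, num_devices // num_slices, 1)   # ('dcn', 'fsdp', 'tensor')
mesh = xs.Mesh(np.arange(num_devices), mesh_shape, ("dcn", "fsdp", "tensor"))
xs.set_global_mesh(mesh)

In practice torch_xla's xs.HybridMesh, which distinguishes ICI from DCN connectivity when laying out devices, may be a better fit for multi-slice topologies than a flat Mesh.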