diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py
index 8a1797dbc0bc..30345828f139 100755
--- a/examples/pytorch/language-modeling/run_clm.py
+++ b/examples/pytorch/language-modeling/run_clm.py
@@ -65,7 +65,6 @@
 logger = logging.getLogger(__name__)
 
-
 MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@@ -554,7 +553,8 @@ def tokenize_function(examples):
             )
             return output
 
-    with training_args.main_process_first(desc="dataset map tokenization"):
+    # with training_args.main_process_first(desc="dataset map tokenization"):
+    if True:
         if not data_args.streaming:
             tokenized_datasets = raw_datasets.map(
                 tokenize_function,
@@ -618,7 +618,8 @@ def group_texts(examples):
     # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
     # https://huggingface.co/docs/datasets/process#map
 
-    with training_args.main_process_first(desc="grouping texts together"):
+    # with training_args.main_process_first(desc="grouping texts together"):
+    if True:
         if not data_args.streaming:
             lm_datasets = tokenized_datasets.map(
                 group_texts,
@@ -724,11 +725,7 @@ def compute_metrics(eval_preds):
         # Data collator will default to DataCollatorWithPadding, so we change it.
         data_collator=default_data_collator,
         compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None,
-        preprocess_logits_for_metrics=preprocess_logits_for_metrics
-        if training_args.do_eval and not is_torch_xla_available()
-        else None,
     )
-
     # Training
     if training_args.do_train:
         checkpoint = None
@@ -736,7 +733,8 @@ def compute_metrics(eval_preds):
             checkpoint = training_args.resume_from_checkpoint
         elif last_checkpoint is not None:
             checkpoint = last_checkpoint
-        train_result = trainer.train(resume_from_checkpoint=checkpoint)
+        # import pdb; pdb.set_trace()
+        train_result = trainer.train()
         trainer.save_model()  # Saves the tokenizer too for easy upload
         metrics = train_result.metrics
diff --git a/fsdp-config.json b/fsdp-config.json
new file mode 100644
index 000000000000..19a19159b29e
--- /dev/null
+++ b/fsdp-config.json
@@ -0,0 +1,8 @@
+{
+  "fsdp_transformer_layer_cls_to_wrap": [
+    "MixtralDecoderLayer"
+  ],
+  "xla": true,
+  "xla_fsdp_v2": true,
+  "xla_fsdp_grad_ckpt": true
+}
diff --git a/mixtral_8x22b.json b/mixtral_8x22b.json
new file mode 100644
index 000000000000..740a84517f44
--- /dev/null
+++ b/mixtral_8x22b.json
@@ -0,0 +1,29 @@
+{
+  "architectures": [
+    "MixtralForCausalLM"
+  ],
+  "attention_dropout": 0.0,
+  "bos_token_id": 1,
+  "eos_token_id": 2,
+  "hidden_act": "silu",
+  "hidden_size": 6144,
+  "initializer_range": 0.02,
+  "intermediate_size": 16384,
+  "max_position_embeddings": 65536,
+  "model_type": "mixtral",
+  "num_attention_heads": 48,
+  "num_experts_per_tok": 2,
+  "num_hidden_layers": 56,
+  "num_key_value_heads": 8,
+  "num_local_experts": 8,
+  "output_router_logits": false,
+  "rms_norm_eps": 1e-05,
+  "rope_theta": 1000000.0,
+  "router_aux_loss_coef": 0.001,
+  "sliding_window": null,
+  "tie_word_embeddings": false,
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.38.0",
+  "use_cache": false,
+  "vocab_size": 32000
+}
\ No newline at end of file
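Note: `fsdp-config.json` is the kind of file handed to `run_clm.py` via `--fsdp full_shard --fsdp_config fsdp-config.json`; its `xla_fsdp_v2` key is what switches the Trainer onto the SPMD/FSDPv2 path patched below. A minimal sketch of how `mixtral_8x22b.json` can be turned into a model object (randomly initialized, no checkpoint; the local file paths are assumptions):

```python
# Sketch: build the 8x22B architecture from the new config file. This creates
# random weights (~141B parameters), so it only makes sense under the sharded
# XLA setup in the rest of this diff, not on a single host.
from transformers import MixtralConfig, MixtralForCausalLM

config = MixtralConfig.from_json_file("mixtral_8x22b.json")
model = MixtralForCausalLM(config)
```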
diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index bb4727f66a4e..010f386da981 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -72,11 +72,13 @@
 import torch_xla.distributed.spmd as xs
 import torch_xla.core.xla_model as xm
 import torch_xla
+import os
 
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "MixtralConfig"
 
+NUM_TPU_SLICE = int(os.environ.get('NUM_TPU_SLICE', 1))
 
 def load_balancing_loss_func(
@@ -395,8 +397,12 @@ def forward(
         query_states /= math.sqrt(self.head_dim)
         partition_spec = None
         if xs.get_global_mesh() is not None:
-            partition_spec = ('fsdp', 'tensor', None, None)
+            if NUM_TPU_SLICE == 1:
+                partition_spec = ('fsdp', 'tensor', None, None)
+            else:
+                partition_spec = (('dcn', 'fsdp'), 'tensor', None, None)
         attn_output = flash_attention(query_states, key_states, value_states, causal=True, partition_spec=partition_spec)
+        # attn_output = FlashAttention.apply(query_states, key_states, value_states, True, None, None, 1.0, None, partition_spec, None)
 
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
             raise ValueError(
@@ -865,7 +871,6 @@ def _eager_gmm_backward(grad_output, lhs, rhs, group_sizes):
             start += size
         return torch.cat(grad_lhs), torch.stack(grad_rhs)
 
-
     @staticmethod
     @xp.trace_me("gmm_forward")
     def forward(ctx, hidden_states: torch.Tensor, top_ks: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, w3: torch.Tensor) -> torch.Tensor:
@@ -894,11 +899,19 @@ def forward(ctx, hidden_states: torch.Tensor, top_ks: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, w3: torch.Tensor) -> torch.Tensor:
 
         # Enter manual sharding zone
         if xs.get_global_mesh() is not None:
-            hidden_states = xs.enable_manual_sharding(hidden_states, ('fsdp', None)).global_tensor
-            top_ks = xs.enable_manual_sharding(top_ks, ('fsdp', None)).global_tensor
-            w1 = xs.enable_manual_sharding(full_w1, (None, None, 'tensor')).global_tensor
-            w2 = xs.enable_manual_sharding(full_w2, (None, 'tensor', None)).global_tensor
-            w3 = xs.enable_manual_sharding(full_w3, (None, None, 'tensor')).global_tensor
+            if NUM_TPU_SLICE == 1:
+                hidden_states = xs.enable_manual_sharding(hidden_states, ('fsdp', None)).global_tensor
+                top_ks = xs.enable_manual_sharding(top_ks, ('fsdp', None)).global_tensor
+                w1 = xs.enable_manual_sharding(full_w1, (None, None, 'tensor')).global_tensor
+                w2 = xs.enable_manual_sharding(full_w2, (None, 'tensor', None)).global_tensor
+                w3 = xs.enable_manual_sharding(full_w3, (None, None, 'tensor')).global_tensor
+            else:
+                hidden_states = xs.enable_manual_sharding(hidden_states, (('dcn', 'fsdp'), None)).global_tensor
+                top_ks = xs.enable_manual_sharding(top_ks, (('dcn', 'fsdp'), None)).global_tensor
+                w1 = xs.enable_manual_sharding(full_w1, (None, None, 'tensor')).global_tensor
+                w2 = xs.enable_manual_sharding(full_w2, (None, 'tensor', None)).global_tensor
+                w3 = xs.enable_manual_sharding(full_w3, (None, None, 'tensor')).global_tensor
+
         # We want to create one big batch of tokens that has all top-k choices in it.
         # Our tokens will thus be duplicated k-times in the batch. To do this we,
@@ -917,12 +930,12 @@ def forward(ctx, hidden_states: torch.Tensor, top_ks: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, w3: torch.Tensor) -> torch.Tensor:
 
         # Replicated MixtralBlockSparseTop2MLP.forward
         # Here we just use silu and ignore the configuration given we need to manually write the backward pass.
-        gmm1 = gmm(hidden_states_sorted, w1, group_sizes)
-        gmm3 = gmm(hidden_states_sorted, w3, group_sizes)
+        gmm1 = gmm(hidden_states_sorted, w1, group_sizes, tiling=(512, 1024, 1024))
+        gmm3 = gmm(hidden_states_sorted, w3, group_sizes, tiling=(512, 1024, 1024))
         # Should I save silu activations?
         silu = F.silu(gmm1)
         sgmm = silu * gmm3
-        gmm2 = gmm(sgmm, w2, group_sizes)
+        gmm2 = gmm(sgmm, w2, group_sizes, tiling=(512, 1024, 1024))
         current_hidden_states = gmm2[hidden_states_reverse_order].reshape(-1, k, n)
 
         # Exit manual sharding zone
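The `tiling=(512, 1024, 1024)` argument overrides the Pallas grouped-matmul kernel's default tile sizes; wider k/n tiles mean fewer grid steps at the cost of more VMEM. A standalone sketch of the call being tuned here (shapes are illustrative, and it needs a real TPU device):

```python
# Sketch: the grouped matmul (gmm) in isolation. `group_sizes` says how many
# consecutive rows of `tokens` belong to each expert; `tiling` is the
# (tm, tk, tn) tile triple handed to the Pallas kernel.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import gmm

device = xm.xla_device()
tokens = torch.randn(4096, 6144, device=device, dtype=torch.bfloat16)  # (m, k), sorted by expert
w1 = torch.randn(8, 6144, 16384, device=device, dtype=torch.bfloat16)  # (num_experts, k, n)
group_sizes = torch.full((8,), 512, dtype=torch.int32, device=device)  # rows per expert, sums to m

out = gmm(tokens, w1, group_sizes, tiling=(512, 1024, 1024))  # (m, n)
```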
@@ -939,15 +952,21 @@ def forward(ctx, hidden_states: torch.Tensor, top_ks: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, w3: torch.Tensor) -> torch.Tensor:
                 # Only reduce-scatter along tensor axis.
                 current_hidden_states = torch_xla.torch_xla._XLAC._xla_spmd_reduce_scatter(xm.REDUCE_SUM, current_hidden_states, 1.0, -1, device_ids.shape[-1], device_ids.tolist())
-
-            current_hidden_states = xs.disable_manual_sharding(current_hidden_states, ('fsdp', None, 'tensor'), (m, k, n)).global_tensor
-
-            # Checkpoints for backward
-            hidden_states_sorted = xs.disable_manual_sharding(hidden_states_sorted, ('fsdp', None), (m * k, n)).global_tensor
-            gmm1 = xs.disable_manual_sharding(gmm1, ('fsdp', 'tensor'), (m * k, l)).global_tensor
-            gmm3 = xs.disable_manual_sharding(gmm3, ('fsdp', 'tensor'), (m * k, l)).global_tensor
-            silu = xs.disable_manual_sharding(silu, ('fsdp', 'tensor'), (m * k, l)).global_tensor
-            sgmm = xs.disable_manual_sharding(sgmm, ('fsdp', 'tensor'), (m * k, l)).global_tensor
+            if NUM_TPU_SLICE == 1:
+                current_hidden_states = xs.disable_manual_sharding(current_hidden_states, ('fsdp', None, 'tensor'), (m, k, n)).global_tensor
+                hidden_states_sorted = xs.disable_manual_sharding(hidden_states_sorted, ('fsdp', None), (m * k, n)).global_tensor
+                gmm1 = xs.disable_manual_sharding(gmm1, ('fsdp', 'tensor'), (m * k, l)).global_tensor
+                gmm3 = xs.disable_manual_sharding(gmm3, ('fsdp', 'tensor'), (m * k, l)).global_tensor
+                silu = xs.disable_manual_sharding(silu, ('fsdp', 'tensor'), (m * k, l)).global_tensor
+                sgmm = xs.disable_manual_sharding(sgmm, ('fsdp', 'tensor'), (m * k, l)).global_tensor
+            else:
+                current_hidden_states = xs.disable_manual_sharding(current_hidden_states, (('dcn', 'fsdp'), None, 'tensor'), (m, k, n)).global_tensor
+                # Checkpoints for backward
+                hidden_states_sorted = xs.disable_manual_sharding(hidden_states_sorted, (('dcn', 'fsdp'), None), (m * k, n)).global_tensor
+                gmm1 = xs.disable_manual_sharding(gmm1, (('dcn', 'fsdp'), 'tensor'), (m * k, l)).global_tensor
+                gmm3 = xs.disable_manual_sharding(gmm3, (('dcn', 'fsdp'), 'tensor'), (m * k, l)).global_tensor
+                silu = xs.disable_manual_sharding(silu, (('dcn', 'fsdp'), 'tensor'), (m * k, l)).global_tensor
+                sgmm = xs.disable_manual_sharding(sgmm, (('dcn', 'fsdp'), 'tensor'), (m * k, l)).global_tensor
 
         # Save for backward
         ctx.save_for_backward(hidden_states_sorted, full_w1, full_w2, full_w3, gmm1, gmm3, silu, sgmm, hidden_states_order, hidden_states_reverse_order, group_sizes)
@@ -960,7 +979,6 @@ def forward(ctx, hidden_states: torch.Tensor, top_ks: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, w3: torch.Tensor) -> torch.Tensor:
     @xp.trace_me("gmm_backward")
     def backward(ctx, grad_output):
         from torch_xla.experimental.custom_kernel import _histogram, gmm_backward
-
         device = grad_output.device
         if device == torch.device('cpu'):
             gmm_backward = Gmm._eager_gmm_backward
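All of these enable/disable pairs follow one pattern: `enable_manual_sharding` freezes SPMD propagation and hands the kernel the local shard, and `disable_manual_sharding` re-attaches the global shape and partition spec on the way out. A minimal sketch of the zone, assuming the global mesh set up in the trainer changes below:

```python
# Sketch of the manual-sharding zone used throughout Gmm.forward/backward.
# Inside the zone, tensor shapes are the per-device shards, so custom kernels
# (flash attention, gmm) see exactly the data they own.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs

x = torch.randn(1024, 6144, device=xm.xla_device())
m, n = x.shape

x_local = xs.enable_manual_sharding(x, ('fsdp', None)).global_tensor  # rows: m / fsdp-axis size
y_local = x_local * 2.0                                               # any per-shard compute
y = xs.disable_manual_sharding(y_local, ('fsdp', None), (m, n)).global_tensor
```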
@@ -978,46 +996,61 @@ def backward(ctx, grad_output):
 
         # Enter manual sharding zone
         if xs.get_global_mesh() is not None:
-            hidden_states_sorted = xs.enable_manual_sharding(hidden_states_sorted, ('fsdp', None)).global_tensor
+            if NUM_TPU_SLICE == 1:
+                hidden_states_sorted = xs.enable_manual_sharding(hidden_states_sorted, ('fsdp', None)).global_tensor
+            else:
+                hidden_states_sorted = xs.enable_manual_sharding(hidden_states_sorted, (('dcn', 'fsdp'), None)).global_tensor
             w1 = xs.enable_manual_sharding(full_w1, (None, None, 'tensor')).global_tensor
             w2 = xs.enable_manual_sharding(full_w2, (None, 'tensor', None)).global_tensor
             w3 = xs.enable_manual_sharding(full_w3, (None, None, 'tensor')).global_tensor
-            gmm1 = xs.enable_manual_sharding(gmm1, ('fsdp', 'tensor')).global_tensor
-            gmm3 = xs.enable_manual_sharding(gmm3, ('fsdp', 'tensor')).global_tensor
-            silu = xs.enable_manual_sharding(silu, ('fsdp', 'tensor')).global_tensor
-            sgmm = xs.enable_manual_sharding(sgmm, ('fsdp', 'tensor')).global_tensor
-            grad_output = xs.enable_manual_sharding(grad_output, ('fsdp', None, None)).global_tensor
+            temp_sharding_spec = ('fsdp', 'tensor') if NUM_TPU_SLICE == 1 else (('dcn', 'fsdp'), 'tensor')
+            gmm1 = xs.enable_manual_sharding(gmm1, temp_sharding_spec).global_tensor
+            gmm3 = xs.enable_manual_sharding(gmm3, temp_sharding_spec).global_tensor
+            silu = xs.enable_manual_sharding(silu, temp_sharding_spec).global_tensor
+            sgmm = xs.enable_manual_sharding(sgmm, temp_sharding_spec).global_tensor
+            if NUM_TPU_SLICE == 1:
+                grad_output = xs.enable_manual_sharding(grad_output, ('fsdp', None, None)).global_tensor
+            else:
+                grad_output = xs.enable_manual_sharding(grad_output, (('dcn', 'fsdp'), None, None)).global_tensor
+
         grad_output_sorted = grad_output.reshape(-1, n)[hidden_states_order]
-        grad_output, grad_w2 = gmm_backward(grad_output_sorted, sgmm, w2, group_sizes)
+        grad_output, grad_w2 = gmm_backward(grad_output_sorted, sgmm, w2, group_sizes, tiling=(512, 1024, 1024))
 
         grad_gmm1 = gmm3 * grad_output
         grad_gmm1 = torch.ops.aten.silu_backward(grad_gmm1, gmm1)
-        grad_gmm1, grad_w1 = gmm_backward(grad_gmm1, hidden_states_sorted, w1, group_sizes)
+
+        grad_gmm1, grad_w1 = gmm_backward(grad_gmm1, hidden_states_sorted, w1, group_sizes, tiling=(512, 1024, 1024))
 
         grad_gmm3 = silu * grad_output
-        grad_gmm3, grad_w3 = gmm_backward(grad_gmm3, hidden_states_sorted, w3, group_sizes)
+        grad_gmm3, grad_w3 = gmm_backward(grad_gmm3, hidden_states_sorted, w3, group_sizes, tiling=(512, 1024, 1024))
 
         grad_output = grad_gmm1 + grad_gmm3
         grad_output = grad_output[hidden_states_reverse_order]
         grad_output = grad_output.reshape(-1, k, grad_output.shape[-1]).sum(dim=1)
-
         # Exit manual sharding zone
         if xs.get_global_mesh() is not None:
             if not hasattr(ctx, "device_ids"):
                 # Here we do a manual reduce scatter as SPMD will not be able to infer this after the manual sharding zone.
-                groups = [xs.get_global_mesh().device_ids]  # a single group across the whole world
+                if NUM_TPU_SLICE == 1:
+                    groups = [xs.get_global_mesh().device_ids]  # a single group across the whole world
+                else:
+                    groups = [list(range(i*256, (i+1)*256)) for i in range(NUM_TPU_SLICE)]
                 world_size = len(groups[0])
                 grad_w1 = torch_xla.torch_xla._XLAC._xla_spmd_reduce_scatter(xm.REDUCE_SUM, grad_w1, 1 / world_size, -1, world_size, groups)
                 grad_w2 = torch_xla.torch_xla._XLAC._xla_spmd_reduce_scatter(xm.REDUCE_SUM, grad_w2, 1 / world_size, -2, world_size, groups)
                 grad_w3 = torch_xla.torch_xla._XLAC._xla_spmd_reduce_scatter(xm.REDUCE_SUM, grad_w3, 1 / world_size, -1, world_size, groups)
 
-                grad_output = xs.disable_manual_sharding(grad_output, (0, None), (m, n)).global_tensor
+                if NUM_TPU_SLICE == 1:
+                    grad_output = xs.disable_manual_sharding(grad_output, ('fsdp', None), (m, n)).global_tensor
+                else:
+                    grad_output = xs.disable_manual_sharding(grad_output, (('dcn', 'fsdp'), None), (m, n)).global_tensor
                 # TODO: make the 0s more programmatic.
-                grad_w1 = xs.disable_manual_sharding(grad_w1, (None, None, 0), w1.shape).global_tensor
-                grad_w2 = xs.disable_manual_sharding(grad_w2, (None, 0, None), w2.shape).global_tensor
-                grad_w3 = xs.disable_manual_sharding(grad_w3, (None, None, 0), w3.shape).global_tensor
+                # grad_w* sharding isn't affected by multipod.
+                grad_w1 = xs.disable_manual_sharding(grad_w1, (None, None, 'fsdp'), w1.shape).global_tensor
+                grad_w2 = xs.disable_manual_sharding(grad_w2, (None, 'fsdp', None), w2.shape).global_tensor
+                grad_w3 = xs.disable_manual_sharding(grad_w3, (None, None, 'fsdp'), w3.shape).global_tensor
             else:
                 # 2d sharding
                 device_ids = ctx.device_ids
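The `range(i*256, (i+1)*256)` groups hardcode 256 devices per slice, which only holds for this particular pod topology. A sketch of deriving the same groups from the runtime instead (assuming equally sized slices):

```python
# Sketch: build one reduce-scatter group per slice without hardcoding 256.
import os
import torch_xla.runtime as xr

NUM_TPU_SLICE = int(os.environ.get('NUM_TPU_SLICE', 1))
devices_per_slice = xr.global_runtime_device_count() // NUM_TPU_SLICE
groups = [list(range(i * devices_per_slice, (i + 1) * devices_per_slice))
          for i in range(NUM_TPU_SLICE)]
```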
@@ -1036,7 +1069,6 @@ def backward(ctx, grad_output):
             grad_w1 = xs.disable_manual_sharding(grad_w1, (None, 'fsdp', 'tensor'), full_w1.shape).global_tensor
             grad_w2 = xs.disable_manual_sharding(grad_w2, (None, 'tensor', 'fsdp'), full_w2.shape).global_tensor
             grad_w3 = xs.disable_manual_sharding(grad_w3, (None, 'fsdp', 'tensor'), full_w3.shape).global_tensor
-
         return grad_output, None, grad_w1, grad_w2, grad_w3
@@ -1091,7 +1123,6 @@ def __init__(self, config):
 
         # gating
         self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
-
         if not self.gmm or self.gmm_stack:
             self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
         else:
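Before the trainer changes: the reason the partition specs above switch from `'fsdp'` to `('dcn', 'fsdp')` is that a nested tuple in an SPMD partition spec shards that dimension over the product of the named mesh axes, so the batch is split across slices (`dcn`) and then across chips within a slice (`fsdp`). A small sketch, assuming the hybrid mesh built in the trainer below:

```python
# Sketch: shard a (batch, seq, hidden) activation over the combined
# dcn x fsdp axes of a multi-slice mesh. The nested tuple makes the batch
# dimension span both axes; 'tensor' stays unused, as in FSDPv2.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs

hidden = torch.randn(16, 4096, 6144, device=xm.xla_device())
xs.mark_sharding(hidden, xs.get_global_mesh(), (('dcn', 'fsdp'), None, None))
```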
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index c818bc814f91..bcf3d6bc702b 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -238,6 +238,7 @@
 if is_accelerate_available("0.28.0"):
     from accelerate.utils import DataLoaderConfiguration
 
+NUM_TPU_SLICE = int(os.environ.get('NUM_TPU_SLICE', 1))
 
 def _is_peft_model(model):
     if is_peft_available():
@@ -681,7 +682,20 @@ def __init__(
             # Prepare the SPMD mesh that is going to be used by the data loader and the FSDPv2 wrapper.
             # Tensor axis is just a placeholder where it will not be used in FSDPv2.
             num_devices = xr.global_runtime_device_count()
-            xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
+            if NUM_TPU_SLICE == 1:
+                mesh_shape = (num_devices, 1)
+                device_ids = np.array(range(num_devices))
+                # To be noted, the mesh must have an axis named 'fsdp', which the weights and activations will be sharded on.
+                mesh = xs.Mesh(device_ids, mesh_shape, ('fsdp', 'tensor'))
+                xs.set_global_mesh(mesh)
+            elif NUM_TPU_SLICE > 1:
+                dcn_axis = NUM_TPU_SLICE
+                ici_mesh_shape = (1, num_devices // dcn_axis, 1)
+                dcn_mesh_shape = (dcn_axis, 1, 1)
+                mesh = xs.HybridMesh(ici_mesh_shape=ici_mesh_shape, dcn_mesh_shape=dcn_mesh_shape, axis_names=('dcn', 'fsdp', 'tensor'))
+                xs.set_global_mesh(mesh)
+            else:
+                raise ValueError("Expected NUM_TPU_SLICE > 0")
 
     def _activate_neftune(self, model):
         r"""
@@ -1712,7 +1726,11 @@ def shard_output(output, mesh):
                 if real_output is None:
                     raise ValueError("Something went wrong, the output of the model shouldn't be `None`")
-                xs.mark_sharding(real_output, mesh, ("fsdp", None, None))
+                if NUM_TPU_SLICE == 1:
+                    xs.mark_sharding(real_output, mesh, ("fsdp", None, None))
+                else:
+                    xs.mark_sharding(real_output, mesh, (("dcn", "fsdp"), None, None))
+
             self.model = model = FSDPv2(
                 model,
@@ -1796,6 +1814,7 @@ def train(
         args = self.args
         self.is_in_train = True
+        os.environ['USE_SINGLE_SLICE'] = 'true'
 
         # Attach NEFTune hooks if necessary
         if self.neftune_noise_alpha is not None:
@@ -1896,7 +1915,12 @@ def _inner_training_loop(
         # Data loader and number of training steps
         train_dataloader = self.get_train_dataloader()
         if self.is_fsdp_xla_v2_enabled:
-            train_dataloader = tpu_spmd_dataloader(train_dataloader)
+            # train_dataloader = tpu_spmd_dataloader(train_dataloader)
+            if NUM_TPU_SLICE == 1:
+                sharding_spec = xs.ShardingSpec(xs.get_global_mesh(), ("fsdp", None))
+            else:
+                sharding_spec = xs.ShardingSpec(xs.get_global_mesh(), (("dcn", "fsdp"), None))
+            train_dataloader._parallel_loader_kwargs["input_sharding"] = sharding_spec
 
         # Setting up training control variables:
         # number of training epochs: num_train_epochs
@@ -2224,6 +2248,9 @@ def _inner_training_loop(
                 # with self.accelerator.accumulate(model):
                 tr_loss_step = self.training_step(model, inputs)
 
+                if step % 10 == 0:
+                    print(f"Training step {step} : Loss: {tr_loss_step}")
+
                 if (
                     args.logging_nan_inf_filter
                     and not is_torch_xla_available()
@@ -4467,4 +4494,4 @@ def _fsdp_qlora_plugin_updates(self):
         ):
             fsdp_plugin.set_mixed_precision(
                 self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
-            )
+            )
\ No newline at end of file
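For reference, the `_parallel_loader_kwargs` poke above replicates what `tpu_spmd_dataloader` does, generalized to the multi-slice input spec. A hedged sketch of the same logic as a standalone helper (`spmd_dataloader` is a hypothetical name):

```python
# Sketch: multi-slice-aware variant of transformers' tpu_spmd_dataloader.
# MpDeviceLoader forwards input_sharding to ParallelLoader, so host batches
# arrive on device already sharded along the (dcn, fsdp) batch axes.
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.spmd as xs

def spmd_dataloader(dataloader, num_slices=1):
    spec = ("fsdp", None) if num_slices == 1 else (("dcn", "fsdp"), None)
    sharding = xs.ShardingSpec(xs.get_global_mesh(), spec)
    return pl.MpDeviceLoader(dataloader, xm.xla_device(), input_sharding=sharding)
```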