diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index d067e1ddfb6a..dafc14b2b898 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -20,6 +20,7 @@
 """PyTorch LLaMA model."""
 
 import math
+import os
 import warnings
 from typing import List, Optional, Tuple, Union
 
@@ -61,6 +62,7 @@
 _CONFIG_FOR_DOC = "LlamaConfig"
 
+NUM_SLICE = int(os.getenv('NUM_SLICE', 1))
 
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
@@ -387,7 +389,11 @@ def forward(
         # Integrated with PyTorch/XLA Pallas Flash Attention:
         from torch_xla.experimental.custom_kernel import flash_attention
         query_states /= math.sqrt(self.head_dim)
-        attn_output = flash_attention(query_states, key_states, value_states, causal=True, partition_spec=('fsdp', 'tensor', None, None))
+        if NUM_SLICE == 1:
+            attn_output = flash_attention(query_states, key_states, value_states, causal=True, partition_spec=('fsdp', 'tensor', None, None))
+        else:
+            attn_output = flash_attention(query_states, key_states, value_states, causal=True, partition_spec=(('dcn', 'fsdp'), None, None, None))
+
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
             raise ValueError(
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index e2ddc3fd74e4..d78929e5e2a3 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -127,6 +127,7 @@
     set_seed,
     speed_metrics,
 )
+import torch_xla.distributed.parallel_loader as pl
 from .training_args import OptimizerNames, ParallelMode, TrainingArguments
 from .utils import (
     ADAPTER_CONFIG_NAME,
@@ -264,6 +265,7 @@ def _get_fsdp_ckpt_kwargs():
 logger = logging.get_logger(__name__)
 
+NUM_SLICE = int(os.getenv('NUM_SLICE', 1))
 
 # Name of the files used for checkpointing
 TRAINING_ARGS_NAME = "training_args.bin"
@@ -381,6 +383,7 @@ def __init__(
             args = TrainingArguments(output_dir=output_dir)
         self.args = args
         # Seed must be set before instantiating the model when using model
+        set_seed(self.args.seed)
         enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
         self.hp_name = None
         self.deepspeed = None
@@ -679,6 +682,18 @@ def __init__(
             # Tensor axis is just a placeholder where it will not be used in FSDPv2.
             num_devices = xr.global_runtime_device_count()
             xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
+            if NUM_SLICE == 1:
+                mesh_shape = (num_devices, 1)
+                device_ids = np.array(range(num_devices))
+                # Note that the mesh must have an axis named 'fsdp', on which the weights and activations will be sharded.
+                mesh = xs.Mesh(device_ids, mesh_shape, ('fsdp', 'tensor'))
+                xs.set_global_mesh(mesh)
+            else:
+                dcn_axis = NUM_SLICE
+                ici_mesh_shape = (1, num_devices // dcn_axis, 1)
+                dcn_mesh_shape = (dcn_axis, 1, 1)
+                mesh = xs.HybridMesh(ici_mesh_shape=ici_mesh_shape, dcn_mesh_shape=dcn_mesh_shape, axis_names=('dcn', 'fsdp', 'tensor'))
+                xs.set_global_mesh(mesh)
 
     def _activate_neftune(self, model):
         r"""
@@ -877,6 +892,24 @@ def get_train_dataloader(self) -> DataLoader:
             dataloader_params["worker_init_fn"] = seed_worker
             dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
 
+        if is_torch_xla_available():
+            torch_dataloader = DataLoader(train_dataset, **dataloader_params)
+            device = xm.xla_device()
+            if NUM_SLICE == 1:
+                mp_device_loader = pl.MpDeviceLoader(
+                    torch_dataloader,
+                    device,
+                    input_sharding=xs.ShardingSpec(xs.get_global_mesh(), ("fsdp", None)),
+                )
+            else:
+                mp_device_loader = pl.MpDeviceLoader(
+                    torch_dataloader,
+                    device,
+                    input_sharding=xs.ShardingSpec(xs.get_global_mesh(), (("dcn", "fsdp"), None)),
+                )
+            return mp_device_loader
+
         return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
 
     def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
@@ -1681,7 +1714,6 @@ def _wrap_model(self, model, training=True, dataloader=None):
                     # Transformer layer class to wrap
                     transformer_layer_cls=transformer_cls_to_wrap,
                 )
-            fsdp_kwargs = self.args.xla_fsdp_config
             if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
                 if model.config.use_cache:
                     logger.warning_once(
@@ -1709,7 +1741,11 @@ def shard_output(output, mesh):
                     if real_output is None:
                         raise ValueError("Something went wrong, the output of the model shouldn't be `None`")
-                    xs.mark_sharding(real_output, mesh, ("fsdp", None, None))
+
+                    if NUM_SLICE == 1:
+                        xs.mark_sharding(real_output, mesh, ("fsdp", None, None))
+                    else:
+                        xs.mark_sharding(real_output, mesh, (("dcn", "fsdp"), None, None))
 
                 self.model = model = FSDPv2(
                     model,
@@ -1718,10 +1754,12 @@ def shard_output(output, mesh):
                     auto_wrapper_callable=auto_wrapper_callable,
                 )
             else:
+                fsdp_kwargs = self.args.xla_fsdp_config
                 self.model = model = FSDP(
                     model,
                     auto_wrap_policy=auto_wrap_policy,
                     auto_wrapper_callable=auto_wrapper_callable,
+                    reshard_after_forward=False,
                     **fsdp_kwargs,
                 )
@@ -1854,6 +1892,7 @@ def train(
                 # Disable progress bars when uploading models during checkpoints to avoid polluting stdout
                 hf_hub_utils.disable_progress_bars()
                 return inner_training_loop(
+                    batch_size=self._train_batch_size,
                     args=args,
                     resume_from_checkpoint=resume_from_checkpoint,
                     trial=trial,
@@ -1863,6 +1902,7 @@ def train(
                 hf_hub_utils.enable_progress_bars()
         else:
             return inner_training_loop(
+                batch_size=self._train_batch_size,
                 args=args,
                 resume_from_checkpoint=resume_from_checkpoint,
                 trial=trial,
@@ -1892,8 +1932,8 @@ def _inner_training_loop(
         logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
         # Data loader and number of training steps
         train_dataloader = self.get_train_dataloader()
-        if self.is_fsdp_xla_v2_enabled:
-            train_dataloader = tpu_spmd_dataloader(train_dataloader)
+        # if self.is_fsdp_xla_v2_enabled:
+        #     train_dataloader = tpu_spmd_dataloader(train_dataloader)
 
         # Setting up training control variables:
         # number of training epochs: num_train_epochs
@@ -4454,3 +4494,4 @@ def _fsdp_qlora_plugin_updates(self):
             fsdp_plugin.set_mixed_precision(
                 self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
             )
+