
Commit da9886a

support multipod
1 parent 3dd4ec1 commit da9886a

File tree

2 files changed (+52, -5 lines)

src/transformers/models/llama/modeling_llama.py

Lines changed: 7 additions & 1 deletion
@@ -20,6 +20,7 @@
 """PyTorch LLaMA model."""

 import math
+import os
 import warnings
 from typing import List, Optional, Tuple, Union

@@ -61,6 +62,7 @@

 _CONFIG_FOR_DOC = "LlamaConfig"

+NUM_SLICE=int(os.getenv('NUM_SLICE', 1))

 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)

@@ -387,7 +389,11 @@ def forward(
         # Integrated with PyTorch/XLA Pallas Flash Attention:
         from torch_xla.experimental.custom_kernel import flash_attention
         query_states /= math.sqrt(self.head_dim)
-        attn_output = flash_attention(query_states, key_states, value_states, causal=True, partition_spec=('fsdp', 'tensor', None, None))
+        if NUM_SLICE == 1:
+            attn_output = flash_attention(query_states, key_states, value_states, causal=True, partition_spec=('fsdp', 'tensor', None, None))
+        else:
+            attn_output = flash_attention(query_states, key_states, value_states, causal=True, partition_spec=(('dcn', 'fsdp'), None, None, None))

         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
             raise ValueError(
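Note (not part of the commit): the change above only switches the partition spec passed to the Pallas flash-attention kernel. With one slice the batch dimension shards on 'fsdp' and the head dimension on 'tensor'; with several slices the batch shards across the combined ('dcn', 'fsdp') axes so every device in every slice still receives a distinct batch shard. A minimal sketch of the same idea in isolation, assuming an SPMD mesh with 'dcn'/'fsdp'/'tensor' axes is already set globally and reusing the NUM_SLICE convention introduced here:

    # Sketch only; `head_dim` and the query/key/value tensors are assumed inputs.
    import math
    import os
    from torch_xla.experimental.custom_kernel import flash_attention

    NUM_SLICE = int(os.getenv('NUM_SLICE', 1))

    def sharded_flash_attention(query_states, key_states, value_states, head_dim):
        # The kernel is fed pre-scaled queries, matching the line above the diff hunk.
        query_states = query_states / math.sqrt(head_dim)
        if NUM_SLICE == 1:
            # Single slice: shard batch on 'fsdp', heads on 'tensor'.
            spec = ('fsdp', 'tensor', None, None)
        else:
            # Multipod: shard batch jointly over inter-slice 'dcn' and intra-slice 'fsdp'.
            spec = (('dcn', 'fsdp'), None, None, None)
        return flash_attention(query_states, key_states, value_states, causal=True, partition_spec=spec)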

src/transformers/trainer.py

Lines changed: 45 additions & 4 deletions
@@ -127,6 +127,7 @@
     set_seed,
     speed_metrics,
 )
+import torch_xla.distributed.parallel_loader as pl
 from .training_args import OptimizerNames, ParallelMode, TrainingArguments
 from .utils import (
     ADAPTER_CONFIG_NAME,

@@ -264,6 +265,7 @@ def _get_fsdp_ckpt_kwargs():

 logger = logging.get_logger(__name__)

+NUM_SLICE=int(os.getenv('NUM_SLICE', 1))

 # Name of the files used for checkpointing
 TRAINING_ARGS_NAME = "training_args.bin"

@@ -381,6 +383,7 @@ def __init__(
             args = TrainingArguments(output_dir=output_dir)
         self.args = args
         # Seed must be set before instantiating the model when using model
+        set_seed(self.args.seed)
         enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
         self.hp_name = None
         self.deepspeed = None
@@ -679,6 +682,18 @@ def __init__(
             # Tensor axis is just a placeholder where it will not be used in FSDPv2.
             num_devices = xr.global_runtime_device_count()
             xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
+            if NUM_SLICE==1:
+                mesh_shape = (num_devices, 1)
+                device_ids = np.array(range(num_devices))
+                # To be noted, the mesh must have an axis named 'fsdp', which the weights and activations will be sharded on.
+                mesh = xs.Mesh(device_ids, mesh_shape, ('fsdp', 'tensor'))
+                xs.set_global_mesh(mesh)
+            else:
+                dcn_axis = NUM_SLICE
+                ici_mesh_shape = (1, num_devices // dcn_axis, 1)
+                dcn_mesh_shape = (dcn_axis, 1, 1)
+                mesh = xs.HybridMesh(ici_mesh_shape=ici_mesh_shape, dcn_mesh_shape=dcn_mesh_shape, axis_names=('dcn', 'fsdp', 'tensor'))
+                xs.set_global_mesh(mesh)

     def _activate_neftune(self, model):
         r"""
@@ -877,6 +892,24 @@ def get_train_dataloader(self) -> DataLoader:
             dataloader_params["worker_init_fn"] = seed_worker
             dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

+
+        if is_torch_xla_available():
+            torch_dataloader = DataLoader(train_dataset, **dataloader_params)
+            device = xm.xla_device()
+            if NUM_SLICE==1:
+                mp_device_loader = pl.MpDeviceLoader(
+                    torch_dataloader,
+                    device,
+                    input_sharding=xs.ShardingSpec(xs.get_global_mesh(), ("fsdp", None)),
+                )
+            else:
+                mp_device_loader = pl.MpDeviceLoader(
+                    torch_dataloader,
+                    device,
+                    input_sharding=xs.ShardingSpec(xs.get_global_mesh(), (("dcn", "fsdp"), None)),
+                )
+            return mp_device_loader
+
         return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))

     def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
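Note (sketch only): on XLA the train dataloader now bypasses accelerate and is wrapped in an MpDeviceLoader whose input_sharding mirrors the global mesh, so each host only transfers its shard of the global batch. The same idea in isolation, assuming the mesh from the previous hunk is already set and `torch_dataloader` is an ordinary torch DataLoader over the training set:

    import os
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.spmd as xs
    import torch_xla.distributed.parallel_loader as pl

    NUM_SLICE = int(os.getenv('NUM_SLICE', 1))
    device = xm.xla_device()

    # Shard the batch dimension of the inputs the same way the model shards activations.
    batch_axes = ("fsdp", None) if NUM_SLICE == 1 else (("dcn", "fsdp"), None)
    sharding = xs.ShardingSpec(xs.get_global_mesh(), batch_axes)

    mp_device_loader = pl.MpDeviceLoader(torch_dataloader, device, input_sharding=sharding)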
@@ -1681,7 +1714,6 @@ def _wrap_model(self, model, training=True, dataloader=None):
                     # Transformer layer class to wrap
                     transformer_layer_cls=transformer_cls_to_wrap,
                 )
-            fsdp_kwargs = self.args.xla_fsdp_config
             if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
                 if model.config.use_cache:
                     logger.warning_once(

@@ -1709,7 +1741,11 @@ def shard_output(output, mesh):

                 if real_output is None:
                     raise ValueError("Something went wrong, the output of the model shouldn't be `None`")
-                xs.mark_sharding(real_output, mesh, ("fsdp", None, None))
+
+                if NUM_SLICE==1:
+                    xs.mark_sharding(real_output, mesh, ("fsdp", None, None))
+                else:
+                    xs.mark_sharding(real_output, mesh, (("dcn", "fsdp"), None, None))

             self.model = model = FSDPv2(
                 model,
@@ -1718,10 +1754,12 @@ def shard_output(output, mesh):
                 auto_wrapper_callable=auto_wrapper_callable,
             )
         else:
+            fsdp_kwargs = self.args.xla_fsdp_config
             self.model = model = FSDP(
                 model,
                 auto_wrap_policy=auto_wrap_policy,
                 auto_wrapper_callable=auto_wrapper_callable,
+                reshard_after_forward=False,
                 **fsdp_kwargs,
             )
@@ -1854,6 +1892,7 @@ def train(
                # Disable progress bars when uploading models during checkpoints to avoid polluting stdout
                hf_hub_utils.disable_progress_bars()
                return inner_training_loop(
+                   batch_size=self._train_batch_size,
                    args=args,
                    resume_from_checkpoint=resume_from_checkpoint,
                    trial=trial,

@@ -1863,6 +1902,7 @@
                hf_hub_utils.enable_progress_bars()
        else:
            return inner_training_loop(
+               batch_size=self._train_batch_size,
                args=args,
                resume_from_checkpoint=resume_from_checkpoint,
                trial=trial,

@@ -1892,8 +1932,8 @@ def _inner_training_loop(
        logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()
-        if self.is_fsdp_xla_v2_enabled:
-            train_dataloader = tpu_spmd_dataloader(train_dataloader)
+        # if self.is_fsdp_xla_v2_enabled:
+        #     train_dataloader = tpu_spmd_dataloader(train_dataloader)

        # Setting up training control variables:
        # number of training epochs: num_train_epochs

@@ -4454,3 +4494,4 @@ def _fsdp_qlora_plugin_updates(self):
            fsdp_plugin.set_mixed_precision(
                self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
            )
+
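Usage note (assumption, not stated in the commit): NUM_SLICE is read from the environment at module import time in both modified files, so it has to be exported on every host of every slice before transformers is imported; the value below is only an example for a two-slice run.

    import os
    os.environ.setdefault("NUM_SLICE", "2")  # example: two TPU slices participating in training

    from transformers import Trainer  # trainer.py picks up NUM_SLICE when it is imported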
