     set_seed,
     speed_metrics,
 )
+import torch_xla.distributed.parallel_loader as pl
 from .training_args import OptimizerNames, ParallelMode, TrainingArguments
 from .utils import (
     ADAPTER_CONFIG_NAME,
@@ -264,6 +265,7 @@ def _get_fsdp_ckpt_kwargs():
 
 logger = logging.get_logger(__name__)
 
+NUM_SLICE = int(os.getenv('NUM_SLICE', 1))
 
 # Name of the files used for checkpointing
 TRAINING_ARGS_NAME = "training_args.bin"
@@ -381,6 +383,7 @@ def __init__(
             args = TrainingArguments(output_dir=output_dir)
         self.args = args
         # Seed must be set before instantiating the model when using model
+        set_seed(self.args.seed)
         enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
         self.hp_name = None
         self.deepspeed = None
@@ -679,6 +682,18 @@ def __init__(
             # Tensor axis is just a placeholder where it will not be used in FSDPv2.
             num_devices = xr.global_runtime_device_count()
             xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
+            if NUM_SLICE == 1:
+                mesh_shape = (num_devices, 1)
+                device_ids = np.array(range(num_devices))
+                # Note: the mesh must have an axis named 'fsdp', on which the weights and activations are sharded.
+                mesh = xs.Mesh(device_ids, mesh_shape, ('fsdp', 'tensor'))
+                xs.set_global_mesh(mesh)
+            else:
+                dcn_axis = NUM_SLICE
+                ici_mesh_shape = (1, num_devices // dcn_axis, 1)
+                dcn_mesh_shape = (dcn_axis, 1, 1)
+                mesh = xs.HybridMesh(ici_mesh_shape=ici_mesh_shape, dcn_mesh_shape=dcn_mesh_shape, axis_names=('dcn', 'fsdp', 'tensor'))
+                xs.set_global_mesh(mesh)
 
     def _activate_neftune(self, model):
         r"""
@@ -877,6 +892,24 @@ def get_train_dataloader(self) -> DataLoader:
             dataloader_params["worker_init_fn"] = seed_worker
             dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
 
+
+        if is_torch_xla_available():
+            torch_dataloader = DataLoader(train_dataset, **dataloader_params)
+            device = xm.xla_device()
+            if NUM_SLICE == 1:
+                mp_device_loader = pl.MpDeviceLoader(
+                    torch_dataloader,
+                    device,
+                    input_sharding=xs.ShardingSpec(xs.get_global_mesh(), ("fsdp", None)),
+                )
+            else:
+                mp_device_loader = pl.MpDeviceLoader(
+                    torch_dataloader,
+                    device,
+                    input_sharding=xs.ShardingSpec(xs.get_global_mesh(), (("dcn", "fsdp"), None)),
+                )
+            return mp_device_loader
+
         return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
 
     def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
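On XLA, the new `get_train_dataloader` branch returns an `MpDeviceLoader` whose `input_sharding` splits the batch dimension over the mesh's data axes instead of going through `accelerator.prepare`. A minimal sketch of that wrapping, under the same assumptions about the `xs`, `xm`, and `pl` aliases as above:

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.spmd as xs


def shard_train_loader(torch_dataloader, num_slices: int) -> pl.MpDeviceLoader:
    # The batch dim is partitioned over 'fsdp' alone on a single slice, or over the
    # combined ('dcn', 'fsdp') axes on a multi-slice pod; feature dims stay replicated.
    batch_partition = ("fsdp", None) if num_slices == 1 else (("dcn", "fsdp"), None)
    return pl.MpDeviceLoader(
        torch_dataloader,
        xm.xla_device(),
        input_sharding=xs.ShardingSpec(xs.get_global_mesh(), batch_partition),
    )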
@@ -1681,7 +1714,6 @@ def _wrap_model(self, model, training=True, dataloader=None):
                     # Transformer layer class to wrap
                     transformer_layer_cls=transformer_cls_to_wrap,
                 )
-            fsdp_kwargs = self.args.xla_fsdp_config
             if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
                 if model.config.use_cache:
                     logger.warning_once(
@@ -1709,7 +1741,11 @@ def shard_output(output, mesh):
 
                     if real_output is None:
                         raise ValueError("Something went wrong, the output of the model shouldn't be `None`")
-                    xs.mark_sharding(real_output, mesh, ("fsdp", None, None))
+
+                    if NUM_SLICE == 1:
+                        xs.mark_sharding(real_output, mesh, ("fsdp", None, None))
+                    else:
+                        xs.mark_sharding(real_output, mesh, (("dcn", "fsdp"), None, None))
 
                 self.model = model = FSDPv2(
                     model,
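The same single-slice / multi-slice split appears in `shard_output`: the model's logits are annotated so their batch dimension follows the data sharding of the inputs. A condensed sketch of just that annotation (same `xs` assumption as above):

import torch_xla.distributed.spmd as xs


def shard_logits(logits, mesh, num_slices: int) -> None:
    # Logits are (batch, sequence, vocab); only the batch dim is partitioned.
    batch_axis = "fsdp" if num_slices == 1 else ("dcn", "fsdp")
    xs.mark_sharding(logits, mesh, (batch_axis, None, None))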
@@ -1718,10 +1754,12 @@ def shard_output(output, mesh):
                     auto_wrapper_callable=auto_wrapper_callable,
                 )
             else:
+                fsdp_kwargs = self.args.xla_fsdp_config
                 self.model = model = FSDP(
                     model,
                     auto_wrap_policy=auto_wrap_policy,
                     auto_wrapper_callable=auto_wrapper_callable,
+                    reshard_after_forward=False,
                     **fsdp_kwargs,
                 )
 
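For the FSDPv1 branch, the hunk moves `fsdp_kwargs` next to its only use and adds `reshard_after_forward=False`. A hedged sketch of the resulting wrapping call; the comment on `reshard_after_forward` is an assumption about torch_xla's default resharding behavior, not something stated in the patch:

# Sketch only: how the patched FSDPv1 path wraps the model, assuming
# torch_xla.distributed.fsdp.XlaFullyShardedDataParallel as in trainer.py.
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP


def wrap_xla_fsdp_v1(model, auto_wrap_policy, auto_wrapper_callable, fsdp_kwargs):
    return FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        auto_wrapper_callable=auto_wrapper_callable,
        # Assumption: with resharding disabled, gathered parameters stay materialized
        # between forward and backward, trading memory for fewer all-gathers.
        reshard_after_forward=False,
        **fsdp_kwargs,
    )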
@@ -1854,6 +1892,7 @@ def train(
                 # Disable progress bars when uploading models during checkpoints to avoid polluting stdout
                 hf_hub_utils.disable_progress_bars()
                 return inner_training_loop(
+                    batch_size=self._train_batch_size,
                     args=args,
                     resume_from_checkpoint=resume_from_checkpoint,
                     trial=trial,
@@ -1863,6 +1902,7 @@ def train(
             hf_hub_utils.enable_progress_bars()
         else:
             return inner_training_loop(
+                batch_size=self._train_batch_size,
                 args=args,
                 resume_from_checkpoint=resume_from_checkpoint,
                 trial=trial,
@@ -1892,8 +1932,8 @@ def _inner_training_loop(
         logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
         # Data loader and number of training steps
         train_dataloader = self.get_train_dataloader()
-        if self.is_fsdp_xla_v2_enabled:
-            train_dataloader = tpu_spmd_dataloader(train_dataloader)
+        # if self.is_fsdp_xla_v2_enabled:
+        #     train_dataloader = tpu_spmd_dataloader(train_dataloader)
 
         # Setting up training control variables:
         # number of training epochs: num_train_epochs
@@ -4454,3 +4494,4 @@ def _fsdp_qlora_plugin_updates(self):
             fsdp_plugin.set_mixed_precision(
                 self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
             )
+