 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging import logger
 from torchtitan.parallelisms.parallel_dims import ParallelDims
-
-
-# NOTE(lty): experimental for the PT-D 24 research internship project
-def torch_spmd_parallelize(
-    model: nn.Module,
-    world_mesh: DeviceMesh,
-    parallel_dims: ParallelDims,
-    job_config: JobConfig,
-):
-    torch._inductor.config.simplefsdp.enable_reorder = True
-    torch._inductor.config.simplefsdp.enable_bucket = True
-
-    if parallel_dims.tp_enabled:
-        apply_tp(
-            model,
-            world_mesh["tp"],
-            loss_parallel=parallel_dims.loss_parallel_enabled,
-            enable_float8=job_config.float8.enable_float8_linear,
-            enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
-        )
-
-    ac_config = job_config.activation_checkpoint
-    if ac_config.mode != "none":
-        apply_ac(model, ac_config)
-        logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
-
-    if parallel_dims.dp_enabled:
-        from torch_spmd.data_parallel import data_parallel, MixedPrecisionPolicy
-
-        mp_policy = MixedPrecisionPolicy(
-            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
-            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
-        )
-        dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
-
-        model = data_parallel(
-            model,
-            dp_mesh,
-            mode="fully_shard",
-            ac_mode=ac_config.mode,
-            mp_policy=mp_policy,
-        )
-        logger.info("Applied Simple FSDP to the model")
-
-    if job_config.training.compile:
-        model = torch.compile(model, fullgraph=True)
-        logger.info("Compiling with torch.compile")
-
-    return model
+from torchtitan.parallelisms.utils import check_strided_sharding_enabled


 def parallelize_llama(
@@ -93,9 +45,6 @@ def parallelize_llama(
     NOTE: The passed-in model preferably should be on meta device. Otherwise,
     the model must fit on GPU or CPU memory.
     """
-    # NOTE(lty): experimental for the PT-D 24 research internship project
-    if job_config.experimental.torch_spmd:
-        return torch_spmd_parallelize(model, world_mesh, parallel_dims, job_config)

     if parallel_dims.tp_enabled:
         if (
@@ -351,12 +300,11 @@ def apply_fsdp(
     mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
     fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}

-    # TODO(lty): the check below requires the latest PyTorch nightly; remove for now
     # TODO: remove this check once PyTorch 2.5 is released. We can safely assume
     # that users won't use a nightly build which is older than 20240809 by then.
-    # if tp_enabled:
-    #     # check if strided sharding is enabled, which is necessary for 2D/3D DCP
-    #     check_strided_sharding_enabled()
+    if tp_enabled:
+        # check if strided sharding is enabled, which is necessary for 2D/3D DCP
+        check_strided_sharding_enabled()

     for layer_id, transformer_block in model.layers.items():
         if pp_enabled: