@@ -157,7 +157,7 @@ def build_test_list():
157157 "--experimental.pipeline_parallel_degree 2" ,
158158 "--experimental.pipeline_parallel_split_points layers.4" ,
159159 "--experimental.pipeline_parallel_schedule 1f1b" ,
160- "--training.data_parallel_degree 1" ,
160+ "--training.data_parallel_shard_degree 1" ,
161161 ],
162162 ],
163163 "PP 1D test 1f1b" ,
@@ -172,7 +172,7 @@ def build_test_list():
172172 "--experimental.pipeline_parallel_degree 2" ,
173173 "--experimental.pipeline_parallel_split_points layers.4" ,
174174 "--experimental.pipeline_parallel_schedule gpipe" ,
175- "--training.data_parallel_degree 1" ,
175+ "--training.data_parallel_shard_degree 1" ,
176176 ],
177177 ],
178178 "PP 1D test gpipe" ,
@@ -187,7 +187,7 @@ def build_test_list():
187187 "--experimental.pipeline_parallel_degree 2" ,
188188 "--experimental.pipeline_parallel_split_points layers.4" ,
189189 "--experimental.pipeline_parallel_schedule 1f1b" ,
190- "--training.data_parallel_degree 2" ,
190+ "--training.data_parallel_shard_degree 2" ,
191191 ],
192192 ],
193193 "PP+DP 1f1b 2D test" ,
@@ -201,7 +201,7 @@ def build_test_list():
201201 "--experimental.pipeline_parallel_degree 2" ,
202202 "--experimental.pipeline_parallel_split_points layers.4" ,
203203 "--experimental.pipeline_parallel_schedule gpipe" ,
204- "--training.data_parallel_degree 2" ,
204+ "--training.data_parallel_shard_degree 2" ,
205205 ],
206206 ],
207207 "PP+DP gpipe 2D test" ,
@@ -227,15 +227,15 @@ def build_test_list():
227227 "--checkpoint.enable_checkpoint" ,
228228 "--experimental.pipeline_parallel_degree 2" ,
229229 "--experimental.pipeline_parallel_split_points layers.4" ,
230- "--training.data_parallel_degree 2" ,
230+ "--training.data_parallel_shard_degree 2" ,
231231 "--training.tensor_parallel_degree 2" ,
232232 ],
233233 [
234234 "--training.steps 20" ,
235235 "--checkpoint.enable_checkpoint" ,
236236 "--experimental.pipeline_parallel_degree 2" ,
237237 "--experimental.pipeline_parallel_split_points layers.4" ,
238- "--training.data_parallel_degree 2" ,
238+ "--training.data_parallel_shard_degree 2" ,
239239 "--training.tensor_parallel_degree 2" ,
240240 ],
241241 ],
@@ -249,7 +249,7 @@ def build_test_list():
249249 [
250250 "--experimental.pipeline_parallel_degree 2" ,
251251 "--experimental.pipeline_parallel_split_points layers.4" ,
252- "--training.data_parallel_degree 2" ,
252+ "--training.data_parallel_shard_degree 2" ,
253253 "--training.tensor_parallel_degree 2" ,
254254 "--training.compile" ,
255255 ],
@@ -285,13 +285,37 @@ def build_test_list():
285285 OverrideDefinitions (
286286 [
287287 [
288- "--training.data_parallel_type ddp" ,
288+ "--training.data_parallel_shard_degree=1" ,
289+ "--training.data_parallel_replicate_degree=4" ,
289290 ]
290291 ],
291292 "DDP" ,
292293 "ddp" ,
293294 ngpu = 4 ,
294295 ),
296+ OverrideDefinitions (
297+ [
298+ [
299+ "--training.data_parallel_shard_degree=2" ,
300+ "--training.data_parallel_replicate_degree=2" ,
301+ ]
302+ ],
303+ "HSDP" ,
304+ "hsdp" ,
305+ ngpu = 4 ,
306+ ),
307+ OverrideDefinitions (
308+ [
309+ [
310+ "--training.data_parallel_shard_degree=2" ,
311+ "--training.data_parallel_replicate_degree=2" ,
312+ "--training.tensor_parallel_degree=2" ,
313+ ]
314+ ],
315+ "HSDP+TP" ,
316+ "hsdp+tp" ,
317+ ngpu = 8 ,
318+ ),
295319 OverrideDefinitions (
296320 [
297321 [
0 commit comments