
Commit 1923ce4

Revert "merge upstream changes" (#570)
Reverts #569. Sorry, accidentally committed things to the wrong fork.
1 parent a09cde3 commit 1923ce4

File tree

11 files changed: +21 additions, -392 deletions


.gitignore

Lines changed: 0 additions & 3 deletions
@@ -17,6 +17,3 @@ torchtitan/datasets/**/*.model
 *.log
 error.json
 _remote_module_non_scriptable.py
-
-# torch compile debug related
-torch_compile_debug/*

benchmark.py

Lines changed: 0 additions & 232 deletions
This file was deleted.

run_benchmark_train.sh

Lines changed: 0 additions & 24 deletions
This file was deleted.

run_llama_train.sh

Lines changed: 0 additions & 2 deletions
@@ -19,8 +19,6 @@ if [ $# -ne 0 ]; then
     overrides="$*"
 fi
 
-# TORCH_TRACE="./outputs/trace" \
-TORCH_NCCL_AVOID_RECORD_STREAMS=1 \
 torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
 train.py --job.config_file ${CONFIG_FILE} $overrides

torchtitan/config_manager.py

Lines changed: 0 additions & 8 deletions
@@ -241,14 +241,6 @@ def __init__(self):
             action="store_true",
             help="Whether to apply loss parallel when sequence parallel is enabled",
         )
-
-        # experimental configs
-        self.parser.add_argument(
-            "--experimental.torch_spmd",
-            default=False,
-            action="store_true",
-            help="Whether to use the experimental torch_spmd style parallelism",
-        )
         self.parser.add_argument(
             "--experimental.enable_async_tensor_parallel",
             default=False,
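
For context, the "--experimental.*" flags above follow a section.key naming convention in argparse. Below is a minimal, self-contained Python sketch of how such dotted flags can be registered and regrouped by section; it is illustrative only and is not torchtitan's actual JobConfig code.

import argparse

# Minimal sketch (not torchtitan's JobConfig): boolean flags registered with a
# "section.key" naming convention, then grouped back into per-section dicts.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--experimental.enable_async_tensor_parallel",
    default=False,
    action="store_true",
    help="Whether to enable async tensor parallel",
)
args = parser.parse_args(["--experimental.enable_async_tensor_parallel"])

# argparse keeps the dot in the attribute name, so group via vars(args).
sections: dict[str, dict[str, object]] = {}
for dotted_name, value in vars(args).items():
    section, _, key = dotted_name.partition(".")
    sections.setdefault(section, {})[key] = value

print(sections)  # {'experimental': {'enable_async_tensor_parallel': True}}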

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 4 additions & 56 deletions
@@ -29,55 +29,7 @@
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging import logger
 from torchtitan.parallelisms.parallel_dims import ParallelDims
-
-
-# NOTE(lty): experimental for the PT-D 24 research internship project
-def torch_spmd_parallelize(
-    model: nn.Module,
-    world_mesh: DeviceMesh,
-    parallel_dims: ParallelDims,
-    job_config: JobConfig,
-):
-    torch._inductor.config.simplefsdp.enable_reorder = True
-    torch._inductor.config.simplefsdp.enable_bucket = True
-
-    if parallel_dims.tp_enabled:
-        apply_tp(
-            model,
-            world_mesh["tp"],
-            loss_parallel=parallel_dims.loss_parallel_enabled,
-            enable_float8=job_config.float8.enable_float8_linear,
-            enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
-        )
-
-    ac_config = job_config.activation_checkpoint
-    if ac_config.mode != "none":
-        apply_ac(model, ac_config)
-        logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
-
-    if parallel_dims.dp_enabled:
-        from torch_spmd.data_parallel import data_parallel, MixedPrecisionPolicy
-
-        mp_policy = MixedPrecisionPolicy(
-            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
-            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
-        )
-        dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
-
-        model = data_parallel(
-            model,
-            dp_mesh,
-            mode="fully_shard",
-            ac_mode=ac_config.mode,
-            mp_policy=mp_policy,
-        )
-        logger.info("Applied Simple FSDP to the model")
-
-    if job_config.training.compile:
-        model = torch.compile(model, fullgraph=True)
-        logger.info("Compiling with torch.compile")
-
-    return model
+from torchtitan.parallelisms.utils import check_strided_sharding_enabled
 
 
 def parallelize_llama(

@@ -93,9 +45,6 @@ def parallelize_llama(
     NOTE: The passed-in model preferably should be on meta device. Otherwise,
     the model must fit on GPU or CPU memory.
     """
-    # NOTE(lty): experimental for the PT-D 24 research internship project
-    if job_config.experimental.torch_spmd:
-        return torch_spmd_parallelize(model, world_mesh, parallel_dims, job_config)
 
     if parallel_dims.tp_enabled:
         if (

@@ -351,12 +300,11 @@ def apply_fsdp(
     mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
     fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
 
-    # TODO(lty): the check below requires the latest PyTorch nightly; remove for now
     # TODO: remove this check once PyTorch 2.5 is released. We can safely assume
     # that users won't use a nightly build which is older than 20240809 by then.
-    # if tp_enabled:
-    #     # check if strided sharding is enabled, which is necessary for 2D/3D DCP
-    #     check_strided_sharding_enabled()
+    if tp_enabled:
+        # check if strided sharding is enabled, which is necessary for 2D/3D DCP
+        check_strided_sharding_enabled()
 
     for layer_id, transformer_block in model.layers.items():
         if pp_enabled:
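
The re-enabled check above guards 2D/3D strided sharding behind a sufficiently recent PyTorch build (the TODO mentions nightlies no older than 20240809). As a rough illustration of that kind of version gate, here is a hypothetical, self-contained helper; it is not the real check_strided_sharding_enabled from torchtitan.parallelisms.utils, whose implementation is not shown in this diff.

import re

import torch


def strided_sharding_available(min_nightly: str = "20240809") -> bool:
    """Hypothetical version gate (illustrative only): report whether this
    torch build looks new enough for strided sharding in 2D/3D DCP."""
    # Nightly builds carry a date suffix, e.g. "2.5.0.dev20240809+cu121";
    # compare that date against the minimum known-good nightly.
    match = re.search(r"\.dev(\d{8})", torch.__version__)
    if match is not None:
        return match.group(1) >= min_nightly
    # Release builds (e.g. "2.5.0") have no .dev suffix; for this sketch,
    # assume they are recent enough.
    return True


if __name__ == "__main__":
    print(torch.__version__, strided_sharding_available())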

0 commit comments
