Skip to content

Commit

Permalink
Update on "Fix 1D PP tracer test, add 2D test"
Browse files Browse the repository at this point in the history
forgot to enable tracer for tracer test in the last PR

[ghstack-poisoned]
  • Loading branch information
kwen2501 committed May 29, 2024
2 parents cf70283 + e76e0f7 commit 435b3ca
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torchtitan/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
# apply AC + torch.compile
ac_config = job_config.activation_checkpoint
enable_compile = job_config.training.compile
for layer_id, transformer_block in model.layers.items():
for layer_id, transformer_block in model.layers.named_children():
if ac_config.mode in ("full", "selective"):
transformer_block = checkpoint_wrapper(transformer_block, ac_config)
if enable_compile:
Expand All @@ -379,7 +379,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
# compile time.
# torch._dynamo.config.inline_inbuilt_nn_modules = True
transformer_block = torch.compile(transformer_block, dynamic=False)
model.layers[layer_id] = transformer_block
model.layers.register_module(layer_id, transformer_block)

if ac_config.mode in ("full", "selective"):
logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
Expand Down

0 comments on commit 435b3ca

Please sign in to comment.