Skip to content

Commit da230f2

Browse files
committed
make gradient clipping configurable on restart
1 parent 6feb7d3 commit da230f2

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

megatron/checkpointing.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,11 @@ def load_checkpoint(
402402
load_module_strict=neox_args.train_impl != "rm",
403403
)
404404

405+
# Respect the gradient-clipping value from the config: overwrite the clip_grad restored from the checkpoint with neox_args.gradient_clipping
406+
print_rank_0(f"Overwriting ckpt grad clip of {model.optimizer.clip_grad} to config value of {neox_args.gradient_clipping}")
407+
model.optimizer.clip_grad = neox_args.gradient_clipping
408+
print_rank_0(f"Value successfully set to {model.optimizer.clip_grad}")
409+
405410
if checkpoint_name is None:
406411
# if an iteration is specified, we want to raise an error here rather than
407412
# continuing silently, since we are trying to load a specific checkpoint

0 commit comments

Comments
 (0)