diff --git a/megatron/core/optimizer/__init__.py b/megatron/core/optimizer/__init__.py
index 4a83564ce7..a462b094f9 100644
--- a/megatron/core/optimizer/__init__.py
+++ b/megatron/core/optimizer/__init__.py
@@ -56,7 +56,8 @@ def _get_param_groups(
     Creates parameter groups based on weight decay condition (regularized vs
     non regularized), learning rate scale condition (lr vs lr_mult * lr),
     and whether it is expert parameters. scale_lr_cond is used during finetuning
-    where head of the network requires a scaled version of the base learning rate.
+    where the head of the network can have a scaled version of the base learning rate, or
+    during pre-training, where the down-projection layer (linear_fc2) can have a lower learning rate.
 
     Args:
         model_chunks (List[MegatronModule]): model chunks to create parameter
diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
index 8b374ca4be..3b660deb01 100644
--- a/megatron/core/transformer/transformer_config.py
+++ b/megatron/core/transformer/transformer_config.py
@@ -59,7 +59,7 @@ class TransformerConfig(ModelParallelConfig):
 
     # @jcasper should we keep this option?
     apply_residual_connection_post_layernorm: bool = False
-    """If True, uses the original BERT residule connection ordering."""
+    """If True, uses the original BERT residual connection ordering."""
 
     layernorm_epsilon: float = 1e-5
     """Epsilon value for any LayerNorm operations."""
diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index 64c92ea3cd..e79b5d3f2d 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -1369,6 +1369,12 @@ def _add_learning_rate_args(parser):
     group.add_argument('--decoupled-min-lr', type=float, default=None,
                        help='Minimum value for learning rate for the input and output layer. The scheduler'
                        'clip values below this threshold')
+    group.add_argument('--scale-lr-layer', type=str, default=None,
+                       help='Scale the learning rate of the specified layer. '
+                       'E.g., --scale-lr-layer "linear_fc2" scales the lr of the down-projection layer (during pretraining or finetuning). '
+                       'Or, --scale-lr-layer "head" scales the lr of the lm-head (during pretraining or finetuning).')
+    group.add_argument('--lr-multiplier', type=float, default=1.0,
+                       help='Learning rate multiplier for the layer specified by --scale-lr-layer.')
 
     return parser
 
@@ -1821,8 +1827,6 @@ def _add_vision_args(parser):
     group.add_argument('--no-data-sharding', action='store_false',
                        help='Disable data sharding.',
                        dest='data_sharding')
-    group.add_argument('--head-lr-mult', type=float, default=1.0,
-                       help='learning rate multiplier for head during finetuning')
 
     # pretraining type and backbone selection
     group.add_argument('--vision-pretraining', action='store_true',
diff --git a/megatron/training/training.py b/megatron/training/training.py
index d5ee16be5f..349e99b76f 100644
--- a/megatron/training/training.py
+++ b/megatron/training/training.py
@@ -299,7 +299,10 @@ def pretrain(
     timers('model-and-optimizer-setup', log_level=0).start(barrier=True)
     app_metrics['app_build_optimizer_start_time'] = one_logger_utils.get_timestamp_in_ms()
     model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
-        model_provider, model_type, checkpointing_context=checkpointing_context)
+        model_provider, model_type,
+        scale_lr_cond=(lambda name, param: args.scale_lr_layer in name) if args.scale_lr_layer else None,
+        lr_mult=args.lr_multiplier,
+        checkpointing_context=checkpointing_context)
 
     timers('model-and-optimizer-setup').stop()
     print_datetime('after model, optimizer, and learning rate '
diff --git a/tasks/vision/finetune_utils.py b/tasks/vision/finetune_utils.py
index ced2e674e6..2f55ebf4a4 100644
--- a/tasks/vision/finetune_utils.py
+++ b/tasks/vision/finetune_utils.py
@@ -224,8 +224,8 @@ def finetune(
         setup_model_and_optimizer(
             model_provider,
             model_type,
-            scale_lr_cond=lambda name, param: ".head." in name,
-            lr_mult=args.head_lr_mult)
+            scale_lr_cond=(lambda name, param: args.scale_lr_layer in name) if args.scale_lr_layer else None,
+            lr_mult=args.lr_multiplier)
     timers("model and optimizer").stop()
 
     # If pretrained checkpoint is provided and we have not trained for
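For reference, here is a minimal sketch (not the actual `_get_param_groups` implementation) of how the new `--scale-lr-layer` / `--lr-multiplier` pair translates into a name-based `scale_lr_cond` and an `lr_mult`-scaled parameter group. The `split_params_by_lr` helper and the parameter names are hypothetical, chosen only to illustrate the matching behavior.

```python
# Illustrative sketch: how a name-based scale_lr_cond plus lr_mult could
# split parameters into a base-lr group and a scaled-lr group.
from typing import Callable, Dict, List, Optional, Tuple


def split_params_by_lr(
    named_params: List[Tuple[str, object]],
    scale_lr_cond: Optional[Callable[[str, object], bool]],
    lr_mult: float,
) -> List[Dict]:
    """Return two groups: params at the base lr and params at lr_mult * lr."""
    base, scaled = [], []
    for name, param in named_params:
        if scale_lr_cond is not None and scale_lr_cond(name, param):
            scaled.append(name)
        else:
            base.append(name)
    return [
        {'params': base, 'lr_mult': 1.0},
        {'params': scaled, 'lr_mult': lr_mult},
    ]


if __name__ == '__main__':
    # Hypothetical parameter names; 'linear_fc2' matches --scale-lr-layer "linear_fc2".
    params = [
        ('decoder.layers.0.mlp.linear_fc1.weight', None),
        ('decoder.layers.0.mlp.linear_fc2.weight', None),
        ('output_layer.weight', None),
    ]
    # Stand-ins for args.scale_lr_layer and args.lr_multiplier.
    scale_lr_layer, lr_multiplier = 'linear_fc2', 0.5
    # Same condition shape as the diff: substring match on the parameter name.
    cond = (lambda name, param: scale_lr_layer in name) if scale_lr_layer else None
    print(split_params_by_lr(params, cond, lr_multiplier))
```

With `--scale-lr-layer "linear_fc2" --lr-multiplier 0.5`, only `linear_fc2` parameters land in the scaled group, which mirrors how the previous vision-only `--head-lr-mult` behavior is now expressed as the general `--scale-lr-layer "head"` case.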