Skip to content

Commit

Permalink
AdagradW (pytorch#3605)
Browse files Browse the repository at this point in the history
Summary:

CounterWeightDecayMode.SQRT

Differential Revision: D67625467
  • Loading branch information
minhua-chen authored and facebook-github-bot committed Feb 4, 2025
1 parent bdcce9c commit 9b39781
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
14 changes: 12 additions & 2 deletions fbgemm_gpu/codegen/genscript/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,11 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]:
const auto iter_delta = prev_iter[idx] == 0 ? 1.0 : iter * 1.0 - prev_iter[idx];
prev_iter[idx] = iter * 1.0;
const auto counter_log_rho = logf(2.0) / counter_halflife;
row_counter[idx] = 1.0 + expf(-iter_delta * counter_log_rho) * row_counter[idx];
if (regularization_mode == 3 && weight_decay_mode == 3) {
tail_id_threshold_val = iter_delta;
} else {
row_counter[idx] = 1.0 + expf(-iter_delta * counter_log_rho) * row_counter[idx];
}
} else if (counter_halflife == 0) { // count only 1 (appear or not)
row_counter[idx] = 1.0;
} else { // count raw appearance without decaying
Expand Down Expand Up @@ -552,7 +556,13 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]:
exp_reg_correction = 1.0;
if (regularization_mode == 3) { // counter-based regularization (regularization_mode=3)
if (adjustment_enabled) {
if (weight_decay_mode == 2) { // Decoupled weight decay (weight_decay_mode=2)
if (weight_decay_mode == 3) { // AdagradW (weight_decay_mode=3)
tail_id_threshold_val = min(tail_id_threshold_val, iter*1.0 - adjustment_iter);
exp_reg_correction = 1.0 - weight_decay * learning_rate / sqrtf(iter*1.0);
freq = expf(- weight_decay * learning_rate * 2.0 * (sqrtf(iter*1.0) - sqrtf(iter*1.0 - tail_id_threshold_val + 1.0)));
adjusted_multiplier *= freq; // lazy update
exp_reg_correction *= freq;
} else if (weight_decay_mode == 2) { // Decoupled weight decay (weight_decay_mode=2)
exp_reg_correction = 1.0 - freq * weight_decay * learning_rate;
} else if (weight_decay_mode == 1) { // L2 regularization (coupled wd)
exp_reg_correction = 1.0 - freq * weight_decay * multiplier;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ class CounterWeightDecayMode(enum.IntEnum):
NONE = 0
L2 = 1
DECOUPLE = 2
SQRT = 3


class StepMode(enum.IntEnum):
Expand Down

0 comments on commit 9b39781

Please sign in to comment.