Skip to content

Commit

Permalink
AdagradW (pytorch#3605)
Browse files Browse the repository at this point in the history
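Summary:

Add CounterWeightDecayMode.SQRT (weight_decay_mode=3), a decoupled weight decay with SQRT, to the rowwise_adagrad_with_counter optimizer.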

Differential Revision: D67625467
  • Loading branch information
minhua-chen authored and facebook-github-bot committed Jan 31, 2025
1 parent 3266957 commit 240f006
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
12 changes: 10 additions & 2 deletions fbgemm_gpu/codegen/genscript/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,14 +463,15 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]:
split_precomputation = """
at::acc_type<cache_t, true> freq = 1.0;
at::acc_type<cache_t, true> tail_id_threshold_val = tail_id_threshold;
at::acc_type<cache_t, true> iter_delta = 1.0;
CUDA_KERNEL_ASSERT(max_counter != 0.0); // avoid divide by zero error
if (is_tail_id_thresh_ratio == 1){
tail_id_threshold_val = floorf(tail_id_threshold * max_counter);
}
if (threadIdx.x == 0) {
if (counter_halflife > 0) { // decay based on counter_halflife
// if id occurs multiple times in a batch, iter_delta=1
const auto iter_delta = prev_iter[idx] == 0 ? 1.0 : iter * 1.0 - prev_iter[idx];
iter_delta = prev_iter[idx] == 0 ? 1.0 : iter * 1.0 - prev_iter[idx];
prev_iter[idx] = iter * 1.0;
const auto counter_log_rho = logf(2.0) / counter_halflife;
row_counter[idx] = 1.0 + expf(-iter_delta * counter_log_rho) * row_counter[idx];
Expand All @@ -483,6 +484,7 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]:
}
freq = SHFL_SYNC(freq, 0);
tail_id_threshold_val = SHFL_SYNC(tail_id_threshold_val, 0);
iter_delta = SHFL_SYNC(iter_delta, 0);
at::acc_type<cache_t, true> g_local_sum_square = 0.0;
at::acc_type<cache_t, true> w_local_sum_square = 0.0;
Expand Down Expand Up @@ -552,7 +554,13 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]:
exp_reg_correction = 1.0;
if (regularization_mode == 3) { // counter-based regularization (regularization_mode=3)
if (adjustment_enabled) {
if (weight_decay_mode == 2) { // Decoupled weight decay (weight_decay_mode=2)
if (weight_decay_mode == 3) { // Decoupled weight decay with SQRT (weight_decay_mode=3)
iter_delta = min(iter_delta, iter*1.0 - adjustment_iter);
exp_reg_correction = 1.0 - weight_decay * learning_rate / sqrtf(iter*1.0);
tail_id_threshold_val = expf(- weight_decay * learning_rate * 2.0 * (sqrtf(iter*1.0) - sqrtf(iter*1.0 - iter_delta + 1.0)));
adjusted_multiplier *= tail_id_threshold_val; // lazy update
exp_reg_correction *= tail_id_threshold_val;
} else if (weight_decay_mode == 2) { // Decoupled weight decay (weight_decay_mode=2)
exp_reg_correction = 1.0 - freq * weight_decay * learning_rate;
} else if (weight_decay_mode == 1) { // L2 regularization (coupled wd)
exp_reg_correction = 1.0 - freq * weight_decay * multiplier;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ class CounterWeightDecayMode(enum.IntEnum):
NONE = 0
L2 = 1
DECOUPLE = 2
SQRT = 3


class StepMode(enum.IntEnum):
Expand Down

0 comments on commit 240f006

Please sign in to comment.