diff --git a/fbgemm_gpu/codegen/genscript/optimizers.py b/fbgemm_gpu/codegen/genscript/optimizers.py
index ff570b6e51..eeef2aaadd 100644
--- a/fbgemm_gpu/codegen/genscript/optimizers.py
+++ b/fbgemm_gpu/codegen/genscript/optimizers.py
@@ -463,15 +463,16 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]:
     split_precomputation = """
     at::acc_type<cache_t, true> freq = 1.0;
     at::acc_type<cache_t, true> tail_id_threshold_val = tail_id_threshold;
+    at::acc_type<cache_t, true> iter_delta = 1.0;
     CUDA_KERNEL_ASSERT(max_counter != 0.0); // avoid divide by zero error
     if (is_tail_id_thresh_ratio == 1){
         tail_id_threshold_val = floorf(tail_id_threshold * max_counter);
     }
     if (threadIdx.x == 0) {
+        iter_delta = prev_iter[idx] == 0 ? 1.0 : iter * 1.0 - prev_iter[idx];
+        prev_iter[idx] = iter * 1.0;
         if (counter_halflife > 0) { // decay based on counter_halflife
             // if id occurs multiple times in a batch, iter_delta=1
-            const auto iter_delta = prev_iter[idx] == 0 ? 1.0 : iter * 1.0 - prev_iter[idx];
-            prev_iter[idx] = iter * 1.0;
             const auto counter_log_rho = logf(2.0) / counter_halflife;
             row_counter[idx] = 1.0 + expf(-iter_delta * counter_log_rho) * row_counter[idx];
         } else if (counter_halflife == 0) { // count only 1 (appear or not)
@@ -483,6 +484,7 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]:
     }
     freq = SHFL_SYNC(freq, 0);
     tail_id_threshold_val = SHFL_SYNC(tail_id_threshold_val, 0);
+    iter_delta = SHFL_SYNC(iter_delta, 0);
 
     at::acc_type<cache_t, true> g_local_sum_square = 0.0;
     at::acc_type<cache_t, true> w_local_sum_square = 0.0;
@@ -552,7 +554,13 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]:
         exp_reg_correction = 1.0;
         if (regularization_mode == 3) { // counter-based regularization (regularization_mode=3)
             if (adjustment_enabled) {
-                if (weight_decay_mode == 2) { // Decoupled weight decay (weight_decay_mode=2)
+                if (weight_decay_mode == 3) { // Decoupled weight decay with rate SQRT(iter) (weight_decay_mode=3)
+                    iter_delta = min(iter_delta, iter*1.0 - adjustment_iter);
+                    exp_reg_correction = 1.0 - weight_decay * learning_rate / sqrtf(iter*1.0);
+                    tail_id_threshold_val = expf(- weight_decay * learning_rate * 2.0 * (sqrtf(iter*1.0) - sqrtf(iter*1.0 - iter_delta + 1.0)));
+                    adjusted_multiplier *= tail_id_threshold_val; // lazy update
+                    exp_reg_correction *= tail_id_threshold_val; // lazy update
+                } else if (weight_decay_mode == 2) { // Decoupled weight decay (weight_decay_mode=2)
                     exp_reg_correction = 1.0 - freq * weight_decay * learning_rate;
                 } else if (weight_decay_mode == 1) { // L2 regularization (coupled wd)
                     exp_reg_correction = 1.0 - freq * weight_decay * multiplier;
diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
index 87d9437be7..8f3c98a353 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -93,6 +93,7 @@ class CounterWeightDecayMode(enum.IntEnum):
     NONE = 0
     L2 = 1
     DECOUPLE = 2
+    SQRT = 3
 
 
 class StepMode(enum.IntEnum):
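
Note (illustrative, not part of the diff): as I read the new weight_decay_mode == 3 branch, decoupled weight decay is applied at the per-step rate weight_decay * learning_rate / sqrt(iter), and the steps during which a row was not touched are caught up lazily with the closed-form factor expf(-weight_decay * learning_rate * 2.0 * (sqrtf(iter) - sqrtf(iter - iter_delta + 1.0))). That factor follows from sum_{t=a}^{b} 1/sqrt(t) ≈ 2*(sqrt(b) - sqrt(a)) and prod_t (1 - x_t) ≈ exp(-sum_t x_t) for small x_t. A minimal Python sketch of this approximation, with hypothetical helper names chosen here for illustration and under the assumption that the catch-up factor covers the iter_delta - 1 skipped steps while exp_reg_correction handles the current one:

import math


def lazy_sqrt_decay_factor(it: int, iter_delta: int, weight_decay: float, learning_rate: float) -> float:
    # Closed-form catch-up factor matching the kernel expression:
    # exp(-wd * lr * 2 * (sqrt(iter) - sqrt(iter - iter_delta + 1)))
    return math.exp(
        -weight_decay * learning_rate * 2.0
        * (math.sqrt(it) - math.sqrt(it - iter_delta + 1.0))
    )


def per_step_sqrt_decay_factor(it: int, iter_delta: int, weight_decay: float, learning_rate: float) -> float:
    # Reference value: apply the per-step factor (1 - wd * lr / sqrt(t))
    # once for each of the iter_delta - 1 steps the row was skipped.
    factor = 1.0
    for t in range(it - iter_delta + 1, it):
        factor *= 1.0 - weight_decay * learning_rate / math.sqrt(t)
    return factor


# The two agree closely when weight_decay * learning_rate is small:
print(lazy_sqrt_decay_factor(10_000, 500, 1e-4, 0.1))      # ~0.99995
print(per_step_sqrt_decay_factor(10_000, 500, 1e-4, 0.1))  # ~0.99995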