# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.optim
import math

class AdamWScheduleFree(torch.optim.Optimizer):
    r"""
    Schedule-Free AdamW
    As the name suggests, no scheduler is needed with this optimizer.
    To add warmup, rather than using a learning rate schedule you can just
    set the warmup_steps parameter.

    This optimizer requires that .train() and .eval() be called before the
    beginning of training and evaluation respectively.

    Arguments:
        params (iterable):
            Iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float):
            Learning rate parameter (default 0.0025).
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999)).
        eps (float):
            Term added to the denominator outside of the square root operation
            to improve numerical stability (default: 1e-8).
        weight_decay (float):
            Weight decay, i.e. an L2 penalty (default: 0).
        warmup_steps (int): Enables a linear learning rate warmup (default 0).
        r (float): Use polynomial weighting in the average
            with power r (default 0).
        weight_lr_power (float): During warmup, the weights in the average will
            be equal to lr raised to this power. Set to 0 for no weighting
            (default 2.0).
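
    Example (a minimal usage sketch; model, loader and loss_fn are
    placeholders, not part of this module):

        optimizer = AdamWScheduleFree(model.parameters(), lr=0.0025,
                                      warmup_steps=100)
        optimizer.train()
        for inputs, targets in loader:
            optimizer.zero_grad()
            loss_fn(model(inputs), targets).backward()
            optimizer.step()
        optimizer.eval()  # switch parameters to the averaged iterate before evaluation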
    """
    def __init__(self,
                 params,
                 lr=0.0025,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 weight_decay=0,
                 warmup_steps=0,
                 r=0.0,
                 weight_lr_power=2.0,
                 ):

        defaults = dict(lr=lr,
                        betas=betas,
                        eps=eps,
                        r=r,
                        k=0,
                        warmup_steps=warmup_steps,
                        train_mode=True,
                        weight_sum=0.0,
                        lr_max=-1.0,
                        weight_lr_power=weight_lr_power,
                        weight_decay=weight_decay)
        super().__init__(params, defaults)
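
    # Note on iterates (derived from the lerp weights used below): during
    # training each parameter tensor p holds the y sequence,
    #   y = beta1 * x + (1 - beta1) * z,
    # where z is the base iterate kept in state['z'] and x is the averaged
    # iterate used at evaluation time; x itself is never stored.
    # eval() and train() switch p between x and y by interpolating toward z.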
    def eval(self):
        for group in self.param_groups:
            train_mode = group['train_mode']
            beta1, _ = group['betas']
            if train_mode:
                for p in group['params']:
                    state = self.state[p]
                    if 'z' in state:
                        # Set p.data to x
                        p.data.lerp_(end=state['z'], weight=1-1/beta1)
                group['train_mode'] = False

    def train(self):
        for group in self.param_groups:
            train_mode = group['train_mode']
            beta1, _ = group['betas']
            if not train_mode:
                for p in group['params']:
                    state = self.state[p]
                    if 'z' in state:
                        # Set p.data to y
                        p.data.lerp_(end=state['z'], weight=1-beta1)
                group['train_mode'] = True

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """

        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            eps = group['eps']
            beta1, beta2 = group['betas']
            decay = group['weight_decay']
            k = group['k']
            r = group['r']
            warmup_steps = group['warmup_steps']
            weight_lr_power = group['weight_lr_power']

            if k < warmup_steps:
                sched = (k+1) / warmup_steps
            else:
                sched = 1.0

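            # Adam-style bias correction for the second moment is folded into
            # the step size below: scaling lr by sqrt(1 - beta2**(k+1)) is
            # equivalent (up to where eps is applied) to dividing exp_avg_sq
            # by the correction term inside the square root.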
            bias_correction2 = 1 - beta2 ** (k+1)
            lr = group['lr']*sched*math.sqrt(bias_correction2)

            lr_max = group['lr_max'] = max(lr, group['lr_max'])

            weight = ((k+1)**r) * (lr_max**weight_lr_power)
            weight_sum = group['weight_sum'] = group['weight_sum'] + weight

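            # ckp1 is the averaging coefficient c_{k+1}: the (implicit) running
            # average satisfies x_{k+1} = (1 - c_{k+1})*x_k + c_{k+1}*z_{k+1},
            # but only y and z are stored, so x is never materialized.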
            ckp1 = weight/weight_sum

            if not group['train_mode']:
                raise Exception("Optimizer was not in train mode when step() "
                                "was called; call optimizer.train() first.")

            for p in group['params']:
                if p.grad is None:
                    continue

                y = p.data  # Notation to match theory
                grad = p.grad.data

                state = self.state[p]

                if 'z' not in state:
                    state['z'] = torch.clone(y)
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                z = state['z']
                exp_avg_sq = state['exp_avg_sq']

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2)
                denom = exp_avg_sq.sqrt().add_(eps)

                # Reuse grad buffer for memory efficiency
                grad_normalized = grad.div_(denom)

                # Weight decay calculated at y
                if decay != 0:
                    grad_normalized.add_(y, alpha=decay)

                # These operations update y in-place,
                # without computing x explicitly.
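                # In closed form:
                #   y <- (1 - ckp1)*y + ckp1*z
                #        - lr*(1 - beta1*(1 - ckp1))*grad_normalized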
                y.lerp_(end=z, weight=ckp1)
                y.add_(grad_normalized, alpha=lr*(beta1*(1-ckp1)-1))

                # z step
                z.sub_(grad_normalized, alpha=lr)

            group['k'] = k+1
        return loss