|
# Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer.

# https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
# and/or
# https://github.com/lessw2020/Best-Deep-Learning-Optimizers

# Ranger has been used to capture 12 records on the FastAI leaderboard.

# This version = 2020.9.4


# Credits:
# Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization
# RAdam --> https://github.com/LiyuanLucasLiu/RAdam
# Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code.
# Lookahead paper --> MZhang, G Hinton https://arxiv.org/abs/1907.08610

# summary of changes:
# 9/4/20 - updated addcmul_ signature to avoid warning. Integrates latest changes from the GC developer (he did the work for this), and verified performance on a private dataset.
# 4/11/20 - added gradient centralization option. Set a new testing benchmark for accuracy with it; toggle with the use_gc flag at init.
# full code integration with all updates at the param level instead of the group level, moves slow weights into the state dict (from generic weights),
# supports group learning rates (thanks @SHolderbach), fixes sporadic load-from-saved-model issues.
# changes 8/31/19 - fixed references to *self*.N_sma_threshold;
# changed eps to 1e-5 as a better default than 1e-8.
| 26 | +import math |
| 27 | +import torch |
| 28 | +from torch.optim.optimizer import Optimizer, required |
| 29 | + |
| 30 | + |
def centralized_gradient(x, use_gc=True, gc_conv_only=False):
    """Apply Gradient Centralization to ``x`` in place and return it.

    credit - https://github.com/Yonghongwei/Gradient-Centralization

    Args:
        x: gradient tensor; modified in place when centralization applies.
        use_gc: master switch — when False, ``x`` is returned untouched.
        gc_conv_only: when True, centralize only conv-like tensors
            (dim > 3); otherwise any tensor with dim > 1 (conv + fc).

    Returns:
        The same tensor ``x`` (centralized when eligible).
    """
    if use_gc:
        # Both original branches ran the identical add_ line; only the
        # dimensionality threshold differed (3 for conv-only, else 1).
        min_dim = 3 if gc_conv_only else 1
        if x.dim() > min_dim:
            # subtract the mean over all non-leading dims, per output slice
            x.add_(-x.mean(dim=tuple(range(1, x.dim())), keepdim=True))
    return x
| 41 | + |
| 42 | + |
class Ranger(Optimizer):
    """Ranger: RAdam + Lookahead + optional Gradient Centralization (GC).

    Each inner step is an RAdam (rectified Adam) update; every ``k`` steps
    the fast weights are interpolated toward a slow-moving copy kept in the
    per-parameter state (Lookahead). GC subtracts the per-slice gradient
    mean either from the raw gradient (``gc_loc=True``) or from the final
    update direction (``gc_loc=False``).

    Args:
        params: iterable of parameters or param groups (torch.optim style).
        lr: learning rate (per group).
        alpha: Lookahead interpolation factor, in [0, 1].
        k: Lookahead synchronization period, in steps (>= 1).
        N_sma_threshhold: RAdam variance-rectification threshold
            (misspelling kept for backward compatibility; 5 tested better
            than 4).
        betas: Adam (beta1, beta2) coefficients (.95 tested better than .90
            for beta1).
        eps: denominator term for numerical stability.
        weight_decay: L2 penalty coefficient added to the update direction.
        use_gc: enable gradient centralization.
        gc_conv_only: apply GC to conv layers (dim > 3) only, instead of
            conv + fc layers.
        gc_loc: apply GC to the raw gradient (True) or to the computed
            update direction (False).
    """

    def __init__(self, params, lr=1e-3,  # lr
                 alpha=0.5, k=6, N_sma_threshhold=5,  # Ranger options
                 betas=(.95, 0.999), eps=1e-5, weight_decay=0,  # Adam options
                 # GC on/off, applied to conv layers only or conv + fc layers
                 use_gc=True, gc_conv_only=False, gc_loc=True
                 ):

        # parameter checks
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        if not lr > 0:
            raise ValueError(f'Invalid Learning Rate: {lr}')
        if not eps > 0:
            raise ValueError(f'Invalid eps: {eps}')

        # prep defaults and init torch.optim base
        defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas,
                        N_sma_threshhold=N_sma_threshhold, eps=eps,
                        weight_decay=weight_decay)
        super().__init__(params, defaults)

        # adjustable rectification threshold
        self.N_sma_threshhold = N_sma_threshhold

        # look ahead params
        self.alpha = alpha
        self.k = k

        # RAdam buffer: caches (step, N_sma, step_size) keyed by step % 10,
        # shared across params since all params in a group share a step count
        self.radam_buffer = [[None, None, None] for ind in range(10)]

        # gc on or off
        self.gc_loc = gc_loc
        self.use_gc = use_gc
        self.gc_conv_only = gc_conv_only

        print(
            f"Ranger optimizer loaded. \nGradient Centralization usage = {self.use_gc}")
        if self.use_gc and not self.gc_conv_only:
            print("GC applied to both conv and fc layers")
        elif self.use_gc and self.gc_conv_only:
            print("GC applied to conv layers only")

    def __setstate__(self, state):
        # debug trace kept from upstream; flags load-from-checkpoint paths
        print("set state called")
        super(Ranger, self).__setstate__(state)

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional. Either a callable that re-evaluates the model
                and returns the loss (standard torch.optim contract), or an
                already-computed loss value, which is passed through
                unchanged (kept for backward compatibility with callers that
                pass back the loss as a float rather than a closure).

        Returns:
            The loss from the closure / pass-through value, or None.
        """
        loss = None
        # Fix: `closure` used to be silently ignored (the call was commented
        # out), violating the torch.optim.Optimizer contract. Support both a
        # callable closure and a pre-computed loss value.
        if closure is not None:
            loss = closure() if callable(closure) else closure

        # Evaluate averages and grad, update param tensors
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()

                if grad.is_sparse:
                    raise RuntimeError(
                        'Ranger optimizer does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]  # per-param state dict

                if len(state) == 0:  # first step for this param: init state
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)

                    # Lookahead slow-weight storage lives in the state dict
                    # so it survives save/load of the optimizer state
                    state['slow_buffer'] = torch.empty_like(p.data)
                    state['slow_buffer'].copy_(p.data)
                else:
                    # keep the moment tensors in the working (fp32) dtype
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(
                        p_data_fp32)

                # begin computations
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # GC on the raw gradient; skip the (identity) call when GC
                # is disabled so the optimizer works without the helper
                if self.gc_loc and self.use_gc:
                    grad = centralized_gradient(
                        grad, use_gc=self.use_gc,
                        gc_conv_only=self.gc_conv_only)

                state['step'] += 1

                # second-moment (variance) moving average
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                # first-moment (mean) moving average
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                # RAdam: N_sma and step_size depend only on the step count,
                # so cache them per (step mod 10)
                buffered = self.radam_buffer[int(state['step'] % 10)]
                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * \
                        state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma
                    if N_sma > self.N_sma_threshhold:
                        # variance is tractable: rectified adaptive step
                        step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (
                            N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    else:
                        # warm-up: bias-corrected SGD-with-momentum step
                        step_size = 1.0 / (1 - beta1 ** state['step'])
                    buffered[2] = step_size

                # apply lr: adaptive denom only once the variance estimate
                # is considered reliable
                if N_sma > self.N_sma_threshhold:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    G_grad = exp_avg / denom
                else:
                    G_grad = exp_avg

                if group['weight_decay'] != 0:
                    G_grad.add_(p_data_fp32, alpha=group['weight_decay'])

                # GC on the final update direction instead of the raw grad
                if not self.gc_loc and self.use_gc:
                    G_grad = centralized_gradient(
                        G_grad, use_gc=self.use_gc,
                        gc_conv_only=self.gc_conv_only)

                p_data_fp32.add_(G_grad, alpha=-step_size * group['lr'])
                p.data.copy_(p_data_fp32)

                # integrated Lookahead, at the param level instead of the
                # group level
                if state['step'] % group['k'] == 0:
                    slow_p = state['slow_buffer']  # slow weights
                    # slow += alpha * (fast - slow)
                    slow_p.add_(p.data - slow_p, alpha=self.alpha)
                    # copy interpolated weights back into the fast params
                    p.data.copy_(slow_p)

        return loss
| 207 | + |
| 208 | + return loss |