Skip to content

Commit 02d0540

Browse files
authored
adds gradient_centralization2, fixes addcmul_ warning
1 parent 81d386b commit 02d0540

File tree

1 file changed

+208
-0
lines changed

1 file changed

+208
-0
lines changed

ranger/ranger2020.py

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer.
2+
3+
# https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
4+
# and/or
5+
# https://github.com/lessw2020/Best-Deep-Learning-Optimizers
6+
7+
# Ranger has been used to capture 12 records on the FastAI leaderboard.
8+
9+
# This version = 2020.9.4
10+
11+
12+
# Credits:
13+
# Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization
14+
# RAdam --> https://github.com/LiyuanLucasLiu/RAdam
15+
# Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code.
16+
# Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610
17+
18+
# summary of changes:
19+
# 9/4/20 - updated addcmul_ signature to avoid warning. Integrates latest changes from GC developer (he did the work for this), and verified on performance on private dataset.
20+
# 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init.
21+
# full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights),
22+
# supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues.
23+
# changes 8/31/19 - fix references to *self*.N_sma_threshold;
24+
# changed eps to 1e-5 as better default than 1e-8.
25+
26+
import math
27+
import torch
28+
from torch.optim.optimizer import Optimizer, required
29+
30+
31+
def centralized_gradient(x, use_gc=True, gc_conv_only=False):
32+
'''credit - https://github.com/Yonghongwei/Gradient-Centralization '''
33+
if use_gc:
34+
if gc_conv_only:
35+
if len(list(x.size())) > 3:
36+
x.add_(-x.mean(dim=tuple(range(1, len(list(x.size())))), keepdim=True))
37+
else:
38+
if len(list(x.size())) > 1:
39+
x.add_(-x.mean(dim=tuple(range(1, len(list(x.size())))), keepdim=True))
40+
return x
41+
42+
43+
class Ranger(Optimizer):
44+
45+
def __init__(self, params, lr=1e-3, # lr
46+
alpha=0.5, k=6, N_sma_threshhold=5, # Ranger options
47+
betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam options
48+
# Gradient centralization on or off, applied to conv layers only or conv + fc layers
49+
use_gc=True, gc_conv_only=False, gc_loc=True
50+
):
51+
52+
# parameter checks
53+
if not 0.0 <= alpha <= 1.0:
54+
raise ValueError(f'Invalid slow update rate: {alpha}')
55+
if not 1 <= k:
56+
raise ValueError(f'Invalid lookahead steps: {k}')
57+
if not lr > 0:
58+
raise ValueError(f'Invalid Learning Rate: {lr}')
59+
if not eps > 0:
60+
raise ValueError(f'Invalid eps: {eps}')
61+
62+
# parameter comments:
63+
# beta1 (momentum) of .95 seems to work better than .90...
64+
# N_sma_threshold of 5 seems better in testing than 4.
65+
# In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you.
66+
67+
# prep defaults and init torch.optim base
68+
defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas,
69+
N_sma_threshhold=N_sma_threshhold, eps=eps, weight_decay=weight_decay)
70+
super().__init__(params, defaults)
71+
72+
# adjustable threshold
73+
self.N_sma_threshhold = N_sma_threshhold
74+
75+
# look ahead params
76+
77+
self.alpha = alpha
78+
self.k = k
79+
80+
# radam buffer for state
81+
self.radam_buffer = [[None, None, None] for ind in range(10)]
82+
83+
# gc on or off
84+
self.gc_loc = gc_loc
85+
self.use_gc = use_gc
86+
self.gc_conv_only = gc_conv_only
87+
# level of gradient centralization
88+
#self.gc_gradient_threshold = 3 if gc_conv_only else 1
89+
90+
print(
91+
f"Ranger optimizer loaded. \nGradient Centralization usage = {self.use_gc}")
92+
if (self.use_gc and self.gc_conv_only == False):
93+
print(f"GC applied to both conv and fc layers")
94+
elif (self.use_gc and self.gc_conv_only == True):
95+
print(f"GC applied to conv layers only")
96+
97+
def __setstate__(self, state):
98+
print("set state called")
99+
super(Ranger, self).__setstate__(state)
100+
101+
def step(self, closure=None):
102+
loss = None
103+
# note - below is commented out b/c I have other work that passes back the loss as a float, and thus not a callable closure.
104+
# Uncomment if you need to use the actual closure...
105+
106+
# if closure is not None:
107+
#loss = closure()
108+
109+
# Evaluate averages and grad, update param tensors
110+
for group in self.param_groups:
111+
112+
for p in group['params']:
113+
if p.grad is None:
114+
continue
115+
grad = p.grad.data.float()
116+
117+
if grad.is_sparse:
118+
raise RuntimeError(
119+
'Ranger optimizer does not support sparse gradients')
120+
121+
p_data_fp32 = p.data.float()
122+
123+
state = self.state[p] # get state dict for this param
124+
125+
if len(state) == 0: # if first time to run...init dictionary with our desired entries
126+
# if self.first_run_check==0:
127+
# self.first_run_check=1
128+
#print("Initializing slow buffer...should not see this at load from saved model!")
129+
state['step'] = 0
130+
state['exp_avg'] = torch.zeros_like(p_data_fp32)
131+
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
132+
133+
# look ahead weight storage now in state dict
134+
state['slow_buffer'] = torch.empty_like(p.data)
135+
state['slow_buffer'].copy_(p.data)
136+
137+
else:
138+
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
139+
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(
140+
p_data_fp32)
141+
142+
# begin computations
143+
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
144+
beta1, beta2 = group['betas']
145+
146+
# GC operation for Conv layers and FC layers
147+
# if grad.dim() > self.gc_gradient_threshold:
148+
# grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True))
149+
if self.gc_loc:
150+
grad = centralized_gradient(grad, use_gc=self.use_gc, gc_conv_only=self.gc_conv_only)
151+
152+
state['step'] += 1
153+
154+
# compute variance mov avg
155+
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
156+
157+
# compute mean moving avg
158+
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
159+
160+
buffered = self.radam_buffer[int(state['step'] % 10)]
161+
162+
if state['step'] == buffered[0]:
163+
N_sma, step_size = buffered[1], buffered[2]
164+
else:
165+
buffered[0] = state['step']
166+
beta2_t = beta2 ** state['step']
167+
N_sma_max = 2 / (1 - beta2) - 1
168+
N_sma = N_sma_max - 2 * \
169+
state['step'] * beta2_t / (1 - beta2_t)
170+
buffered[1] = N_sma
171+
if N_sma > self.N_sma_threshhold:
172+
step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (
173+
N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
174+
else:
175+
step_size = 1.0 / (1 - beta1 ** state['step'])
176+
buffered[2] = step_size
177+
178+
# if group['weight_decay'] != 0:
179+
# p_data_fp32.add_(-group['weight_decay']
180+
# * group['lr'], p_data_fp32)
181+
182+
# apply lr
183+
if N_sma > self.N_sma_threshhold:
184+
denom = exp_avg_sq.sqrt().add_(group['eps'])
185+
G_grad = exp_avg / denom
186+
else:
187+
G_grad = exp_avg
188+
189+
if group['weight_decay'] != 0:
190+
G_grad.add_(p_data_fp32, alpha=group['weight_decay'])
191+
# GC operation
192+
if self.gc_loc == False:
193+
G_grad = centralized_gradient(G_grad, use_gc=self.use_gc, gc_conv_only=self.gc_conv_only)
194+
195+
p_data_fp32.add_(G_grad, alpha=-step_size * group['lr'])
196+
p.data.copy_(p_data_fp32)
197+
198+
# integrated look ahead...
199+
# we do it at the param level instead of group level
200+
if state['step'] % group['k'] == 0:
201+
# get access to slow param tensor
202+
slow_p = state['slow_buffer']
203+
# (fast weights - slow weights) * alpha
204+
slow_p.add_(p.data - slow_p, alpha=self.alpha)
205+
# copy interpolated weights to RAdam param tensor
206+
p.data.copy_(slow_p)
207+
208+
return loss

0 commit comments

Comments
 (0)