Skip to content

Commit f97c298

Browse files
authored
now adds Gradient Centralization
1 parent a0154c8 commit f97c298

File tree

1 file changed

+193
-165
lines changed

1 file changed

+193
-165
lines changed

ranger/ranger.py

Lines changed: 193 additions & 165 deletions
Original file line numberDiff line numberDiff line change
@@ -1,165 +1,193 @@
1-
#Ranger deep learning optimizer - RAdam + Lookahead combined.
2-
#https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
3-
4-
#Ranger has now been used to capture 12 records on the FastAI leaderboard.
5-
6-
#This version = 9.3.19
7-
8-
#Credits:
9-
#RAdam --> https://github.com/LiyuanLucasLiu/RAdam
10-
#Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code.
11-
#Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610
12-
13-
#summary of changes:
14-
#full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights),
15-
#supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues.
16-
#changes 8/31/19 - fix references to *self*.N_sma_threshold;
17-
#changed eps to 1e-5 as better default than 1e-8.
18-
19-
import math
20-
import torch
21-
from torch.optim.optimizer import Optimizer, required
22-
import itertools as it
23-
24-
25-
26-
class Ranger(Optimizer):
    """Ranger: RAdam with an integrated Lookahead wrapper.

    Slow (lookahead) weights live in the per-parameter state dict so that
    save/load round-trips cleanly; per-group learning rates are supported.
    """

    def __init__(self, params, lr=1e-3, alpha=0.5, k=6, N_sma_threshhold=5, betas=(.95,0.999), eps=1e-5, weight_decay=0):
        # hyperparameter sanity checks
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        if not lr > 0:
            raise ValueError(f'Invalid Learning Rate: {lr}')
        if not eps > 0:
            raise ValueError(f'Invalid eps: {eps}')

        # tuning notes: beta1 of .95 tends to beat .90, and an N_sma threshold
        # of 5 tends to beat 4 — both are worth testing per dataset.
        defaults = {
            'lr': lr,
            'alpha': alpha,
            'k': k,
            'step_counter': 0,
            'betas': betas,
            'N_sma_threshhold': N_sma_threshhold,
            'eps': eps,
            'weight_decay': weight_decay,
        }
        super().__init__(params, defaults)

        # adjustable SMA threshold ('threshhold' spelling kept for compatibility)
        self.N_sma_threshhold = N_sma_threshhold

        # lookahead configuration
        self.alpha = alpha
        self.k = k

        # rolling cache of [step, N_sma, step_size] keyed by step % 10 (RAdam)
        self.radam_buffer = [[None, None, None] for _ in range(10)]

        # NOTE: lookahead slow weights are stored per-parameter in the state
        # dict (not as a separate tensor list) so model save/load works.

    def __setstate__(self, state):
        print("set state called")
        super(Ranger, self).__setstate__(state)

    def step(self, closure=None):
        """Run one optimization step; the closure is intentionally ignored.

        The caller passes loss back as a float elsewhere rather than as a
        callable closure, so nothing is evaluated here and None is returned.
        """
        loss = None

        # evaluate averages and grad, then update each parameter tensor
        for group in self.param_groups:
            for param in group['params']:
                if param.grad is None:
                    continue

                grad = param.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('Ranger optimizer does not support sparse gradients')

                fp32_weights = param.data.float()
                state = self.state[param]  # per-parameter state dict

                if len(state) == 0:
                    # first touch: init moment buffers + lookahead slow weights
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(fp32_weights)
                    state['exp_avg_sq'] = torch.zeros_like(fp32_weights)
                    state['slow_buffer'] = torch.empty_like(param.data)
                    state['slow_buffer'].copy_(param.data)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(fp32_weights)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(fp32_weights)

                exp_avg = state['exp_avg']
                exp_avg_sq = state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # second-moment then first-moment moving averages
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                exp_avg.mul_(beta1).add_(1 - beta1, grad)

                state['step'] += 1

                # RAdam: reuse cached (N_sma, step_size) when the step matches
                cache = self.radam_buffer[int(state['step'] % 10)]
                if state['step'] == cache[0]:
                    N_sma, step_size = cache[1], cache[2]
                else:
                    cache[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    cache[1] = N_sma
                    bias_correction = 1 - beta1 ** state['step']
                    if N_sma > self.N_sma_threshhold:
                        step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / bias_correction
                    else:
                        step_size = 1.0 / bias_correction
                    cache[2] = step_size

                if group['weight_decay'] != 0:
                    fp32_weights.add_(-group['weight_decay'] * group['lr'], fp32_weights)

                if N_sma > self.N_sma_threshhold:
                    # variance is tractable: adaptive (Adam-style) update
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    fp32_weights.addcdiv_(-step_size * group['lr'], exp_avg, denom)
                else:
                    # early steps: plain momentum-style update
                    fp32_weights.add_(-step_size * group['lr'], exp_avg)

                param.data.copy_(fp32_weights)

                # integrated lookahead, applied per-parameter every k steps
                if state['step'] % group['k'] == 0:
                    slow = state['slow_buffer']
                    slow.add_(self.alpha, param.data - slow)  # slow += alpha * (fast - slow)
                    param.data.copy_(slow)

        return loss
1+
# Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer.
2+
3+
# https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
4+
# and/or
5+
# https://github.com/lessw2020/Best-Deep-Learning-Optimizers
6+
7+
# Ranger has now been used to capture 12 records on the FastAI leaderboard.
8+
9+
# This version = 20.4.11
10+
11+
# Credits:
12+
# Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization
13+
# RAdam --> https://github.com/LiyuanLucasLiu/RAdam
14+
# Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code.
15+
# Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610
16+
17+
# summary of changes:
18+
# 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init.
19+
# full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights),
20+
# supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues.
21+
# changes 8/31/19 - fix references to *self*.N_sma_threshold;
22+
# changed eps to 1e-5 as better default than 1e-8.
23+
24+
import math
25+
import torch
26+
from torch.optim.optimizer import Optimizer, required
27+
28+
29+
30+
class Ranger(Optimizer):
    """Ranger optimizer: RAdam + Lookahead + optional Gradient Centralization.

    RAdam provides a variance-rectified adaptive step; Lookahead interpolates
    fast weights toward slow weights every ``k`` steps; Gradient
    Centralization (GC, arXiv:2004.01461) subtracts the per-output-slice mean
    from each gradient before the moment updates.

    Fixes vs. the previous revision:
    - ``use_gc`` is now honored in ``step()`` (it was stored but never
      checked, so GC always ran on eligible tensors).
    - The GC subtraction block was duplicated verbatim; it is applied once
      now (GC is idempotent, so results are unchanged up to float noise).
    - Deprecated ``add_/addcmul_/addcdiv_(Number, Tensor, ...)`` positional
      overloads (removed in current PyTorch) replaced by the supported
      ``alpha=``/``value=`` keyword forms — numerically identical.
    """

    def __init__(self, params, lr=1e-3,                          # lr
                 alpha=0.5, k=6, N_sma_threshhold=5,             # Ranger options
                 betas=(.95, 0.999), eps=1e-5, weight_decay=0,   # Adam options
                 use_gc=True, gc_conv_only=False                 # GC on/off; conv-only or conv + fc
                 ):
        # --- parameter checks ---
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        if not lr > 0:
            raise ValueError(f'Invalid Learning Rate: {lr}')
        if not eps > 0:
            raise ValueError(f'Invalid eps: {eps}')

        # tuning notes: beta1 (momentum) of .95 tends to beat .90, and an
        # N_sma threshold of 5 tends to beat 4; test both on your dataset.

        # prep defaults and init torch.optim base
        defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas,
                        N_sma_threshhold=N_sma_threshhold, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

        # adjustable threshold ('threshhold' spelling kept for compatibility)
        self.N_sma_threshhold = N_sma_threshhold

        # lookahead params
        self.alpha = alpha
        self.k = k

        # rolling cache of [step, N_sma, step_size] keyed by step % 10 (RAdam)
        self.radam_buffer = [[None, None, None] for ind in range(10)]

        # gc on or off
        self.use_gc = use_gc

        # minimum tensor rank that gets centralized:
        # 3 -> conv layers only, 1 -> conv + fc layers
        self.gc_gradient_threshold = 3 if gc_conv_only else 1

        print(f"Ranger optimizer loaded. \nGradient Centralization usage = {self.use_gc}")
        if (self.use_gc and self.gc_gradient_threshold == 1):
            print(f"GC applied to both conv and fc layers")
        elif (self.use_gc and self.gc_gradient_threshold == 3):
            print(f"GC applied to conv layers only")

    def __setstate__(self, state):
        print("set state called")
        super(Ranger, self).__setstate__(state)

    def step(self, closure=None):
        """Run one optimization step; the closure is intentionally ignored.

        The caller passes loss back as a float elsewhere rather than as a
        callable closure, so nothing is evaluated here and None is returned.
        """
        loss = None

        # Evaluate averages and grad, update param tensors
        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()

                if grad.is_sparse:
                    raise RuntimeError('Ranger optimizer does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]  # get state dict for this param

                if len(state) == 0:
                    # first time for this param: init moments + lookahead slow weights
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)

                    # lookahead weight storage lives in the state dict
                    state['slow_buffer'] = torch.empty_like(p.data)
                    state['slow_buffer'].copy_(p.data)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                # begin computations
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # GC: subtract per-slice gradient mean (threshold 3 = conv only,
                # 1 = conv + fc). FIX: gated on self.use_gc and applied once
                # (it was previously always-on and duplicated).
                if self.use_gc and grad.dim() > self.gc_gradient_threshold:
                    grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True))

                state['step'] += 1

                # compute variance moving avg
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                # compute mean moving avg
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                # RAdam: reuse cached (N_sma, step_size) when the step matches
                buffered = self.radam_buffer[int(state['step'] % 10)]

                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma
                    if N_sma > self.N_sma_threshhold:
                        step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    else:
                        step_size = 1.0 / (1 - beta1 ** state['step'])
                    buffered[2] = step_size

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])

                # apply lr
                if N_sma > self.N_sma_threshhold:
                    # variance is tractable: adaptive (Adam-style) update
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                else:
                    # early steps: plain momentum-style update
                    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])

                p.data.copy_(p_data_fp32)

                # integrated lookahead, applied per-parameter every k steps
                if state['step'] % group['k'] == 0:
                    slow_p = state['slow_buffer']  # slow param tensor
                    slow_p.add_(p.data - slow_p, alpha=self.alpha)  # slow += alpha * (fast - slow)
                    p.data.copy_(slow_p)  # copy interpolated weights back to the RAdam param

        return loss

0 commit comments

Comments
 (0)