11# Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer.
22
3- # https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
3+ # https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
44# and/or
55# https://github.com/lessw2020/Best-Deep-Learning-Optimizers
66
77# Ranger has now been used to capture 12 records on the FastAI leaderboard.
88
9- # This version = 20.4.11
9+ # This version = 20.4.11
1010
1111# Credits:
1212# Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization
1313# RAdam --> https://github.com/LiyuanLucasLiu/RAdam
1414# Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code.
1515# Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610
1616
17- # summary of changes:
18- # 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init.
19- # full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights),
17+ # summary of changes:
18+ # 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init.
19+ # full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights),
2020# supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues.
21- # changes 8/31/19 - fix references to *self*.N_sma_threshold;
21+ # changes 8/31/19 - fix references to *self*.N_sma_threshold;
2222# changed eps to 1e-5 as better default than 1e-8.
2323
2424import math
2525import torch
2626from torch .optim .optimizer import Optimizer , required
2727
2828
29-
3029class Ranger (Optimizer ):
3130
3231 def __init__ (self , params , lr = 1e-3 , # lr
33- alpha = 0.5 , k = 6 , N_sma_threshhold = 5 , # Ranger options
34- betas = (.95 ,0.999 ), eps = 1e-5 , weight_decay = 0 , # Adam options
35- use_gc = True , gc_conv_only = False # Gradient centralization on or off, applied to conv layers only or conv + fc layers
36- ):
32+ alpha = 0.5 , k = 6 , N_sma_threshhold = 5 , # Ranger options
33+ betas = (.95 , 0.999 ), eps = 1e-5 , weight_decay = 0 , # Adam options
34+ # Gradient centralization on or off, applied to conv layers only or conv + fc layers
35+ use_gc = True , gc_conv_only = False
36+ ):
3737
38- #parameter checks
38+ # parameter checks
3939 if not 0.0 <= alpha <= 1.0 :
4040 raise ValueError (f'Invalid slow update rate: { alpha } ' )
4141 if not 1 <= k :
@@ -45,58 +45,53 @@ def __init__(self, params, lr=1e-3, # lr
4545 if not eps > 0 :
4646 raise ValueError (f'Invalid eps: { eps } ' )
4747
48- #parameter comments:
48+ # parameter comments:
4949 # beta1 (momentum) of .95 seems to work better than .90...
50- #N_sma_threshold of 5 seems better in testing than 4.
51- #In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you.
50+ # N_sma_threshold of 5 seems better in testing than 4.
51+ # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you.
5252
53- #prep defaults and init torch.optim base
54- defaults = dict (lr = lr , alpha = alpha , k = k , step_counter = 0 , betas = betas , N_sma_threshhold = N_sma_threshhold , eps = eps , weight_decay = weight_decay )
55- super ().__init__ (params ,defaults )
53+ # prep defaults and init torch.optim base
54+ defaults = dict (lr = lr , alpha = alpha , k = k , step_counter = 0 , betas = betas ,
55+ N_sma_threshhold = N_sma_threshhold , eps = eps , weight_decay = weight_decay )
56+ super ().__init__ (params , defaults )
5657
57- #adjustable threshold
58+ # adjustable threshold
5859 self .N_sma_threshhold = N_sma_threshhold
5960
60-
61- #look ahead params
61+ # look ahead params
6262
6363 self .alpha = alpha
64- self .k = k
64+ self .k = k
65+
66+ # radam buffer for state
67+ self .radam_buffer = [[None , None , None ] for ind in range (10 )]
6568
66- #radam buffer for state
67- self .radam_buffer = [[ None , None , None ] for ind in range ( 10 )]
69+ # gc on or off
70+ self .use_gc = use_gc
6871
69- #gc on or off
70- self .use_gc = use_gc
71-
72- #level of gradient centralization
72+ # level of gradient centralization
7373 self .gc_gradient_threshold = 3 if gc_conv_only else 1
74-
75-
76- print ( f"Ranger optimizer loaded. \n Gradient Centralization usage = { self .use_gc } " )
77- if (self .use_gc and self .gc_gradient_threshold == 1 ):
74+
75+ print (
76+ f"Ranger optimizer loaded. \n Gradient Centralization usage = { self .use_gc } " )
77+ if (self .use_gc and self .gc_gradient_threshold == 1 ):
7878 print (f"GC applied to both conv and fc layers" )
79- elif (self .use_gc and self .gc_gradient_threshold == 3 ):
79+ elif (self .use_gc and self .gc_gradient_threshold == 3 ):
8080 print (f"GC applied to conv layers only" )
81-
82-
83-
84-
8581
8682 def __setstate__ (self , state ):
8783 print ("set state called" )
8884 super (Ranger , self ).__setstate__ (state )
8985
90-
9186 def step (self , closure = None ):
9287 loss = None
93- #note - below is commented out b/c I have other work that passes back the loss as a float, and thus not a callable closure.
94- #Uncomment if you need to use the actual closure...
88+ # note - below is commented out b/c I have other work that passes back the loss as a float, and thus not a callable closure.
89+ # Uncomment if you need to use the actual closure...
9590
96- #if closure is not None:
97- #loss = closure()
91+ # if closure is not None:
92+ #loss = closure()
9893
99- #Evaluate averages and grad, update param tensors
94+ # Evaluate averages and grad, update param tensors
10095 for group in self .param_groups :
10196
10297 for p in group ['params' ]:
@@ -105,84 +100,85 @@ def step(self, closure=None):
105100 grad = p .grad .data .float ()
106101
107102 if grad .is_sparse :
108- raise RuntimeError ('Ranger optimizer does not support sparse gradients' )
103+ raise RuntimeError (
104+ 'Ranger optimizer does not support sparse gradients' )
109105
110106 p_data_fp32 = p .data .float ()
111107
112- state = self .state [p ] #get state dict for this param
108+ state = self .state [p ] # get state dict for this param
113109
114- if len (state ) == 0 : # if first time to run...init dictionary with our desired entries
115- #if self.first_run_check==0:
116- # self.first_run_check=1
117- #print("Initializing slow buffer...should not see this at load from saved model!")
110+ if len (state ) == 0 : # if first time to run...init dictionary with our desired entries
111+ # if self.first_run_check==0:
112+ # self.first_run_check=1
113+ #print("Initializing slow buffer...should not see this at load from saved model!")
118114 state ['step' ] = 0
119115 state ['exp_avg' ] = torch .zeros_like (p_data_fp32 )
120116 state ['exp_avg_sq' ] = torch .zeros_like (p_data_fp32 )
121117
122- #look ahead weight storage now in state dict
118+ # look ahead weight storage now in state dict
123119 state ['slow_buffer' ] = torch .empty_like (p .data )
124120 state ['slow_buffer' ].copy_ (p .data )
125121
126122 else :
127123 state ['exp_avg' ] = state ['exp_avg' ].type_as (p_data_fp32 )
128- state ['exp_avg_sq' ] = state ['exp_avg_sq' ].type_as (p_data_fp32 )
124+ state ['exp_avg_sq' ] = state ['exp_avg_sq' ].type_as (
125+ p_data_fp32 )
129126
130- #begin computations
127+ # begin computations
131128 exp_avg , exp_avg_sq = state ['exp_avg' ], state ['exp_avg_sq' ]
132129 beta1 , beta2 = group ['betas' ]
133-
134-
135- #GC operation for Conv layers and FC layers
136- if grad .dim () > self .gc_gradient_threshold :
137- grad .add_ (- grad .mean (dim = tuple (range (1 ,grad .dim ())), keepdim = True ))
138-
139-
130+
131+ # GC operation for Conv layers and FC layers
132+ if grad .dim () > self .gc_gradient_threshold :
133+ grad .add_ (- grad .mean (dim = tuple (range (1 , grad .dim ())), keepdim = True ))
140134
141135 state ['step' ] += 1
142136
143- #compute variance mov avg
137+ # compute variance mov avg
144138 exp_avg_sq .mul_ (beta2 ).addcmul_ (1 - beta2 , grad , grad )
145- #compute mean moving avg
139+ # compute mean moving avg
146140 exp_avg .mul_ (beta1 ).add_ (1 - beta1 , grad )
147141
148-
149-
150-
151-
152142 buffered = self .radam_buffer [int (state ['step' ] % 10 )]
153-
143+
154144 if state ['step' ] == buffered [0 ]:
155145 N_sma , step_size = buffered [1 ], buffered [2 ]
156146 else :
157147 buffered [0 ] = state ['step' ]
158148 beta2_t = beta2 ** state ['step' ]
159149 N_sma_max = 2 / (1 - beta2 ) - 1
160- N_sma = N_sma_max - 2 * state ['step' ] * beta2_t / (1 - beta2_t )
150+ N_sma = N_sma_max - 2 * \
151+ state ['step' ] * beta2_t / (1 - beta2_t )
161152 buffered [1 ] = N_sma
162153 if N_sma > self .N_sma_threshhold :
163- step_size = math .sqrt ((1 - beta2_t ) * (N_sma - 4 ) / (N_sma_max - 4 ) * (N_sma - 2 ) / N_sma * N_sma_max / (N_sma_max - 2 )) / (1 - beta1 ** state ['step' ])
154+ step_size = math .sqrt ((1 - beta2_t ) * (N_sma - 4 ) / (N_sma_max - 4 ) * (
155+ N_sma - 2 ) / N_sma * N_sma_max / (N_sma_max - 2 )) / (1 - beta1 ** state ['step' ])
164156 else :
165157 step_size = 1.0 / (1 - beta1 ** state ['step' ])
166158 buffered [2 ] = step_size
167159
168-
169160 if group ['weight_decay' ] != 0 :
170- p_data_fp32 .add_ (- group ['weight_decay' ] * group ['lr' ], p_data_fp32 )
161+ p_data_fp32 .add_ (- group ['weight_decay' ]
162+ * group ['lr' ], p_data_fp32 )
171163
172164 # apply lr
173165 if N_sma > self .N_sma_threshhold :
174166 denom = exp_avg_sq .sqrt ().add_ (group ['eps' ])
175- p_data_fp32 .addcdiv_ (- step_size * group ['lr' ], exp_avg , denom )
167+ p_data_fp32 .addcdiv_ (- step_size *
168+ group ['lr' ], exp_avg , denom )
176169 else :
177170 p_data_fp32 .add_ (- step_size * group ['lr' ], exp_avg )
178171
179172 p .data .copy_ (p_data_fp32 )
180173
181- #integrated look ahead...
182- #we do it at the param level instead of group level
174+ # integrated look ahead...
175+ # we do it at the param level instead of group level
183176 if state ['step' ] % group ['k' ] == 0 :
184- slow_p = state ['slow_buffer' ] #get access to slow param tensor
185- slow_p .add_ (self .alpha , p .data - slow_p ) #(fast weights - slow weights) * alpha
186- p .data .copy_ (slow_p ) #copy interpolated weights to RAdam param tensor
177+ # get access to slow param tensor
178+ slow_p = state ['slow_buffer' ]
179+ # (fast weights - slow weights) * alpha
180+ slow_p .add_ (self .alpha , p .data - slow_p )
181+ # copy interpolated weights to RAdam param tensor
182+ p .data .copy_ (slow_p )
187183
188184 return loss
0 commit comments