Skip to content

Commit 22e8bb6

Browse files
authored
Update ranger.py
made a new parameter for N_sma_threshhold. Testing on FastAI datasets shows best results with N_sma_threshhold=5 (default). Paper uses 4, but please test with your dataset!
1 parent 44f66a5 commit 22e8bb6

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

ranger.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
class Ranger(Optimizer):
1212

13-
def __init__(self, params, lr=1e-3, alpha=0.5, k=6, betas=(.9,0.999), eps=1e-8, weight_decay=0):
13+
def __init__(self, params, lr=1e-3, alpha=0.5, k=6, N_sma_threshhold=5, betas=(.9,0.999), eps=1e-8, weight_decay=0):
1414
#parameter checks
1515
if not 0.0 <= alpha <= 1.0:
1616
raise ValueError(f'Invalid slow update rate: {alpha}')
@@ -25,6 +25,9 @@ def __init__(self, params, lr=1e-3, alpha=0.5, k=6, betas=(.9,0.999), eps=1e-8,
2525
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
2626
super().__init__(params,defaults)
2727

28+
#adjustable threshold
29+
self.N_sma_threshhold = N_sma_threshhold
30+
2831
#now we can get to work...
2932
for group in self.param_groups:
3033
group["step_counter"] = 0
@@ -96,7 +99,7 @@ def step(self, closure=None):
9699
N_sma_max = 2 / (1 - beta2) - 1
97100
N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
98101
buffered[1] = N_sma
99-
if N_sma > 4:
102+
if N_sma > self.N_sma_threshhold:
100103
step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
101104
else:
102105
step_size = group['lr'] / (1 - beta1 ** state['step'])

0 commit comments

Comments
 (0)