
Commit 7fbbaec

QH = Quasi-Hyperbolic
Uses the Quasi-Hyperbolic Momentum from https://arxiv.org/abs/1810.06801v4. Defaults of nus = (0.7, 1.0) work well, but testing and tuning are still in progress. It is not quite as fast as the original Ranger, but training curves are smoother and appear better for extended training durations. A new high was reached on one set of 80 epochs on ImageWoof, though training/tuning is still ongoing.
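
For reference, a minimal sketch of the quasi-hyperbolic update rule the optimizer is built around, assuming PyTorch tensors for the state; the helper name qh_adam_update is illustrative only, and bias correction is omitted for brevity (the committed code handles it via running beta weights):

# Illustrative sketch of a QHAdam-style update (hypothetical helper, not part of rangerqh.py).
import torch

def qh_adam_update(param, grad, exp_avg, exp_avg_sq,
                   lr=1e-3, beta1=0.9, beta2=0.999, nu1=0.7, nu2=1.0, eps=1e-8):
    # standard exponential moving averages of the gradient and its square
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).add_(grad * grad, alpha=1 - beta2)
    # quasi-hyperbolic blend: nu controls how much raw gradient is mixed back in
    numer = (1 - nu1) * grad + nu1 * exp_avg
    denom = ((1 - nu2) * grad * grad + nu2 * exp_avg_sq).sqrt().add_(eps)
    param.addcdiv_(numer, denom, value=-lr)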
1 parent 47ba917 commit 7fbbaec

File tree

1 file changed: +182 -0 lines changed


rangerqh.py

Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
# RangerQH - @lessw2020 github
# Combines Quasi-Hyperbolic momentum with Hinton Lookahead.

# QH paper: https://arxiv.org/abs/1810.06801v4
# Lookahead paper --> M. Zhang, G. Hinton: https://arxiv.org/abs/1907.08610


# Some portions = Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch.optim.optimizer import Optimizer

#from ..common import param_conv

class RangerQH(Optimizer):
    r"""Implements the QHAdam optimization algorithm `(Ma and Yarats, 2019)`_,
    combined with Hinton/Zhang Lookahead.

    Args:
        params (iterable):
            iterable of parameters to optimize or dicts defining parameter
            groups
        lr (float, optional): learning rate (:math:`\alpha` from the paper)
            (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of the gradient and its square
            (default: (0.9, 0.999))
        nus (Tuple[float, float], optional): immediate discount factors used to
            estimate the gradient and its square
            (default: (0.7, 1.0))
        weight_decay (float, optional): weight decay (default: 0.0)
        k (int, optional): number of optimizer steps between Lookahead updates
            (default: 6)
        alpha (float, optional): Lookahead interpolation factor applied to the
            slow weights (default: 0.5)
        decouple_weight_decay (bool, optional): whether to decouple the weight
            decay from the gradient-based optimization step
            (default: False)
        eps (float, optional): term added to the denominator to improve
            numerical stability
            (default: 1e-8)

    Example:
        >>> optimizer = RangerQH(
        ...     model.parameters(),
        ...     lr=3e-4, nus=(0.8, 1.0), betas=(0.99, 0.999))
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    .. _`(Ma and Yarats, 2019)`: https://arxiv.org/abs/1810.06801
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        nus=(0.7, 1.0),
        weight_decay=0.0,
        k=6,
        alpha=0.5,
        decouple_weight_decay=False,
        eps=1e-8,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = {
            "lr": lr,
            "betas": betas,
            "nus": nus,
            "weight_decay": weight_decay,
            "decouple_weight_decay": decouple_weight_decay,
            "eps": eps,
        }
        super().__init__(params, defaults)

        # look ahead params
        self.alpha = alpha
        self.k = k

    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional):
                A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            beta1, beta2 = group["betas"]
            nu1, nu2 = group["nus"]
            weight_decay = group["weight_decay"]
            decouple_weight_decay = group["decouple_weight_decay"]
            eps = group["eps"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                d_p = p.grad.data
                if d_p.is_sparse:
                    raise RuntimeError("RangerQH does not support sparse gradients")

                if weight_decay != 0:
                    if decouple_weight_decay:
                        p.data.mul_(1 - lr * weight_decay)
                    else:
                        d_p.add_(weight_decay, p.data)

                d_p_sq = d_p.mul(d_p)

                # prep for saved param loading
                param_state = self.state[p]

                if len(param_state) == 0:
                    param_state["beta1_weight"] = 0.0
                    param_state["beta2_weight"] = 0.0
                    param_state['step'] = 0
                    param_state["exp_avg"] = torch.zeros_like(p.data)
                    param_state["exp_avg_sq"] = torch.zeros_like(p.data)
                    # look ahead weight storage now in state dict
                    param_state['slow_buffer'] = torch.empty_like(p.data)
                    param_state['slow_buffer'].copy_(p.data)

                param_state['step'] += 1

                # geometric-series weights; the 1/weight adjustment below keeps the
                # moving averages bias-corrected without a separate correction term
                param_state["beta1_weight"] = 1.0 + beta1 * param_state["beta1_weight"]
                param_state["beta2_weight"] = 1.0 + beta2 * param_state["beta2_weight"]

                beta1_weight = param_state["beta1_weight"]
                beta2_weight = param_state["beta2_weight"]
                exp_avg = param_state["exp_avg"]
                exp_avg_sq = param_state["exp_avg_sq"]

                beta1_adj = 1.0 - (1.0 / beta1_weight)
                beta2_adj = 1.0 - (1.0 / beta2_weight)
                exp_avg.mul_(beta1_adj).add_(1.0 - beta1_adj, d_p)
                exp_avg_sq.mul_(beta2_adj).add_(1.0 - beta2_adj, d_p_sq)

                avg_grad = exp_avg.mul(nu1)
                if nu1 != 1.0:
                    avg_grad.add_(1.0 - nu1, d_p)

                avg_grad_rms = exp_avg_sq.mul(nu2)
                if nu2 != 1.0:
                    avg_grad_rms.add_(1.0 - nu2, d_p_sq)
                avg_grad_rms.sqrt_()
                if eps != 0.0:
                    avg_grad_rms.add_(eps)

                p.data.addcdiv_(-lr, avg_grad, avg_grad_rms)

                # integrated look ahead...
                # we do it at the param level instead of group level
                if param_state['step'] % self.k == 0:
                    slow_p = param_state['slow_buffer']  # get access to slow param tensor
                    slow_p.add_(self.alpha, p.data - slow_p)  # (fast weights - slow weights) * alpha
                    p.data.copy_(slow_p)  # copy interpolated weights back to the param tensor

        return loss

    @classmethod
    def _params_to_dict(cls, params):
        return {"lr": params.alpha, "nus": (params.nu1, params.nu2), "betas": (params.beta1, params.beta2)}
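
A minimal usage sketch, assuming rangerqh.py is on the import path and that model, loader, and criterion are placeholder names for your own network, dataloader, and loss; Lookahead is integrated into step(), so no wrapper optimizer is needed:

# Hypothetical training loop; model / loader / criterion are placeholders, not part of this commit.
from rangerqh import RangerQH

optimizer = RangerQH(model.parameters(), lr=1e-3, nus=(0.7, 1.0), k=6, alpha=0.5)
for inputs, targets in loader:
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()  # every k-th step also pulls the weights toward the slow (Lookahead) buffer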
