Skip to content

Commit bf1a615

Browse files
R. Xia and jamie-mcg
authored and committed
Add option batch_speedup
1 parent 067293c commit bf1a615

File tree

3 files changed

+21
-18
lines changed

3 files changed

+21
-18
lines changed

examples/autoencoder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,7 @@ def draw(model, data):
418418
sgd_lr=eta_sgd,
419419
initialization="normal",
420420
device=device,
421+
batch_speedup=False
421422
)
422423

423424
print(opt.__dict__["fish_lr"])

src/optim/FishLeg/fishleg.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -133,12 +133,14 @@ def __init__(
133133
initialization: str = "uniform",
134134
device: str = "cpu",
135135
num_steps = None,
136-
para_name: str = ''
136+
para_name: str = '',
137+
batch_speedup: bool = False
137138
) -> None:
138139
self.model = model
139140
self.sgd_lr = sgd_lr
140141
self.fish_lr = fish_lr
141142
self.device = device
143+
self.batch_speedup = batch_speedup
142144
self.para_name = para_name
143145
self.initialization = initialization
144146

@@ -163,19 +165,24 @@ def __init__(
163165
"gradbar": [
164166
torch.zeros_like(params[name]) for name in module.order
165167
],
166-
"theta0": [params[name].data.clone() for name in module.order],
167-
"grad": [torch.zeros_like(params[name]) for name in module.order],
168-
"Qg": module.Qg,
168+
"theta0": [
169+
params[name].data.clone() for name in module.order
170+
],
171+
"grad": [
172+
torch.zeros_like(params[name]) for name in module.order
173+
],
174+
"Qv": module.Qg if self.batch_speedup else module.Qv,
169175
"order": module.order,
170176
"name": module_name,
171177
"module": module,
172178
}
173179
param_groups.append(g)
174180

175181
# Register hooks on trainable modules
176-
module.register_forward_pre_hook(self._save_input)
177-
module.register_full_backward_hook(self._save_grad_output)
178-
182+
if self.batch_speedup:
183+
module.register_forward_pre_hook(self._save_input)
184+
module.register_full_backward_hook(self._save_grad_output)
185+
179186
likelihood_params = self.likelihood.get_parameters()
180187
if len(likelihood_params) > 0:
181188
self.likelihood.init_aux(init_scale=np.sqrt(self.sgd_lr / self.fish_lr))
@@ -328,7 +335,7 @@ def update_aux(self) -> None:
328335
name = group["name"]
329336

330337
grad_norm = [grad/g_norm for grad in group['grad']]
331-
qg = group["Qg"]()
338+
qg = group["Qv"]() if self.batch_speedup else group["Qv"](group['grad'])
332339

333340
for p, g, d_p in zip(
334341
group['params'], grad_norm, qg
@@ -352,10 +359,9 @@ def update_aux(self) -> None:
352359

353360
def step(self) -> None:
354361
"""Performes a single optimization step of FishLeg."""
355-
362+
self.updated = False
356363
if self.step_t == 0:
357364
self.step_t += 1
358-
print("== pretraining==")
359365
aux_losses = []
360366
aux = 0
361367
for pre in range(self.pre_aux_training):
@@ -366,27 +372,23 @@ def step(self) -> None:
366372
aux_loss, linear_term, quad_term, reg_term, g2 = self.update_aux()
367373
aux += aux_loss
368374

369-
if pre % 10 == 0 and pre != 0:
370-
print('aux_loss: {:.4f}, \t linear: {:.4f}, quad: {:.4f}, reg: {:.4f} g2: {:.4}'.format(
371-
aux/10, linear_term, quad_term, reg_term, g2
372-
))
373-
aux = 0
374-
aux_losses.append(aux_loss.detach().cpu().numpy())
375375
return aux_losses
376376

377377
if self.update_aux_every > 0:
378378
if self.step_t % self.update_aux_every == 0:
379379
aux_loss, linear_term, quad_term, reg_term, g2 = self.update_aux()
380+
self.updated = True
380381
elif self.update_aux_every < 0:
381382
for _ in range(-self.update_aux_every):
382383
self.update_aux()
384+
self.updated = True
383385

384386
self.step_t += 1
385387

386388
for group in self.param_groups:
387389
name = group["name"]
388390
with torch.no_grad():
389-
nat_grad = group["Qg"]()
391+
nat_grad = group["Qv"]() if self.batch_speedup else group["Qv"](group['grad'] if self.updated else [p.grad.data for p in group['params']])
390392

391393
for p, d_p, gbar, p0 in zip(
392394
group["params"], nat_grad, group["gradbar"], group["theta0"]

src/optim/FishLeg/fishleg_layers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def Qv(self, v: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
123123
R = torch.sqrt(self.fishleg_aux["scale"]) * self.fishleg_aux["R"]
124124
# print("u", v[0].shape, v[1][:, None].shape)
125125
u = torch.cat([v[0], v[1][:, None]], dim=-1)
126-
z = torch.linalg.multi_dot((R, R.T, u, L, L.T))
126+
z = torch.linalg.multi_dot((R.T, R, u, L, L.T))
127127
return (z[:, :-1], z[:, -1])
128128

129129
def diagQ(self) -> Tensor:

0 commit comments

Comments (0)