@@ -133,12 +133,14 @@ def __init__(
        initialization: str = "uniform",
        device: str = "cpu",
        num_steps=None,
-        para_name: str = ''
+        para_name: str = '',
+        batch_speedup: bool = False
    ) -> None:
        self.model = model
        self.sgd_lr = sgd_lr
        self.fish_lr = fish_lr
        self.device = device
+        self.batch_speedup = batch_speedup
        self.para_name = para_name
        self.initialization = initialization
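The new `batch_speedup` flag is simply threaded through `__init__` and stored on the optimizer. A minimal usage sketch (the model and hyperparameter values below are placeholders, not taken from this PR):

```python
import torch.nn as nn

# Hypothetical example: construct the optimizer with the new flag enabled.
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))
optimizer = FishLeg(
    model,
    sgd_lr=1e-2,         # placeholder learning rates
    fish_lr=5e-2,
    device="cpu",
    batch_speedup=True,  # opt into the hook-based fast path added below
)
```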
@@ -163,19 +165,24 @@ def __init__(
163165 "gradbar" : [
164166 torch .zeros_like (params [name ]) for name in module .order
165167 ],
166- "theta0" : [params [name ].data .clone () for name in module .order ],
167- "grad" : [torch .zeros_like (params [name ]) for name in module .order ],
168- "Qg" : module .Qg ,
168+ "theta0" : [
169+ params [name ].data .clone () for name in module .order
170+ ],
171+ "grad" : [
172+ torch .zeros_like (params [name ]) for name in module .order
173+ ],
174+ "Qv" : module .Qg if self .batch_speedup else module .Qv ,
169175 "order" : module .order ,
170176 "name" : module_name ,
171177 "module" : module ,
172178 }
173179 param_groups .append (g )
174180
175181 # Register hooks on trainable modules
176- module .register_forward_pre_hook (self ._save_input )
177- module .register_full_backward_hook (self ._save_grad_output )
178-
182+ if self .batch_speedup :
183+ module .register_forward_pre_hook (self ._save_input )
184+ module .register_full_backward_hook (self ._save_grad_output )
185+
179186 likelihood_params = self .likelihood .get_parameters ()
180187 if len (likelihood_params ) > 0 :
181188 self .likelihood .init_aux (init_scale = np .sqrt (self .sgd_lr / self .fish_lr ))
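For context, PyTorch invokes `register_forward_pre_hook` callbacks as `fn(module, inputs)` and `register_full_backward_hook` callbacks as `fn(module, grad_input, grad_output)`. A minimal sketch of what the two hook methods might look like; the method names come from this diff, but the cached attribute names are assumptions:

```python
def _save_input(self, module, inputs):
    # Forward pre-hook: runs before module.forward; cache the layer input
    # so the no-argument Qg path can be evaluated from batch statistics.
    module._cached_input = inputs[0].detach()

def _save_grad_output(self, module, grad_input, grad_output):
    # Full backward hook: cache the gradient w.r.t. the module's output.
    module._cached_grad_output = grad_output[0].detach()
```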
@@ -328,7 +335,7 @@ def update_aux(self) -> None:
            name = group["name"]

            grad_norm = [grad / g_norm for grad in group['grad']]
-            qg = group["Qg"]()
+            qg = group["Qv"]() if self.batch_speedup else group["Qv"](group['grad'])

            for p, g, d_p in zip(
                group['params'], grad_norm, qg
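The `"Qv"` entry thus holds one of two callables with different signatures: `module.Qg` (no arguments, relying on the hooked batch statistics) on the speedup path, and `module.Qv` (applied to an explicit list of gradient tensors) otherwise. As a purely hypothetical illustration of the second interface, a module could expose `Qv` as a learned diagonal preconditioner:

```python
import torch

class DiagonalQ(torch.nn.Module):
    """Hypothetical module showing the Qv(v) calling convention used above:
    apply an approximate inverse Fisher to a list of gradient tensors."""

    def __init__(self, params):
        super().__init__()
        # one positive scale per parameter tensor, log-parameterised
        self.log_scale = torch.nn.ParameterList(
            torch.nn.Parameter(torch.zeros_like(p)) for p in params
        )

    def Qv(self, v):
        # elementwise product with exp(log_scale): a diagonal Q
        return [s.exp() * g for s, g in zip(self.log_scale, v)]
```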
@@ -352,10 +359,9 @@ def update_aux(self) -> None:

    def step(self) -> None:
        """Performs a single optimization step of FishLeg."""
-
+        self.updated = False
        if self.step_t == 0:
            self.step_t += 1
-            print("== pretraining==")
            aux_losses = []
            aux = 0
            for pre in range(self.pre_aux_training):
@@ -366,27 +372,23 @@ def step(self) -> None:
                aux_loss, linear_term, quad_term, reg_term, g2 = self.update_aux()
                aux += aux_loss
-                if pre % 10 == 0 and pre != 0:
-                    print('aux_loss: {:.4f}, \t linear: {:.4f}, quad: {:.4f}, reg: {:.4f} g2: {:.4}'.format(
-                        aux / 10, linear_term, quad_term, reg_term, g2
-                    ))
-                    aux = 0
-                aux_losses.append(aux_loss.detach().cpu().numpy())
            return aux_losses

        if self.update_aux_every > 0:
            if self.step_t % self.update_aux_every == 0:
                aux_loss, linear_term, quad_term, reg_term, g2 = self.update_aux()
+                self.updated = True
        elif self.update_aux_every < 0:
            for _ in range(-self.update_aux_every):
                self.update_aux()
+            self.updated = True

        self.step_t += 1
        for group in self.param_groups:
            name = group["name"]
            with torch.no_grad():
-                nat_grad = group["Qg"]()
+                nat_grad = group["Qv"]() if self.batch_speedup else group["Qv"](group['grad'] if self.updated else [p.grad.data for p in group['params']])

            for p, d_p, gbar, p0 in zip(
                group["params"], nat_grad, group["gradbar"], group["theta0"]
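Unrolled, the branching on the changed line above reads as follows (same semantics, all names from this diff): when `update_aux` ran this step, precondition the gradients held in `group['grad']`; otherwise fall back to the live `.grad` buffers; on the speedup path, `"Qv"` is bound to `module.Qg` and needs no argument at all.

```python
# Spelled-out equivalent of the one-line conditional above.
with torch.no_grad():
    if self.batch_speedup:
        nat_grad = group["Qv"]()  # uses batch statistics cached by the hooks
    elif self.updated:
        nat_grad = group["Qv"](group["grad"])  # grads refreshed this step
    else:
        nat_grad = group["Qv"]([p.grad.data for p in group["params"]])
```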