@@ -41,15 +41,19 @@ def sample_from_logits(
4141
4242
4343def obj_func_scaler (
44- obj_func : Callable , exp_offset_and_scale : Optional [Tuple [float , float ]]
45- ) -> Callable :
44+ obj_func : Optional [Callable [[Dict [str , torch .Tensor ]], torch .Tensor ]],
45+ exp_offset_and_scale : Optional [Tuple [float , float ]],
46+ ) -> Optional [Callable ]:
4647 """
4748 Scale objective functions to make optimizers get out of local minima more easily.
4849
4950 The scaling formula is: exp((reward - offset) / scale)
5051
5152 if exp_offset_and_scale is None, do not scale the obj_function (i.e., reward == scaled_reward)
5253 """
54+ if obj_func is None :
55+ return None
56+
5357 if exp_offset_and_scale is not None :
5458 offset , scale = exp_offset_and_scale
5559
@@ -103,7 +107,7 @@ class ComboOptimizerBase:
103107 def __init__ (
104108 self ,
105109 param : ng .p .Dict ,
106- obj_func : Callable ,
110+ obj_func : Optional [ Callable [[ Dict [ str , torch . Tensor ]], torch . Tensor ]] = None ,
107111 batch_size : int = BATCH_SIZE ,
108112 obj_exp_offset_scale : Optional [Tuple [float , float ]] = None ,
109113 ) -> None :
@@ -123,6 +127,11 @@ def _init(self) -> None:
123127 pass
124128
125129 def optimize_step (self ) -> Tuple :
130+ assert self .obj_func is not None , (
131+ "obj_func not provided. Can't call optimize_step() for optimization. "
132+ "You have to perform manual optimization, i.e., call sample_internal() then update_params()"
133+ )
134+
126135 all_results = self ._optimize_step ()
127136 sampled_solutions , sampled_reward = all_results [0 ], all_results [1 ]
128137 self ._maintain_best_solutions (sampled_solutions , sampled_reward )
@@ -249,7 +258,7 @@ class RandomSearchOptimizer(ComboOptimizerBase):
249258 def __init__ (
250259 self ,
251260 param : ng .p .Dict ,
252- obj_func : Callable ,
261+ obj_func : Optional [ Callable [[ Dict [ str , torch . Tensor ]], torch . Tensor ]] = None ,
253262 batch_size : int = BATCH_SIZE ,
254263 sampling_weights : Optional [Dict [str , np .ndarray ]] = None ,
255264 ) -> None :
@@ -304,16 +313,16 @@ class NeverGradOptimizer(ComboOptimizerBase):
304313 Args:
305314 param (ng.p.Dict): a nevergrad dictionary for specifying input choices
306315
316+ estimated_budgets (int): estimated number of budgets (objective evaluation
317+ times) for nevergrad to perform auto tuning.
318+
307319 obj_func (Callable[[Dict[str, torch.Tensor]], torch.Tensor]):
308320 a function which consumes sampled solutions and returns
309321 rewards as tensors of shape (batch_size, 1).
310322
311323 The input dictionary has choice names as the key and sampled choice
312324 indices as the value (of shape (batch_size, ))
313325
314- estimated_budgets (int): estimated number of budgets (objective evaluation
315- times) for nevergrad to perform auto tuning.
316-
317326 optimizer_name (Optional[str]): ng optimizer to be used specifically
318327 All possible nevergrad optimizers are available at:
319328 https://facebookresearch.github.io/nevergrad/optimization.html#choosing-an-optimizer.
@@ -331,8 +340,9 @@ class NeverGradOptimizer(ComboOptimizerBase):
331340 ... reward[i, 0] = 0.0
332341 ... return reward
333342 ...
343+ >>> estimated_budgets = 40
334344 >>> optimizer = NeverGradOptimizer(
335- ... ng_param, obj_func, batch_size=BATCH_SIZE, estimated_budgets=40
345+ ... ng_param, estimated_budgets, obj_func, batch_size=BATCH_SIZE,
336346 ... )
337347 >>>
338348 >>> for i in range(10):
@@ -346,8 +356,8 @@ class NeverGradOptimizer(ComboOptimizerBase):
346356 def __init__ (
347357 self ,
348358 param : ng .p .Dict ,
349- obj_func : Callable ,
350359 estimated_budgets : int ,
360+ obj_func : Optional [Callable [[Dict [str , torch .Tensor ]], torch .Tensor ]] = None ,
351361 batch_size : int = BATCH_SIZE ,
352362 optimizer_name : Optional [str ] = None ,
353363 ) -> None :
@@ -422,9 +432,9 @@ class LogitBasedComboOptimizerBase(ComboOptimizerBase):
422432 def __init__ (
423433 self ,
424434 param : ng .p .Dict ,
425- obj_func : Callable ,
426435 start_temp : float ,
427436 min_temp : float ,
437+ obj_func : Optional [Callable [[Dict [str , torch .Tensor ]], torch .Tensor ]] = None ,
428438 learning_rate : float = LEARNING_RATE ,
429439 anneal_rate : float = ANNEAL_RATE ,
430440 batch_size : int = BATCH_SIZE ,
@@ -510,7 +520,7 @@ class GumbelSoftmaxOptimizer(LogitBasedComboOptimizerBase):
510520 ... ng_param, obj_func, anneal_rate=0.9, batch_size=BATCH_SIZE, learning_rate=0.1
511521 ... )
512522 ...
513- >>> for i in range(20 ):
523+ >>> for i in range(30 ):
514524 ... res = optimizer.optimize_step()
515525 ...
516526 >>> assert optimizer.sample(1)['choice1'] == 2
@@ -519,7 +529,7 @@ class GumbelSoftmaxOptimizer(LogitBasedComboOptimizerBase):
519529 def __init__ (
520530 self ,
521531 param : ng .p .Dict ,
522- obj_func : Callable [[Dict [str , torch .Tensor ]], torch .Tensor ],
532+ obj_func : Optional [ Callable [[Dict [str , torch .Tensor ]], torch .Tensor ]] = None ,
523533 start_temp : float = 1.0 ,
524534 min_temp : float = 0.1 ,
525535 learning_rate : float = LEARNING_RATE ,
@@ -530,9 +540,9 @@ def __init__(
530540 self .update_params_within_optimizer = update_params_within_optimizer
531541 super ().__init__ (
532542 param ,
533- obj_func ,
534543 start_temp ,
535544 min_temp ,
545+ obj_func ,
536546 learning_rate ,
537547 anneal_rate ,
538548 batch_size ,
@@ -621,7 +631,7 @@ class PolicyGradientOptimizer(LogitBasedComboOptimizerBase):
621631 def __init__ (
622632 self ,
623633 param : ng .p .Dict ,
624- obj_func : Callable [[Dict [str , torch .Tensor ]], torch .Tensor ],
634+ obj_func : Optional [ Callable [[Dict [str , torch .Tensor ]], torch .Tensor ]] = None ,
625635 # default (start_temp=min_temp=1.0): no temperature change for policy gradient
626636 start_temp : float = 1.0 ,
627637 min_temp : float = 1.0 ,
@@ -632,9 +642,9 @@ def __init__(
632642 ) -> None :
633643 super ().__init__ (
634644 param ,
635- obj_func ,
636645 start_temp ,
637646 min_temp ,
647+ obj_func ,
638648 learning_rate ,
639649 anneal_rate ,
640650 batch_size ,
@@ -756,7 +766,7 @@ class QLearningOptimizer(ComboOptimizerBase):
756766 def __init__ (
757767 self ,
758768 param : ng .p .Dict ,
759- obj_func : Callable [[Dict [str , torch .Tensor ]], torch .Tensor ],
769+ obj_func : Optional [ Callable [[Dict [str , torch .Tensor ]], torch .Tensor ]] = None ,
760770 start_temp : float = 1.0 ,
761771 min_temp : float = 0.1 ,
762772 learning_rate : float = LEARNING_RATE ,
0 commit comments