Skip to content

Commit 24f3dfa

Browse files
committed
merge v3
1 parent 6985145 commit 24f3dfa

File tree

1 file changed

+69
-0
lines changed
  • rdagent/scenarios/data_science/proposal/exp_gen

1 file changed

+69
-0
lines changed

rdagent/scenarios/data_science/proposal/exp_gen/merge.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,3 +342,72 @@ def gen(self, trace: DSTrace) -> DSExperiment:
342342
# return self.merge_exp_gen.gen(trace)
343343
trace.set_current_selection(selection=(-1,))
344344
return self.exp_gen.gen(trace) # continue the last trace, to polish the merged solution
345+
346+
347+
class ExpGen2TraceAndMergeV3(ExpGen):
348+
def __init__(self, *args, **kwargs):
349+
super().__init__(*args, **kwargs)
350+
self.merge_exp_gen = ExpGen2Hypothesis(self.scen)
351+
self.exp_gen = DataScienceRDLoop._get_exp_gen(
352+
"rdagent.scenarios.data_science.proposal.exp_gen.DSExpGen", self.scen
353+
)
354+
self.MAX_TRACE_NUM = DS_RD_SETTING.max_trace_num # maximum number of traces to grow before merging
355+
self.flag_start_merge = False
356+
357+
def gen(self, trace: DSTrace) -> DSExperiment:
358+
timer: RDAgentTimer = RD_Agent_TIMER_wrapper.timer
359+
logger.info(f"Remain time: {timer.remain_time_duration}")
360+
361+
if timer.remain_time_duration >= timedelta(hours=DS_RD_SETTING.merge_hours):
362+
363+
if DS_RD_SETTING.enable_inject_knowledge_at_root:
364+
365+
if len(trace.hist) == 0:
366+
# set the knowledge base option to True for the first trace
367+
DS_RD_SETTING.enable_knowledge_base = True
368+
369+
else:
370+
# set the knowledge base option back to False for the other traces
371+
DS_RD_SETTING.enable_knowledge_base = False
372+
return self.exp_gen.gen(trace)
373+
374+
else:
375+
# disable reset in merging stage
376+
DS_RD_SETTING.coding_fail_reanalyze_threshold = 100000
377+
DS_RD_SETTING.consecutive_errors = 100000
378+
379+
leaves: list[int] = trace.get_leaves()
380+
if len(leaves) < 2:
381+
trace.set_current_selection(selection=(-1,))
382+
return self.exp_gen.gen(trace)
383+
else:
384+
selection = (leaves[0],)
385+
sota_exp_fb = trace.sota_experiment_fb(selection=selection)
386+
if sota_exp_fb is None:
387+
sota_exp_fb = trace.hist[leaves[0]]
388+
exp_to_merge_fb = trace.sota_experiment_fb(selection=(leaves[1],))
389+
if exp_to_merge_fb is None:
390+
exp_to_merge_fb = trace.hist[leaves[1]]
391+
try:
392+
if (
393+
trace.sota_exp_to_submit is not None
394+
and sota_exp_fb[0].result is not None
395+
and exp_to_merge_fb[0].result is not None
396+
):
397+
current_exp_value = exp_to_merge_fb[0].result.loc["ensemble"].iloc[0]
398+
sota_submit_value = trace.sota_exp_to_submit.result.loc["ensemble"].iloc[0]
399+
sota_feedback_value = sota_exp_fb[0].result.loc["ensemble"].iloc[0]
400+
401+
# SOTA experiment value may not be the last value in the trace
402+
logger.info(
403+
f"{leaves[0]} score: {current_exp_value}, {leaves[1]} score: {current_exp_value}, Sota score: {sota_submit_value}"
404+
)
405+
if abs(current_exp_value - sota_submit_value) < abs(current_exp_value - sota_feedback_value):
406+
selection = (leaves[1],)
407+
if sota_exp_fb[0].result is None and exp_to_merge_fb[0].result is not None:
408+
logger.info(f"{leaves[0]} result is None, change selection to {leaves[1]}, result is {exp_to_merge_fb[0].result}")
409+
selection = (leaves[1],)
410+
except Exception as e:
411+
logger.error(f"Get best selection failed: {e}")
412+
trace.set_current_selection(selection)
413+
return self.merge_exp_gen.gen(trace)

0 commit comments

Comments
 (0)