Skip to content

Commit f760a3e

Browse files
authored
fix: fix the bug of Exceed-LLM-Context in online merge of multi-tarce (#892)
* set constrains on max_sota_retrieved, fix logis on identical problem * fix: only Auto SOTA selector use max_sota_retrieved_num * set max_sota_retrieved_num=10 by default * minor update * auto lint
1 parent 12c9ef4 commit f760a3e

File tree

6 files changed

+43
-10
lines changed

6 files changed

+43
-10
lines changed

rdagent/app/data_science/conf.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,5 +98,10 @@ class DataScienceBasePropSetting(KaggleBasePropSetting):
9898
merge_hours: int = 2
9999
"""The time for merge"""
100100

101+
#### multi-trace: max SOTA-retrieved number, used in AutoSOTAexpSelector
102+
# constrains the number of SOTA experiments to retrieve, otherwise too many SOTA experiments to retrieve will cause the exceed of the context window of LLM
103+
max_sota_retrieved_num: int = 10
104+
"""The maximum number of SOTA experiments to retrieve in a LLM call"""
105+
101106

102107
DS_RD_SETTING = DataScienceBasePropSetting()

rdagent/log/mle_summary.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,7 @@
1010
from rdagent.core.experiment import FBWorkspace
1111
from rdagent.core.proposal import ExperimentFeedback
1212
from rdagent.log.storage import FileStorage
13-
from rdagent.log.utils import (
14-
extract_json,
15-
extract_loopid_func_name,
16-
is_valid_session,
17-
)
13+
from rdagent.log.utils import extract_json, extract_loopid_func_name, is_valid_session
1814
from rdagent.log.utils.folder import get_first_session_file_after_duration
1915
from rdagent.scenarios.data_science.experiment.experiment import DSExperiment
2016
from rdagent.scenarios.data_science.test_eval import (

rdagent/scenarios/data_science/proposal/exp_gen/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,11 +212,15 @@ def experiment_and_feedback_list_after_init(
212212
return_type: Literal["sota", "failed", "all"],
213213
search_type: Literal["all", "ancestors"] = "all",
214214
selection: tuple[int, ...] | None = None,
215+
max_retrieve_num: int | None = None,
215216
) -> list[tuple[DSExperiment, ExperimentFeedback]]:
216217
"""
217218
Retrieve a list of experiments and feedbacks based on the return_type.
218219
"""
219220
search_list = self.retrieve_search_list(search_type, selection=selection)
221+
if max_retrieve_num is not None and len(search_list) > 0:
222+
retrieve_num = min(max_retrieve_num, len(search_list))
223+
search_list = search_list[:retrieve_num]
220224

221225
final_component = self.COMPLETE_ORDER[-1]
222226
has_final_component = True if DS_RD_SETTING.coder_on_whole_pipeline else False

rdagent/scenarios/data_science/proposal/exp_gen/merge.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,12 @@ def gen(self, trace: DSTrace, selection: tuple[int, ...] = (-1,)) -> DSExperimen
147147
)
148148

149149
success_fb_list = trace.experiment_and_feedback_list_after_init(
150-
return_type="sota", search_type="ancestors", selection=(leaves[i],)
150+
return_type="sota",
151+
search_type="ancestors",
152+
selection=(leaves[i],),
151153
)
152154
if len(success_fb_list) > 0:
155+
153156
exp_to_merge_fb_desc = T("scenarios.data_science.share:describe.trace").r(
154157
exp_and_feedback_list=success_fb_list,
155158
type="success",

rdagent/scenarios/data_science/proposal/exp_gen/proposal.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -694,8 +694,9 @@ def gen(self, trace: DSTrace, pipeline: bool = False) -> DSExperiment:
694694
inject_diverse = False
695695

696696
# Step 1: Identify problems
697+
current_sub_trace = trace.collect_all_ancestors(selection=(-1,))
697698
all_problems = {}
698-
if len(trace.hist) >= 3:
699+
if len(current_sub_trace) >= 3:
699700
fb_problems = self.identify_feedback_problem(
700701
scenario_desc=scenario_desc,
701702
exp_feedback_list_desc=exp_feedback_list_desc,
@@ -706,7 +707,7 @@ def gen(self, trace: DSTrace, pipeline: bool = False) -> DSExperiment:
706707
fb_problems[problem_name]["label"] = "FEEDBACK_PROBLEM"
707708
all_problems[problem_name] = fb_problems[problem_name]
708709

709-
if len(trace.hist) < 9:
710+
if len(current_sub_trace) < 9:
710711
scen_problems = self.identify_scenario_problem(
711712
scenario_desc=scenario_desc,
712713
sota_exp_desc=sota_exp_desc,

rdagent/scenarios/data_science/proposal/exp_gen/sota_exp_select.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ def __init__(
4444
def get_sota_exp_to_submit(self, trace: Trace) -> DSExperiment | None:
4545
# retrieve all SOTA experiments from the trace
4646

47-
sota_exp_fb_list = trace.experiment_and_feedback_list_after_init(return_type="sota", search_type="all")
47+
sota_exp_fb_list = trace.experiment_and_feedback_list_after_init(
48+
return_type="sota", search_type="all", max_retrieve_num=DS_RD_SETTING.max_sota_retrieved_num
49+
)
4850

4951
if len(sota_exp_fb_list) == 0:
5052
logger.info("Auto SOTA selector: No SOTA in trace yet")
@@ -58,10 +60,32 @@ def get_sota_exp_to_submit(self, trace: Trace) -> DSExperiment | None:
5860
return sota_exp_fb_list[0][0]
5961

6062
else:
61-
logger.info("Auto SOTA selector: Multiple SOTA in trace, calling LLM to select the best one")
63+
logger.info(
64+
f"Auto SOTA selector: Multiple SOTA in trace, calling LLM to select the best one in {DS_RD_SETTING.max_sota_retrieved_num} SOTA experiments"
65+
)
6266

6367
SOAT_exp_with_desc_and_scores = "Historical SOTA experiments:\n\n"
6468

69+
leaves: list[int] = trace.get_leaves()
70+
71+
if len(leaves) >= 2:
72+
# multiple trace case, collect the latest SOTA experiments from each trace
73+
new_sota_exp_fb_list: list[tuple[DSExperiment, ExperimentFeedback]] = []
74+
# calculate the number of SOTA experiments to retrieve from each trace
75+
max_sota_retrieved_num_per_trace = DS_RD_SETTING.max_sota_retrieved_num // len(leaves)
76+
# recall, due to the integer division, the final number of SOTA experiments to retrieve may be different
77+
for leaf in leaves:
78+
sota_exp_fb_list_per_trace = trace.experiment_and_feedback_list_after_init(
79+
return_type="sota",
80+
search_type="ancestors",
81+
selection=(leaf,),
82+
max_retrieve_num=max_sota_retrieved_num_per_trace,
83+
)
84+
85+
new_sota_exp_fb_list.extend(sota_exp_fb_list_per_trace)
86+
87+
sota_exp_fb_list = new_sota_exp_fb_list
88+
6589
for i, (exp, ef) in enumerate(sota_exp_fb_list):
6690
if exp:
6791
current_final_score = pd.DataFrame(exp.result).loc["ensemble"].iloc[0]

0 commit comments

Comments
 (0)