Skip to content

Commit 07d5f1f

Browse files
authored
fix: move next_component_required logic to DSTrace class and implement it accurately (microsoft#612)
* refactor: Move next_component_required logic to DSTrace class * lint
1 parent f900ec5 commit 07d5f1f

File tree

3 files changed

+24
-30
lines changed

3 files changed

+24
-30
lines changed

rdagent/app/data_science/loop.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def coding(self, prev_out: dict[str, Any]):
9393

9494
def running(self, prev_out: dict[str, Any]):
9595
exp: DSExperiment = prev_out["coding"]
96-
if exp.next_component_required() is None:
96+
if self.trace.next_component_required() is None:
9797
new_exp = self.runner.develop(exp)
9898
logger.log_object(new_exp)
9999
return new_exp
@@ -102,7 +102,7 @@ def running(self, prev_out: dict[str, Any]):
102102

103103
def feedback(self, prev_out: dict[str, Any]) -> ExperimentFeedback:
104104
exp: DSExperiment = prev_out["running"]
105-
if exp.next_component_required() is None:
105+
if self.trace.next_component_required() is None:
106106
feedback = self.summarizer.generate_feedback(exp, self.trace)
107107
else:
108108
feedback = ExperimentFeedback(
@@ -124,15 +124,11 @@ def record(self, prev_out: dict[str, Any]):
124124
)
125125
)
126126
if self.trace.sota_experiment() is None and len(self.trace.hist) >= DS_RD_SETTING.consecutive_errors:
127-
trace_exp_next_component_list = [
128-
type(exp.pending_tasks_list[0][0])
129-
for exp, _ in self.trace.hist[-DS_RD_SETTING.consecutive_errors :]
130-
]
131-
last_successful_exp = self.trace.last_successful_exp()
132-
if (
133-
last_successful_exp not in [exp for exp, _ in self.trace.hist[-DS_RD_SETTING.consecutive_errors :]]
134-
and len(set(trace_exp_next_component_list)) == 1
135-
):
127+
# if {in initial/drafting stage} and {tried enough times}
128+
for _, fb in self.trace.hist[-DS_RD_SETTING.consecutive_errors :]:
129+
if fb:
130+
break # any success will stop restarting.
131+
else: # otherwise restart it
136132
logger.error("Consecutive errors reached the limit. Dumping trace.")
137133
logger.log_object(self.trace, tag="trace before restart")
138134
self.trace = DSTrace(scen=self.trace.scen, knowledge_base=self.trace.knowledge_base)

rdagent/scenarios/data_science/experiment/experiment.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,3 @@ def __init__(self, pending_tasks_list: list, *args, **kwargs) -> None:
1919
self.experiment_workspace = FBWorkspace()
2020
self.pending_tasks_list = pending_tasks_list
2121
self.format_check_result = None
22-
23-
def next_component_required(self) -> COMPONENT | None:
24-
files = list(self.experiment_workspace.file_dict.keys())
25-
if "load_data.py" not in files:
26-
return "DataLoadSpec"
27-
if "feature.py" not in files:
28-
return "FeatureEng"
29-
if not any(re.match(r"model.*\.py", file) for file in files):
30-
return "Model"
31-
if "ensemble.py" not in files:
32-
return "Ensemble"
33-
if "main.py" not in files:
34-
return "Workflow"
35-
return None

rdagent/scenarios/data_science/proposal/exp_gen.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,21 @@ def __init__(self, scen: DataScienceScen, knowledge_base: KnowledgeBase | None =
9090
self.hist: list[tuple[DSExperiment, ExperimentFeedback]] = []
9191
self.knowledge_base = knowledge_base
9292

93+
COMPLETE_ORDER = ("DataLoadSpec", "FeatureEng", "Model", "Ensemble", "Workflow")
94+
95+
def next_component_required(self) -> COMPONENT | None:
96+
for c in self.COMPLETE_ORDER:
97+
if not self.has_compponent(c):
98+
return c
99+
return None
100+
101+
def has_compponent(self, component: COMPONENT) -> bool:
102+
for exp, fb in self.hist:
103+
assert isinstance(exp.hypothesis, DSHypothesis), "Hypothesis should be DSHypothesis (and not None)"
104+
if exp.hypothesis.component == component and fb:
105+
return True
106+
return False
107+
93108
def sota_experiment(self, last_n: int = -1) -> DSExperiment | None:
94109
"""
95110
Access the last experiment result.
@@ -108,7 +123,7 @@ def sota_experiment(self, last_n: int = -1) -> DSExperiment | None:
108123
assert last_n < 0
109124
for exp, ef in self.hist[::-1]:
110125
# the SOTA exp must have an accepted decision and all required components completed.
111-
if ef.decision and exp.next_component_required() is None:
126+
if ef.decision and self.next_component_required() is None:
112127
last_n += 1
113128
if last_n == 0:
114129
return exp
@@ -237,10 +252,7 @@ def gen(self, trace: DSTrace) -> DSExperiment:
237252
scenario_desc = trace.scen.get_scenario_all_desc()
238253
last_successful_exp = trace.last_successful_exp()
239254

240-
if len(trace.hist) == 0 or last_successful_exp is None:
241-
next_missing_component = "DataLoadSpec"
242-
else:
243-
next_missing_component = last_successful_exp.next_component_required()
255+
next_missing_component = trace.next_component_required()
244256

245257
init_component_config = {
246258
"DataLoadSpec": {"task_cls": DataLoaderTask, "spec_file": None, "component_prompt_key": "data_loader"},

0 commit comments

Comments
 (0)