Skip to content

Commit a3d5473

Browse files
authored
feat: trace merging (microsoft#836)
* feat: runnalbe -- add exp_gen_cls param, get_leaves and merge exp gen functionalities * fix: remove unused scenario_desc and update YAML task labels * feat: override selection and update merge task description * lint * lint * lint * lint * lint * fix: log competition setting to enable mle_summary * fix name error
1 parent b2eec32 commit a3d5473

File tree

11 files changed

+126
-15
lines changed

11 files changed

+126
-15
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ celerybeat.pid
112112

113113
# Environments
114114
.env*
115+
*.env
115116
.venv
116117
^env/
117118
venv/

rdagent/app/data_science/loop.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ def record(self, prev_out: dict[str, Any]):
166166
self.trace = DSTrace(scen=self.trace.scen, knowledge_base=self.trace.knowledge_base)
167167
logger.log_object(self.trace, tag="trace")
168168
logger.log_object(self.trace.sota_experiment(), tag="SOTA experiment")
169+
169170
if DS_RD_SETTING.enable_knowledge_base and DS_RD_SETTING.knowledge_base_version == "v1":
170171
logger.log_object(self.trace.knowledge_base, tag="knowledge_base")
171172
self.trace.knowledge_base.dump()
@@ -228,6 +229,7 @@ def load(
228229
replace_timer: bool = True,
229230
) -> "LoopBase":
230231
session = super().load(path, output_path, do_truncate, replace_timer)
232+
logger.log_object(DS_RD_SETTING.competition, tag="competition") # NOTE: necessary to make mle_summary work.
231233
if DS_RD_SETTING.enable_knowledge_base and DS_RD_SETTING.knowledge_base_version == "v1":
232234
session.trace.knowledge_base = DSKnowledgeBase(
233235
path=DS_RD_SETTING.knowledge_base_path, idea_pool_json_path=DS_RD_SETTING.idea_pool_json_path
@@ -257,6 +259,7 @@ def main(
257259
do_truncate=True,
258260
timeout=None,
259261
replace_timer=True,
262+
exp_gen_cls: str | None = None,
260263
):
261264
"""
262265
@@ -275,6 +278,10 @@ def main(
275278
competition :
276279
do_truncate :
277280
If set to True, the logger will truncate the future log messages by calling `logger.storage.truncate`.
281+
replace_timer :
282+
If session is loaded, should we replace the timer with session.timer
283+
exp_gen_cls :
284+
When we have different stages, we can replace the exp_gen with the new proposal
278285
279286
280287
Auto R&D Evolving loop for models in a Kaggle scenario.
@@ -300,6 +307,11 @@ def main(
300307
kaggle_loop = DataScienceRDLoop(DS_RD_SETTING)
301308
else:
302309
kaggle_loop = DataScienceRDLoop.load(path, output_path, do_truncate, replace_timer)
310+
311+
# replace exp_gen if we have new class
312+
if exp_gen_cls is not None:
313+
kaggle_loop.exp_gen = import_class(exp_gen_cls)(kaggle_loop.exp_gen.scen)
314+
303315
kaggle_loop.run(step_n=step_n, loop_n=loop_n, all_duration=timeout)
304316

305317

rdagent/scenarios/data_science/experiment/experiment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from rdagent.core.experiment import Experiment, FBWorkspace, Task
77

8-
COMPONENT = Literal["DataLoadSpec", "FeatureEng", "Model", "Ensemble", "Workflow"]
8+
COMPONENT = Literal["DataLoadSpec", "FeatureEng", "Model", "Ensemble", "Workflow", "Pipeline"]
99

1010

1111
class DSExperiment(Experiment[Task, FBWorkspace, FBWorkspace]):

rdagent/scenarios/data_science/proposal/exp_gen/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from rdagent.core.proposal import ExpGen
33
from rdagent.core.utils import import_class
44
from rdagent.scenarios.data_science.experiment.experiment import DSExperiment
5-
from rdagent.scenarios.data_science.proposal.exp_gen.base import DSTrace
5+
from rdagent.scenarios.data_science.proposal.exp_gen.base import DSHypothesis, DSTrace
66
from rdagent.scenarios.data_science.proposal.exp_gen.draft import DSDraftExpGen
77
from rdagent.scenarios.data_science.proposal.exp_gen.proposal import (
88
DSProposalV1ExpGen,

rdagent/scenarios/data_science/proposal/exp_gen/base.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,21 @@ def get_current_selection(self) -> tuple[int, ...]:
6868
def set_current_selection(self, selection: tuple[int, ...]) -> None:
6969
self.current_selection = selection
7070

71+
def get_leaves(self) -> list[int, ...]:
72+
"""
73+
Get the indices of nodes (in hist) that have no children—i.e., "leaves" of current DAG.
74+
Returns:
75+
tuple of ints: Indices of leaf nodes.
76+
- Leaves with lower index comes first.
77+
"""
78+
# Build a set of all parent indices found in dag_parent (skip empty tuples which represent roots)
79+
parent_indices = set(idx for parents in self.dag_parent for idx in parents)
80+
# All node indices
81+
all_indices = set(range(len(self.hist)))
82+
# The leaf nodes have no children, so they are not present as parents of any other node
83+
leaves = list(sorted(all_indices - parent_indices))
84+
return leaves
85+
7186
def sync_dag_parent_and_hist(
7287
self,
7388
) -> None:
@@ -90,7 +105,9 @@ def sync_dag_parent_and_hist(
90105
self.dag_parent.append((current_node_idx,))
91106

92107
def retrieve_search_list(
93-
self, search_type: Literal["all", "ancestors"] = "ancestors"
108+
self,
109+
search_type: Literal["all", "ancestors"] = "ancestors",
110+
selection: tuple[int, ...] | None = None,
94111
) -> list[tuple[DSExperiment, ExperimentFeedback]]:
95112
"""
96113
Retrieve the search list based on the selection and search_type.
@@ -108,7 +125,9 @@ def retrieve_search_list(
108125
The search list.
109126
"""
110127

111-
selection = self.get_current_selection()
128+
if selection is None:
129+
selection = self.get_current_selection()
130+
112131
if selection is None:
113132
# selection is None, which means we switch to a new trace, which is not implemented yet
114133
return []
@@ -175,11 +194,12 @@ def experiment_and_feedback_list_after_init(
175194
self,
176195
return_type: Literal["sota", "failed", "all"],
177196
search_type: Literal["all", "ancestors"] = "all",
197+
selection: tuple[int, ...] | None = None,
178198
) -> list[tuple[DSExperiment, ExperimentFeedback]]:
179199
"""
180200
Retrieve a list of experiments and feedbacks based on the return_type.
181201
"""
182-
search_list = self.retrieve_search_list(search_type)
202+
search_list = self.retrieve_search_list(search_type, selection=selection)
183203

184204
final_component = self.COMPLETE_ORDER[-1]
185205
has_final_component = True if DS_RD_SETTING.coder_on_whole_pipeline else False
@@ -199,6 +219,7 @@ def experiment_and_feedback_list_after_init(
199219
def sota_experiment(
200220
self,
201221
search_type: Literal["all", "ancestors"] = "ancestors",
222+
selection: tuple[int, ...] | None = None,
202223
) -> DSExperiment | None:
203224
"""
204225
@@ -207,7 +228,7 @@ def sota_experiment(
207228
Experiment or None
208229
The experiment result if found, otherwise None.
209230
"""
210-
search_list = self.retrieve_search_list(search_type)
231+
search_list = self.retrieve_search_list(search_type, selection=selection)
211232

212233
if DS_RD_SETTING.coder_on_whole_pipeline or self.next_incomplete_component() is None:
213234
for exp, ef in search_list[::-1]:
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
"""Merge the version in different traces"""
2+
3+
from rdagent.components.coder.data_science.pipeline.exp import PipelineTask
4+
from rdagent.core.proposal import ExpGen
5+
from rdagent.scenarios.data_science.experiment.experiment import DSExperiment
6+
from rdagent.scenarios.data_science.proposal.exp_gen.base import DSHypothesis, DSTrace
7+
from rdagent.utils.agent.tpl import T
8+
9+
10+
class MergeExpGen(ExpGen):
11+
def gen(self, trace: DSTrace, selection: tuple[int, ...] = (-1,)) -> DSExperiment:
12+
# Ignore the selection argument and use all leaves instead.
13+
leaves: list[int] = trace.get_leaves()
14+
trace.set_current_selection((leaves[0],)) # override the current selection.
15+
16+
# assuming merging the first and sencond trace.
17+
sota_exp = trace.sota_experiment(selection=(leaves[0],))
18+
exp_to_merge = trace.sota_experiment(selection=(leaves[1],))
19+
20+
# scenario_desc = trace.scen.get_scenario_all_desc()
21+
# scenario_desc is not needed in task description. So we have to do it.
22+
23+
sota_exp_desc = T("scenarios.data_science.share:describe.exp").r(
24+
exp=sota_exp,
25+
heading="Best of previous exploration of the scenario",
26+
)
27+
exp_to_merge_desc = T("scenarios.data_science.share:describe.exp").r(
28+
exp=exp_to_merge,
29+
heading="A solution that to be merged into previous best solution",
30+
)
31+
32+
exp_and_feedback_list_desc = T("scenarios.data_science.share:describe.trace").r(
33+
exp_and_feedback_list=trace.experiment_and_feedback_list_after_init(
34+
return_type="sota", selection=(leaves[1],)
35+
),
36+
type="success",
37+
)
38+
39+
task = PipelineTask(
40+
description=T("scenarios.data_science.proposal.exp_gen.merge:task").r(
41+
sota_exp_desc=sota_exp_desc,
42+
exp_to_merge_desc=exp_to_merge_desc,
43+
exp_and_feedback_list_desc=exp_and_feedback_list_desc,
44+
)
45+
)
46+
47+
exp = DSExperiment(
48+
pending_tasks_list=[[task]],
49+
hypothesis=DSHypothesis(
50+
component="Pipeline",
51+
hypothesis="Merging two different versions of solutions would get the best of both sides and result in a better solution",
52+
),
53+
)
54+
55+
if sota_exp is not None:
56+
exp.experiment_workspace.inject_code_from_file_dict(sota_exp.experiment_workspace)
57+
return exp
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
task: |-
2+
{% include "scenarios.data_science.share:scen.role" %}
3+
4+
The user is improving a Kaggle competition implementation iteratively.
5+
Your task is to merge two solutions to create a better version. We expect the merged version to perform better than both given solutions.
6+
7+
You will be given:
8+
1) Previous Main Solution: this is the main solution you will build on to create an improved version;
9+
2) Solution to be merged: another solution that you will combine with the previous main solution.
10+
- Solution: the approach or method used in this solution.
11+
- Successful iterations: the steps or changes that led to the success of `Solution to be merged`.
12+
13+
# Previous Main Solution
14+
{{ sota_exp_desc }}
15+
16+
# Solution to be merged
17+
## Solution Descrioption:
18+
{{ exp_to_merge_desc }}
19+
## Successful iterations:
20+
{{ exp_and_feedback_list_desc }}

rdagent/scenarios/data_science/proposal/exp_gen/naive.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
The most naive way to design experiments
33
"""
44

5-
from rdagent.app.data_science.conf import DS_RD_SETTING
65
from rdagent.components.coder.data_science.pipeline.exp import PipelineTask
76
from rdagent.core.proposal import ExpGen
87
from rdagent.scenarios.data_science.experiment.experiment import DSExperiment

rdagent/scenarios/data_science/proposal/exp_gen/proposal.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@
1414
from rdagent.oai.llm_utils import APIBackend, md5_hash
1515
from rdagent.scenarios.data_science.experiment.experiment import DSExperiment
1616
from rdagent.scenarios.data_science.proposal.exp_gen.base import DSHypothesis, DSTrace
17-
from rdagent.scenarios.data_science.proposal.exp_gen.idea_pool import (
18-
DSIdea,
19-
)
17+
from rdagent.scenarios.data_science.proposal.exp_gen.idea_pool import DSIdea
2018
from rdagent.utils.agent.tpl import T
2119
from rdagent.utils.repo.diff import generate_diff_from_dict
2220
from rdagent.utils.workflow import wait_retry

rdagent/scenarios/data_science/share.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,4 +332,4 @@ component_spec:
332332
333333
guidelines:
334334
coding: |-
335-
You might receive exploratory data analysis (EDA) details about the source data. Do not use this EDA information to create assertions or raise errors. We might generate sample data for quick coding (so your code may run on sample data which is part of the full-size data), but remember that the EDA details are based on the full-size data.
335+
You might receive exploratory data analysis (EDA) details about the source data. Do not use this EDA information to create assertions or raise errors. We might generate sample data for quick coding (so your code may run on sample data which is part of the full-size data), but remember that the EDA details are based on the full-size data.

0 commit comments

Comments
 (0)