Skip to content

Commit 8f8afea

Browse files
authored
chore: more optional parameters for running benchmark analysis (microsoft#431)
* set title and round * decision from multiple types * check if decision is true * reformat * remove unused file
1 parent b427960 commit 8f8afea

File tree

2 files changed

+12
-19
lines changed

2 files changed

+12
-19
lines changed

rdagent/app/benchmark/factor/analysis.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -178,28 +178,32 @@ def change_fs(font_size):
178178
plt.rc("figure", titlesize=font_size)
179179

180180
@staticmethod
181-
def plot_data(data, file_name):
181+
def plot_data(data, file_name, title):
182182
plt.figure(figsize=(10, 6))
183183
sns.barplot(x="index", y="b", hue="a", data=data)
184184
plt.xlabel("Method")
185185
plt.ylabel("Value")
186-
plt.title("Comparison of Different Methods")
186+
plt.title(title)
187187
plt.savefig(file_name)
188188

189189

190-
def main(path="git_ignore_folder/eval_results/res_promptV220240724-060037.pkl"):
190+
def main(
191+
path="git_ignore_folder/eval_results/res_promptV220240724-060037.pkl",
192+
round=1,
193+
title="Comparison of Different Methods",
194+
):
191195
settings = BenchmarkSettings()
192196
benchmark = BenchmarkAnalyzer(settings)
193197
results = {
194-
"1 round experiment": path,
198+
f"{round} round experiment": path,
195199
}
196200
final_results = benchmark.process_results(results)
197201
final_results_df = pd.DataFrame(final_results)
198202

199203
Plotter.change_fs(20)
200204
plot_data = final_results_df.drop(["max. accuracy", "avg. accuracy"], axis=0).T
201205
plot_data = plot_data.reset_index().melt("index", var_name="a", value_name="b")
202-
Plotter.plot_data(plot_data, "./comparison_plot.png")
206+
Plotter.plot_data(plot_data, "./comparison_plot.png", title)
203207

204208

205209
if __name__ == "__main__":

rdagent/components/coder/factor_coder/CoSTEER/evaluators.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -195,14 +195,7 @@ def evaluate(
195195
user_prompt=gen_df_info_str, system_prompt=system_prompt, json_mode=True
196196
)
197197
resp_dict = json.loads(resp)
198-
199-
if isinstance(resp_dict["output_format_decision"], str) and resp_dict[
200-
"output_format_decision"
201-
].lower() in (
202-
"true",
203-
"false",
204-
):
205-
resp_dict["output_format_decision"] = resp_dict["output_format_decision"].lower() == "true"
198+
resp_dict["output_format_decision"] = str(resp_dict["output_format_decision"]).lower() in ["true", "1"]
206199

207200
return (
208201
resp_dict["output_format_feedback"],
@@ -243,7 +236,7 @@ def evaluate(
243236
False,
244237
)
245238

246-
time_diff = gen_df.index.get_level_values("datetime").to_series().diff().dropna().unique()
239+
time_diff = pd.to_datetime(gen_df.index.get_level_values("datetime")).to_series().diff().dropna().unique()
247240
if pd.Timedelta(minutes=1) in time_diff:
248241
return (
249242
"The generated dataframe is not daily. The implementation is definitely wrong. Please check the implementation.",
@@ -548,11 +541,7 @@ def evaluate(
548541
final_decision = final_evaluation_dict["final_decision"]
549542
final_feedback = final_evaluation_dict["final_feedback"]
550543

551-
if isinstance(final_decision, str) and final_decision.lower() in ("true", "false"):
552-
final_decision = final_decision.lower() == "true"
553-
elif isinstance(final_decision, int) and final_decision in (0, 1):
554-
final_decision = bool(final_decision)
555-
544+
final_decision = str(final_decision).lower() in ["true", "1"]
556545
return final_decision, final_feedback
557546

558547
except json.JSONDecodeError as e:

0 commit comments

Comments
 (0)