Skip to content

Commit e8f6186

Browse files
committed
update
1 parent f14830b commit e8f6186

File tree

17 files changed

+901
-271
lines changed

17 files changed

+901
-271
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,3 +189,6 @@ cov.xml
189189
*.dot
190190
.python-version
191191
*.csv
192+
/examples/ags/data/baselines/general
193+
/examples/ags/scripts/optimized/HumanEval/graphs
194+
/examples/ags/scripts/optimized/HumanEval/graphs_test

examples/ags/benchmark/humaneval.py

Lines changed: 18 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@
77
from datetime import datetime
88
from typing import List, Tuple, Callable, Dict, Any, Optional
99

10+
import re
1011
import pandas as pd
1112
from tqdm.asyncio import tqdm_asyncio
1213

1314
from examples.ags.benchmark.utils import generate_random_indices
1415
from examples.ags.benchmark.utils import log_mismatch
16+
from metagpt.actions.code_sanitize import sanitize
1517

1618

1719
async def load_data(file_path: str, samples=1, test=False) -> List[dict]:
@@ -38,58 +40,6 @@ async def load_file_data(file_path: str, specific_indices: List[int] = None) ->
3840

3941
return data
4042

41-
# async def check_solution(solution, test, entry_point):
42-
43-
# print(f"solution: {solution}")
44-
45-
# try:
46-
# # 定义一个包含所有必要模块的全局字典
47-
# global_dict = {
48-
# 'math': __import__('math'),
49-
# 'hashlib': __import__('hashlib'),
50-
# 're': __import__('re'),
51-
# 'List': List,
52-
# 'Dict': Dict,
53-
# 'Tuple': Tuple,
54-
# 'Optional': Optional,
55-
# 'Any': Any
56-
# }
57-
# if entry_point == "decode_cyclic":
58-
# solution = "\n\ndef encode_cyclic(s: str):\n \"\"\"\n returns encoded string by cycling groups of three characters.\n \"\"\"\n # split string to groups. Each of length 3.\n groups = [s[(3 * i):min((3 * i + 3), len(s))] for i in range((len(s) + 2) // 3)]\n # cycle elements in each group. Unless group has fewer elements than 3.\n groups = [(group[1:] + group[0]) if len(group) == 3 else group for group in groups]\n return \"\".join(groups)" + "\n\n" + solution
59-
# elif entry_point == "decode_shift":
60-
# solution = "\n\ndef encode_shift(s: str):\n \"\"\"\n returns encoded string by shifting every character by 5 in the alphabet.\n \"\"\"\n return \"\".join([chr(((ord(ch) + 5 - ord(\"a\")) % 26) + ord(\"a\")) for ch in s])\n\n\n" + solution
61-
# elif entry_point == "find_zero":
62-
# solution = "\n\ndef poly(xs: list, x: float):\n return sum(coeff * (x ** i) for i, coeff in enumerate(xs))\n\n" + solution
63-
# # 执行解决方案
64-
# exec(solution, global_dict)
65-
66-
# # 确保入口点函数已定义
67-
# if entry_point not in global_dict:
68-
# raise ValueError(f"函数 {entry_point} 在解决方案中未定义。")
69-
70-
# # 执行测试用例
71-
# exec(test, global_dict)
72-
73-
# # 获取检查函数
74-
# check = global_dict["check"]
75-
76-
# # 运行检查函数
77-
# result = check(global_dict[entry_point])
78-
79-
# if result is None:
80-
# result = (PASS, "解决方案通过了所有测试用例。")
81-
82-
# except Exception as e:
83-
# # 记录详细的错误信息
84-
# error_message = f"错误: {str(e)}.\n 解决方案: {solution}.\n 测试: {test}"
85-
# result = (FAIL, error_message)
86-
87-
# # 将错误信息写入error.log文件
88-
# with open('error.log', 'a', encoding='utf-8') as log_file:
89-
# log_file.write(f"{time.strftime('%Y-%m-%d %H:%M:%S')} - {error_message}\n")
90-
91-
# return result
92-
9343
PASS = "PASS"
9444
FAIL = "FAIL"
9545

@@ -98,24 +48,33 @@ class TimeoutError(Exception):
9848

9949
def run_with_timeout(func, args, timeout):
10050
result = []
51+
stop_event = threading.Event()
52+
10153
def target():
10254
try:
10355
result.append(func(*args))
10456
except Exception as e:
10557
result.append(e)
58+
finally:
59+
stop_event.set()
10660

10761
thread = threading.Thread(target=target)
10862
thread.start()
109-
thread.join(timeout)
110-
if thread.is_alive():
63+
is_timeout = not stop_event.wait(timeout)
64+
65+
if is_timeout:
66+
# 线程仍在运行,我们无法强制终止它,但至少可以标记超时
11167
raise TimeoutError("Function execution timed out")
68+
69+
if not result:
70+
return None
11271
if isinstance(result[0], Exception):
11372
raise result[0]
11473
return result[0]
11574

11675
def check_solution(solution, test, entry_point):
117-
print(f"solution: {solution}")
11876

77+
solution = sanitize(code=solution, entrypoint=entry_point)
11978
try:
12079
# 定义一个包含所有必要模块的全局字典
12180
global_dict = {
@@ -147,8 +106,8 @@ def check_solution(solution, test, entry_point):
147106
# 获取检查函数
148107
check = global_dict["check"]
149108

150-
# 运行检查函数,设置超时时间为5秒
151-
result = run_with_timeout(check, (global_dict[entry_point],), 120)
109+
# 运行检查函数,设置超时时间为120秒
110+
result = run_with_timeout(check, (global_dict[entry_point],), 15)
152111

153112
if result is None:
154113
result = (PASS, "解决方案通过了所有测试用例。")
@@ -171,13 +130,7 @@ async def evaluate_problem(data: dict, graph: Callable, path) -> Tuple[str, str,
171130
max_retries = 5
172131
retries = 0
173132

174-
# prediction = await graph(data["prompt"], data["entry_point"]) if graph else "None"
175-
# cost = prediction[1]
176-
# solution = prediction[0]
177-
# ret = check_solution(solution, data["test"], data["entry_point"])
178-
# test_case_details = ret[1]
179-
# expected_output = test_case_details + "\nCorrect Solution:\ndef " + data["entry_point"] + "(params you should put here):" + "\n\n" + data["canonical_solution"]
180-
# score = 1 if ret[0] == PASS else 0
133+
expected_output = "\nCorrect Solution:\ndef " + data["entry_point"] + "(params you should put here):" + "\n\n" + data["canonical_solution"]
181134

182135
while retries < max_retries:
183136
try:
@@ -186,7 +139,7 @@ async def evaluate_problem(data: dict, graph: Callable, path) -> Tuple[str, str,
186139
solution = prediction[0]
187140
ret = check_solution(solution, data["test"], data["entry_point"])
188141
test_case_details = ret[1]
189-
expected_output = test_case_details + "\nCorrect Solution:\ndef " + data["entry_point"] + "(params you should put here):" + "\n\n" + data["canonical_solution"]
142+
expected_output = test_case_details + "\nCorrect Solution:\ndef " + data["entry_point"] + "(params you should put here):" + "\n\n" + data["canonical_solution"]
190143
score = 1 if ret[0] == PASS else 0
191144

192145
if score == 0:
@@ -258,8 +211,3 @@ async def optimize_humaneval_evaluation(graph: Callable, file_path: str, path: s
258211
print(f"Total Cost: {total_cost:.5f}")
259212
print(f"Average cost on HumanEval dataset: {average_cost:.5f}")
260213
return average_score, average_cost, total_cost
261-
262-
# TODO HumanEval 主实验后续任务
263-
264-
# 1. 修改optimized中的内容,让优化代码能够跑起来
265-
# 2. 启动主实验

examples/ags/benchmark/mbpp.py

Lines changed: 96 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
1+
import os
12
import json
23
import time
34
import asyncio
45
import aiofiles
6+
import threading
57
import pandas as pd
68
from typing import List, Tuple, Callable, Any, Optional, Dict
7-
from tqdm.asyncio import tqdm_asyncio
9+
from datetime import datetime
810

11+
from tqdm.asyncio import tqdm_asyncio
12+
from examples.ags.benchmark.utils import log_mismatch
13+
from metagpt.actions.code_sanitize import sanitize
914
from examples.ags.benchmark.utils import generate_random_indices
1015

1116
PASS = "pass"
@@ -21,7 +26,41 @@ async def load_data(file_path: str, samples=1, test=False) -> List[dict]:
2126
return data
2227

2328

24-
async def check_solution(solution, test, entry_point):
29+
PASS = "PASS"
30+
FAIL = "FAIL"
31+
32+
class TimeoutError(Exception):
33+
pass
34+
35+
def run_with_timeout(func, args, timeout):
36+
result = []
37+
stop_event = threading.Event()
38+
39+
def target():
40+
try:
41+
result.append(func(*args))
42+
except Exception as e:
43+
result.append(e)
44+
finally:
45+
stop_event.set()
46+
47+
thread = threading.Thread(target=target)
48+
thread.start()
49+
is_timeout = not stop_event.wait(timeout)
50+
51+
if is_timeout:
52+
# 线程仍在运行,我们无法强制终止它,但至少可以标记超时
53+
raise TimeoutError("Function execution timed out")
54+
55+
if not result:
56+
return None
57+
if isinstance(result[0], Exception):
58+
raise result[0]
59+
return result[0]
60+
61+
def check_solution(solution, test, entry_point):
62+
63+
solution = sanitize(code=solution, entrypoint=entry_point)
2564
try:
2665
# 定义一个包含所有必要模块的全局字典
2766
global_dict = {
@@ -47,38 +86,43 @@ async def check_solution(solution, test, entry_point):
4786
# 获取检查函数
4887
check = global_dict["check"]
4988

50-
# 运行检查函数
51-
result = check()
89+
# 运行检查函数,设置超时时间为120秒
90+
result = run_with_timeout(check, (global_dict[entry_point],), 15)
5291

5392
if result is None:
5493
result = (PASS, "解决方案通过了所有测试用例。")
5594

56-
# except ValueError as ve:
57-
# if "函数" in str(ve) and "在解决方案中未定义" in str(ve):
58-
# raise
95+
except TimeoutError:
96+
result = (FAIL, "执行超时。请检查您的解决方案是否包含无限循环或过于耗时的操作。")
5997
except Exception as e:
6098
# 记录详细的错误信息
6199
error_message = f"错误: {str(e)}.\n 解决方案: {solution}.\n 测试: {test}"
62100
result = (FAIL, error_message)
63101

64102
# 将错误信息写入error.log文件
65-
with open('error_mbpp.log', 'a', encoding='utf-8') as log_file:
103+
with open('error.log', 'a', encoding='utf-8') as log_file:
66104
log_file.write(f"{time.strftime('%Y-%m-%d %H:%M:%S')} - {error_message}\n")
67105

68106
return result
69107

70-
async def evaluate_problem(data: dict, graph: Callable) -> Tuple[str, str, str, int, str]:
108+
async def evaluate_problem(data: dict, graph: Callable, path) -> Tuple[str, str, str, int, str]:
71109
max_retries = 5
72110
retries = 0
73111

112+
expected_output = "\nCorrect Solution:\ndef " + data["code"]
113+
74114
while retries < max_retries:
75115
try:
76116
prediction = await graph(data["prompt"], data["entry_point"]) if graph else "None"
77117
cost = prediction[1]
78118
solution = prediction[0]
79119
ret = await check_solution(solution, data["test"], data["entry_point"])
80-
120+
test_case_details = ret[1]
81121
score = 1 if ret[0] == PASS else 0
122+
expected_output = test_case_details + "\nCorrect Solution:" + data["code"]
123+
124+
if score == 0:
125+
log_mismatch(data["prompt"], expected_output, solution, score, path)
82126
break
83127

84128
except Exception as e:
@@ -92,28 +136,55 @@ async def evaluate_problem(data: dict, graph: Callable) -> Tuple[str, str, str,
92136
score = 0
93137
break
94138

95-
return data["prompt"], solution, ret[1], score, cost
139+
return data["prompt"], solution, expected_output, score, cost
96140

97-
async def evaluate_all_problems(data: List[dict], graph: Callable, max_concurrent_tasks: int = 50) -> List[Tuple[str, str, str, int, str]]:
141+
async def evaluate_all_problems(data: List[dict], graph: Callable, path:str="", max_concurrent_tasks: int = 50) -> List[Tuple[str, str, str, int, str]]:
98142
semaphore = asyncio.Semaphore(max_concurrent_tasks)
99143

100144
async def sem_evaluate(problem):
101145
async with semaphore:
102-
return await evaluate_problem(problem, graph)
146+
return await evaluate_problem(problem, graph, path)
103147

104148
tasks = [sem_evaluate(problem) for problem in data]
105149

106150
return await tqdm_asyncio.gather(*tasks, desc="Evaluating MBPP problems", total=len(data))
107151

108-
def save_results_to_csv(results: List[Tuple[str, str, str, int, str]], path: str) -> Tuple[float, float]:
109-
df = pd.DataFrame(results, columns=["question", "prediction", "test_case_details", "score", "cost"])
110-
average_score = df["score"].mean()
111-
total_cost = df["cost"].max()
152+
def save_results_to_csv(results: List[Tuple[str, str, str, int]], path):
153+
# 创建 DataFrame
154+
df = pd.DataFrame(results, columns=["question", "prediction", "expected_output", "score", "cost"])
155+
156+
# 计算统计数据
157+
avg_score = df["score"].mean()
158+
t_cost = df["cost"].max()
159+
a_cost = t_cost / len(df) if len(df) > 0 else 0
160+
161+
# 获取当前时间,格式为 YYYYMMDD_HHMMSS
162+
current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
163+
164+
# 生成文件名,包含平均分和当前时间,保留五位小数
165+
filename = f"{avg_score:.5f}_{current_time}.csv"
166+
output_file = os.path.join(path, filename)
112167

113-
output_file = f"{path}/{average_score:.5f}.csv"
168+
# 保存到 CSV
114169
df.to_csv(output_file, index=False)
115170
print(f"Results saved to {output_file}")
116-
return average_score, total_cost
171+
172+
return avg_score, a_cost, t_cost
173+
174+
175+
async def load_file_data(file_path: str, specific_indices: List[int] = None) -> List[dict]:
176+
data = []
177+
# 异步读取文件内容
178+
async with aiofiles.open(file_path, mode="r", encoding='utf-8') as file:
179+
async for line in file:
180+
data.append(json.loads(line))
181+
182+
# 然后在随机选择的样本中基于特定索引列表进行进一步筛选
183+
if specific_indices is not None:
184+
filtered_data = [data[i] for i in specific_indices if i < len(data)]
185+
return filtered_data
186+
187+
return data
117188

118189
async def mbpp_evaluation(graph: Callable, file_path: str, samples: int, path: str, test=False) -> Tuple[float, float]:
119190
data = await load_data(file_path, samples, test)
@@ -124,17 +195,11 @@ async def mbpp_evaluation(graph: Callable, file_path: str, samples: int, path: s
124195
return average_score, total_cost
125196

126197

127-
async def load_file_data(file_path: str) -> List[dict]:
128-
data = []
129-
async with aiofiles.open(file_path, mode="r") as file:
130-
async for line in file:
131-
data.append(json.loads(line))
132-
return data
133-
134-
async def optimize_mbpp_evaluation(graph: Callable, file_path: str, path: str) -> Tuple[float, float]:
135-
data = await load_file_data(file_path)
136-
results = await evaluate_all_problems(data, graph, max_concurrent_tasks=50)
137-
average_score, total_cost = save_results_to_csv(results, path=path)
198+
async def optimize_mbpp_evaluation(graph: Callable, file_path: str, path: str, va_list: List[int]) -> Tuple[float, float]:
199+
data = await load_file_data(file_path, va_list)
200+
results = await evaluate_all_problems(data, graph, path, max_concurrent_tasks=25)
201+
average_score, average_cost, total_cost = save_results_to_csv(results, path=path)
138202
print(f"Average score on MBPP dataset: {average_score:.5f}")
139203
print(f"Total Cost: {total_cost:.5f}")
140-
return average_score, total_cost
204+
print(f"Average cost on MBPP dataset: {average_cost:.5f}")
205+
return average_score, average_cost, total_cost

0 commit comments

Comments
 (0)