Skip to content

Commit eae3514

Browse files
committed
Update AFlow
1 parent 040a732 commit eae3514

File tree

89 files changed

+2309
-395689
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

89 files changed

+2309
-395689
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from scipy.optimize import linear_sum_assignment
1010
from tqdm.asyncio import tqdm_asyncio
1111

12-
from examples.ags.benchmark.utils import generate_random_indices
12+
from examples.aflow.benchmark.utils import generate_random_indices
1313

1414
global cost
1515
cost = 0
Lines changed: 36 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,17 @@
1010
import pandas as pd
1111
from typing import Optional, List, Tuple, Callable, Any
1212
from tqdm.asyncio import tqdm_asyncio
13+
import os
14+
import time
15+
from datetime import datetime
1316

14-
from examples.ags.benchmark.utils import generate_random_indices, log_mismatch
17+
from examples.aflow.benchmark.utils import generate_random_indices, log_mismatch
1518

1619
def extract_number(text: str) -> Optional[float]:
1720
"""Clean text and extract a single number"""
18-
matches = re.findall(r"[-+]?\d+(?:,\d{3})*(?:\.\d+)?|\d+\.\d+", text)
21+
print(f"text: {text}")
22+
matches = re.findall(r"[-+]?\d+(?:,\d{3})*(?:\.\d+)?|\d+\.\d+", str(text))
23+
print(f"matches: {matches}")
1924
if matches:
2025
last_number = matches[-1].replace(",", "")
2126
try:
@@ -25,78 +30,77 @@ def extract_number(text: str) -> Optional[float]:
2530
else:
2631
return None
2732

28-
def loose_match_score(expected_output: str, prediction: str, tolerance: float = 1e-6) -> int:
29-
"""Loose match score calculation function"""
30-
expected_number = extract_number(expected_output)
31-
predicted_number = extract_number(prediction)
32-
33-
if expected_number is None or predicted_number is None:
33+
def loose_match_score(expected_output: float, prediction: float, tolerance: float = 1e-6) -> int:
34+
if prediction is None:
3435
return 0
35-
36-
if abs(expected_number - predicted_number) <= tolerance:
36+
37+
if abs(expected_output - prediction) <= tolerance:
3738
return 1
3839
else:
3940
return 0
4041

41-
def save_results_to_csv(results: List[Tuple[str, str, str, int, str]], path: str) -> Tuple[float, float]:
42-
"""Save results to CSV file"""
42+
def save_results_to_csv(results: List[Tuple[str, str, str, int, str]], path: str) -> Tuple[float, float, float]:
4343
df = pd.DataFrame(results, columns=["question", "prediction", "expected_output", "score", "cost"])
44-
average_score = df["score"].mean()
45-
total_cost = df["cost"].max()
46-
average_cost = total_cost / len(df) if len(df) > 0 else 0
44+
avg_score = df["score"].mean()
45+
t_cost = df["cost"].max()
46+
a_cost = t_cost / len(df) if len(df) > 0 else 0
47+
48+
current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
49+
filename = f"{avg_score:.5f}_{current_time}.csv"
50+
output_file = os.path.join(path, filename)
4751

48-
output_file = f"{path}/{average_score:.5f}.csv"
4952
df.to_csv(output_file, index=False)
5053
print(f"Results saved to {output_file}")
51-
return average_score, average_cost, total_cost
54+
return avg_score, a_cost, t_cost
5255

5356
async def evaluate_problem(input: str, graph: Callable, expected_output: str, path: str = None) -> Tuple[str, str, str, int, str]:
54-
"""Evaluate a single problem"""
55-
max_retries = 5
57+
max_retries = 10
5658
retries = 0
59+
uni_score = 0
60+
5761
while retries < max_retries:
5862
try:
5963
prediction = await graph(input) if graph else None
6064
cost = prediction[1]
6165
output = prediction[0]
62-
66+
6367
if output is not None:
6468
predicted_number = extract_number(output)
65-
expected_output = extract_number(expected_output)
69+
expected_number = extract_number(expected_output)
6670
else:
6771
predicted_number = None
72+
expected_number = extract_number(expected_output)
6873

69-
uni_score = loose_match_score(expected_output, predicted_number)
74+
print(f"predicted_number: {predicted_number}, expected_number: {expected_number}")
75+
uni_score = loose_match_score(expected_number, predicted_number)
7076

7177
if uni_score == 0 and path is not None:
7278
log_mismatch(input, expected_output, output, predicted_number, path)
73-
else:
74-
pass
7579

7680
break
7781

7882
except Exception as e:
7983
retries += 1
8084
print(f"Error generating prediction: {e}. Retrying... ({retries}/{max_retries})")
85+
time.sleep(5 * retries)
8186

8287
if retries == max_retries:
8388
print("Maximum retries reached. Skipping this sample.")
8489
output = str(e)
8590
cost = None
86-
score = 0
91+
uni_score = 0
8792
break
8893

89-
return input, output, expected_output, score, cost
94+
return input, output, expected_output, uni_score, cost
9095

91-
async def evaluate_all_problems(data: List[dict], graph: Callable, max_concurrent_tasks: int = 20) -> List[Tuple[str, str, str, int, str]]:
92-
"""Evaluate all problems"""
96+
async def evaluate_all_problems(data: List[dict], graph: Callable, path, max_concurrent_tasks: int = 20) -> List[Tuple[str, str, str, int, str]]:
9397
semaphore = asyncio.Semaphore(max_concurrent_tasks)
9498

9599
async def sem_evaluate(problem):
96100
async with semaphore:
97101
input_text = problem["question"]
98102
expected_output = problem["answer"]
99-
return await evaluate_problem(input_text, graph, expected_output)
103+
return await evaluate_problem(input_text, graph, expected_output, path)
100104

101105
tasks = [sem_evaluate(problem) for problem in data]
102106

@@ -113,38 +117,28 @@ async def load_data(file_path: str, samples=1, test=False) -> List[dict]:
113117

114118
async def load_file_data(file_path: str, specific_indices: List[int] = None) -> List[dict]:
115119
data = []
116-
# 异步读取文件内容
117120
async with aiofiles.open(file_path, mode="r", encoding='utf-8') as file:
118121
async for line in file:
119122
data.append(json.loads(line))
120123

121-
# 然后在随机选择的样本中基于特定索引列表进行进一步筛选
122124
if specific_indices is not None:
123125
filtered_data = [data[i] for i in specific_indices if i < len(data)]
124126
return filtered_data
125127

126128
return data
127129

128130
async def gsm8k_evaluation(graph: Callable, file_path: str, samples: int, path: str, test=False) -> Tuple[float, float]:
129-
"""GSM8K evaluation main function"""
130131
data = await load_data(file_path, samples, test=test)
131-
results = await evaluate_all_problems(data, graph, max_concurrent_tasks=10)
132+
results = await evaluate_all_problems(data, graph, path, max_concurrent_tasks=20)
132133
average_score, average_cost, total_cost = save_results_to_csv(results, path=path)
133134
print(f"Average score: {average_score:.5f}")
134135
print(f"Total Cost: {total_cost:.5f}")
135136
return average_score, total_cost
136137

137138
async def optimize_gsm8k_evaluation(graph: Callable, file_path: str, path: str, va_list: list) -> Tuple[Any, Any, Any]:
138-
"""Optimize GSM8K evaluation main function"""
139139
data = await load_file_data(file_path, va_list)
140-
results = await evaluate_all_problems(data, graph, path, max_concurrent_tasks=50)
140+
results = await evaluate_all_problems(data, graph, path, max_concurrent_tasks=8)
141141
average_score, average_cost, total_cost = save_results_to_csv(results, path=path)
142142
print(f"Average score: {average_score:.5f}")
143143
print(f"Total Cost: {total_cost:.5f}")
144-
return average_score, average_cost, total_cost
145-
146-
# TODO Benchmark 与 Evaluator 中主要修改四个地方
147-
# 1. Evaluator.py 之中添加 val list
148-
# 2. load_data 函数修改
149-
# 3. result_to_csv 函数需要给 avg return
150-
# 4. evaluate_problem 中添加log.json
144+
return average_score, average_cost, total_cost
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import re
1212

1313

14-
from examples.ags.benchmark.utils import generate_random_indices
14+
from examples.aflow.benchmark.utils import generate_random_indices
1515

1616
global cost
1717
cost = 0
Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
import pandas as pd
1212
from tqdm.asyncio import tqdm_asyncio
1313

14-
from examples.ags.benchmark.utils import generate_random_indices
15-
from examples.ags.benchmark.utils import log_mismatch
14+
from examples.aflow.benchmark.utils import generate_random_indices
15+
from examples.aflow.benchmark.utils import log_mismatch
1616
from metagpt.actions.code_sanitize import sanitize
1717

1818

@@ -134,7 +134,7 @@ async def evaluate_problem(data: dict, graph: Callable, path) -> Tuple[str, str,
134134

135135
while retries < max_retries:
136136
try:
137-
prediction = await graph(data["prompt"], data["entry_point"]) if graph else "None"
137+
prediction = await asyncio.wait_for(graph(data["prompt"], data["entry_point"]), timeout=60) if graph else "None"
138138
cost = prediction[1]
139139
solution = prediction[0]
140140
ret = check_solution(solution, data["test"], data["entry_point"])
@@ -145,6 +145,13 @@ async def evaluate_problem(data: dict, graph: Callable, path) -> Tuple[str, str,
145145
if score == 0:
146146
log_mismatch(data["prompt"], expected_output, solution, score, path)
147147
break
148+
149+
except TimeoutError:
150+
solution = None
151+
ret = (FAIL, ["超时"])
152+
score = 0
153+
cost = 0
154+
break
148155

149156
except Exception as e:
150157
retries += 1
@@ -195,7 +202,7 @@ def save_results_to_csv(results: List[Tuple[str, str, str, int]], path):
195202

196203
async def humaneval_evaluation(graph: Callable, file_path: str, samples: int, path: str, test=False) -> Tuple[float, float]:
197204
data = await load_data(file_path, samples, test=test)
198-
results = await evaluate_all_problems(data, graph, max_concurrent_tasks=50)
205+
results = await evaluate_all_problems(data, graph, max_concurrent_tasks=5)
199206
average_score, average_cost, total_cost = save_results_to_csv(results, path=path)
200207
print(f"Average score on HumanEval dataset: {average_score:.5f}")
201208
print(f"Total Cost: {total_cost:.5f}")
@@ -205,7 +212,7 @@ async def humaneval_evaluation(graph: Callable, file_path: str, samples: int, pa
205212

206213
async def optimize_humaneval_evaluation(graph: Callable, file_path: str, path: str, va_list: List[int]) -> Tuple[float, float, float]:
207214
data = await load_file_data(file_path, va_list)
208-
results = await evaluate_all_problems(data, graph, path, max_concurrent_tasks=25)
215+
results = await evaluate_all_problems(data, graph, path, max_concurrent_tasks=10)
209216
average_score, average_cost, total_cost = save_results_to_csv(results, path=path)
210217
print(f"Average score on HumanEval dataset: {average_score:.5f}")
211218
print(f"Total Cost: {total_cost:.5f}")
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from typing import Optional, List, Tuple, Callable, Union
1313
from tqdm.asyncio import tqdm_asyncio
1414

15-
from examples.ags.benchmark.utils import generate_random_indices
15+
from examples.aflow.benchmark.utils import generate_random_indices
1616

1717
def extract_model_answer(text: str) -> str:
1818
# 提取最后一个 \boxed{...}
Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
from datetime import datetime
1010

1111
from tqdm.asyncio import tqdm_asyncio
12-
from examples.ags.benchmark.utils import log_mismatch
12+
from examples.aflow.benchmark.utils import log_mismatch
1313
from metagpt.actions.code_sanitize import sanitize
14-
from examples.ags.benchmark.utils import generate_random_indices
14+
from examples.aflow.benchmark.utils import generate_random_indices
1515

1616
PASS = "pass"
1717
FAIL = "fail"
@@ -32,13 +32,13 @@ async def load_data(file_path: str, samples=1, test=False) -> List[dict]:
3232
class TimeoutError(Exception):
3333
pass
3434

35-
def run_with_timeout(func, args, timeout):
35+
def run_with_timeout(func, timeout):
3636
result = []
3737
stop_event = threading.Event()
3838

3939
def target():
4040
try:
41-
result.append(func(*args))
41+
result.append(func())
4242
except Exception as e:
4343
result.append(e)
4444
finally:
@@ -61,6 +61,7 @@ def target():
6161
def check_solution(solution, test, entry_point):
6262

6363
solution = sanitize(code=solution, entrypoint=entry_point)
64+
print(test)
6465
try:
6566
# 定义一个包含所有必要模块的全局字典
6667
global_dict = {
@@ -87,7 +88,7 @@ def check_solution(solution, test, entry_point):
8788
check = global_dict["check"]
8889

8990
# 运行检查函数,设置超时时间为120秒
90-
result = run_with_timeout(check, (global_dict[entry_point],), 15)
91+
result = run_with_timeout(check, 15)
9192

9293
if result is None:
9394
result = (PASS, "解决方案通过了所有测试用例。")
@@ -110,16 +111,15 @@ async def evaluate_problem(data: dict, graph: Callable, path) -> Tuple[str, str,
110111
retries = 0
111112

112113
expected_output = "\nCorrect Solution:\ndef " + data["code"]
113-
114114
while retries < max_retries:
115115
try:
116116
prediction = await graph(data["prompt"], data["entry_point"]) if graph else "None"
117117
cost = prediction[1]
118118
solution = prediction[0]
119-
ret = await check_solution(solution, data["test"], data["entry_point"])
119+
ret = check_solution(solution, data["test"], data["entry_point"])
120120
test_case_details = ret[1]
121-
score = 1 if ret[0] == PASS else 0
122-
expected_output = test_case_details + "\nCorrect Solution:" + data["code"]
121+
expected_output = test_case_details + "\nCorrect Solution:" + data["code"]
122+
score = 1 if ret[0] == PASS else 0
123123

124124
if score == 0:
125125
log_mismatch(data["prompt"], expected_output, solution, score, path)
@@ -134,6 +134,7 @@ async def evaluate_problem(data: dict, graph: Callable, path) -> Tuple[str, str,
134134
solution = None
135135
ret = (FAIL, [])
136136
score = 0
137+
cost = 0
137138
break
138139

139140
return data["prompt"], solution, expected_output, score, cost

0 commit comments

Comments
 (0)