1+ import os
12import json
23import time
34import asyncio
45import aiofiles
6+ import threading
57import pandas as pd
68from typing import List , Tuple , Callable , Any , Optional , Dict
7- from tqdm . asyncio import tqdm_asyncio
9+ from datetime import datetime
810
11+ from tqdm .asyncio import tqdm_asyncio
12+ from examples .ags .benchmark .utils import log_mismatch
13+ from metagpt .actions .code_sanitize import sanitize
914from examples .ags .benchmark .utils import generate_random_indices
1015
1116PASS = "pass"
@@ -21,7 +26,41 @@ async def load_data(file_path: str, samples=1, test=False) -> List[dict]:
2126 return data
2227
2328
24- async def check_solution (solution , test , entry_point ):
# Verdict labels stored as the first element of the (status, message) tuple
# built by check_solution and compared against in evaluate_problem.
# NOTE(review): this rebinding supersedes the earlier module-level
# `PASS = "pass"`; the upper-case value is the one in effect from here on.
PASS = "PASS"
FAIL = "FAIL"
31+
class TimeoutError(Exception):
    """Raised by run_with_timeout when the callable exceeds its time budget.

    NOTE(review): this deliberately-named class shadows the builtin
    ``TimeoutError`` inside this module. It derives directly from
    ``Exception`` (not ``OSError``), so ``except TimeoutError`` clauses in
    this module catch only this class, not the builtin one.
    """
34+
def run_with_timeout(func, args, timeout):
    """Run ``func(*args)`` in a worker thread, waiting at most ``timeout`` seconds.

    Args:
        func: Callable to execute.
        args: Iterable of positional arguments, unpacked into ``func``.
        timeout: Maximum number of seconds to wait for ``func`` to finish.

    Returns:
        Whatever ``func`` returns (``None`` included).

    Raises:
        TimeoutError: ``func`` did not finish within ``timeout`` seconds.
        RuntimeError: ``func`` ended with a non-``Exception`` ``BaseException``
            (for example ``SystemExit``); the original is attached as the cause.
        Exception: Any ordinary exception raised by ``func`` is re-raised in
            the calling thread.
    """
    # Holds at most one (succeeded, payload) pair. Tagging the payload keeps
    # "func returned an exception object" distinct from "func raised".
    outcome = []
    finished = threading.Event()

    def target():
        try:
            outcome.append((True, func(*args)))
        except BaseException as exc:
            # BaseException on purpose: a SystemExit/KeyboardInterrupt raised
            # by func would otherwise vanish inside the thread and the caller
            # would see a silent "returned None".
            outcome.append((False, exc))
        finally:
            finished.set()

    # A thread cannot be killed from outside. Marking it as a daemon at least
    # prevents a runaway callable from keeping the interpreter alive at exit.
    worker = threading.Thread(target=target, daemon=True)
    worker.start()

    if not finished.wait(timeout):
        # The worker is still running; all we can do is report the timeout.
        raise TimeoutError("Function execution timed out")

    if not outcome:
        return None

    succeeded, payload = outcome[0]
    if succeeded:
        return payload
    if isinstance(payload, Exception):
        raise payload
    # Do not re-raise SystemExit & co. in the caller: that would terminate
    # the harness itself. Surface it as an ordinary failure instead.
    raise RuntimeError(
        f"Function terminated with {type(payload).__name__}: {payload}"
    ) from payload
60+
61+ def check_solution (solution , test , entry_point ):
62+
63+ solution = sanitize (code = solution , entrypoint = entry_point )
2564 try :
2665 # 定义一个包含所有必要模块的全局字典
2766 global_dict = {
@@ -47,38 +86,43 @@ async def check_solution(solution, test, entry_point):
4786 # 获取检查函数
4887 check = global_dict ["check" ]
4988
50- # 运行检查函数
51- result = check ( )
89+ # Run the check function with a 15-second timeout
90+ result = run_with_timeout ( check , ( global_dict [ entry_point ],), 15 )
5291
5392 if result is None :
5493 result = (PASS , "解决方案通过了所有测试用例。" )
5594
56- # except ValueError as ve:
57- # if "函数" in str(ve) and "在解决方案中未定义" in str(ve):
58- # raise
95+ except TimeoutError :
96+ result = (FAIL , "执行超时。请检查您的解决方案是否包含无限循环或过于耗时的操作。" )
5997 except Exception as e :
6098 # 记录详细的错误信息
6199 error_message = f"错误: { str (e )} .\n 解决方案: { solution } .\n 测试: { test } "
62100 result = (FAIL , error_message )
63101
64102 # 将错误信息写入error.log文件
65- with open ('error_mbpp .log' , 'a' , encoding = 'utf-8' ) as log_file :
103+ with open ('error .log' , 'a' , encoding = 'utf-8' ) as log_file :
66104 log_file .write (f"{ time .strftime ('%Y-%m-%d %H:%M:%S' )} - { error_message } \n " )
67105
68106 return result
69107
70- async def evaluate_problem (data : dict , graph : Callable ) -> Tuple [str , str , str , int , str ]:
108+ async def evaluate_problem (data : dict , graph : Callable , path ) -> Tuple [str , str , str , int , str ]:
71109 max_retries = 5
72110 retries = 0
73111
112+ expected_output = "\n Correct Solution:\n def " + data ["code" ]
113+
74114 while retries < max_retries :
75115 try :
76116 prediction = await graph (data ["prompt" ], data ["entry_point" ]) if graph else "None"
77117 cost = prediction [1 ]
78118 solution = prediction [0 ]
79119 ret = await check_solution (solution , data ["test" ], data ["entry_point" ])
80-
120+ test_case_details = ret [ 1 ]
81121 score = 1 if ret [0 ] == PASS else 0
122+ expected_output = test_case_details + "\n Correct Solution:" + data ["code" ]
123+
124+ if score == 0 :
125+ log_mismatch (data ["prompt" ], expected_output , solution , score , path )
82126 break
83127
84128 except Exception as e :
@@ -92,28 +136,55 @@ async def evaluate_problem(data: dict, graph: Callable) -> Tuple[str, str, str,
92136 score = 0
93137 break
94138
95- return data ["prompt" ], solution , ret [ 1 ] , score , cost
139+ return data ["prompt" ], solution , expected_output , score , cost
96140
async def evaluate_all_problems(data: List[dict], graph: Callable, path: str = "", max_concurrent_tasks: int = 50) -> List[Tuple[str, str, str, int, str]]:
    """Evaluate every problem in ``data`` concurrently and collect the results.

    Args:
        data: Problem records; each one is passed to ``evaluate_problem``.
        graph: Callable forwarded unchanged to ``evaluate_problem``.
        path: Forwarded unchanged to ``evaluate_problem``.
        max_concurrent_tasks: Upper bound on evaluations in flight at once.

    Returns:
        One result per problem, as gathered by ``tqdm_asyncio.gather``.
    """
    # The semaphore caps how many evaluate_problem coroutines run at a time.
    semaphore = asyncio.Semaphore(max_concurrent_tasks)

    async def sem_evaluate(problem):
        async with semaphore:
            return await evaluate_problem(problem, graph, path)

    tasks = [sem_evaluate(problem) for problem in data]

    return await tqdm_asyncio.gather(*tasks, desc="Evaluating MBPP problems", total=len(data))
107151
def save_results_to_csv(results: List[Tuple[str, str, str, int, float]], path: str) -> Tuple[float, float, float]:
    """Write evaluation results to a timestamped CSV file and return summary stats.

    Args:
        results: Rows of (question, prediction, expected_output, score, cost).
        path: Directory the CSV file is written into; created if missing.

    Returns:
        A tuple ``(average_score, average_cost, total_cost)``. All three are
        ``0.0`` when ``results`` is empty.
    """
    df = pd.DataFrame(results, columns=["question", "prediction", "expected_output", "score", "cost"])

    row_count = len(df)
    if row_count:
        avg_score = float(df["score"].mean())
        # The total is taken as the maximum of the cost column.
        # NOTE(review): presumably each row carries a running total — confirm.
        t_cost = float(df["cost"].max())
        a_cost = t_cost / row_count
    else:
        # An empty frame would otherwise yield NaN stats and a "nan_..." file.
        avg_score = t_cost = a_cost = 0.0

    # File name: average score (five decimals) plus a YYYYMMDD_HHMMSS timestamp.
    current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"{avg_score:.5f}_{current_time}.csv"
    output_file = os.path.join(path, filename)

    # to_csv does not create missing directories; an empty path means the
    # current directory, for which makedirs("") would raise.
    if path:
        os.makedirs(path, exist_ok=True)

    df.to_csv(output_file, index=False)
    print(f"Results saved to {output_file}")

    return avg_score, a_cost, t_cost
173+
174+
async def load_file_data(file_path: str, specific_indices: Optional[List[int]] = None) -> List[dict]:
    """Load a JSON-lines file, optionally keeping only selected records.

    Args:
        file_path: Path of a UTF-8 text file with one JSON document per line.
        specific_indices: Zero-based positions of the records to keep, in the
            order given. Positions outside ``0 <= i < len(records)`` are
            skipped. ``None`` returns every record.

    Returns:
        The parsed records, filtered by ``specific_indices`` when provided.

    Raises:
        json.JSONDecodeError: A non-blank line is not valid JSON.
    """
    data = []
    # Read the file asynchronously, one line at a time.
    async with aiofiles.open(file_path, mode="r", encoding='utf-8') as file:
        async for line in file:
            stripped = line.strip()
            if not stripped:
                # Blank lines (e.g. a trailing newline) carry no record.
                continue
            data.append(json.loads(stripped))

    if specific_indices is None:
        return data

    # Require 0 <= i: a negative index would pass a bare `i < len(data)`
    # check and silently select from the end of the list (or raise).
    total = len(data)
    return [data[i] for i in specific_indices if 0 <= i < total]
117188
118189async def mbpp_evaluation (graph : Callable , file_path : str , samples : int , path : str , test = False ) -> Tuple [float , float ]:
119190 data = await load_data (file_path , samples , test )
@@ -124,17 +195,11 @@ async def mbpp_evaluation(graph: Callable, file_path: str, samples: int, path: s
124195 return average_score , total_cost
125196
126197
127- async def load_file_data (file_path : str ) -> List [dict ]:
128- data = []
129- async with aiofiles .open (file_path , mode = "r" ) as file :
130- async for line in file :
131- data .append (json .loads (line ))
132- return data
133-
async def optimize_mbpp_evaluation(graph: Callable, file_path: str, path: str, va_list: List[int]) -> Tuple[float, float, float]:
    """Evaluate ``graph`` on selected records of a data file and save the results.

    Args:
        graph: Callable forwarded to ``evaluate_all_problems``.
        file_path: Data file passed to ``load_file_data``.
        path: Passed to ``evaluate_all_problems`` and to
            ``save_results_to_csv`` as the output location.
        va_list: Record indices passed to ``load_file_data`` to select the
            problems to evaluate.

    Returns:
        A tuple ``(average_score, average_cost, total_cost)`` as returned by
        ``save_results_to_csv``.
    """
    data = await load_file_data(file_path, va_list)
    results = await evaluate_all_problems(data, graph, path, max_concurrent_tasks=25)
    average_score, average_cost, total_cost = save_results_to_csv(results, path=path)
    print(f"Average score on MBPP dataset: {average_score:.5f}")
    print(f"Total Cost: {total_cost:.5f}")
    print(f"Average cost on MBPP dataset: {average_cost:.5f}")
    return average_score, average_cost, total_cost
0 commit comments