-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathevaluation.py
More file actions
144 lines (125 loc) · 4.86 KB
/
evaluation.py
File metadata and controls
144 lines (125 loc) · 4.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import json
import os
import argparse
import multiprocessing
from tqdm import tqdm
from loguru import logger
from utils.config import EvaluationConfig
from utils.llm import LLM
from utils.evaluation_utils import (
extract_cot_answers,
extract_pot_answers,
get_acc,
)
from multiprocessing import Process, Queue
def get_statistics(data):
    """Aggregate per-record evaluation results into summary statistics.

    Args:
        data: list of evaluated records, each carrying
            ``record["result"]["acc"]``, ``record["result"]["execution_rate"]``
            and ``record["completion_tokens"]``.

    Returns:
        dict with ``avg_accuracy`` and ``avg_execution_rate`` as percentages
        rounded to 2 decimals, and ``total_tokens`` summed over all records.
    """
    # Guard: an empty evaluation set would otherwise divide by zero.
    if not data:
        return {"avg_accuracy": 0.0, "avg_execution_rate": 0.0, "total_tokens": 0}
    total_acc = sum(record["result"]["acc"] for record in data)
    total_execution = sum(record["result"]["execution_rate"] for record in data)
    total_tokens = sum(record["completion_tokens"] for record in data)
    count = len(data)
    return {
        "avg_accuracy": round(total_acc / count * 100, 2),
        "avg_execution_rate": round(total_execution / count * 100, 2),
        "total_tokens": total_tokens,
    }
def eval_cot(
    data,
    ans_extract_model: LLM,
    eval_data = None,
    force_extract_answer: bool = False
):
    """Evaluate chain-of-thought (COT) model outputs.

    Answers are extracted from raw outputs with ``ans_extract_model``
    (invalid extractions are retried once). When ``eval_data`` is given and
    ``force_extract_answer`` is False, previously extracted answers are
    reused instead of calling the model again.

    Args:
        data: inference records with "output", "ground_truth", "question_id"
            and "completion_tokens" keys.
        ans_extract_model: LLM used to extract the final answer.
        eval_data: optional prior evaluation records aligned 1:1 with ``data``.
        force_extract_answer: re-extract answers even when ``eval_data`` exists.

    Returns:
        Tuple of (data annotated with a "result" dict per record, aggregate
        statistics from ``get_statistics``).
    """
    def _valid(output: str):
        # A usable extraction is non-empty and not an explicit "none" answer.
        return bool(output) and "none" not in output.lower()

    extract_answer = force_extract_answer or eval_data is None
    if extract_answer:
        responses = extract_cot_answers(data, ans_extract_model)
        # Retry once for records whose first extraction came back invalid.
        to_retry = [
            idx for idx, response in enumerate(responses)
            if not _valid(response['output'])
        ]
        if to_retry:
            retry_responses = extract_cot_answers([data[i] for i in to_retry], ans_extract_model)
            for original_idx, response in zip(to_retry, retry_responses):
                responses[original_idx] = response
    for idx, record in tqdm(enumerate(data), total=len(data), desc="Evaluating COT"):
        if extract_answer:
            extracted_answer = responses[idx]['output']
        else:
            # Reused eval_data must line up record-for-record with data.
            assert eval_data[idx]["question_id"] == record["question_id"]
            extracted_answer = eval_data[idx]['result']['extracted_answer']
        # Default to a failed result; upgraded below if the answer is valid.
        eval_result = { "execution_rate": 0, "acc": 0, "extracted_answer": None }
        if _valid(extracted_answer):
            eval_result = {
                "execution_rate": 1,
                "acc": get_acc(extracted_answer, record["ground_truth"]),
                "extracted_answer": extracted_answer
            }
        record["result"] = eval_result
    statistics = get_statistics(data)
    return data, statistics
def empty_print(*args, **kwargs):
    """No-op stand-in for ``print``, used to silence executed user code."""
    return None
def run_code_in_process(code, result_queue):
    """Execute ``code`` in the current process and report the outcome.

    The code string is expected to define a ``solution()`` function. Its
    return value is reported as ("success", value) on ``result_queue``;
    any exception is reported as ("error", message). ``print`` is shadowed
    with a no-op so the executed code cannot write to stdout.
    """
    try:
        sandbox = {"print": lambda *a, **kw: None}
        exec(code, sandbox)
        result_queue.put(("success", sandbox["solution"]()))
    except Exception as exc:
        result_queue.put(("error", str(exc)))
def exec_code_with_timeout(code, timeout_duration):
    """Run ``code`` (which must define ``solution()``) in a subprocess.

    Args:
        code: Python source string defining a ``solution()`` function.
        timeout_duration: maximum seconds to wait for the result.

    Returns:
        The value returned by ``solution()``.

    Raises:
        Exception: if the executed code raised, or if no result arrived
            within ``timeout_duration`` seconds.
    """
    result_queue = Queue()
    process = Process(target=run_code_in_process, args=(code, result_queue))
    process.start()
    try:
        status, result = result_queue.get(timeout=timeout_duration)
        if status == "error":
            # Re-raise the child's error message in this process.
            raise Exception(result)
        return result
    except multiprocessing.queues.Empty:
        raise Exception("Code execution took too long!")
    finally:
        # Always tear the worker down, and join so the killed process is
        # reaped immediately instead of lingering as a zombie.
        process.kill()
        process.join()
def eval_pot( data, timeout_duration: int):
    """Evaluate program-of-thought (POT) outputs by executing generated code.

    Each record's output is converted to code and executed under a timeout;
    the executed result is scored against the record's ground truth.

    Args:
        data: inference records with "output", "ground_truth" and
            "completion_tokens" keys.
        timeout_duration: per-record execution timeout in seconds.

    Returns:
        Tuple of (data annotated with a "result" dict per record, aggregate
        statistics from ``get_statistics``).
    """
    for record in tqdm(data, desc="Evaluating POT"):
        code = extract_pot_answers(record['output'])
        # Start from a failed result; replaced only on successful execution.
        record["result"] = { "acc": 0, "execution_rate": 0, "executed_result": None }
        try:
            executed_result = exec_code_with_timeout(code, timeout_duration)
        except Exception as e:
            logger.warning(f"Error while executing code: {e}")
        else:
            record["result"] = {
                "acc": get_acc(executed_result, record["ground_truth"]),
                "execution_rate": 1,
                "executed_result": str(executed_result)
            }
    return data, get_statistics(data)
def make_args():
    """Parse command-line arguments; requires ``--config`` (YAML path)."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--config", type=str, required=True)
    parsed_args = arg_parser.parse_args()
    return parsed_args
def main():
    """Entry point: load inference results, evaluate, and write results.

    Dispatches to COT or POT evaluation based on ``config.prompt_type`` and
    writes the annotated records as JSON to ``config.evaluation_file``.

    Raises:
        ValueError: if ``config.prompt_type`` mentions neither 'cot' nor 'pot'.
    """
    args = make_args()
    config = EvaluationConfig.from_yaml(args.config)
    # Context manager so the inference file handle is closed promptly
    # (the original `json.load(open(...))` leaked it).
    with open(config.inference_file, "r", encoding="utf-8") as f:
        data = json.load(f)
    if 'cot' in config.prompt_type:
        eval_data = None
        # Reuse previously extracted answers when an evaluation file exists.
        if os.path.exists(config.evaluation_file):
            print(f"Loading evaluation data from {config.evaluation_file}")
            with open(config.evaluation_file, "r", encoding="utf-8") as f:
                eval_data = json.load(f)
        ans_extract_model = LLM(config.llms[config.ans_extract_model_name])
        force_extract_answer = config.force_extract_answer
        data, statistics = eval_cot(data, ans_extract_model, eval_data, force_extract_answer)
    elif 'pot' in config.prompt_type:
        data, statistics = eval_pot(data, config.timeout_duration)
    else:
        # Fail fast with a clear message instead of hitting a NameError
        # on `statistics` below when prompt_type is unrecognized.
        raise ValueError(f"Unsupported prompt_type: {config.prompt_type}")
    logger.info(f"Statistics: {statistics}")
    with open(config.evaluation_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4, ensure_ascii=False)

if __name__ == "__main__":
    main()