1+ import os
2+ import json
3+ import spacy
4+ from spacy .tokens import Doc
5+ from const import ROLE , All_Valid_EntTypes
6+ import torch
7+ from torchmetrics .classification import MulticlassCalibrationError
8+
9+
class WhitespaceTokenizer:
    """spaCy tokenizer replacement that splits text on single spaces only.

    Used so spaCy's pipeline respects the dataset's pre-tokenized word
    boundaries instead of applying its own tokenization rules.
    """

    def __init__(self, vocab):
        self.vocab = vocab

    def __call__(self, text):
        return Doc(self.vocab, words=text.split(" "))
17+
18+
# Shared spaCy pipeline used by find_head/evaluate/filter_invalid_answer.
# The tokenizer is swapped out so spans indexed by dataset word positions
# line up one-to-one with spaCy tokens.
nlp = spacy.load('en_core_web_sm')
nlp.tokenizer = WhitespaceTokenizer(nlp.vocab)
21+
22+
def safe_div(num, denom):
    """Return num / denom, or 0 when denom is not positive (avoids ZeroDivisionError)."""
    return num / denom if denom > 0 else 0
28+
def compute_f1(predicted, gold, matched):
    """Compute precision, recall and F1 from raw counts.

    Args:
        predicted: number of predicted items.
        gold: number of gold items.
        matched: number of predictions that match a gold item.

    Returns:
        (precision, recall, f1) tuple; each component is 0 when its
        denominator is not positive.
    """
    precision, recall = safe_div(matched, predicted), safe_div(matched, gold)
    return precision, recall, safe_div(2 * precision * recall, precision + recall)
34+
35+
def find_head(arg_start, arg_end, doc):
    """Locate the syntactic head token of the span [arg_start, arg_end].

    Starting from arg_start, repeatedly climbs dependency heads while they
    stay inside the span; stops at a self-headed token (the root) or when
    the head falls outside the span.

    Args:
        arg_start: inclusive start token index of the span.
        arg_end: inclusive end token index of the span.
        doc: sequence of tokens exposing `.head.i` (e.g. a spaCy Doc).

    Returns:
        (head, head): the head index repeated, i.e. a single-token span.
    """
    idx = arg_start
    while True:
        parent = doc[idx].head.i
        if parent < arg_start or parent > arg_end:
            # head escapes the span: current token is the span's head
            break
        if parent == idx:
            # token is its own head (sentence root)
            break
        idx = parent
    return (idx, idx)
48+
49+
def clean_span(tokens, start, end):
    """Drop a leading determiner ('the'/'an'/'a') from a span.

    Leaves a single-token span untouched so the span never becomes empty.
    Both `start` and `end` are inclusive token indices; returns the
    (possibly shifted) (start, end) pair.
    """
    starts_with_det = tokens[start].lower() in ('the', 'an', 'a')
    if starts_with_det and start != end:
        return (start + 1, end)
    return start, end
55+
56+
def evaluate(preds, gold, only_head=False):
    """Score predictions against gold arguments and print P/R/F plus calibration stats.

    Mutates both `preds` and `gold` in place: pops the per-argument word list
    (index 4) and, when `only_head` is True, reduces gold spans to their
    syntactic head token via clean_span/find_head.

    Args:
        preds: {example_id: [[role, span, start, end, confidence, if_reasonable], ...]}
               — the shape produced after filter_invalid_answer has run.
        gold: {example_id: [[role, text, start, end, words], ...]} with
              exclusive end indices (see read_gold_example / the -1/+1 below).
        only_head: if True, compare on head tokens instead of full spans.

    Relies on module globals `nlp` and `LABEL2ID`.
    """
    for example_id in gold:
        for argument in gold[example_id]:
            if only_head:
                words = argument.pop(4)
                doc = nlp(' '.join(words))
                # gold end is exclusive; find_head expects inclusive bounds
                argument[3] -= 1
                argument[2], argument[3] = clean_span(words, argument[2], argument[3])
                argument[2], argument[3] = find_head(argument[2], argument[3], doc)
                assert argument[2] == argument[3]
                argument[1] = words[argument[2]]
                # restore exclusive end so it matches the prediction format
                argument[3] += 1
            else:
                argument.pop(4)


    # Deduplicate predictions; tuples make the arguments hashable and let
    # the membership tests below compare whole records.
    # NOTE(review): gold entries are only converted for ids present in preds.
    for example_id in preds:
        preds[example_id] = list(set([tuple(i) for i in preds[example_id]]))
        gold[example_id] = [tuple(i) for i in gold[example_id]]

    pred_arg_num, gold_arg_num = 0, 0
    arg_idn_num, arg_class_num, arg_ic_num = 0, 0, 0

    for example_id in preds:
        pred_arg_num += len(preds[example_id])
        gold_arg_num += len(gold[example_id])

    correct_confidence = 0
    incorrect_confidence = 0
    if_reasonable_num = 0
    calibrate_record = []  # rows of [gold_label_idx, pred_label_idx, probability]
    invalid_role = 0
    for example_id in preds:
        for pred_arg in preds[example_id]:
            role, span, arg_start, arg_end, confidence, if_reasonable = pred_arg
            # identification match: same span offsets, role ignored
            gold_idn = {item for item in gold[example_id] if item[2] == arg_start and item[3] == arg_end}
            # "ic" match: same role and same surface text, offsets ignored
            gold_ic = [item for item in gold[example_id] if item[0] == role and item[1] == span]
            if gold_ic:
                arg_ic_num += 1
                correct_confidence += confidence
                if if_reasonable:
                    if_reasonable_num += 1
                if confidence == 0:
                    # NOTE(review): this `continue` also skips the gold_idn
                    # bookkeeping below, so zero-confidence predictions never
                    # count toward arg_idn_num/arg_class_num — confirm intended.
                    invalid_role += 1
                    continue
                gold_label_idx = LABEL2ID[gold_ic[0][0]]
                pred_label_idx = LABEL2ID[role]
                # confidence is assumed to be a 0-100 score — scale to [0, 1]
                calibrate_record.append([gold_label_idx, pred_label_idx, confidence / 100])
            else:
                incorrect_confidence += confidence
                if role not in LABEL2ID:
                    # hallucinated role name: excluded from calibration
                    invalid_role += 1
                    continue
                if confidence == 0:
                    invalid_role += 1
                    continue
                gold_label_idx = LABEL2ID['None']
                pred_label_idx = LABEL2ID[role]
                calibrate_record.append([gold_label_idx, pred_label_idx, confidence / 100])
            if gold_idn:
                arg_idn_num += 1
                gold_class = {item for item in gold_idn if item[0] == role}
                if gold_class:
                    arg_class_num += 1

    print(f"gold_arg_num: {gold_arg_num}, pred_arg_num: {pred_arg_num}, arg_idn_num: {arg_idn_num}, arg_class_num: {arg_class_num}, arg_ic_num: {arg_ic_num}")

    role_id_prec, role_id_rec, role_id_f = compute_f1(pred_arg_num, gold_arg_num, arg_idn_num)
    role_prec, role_rec, role_f = compute_f1(pred_arg_num, gold_arg_num, arg_class_num)
    role_ic_prec, role_ic_rec, role_ic_f = compute_f1(pred_arg_num, gold_arg_num, arg_ic_num)
    print('Role identification: P: {:.2f}, R: {:.2f}, F: {:.2f}'.format(role_id_prec * 100.0, role_id_rec * 100.0, role_id_f * 100.0))
    print('Role: P: {:.2f}, R: {:.2f}, F: {:.2f}'.format(role_prec * 100.0, role_rec * 100.0, role_f * 100.0))
    print('Role ic: P: {:.2f}, R: {:.2f}, F: {:.2f}'.format(role_ic_prec * 100.0, role_ic_rec * 100.0, role_ic_f * 100.0))
    # NOTE(review): raises ZeroDivisionError when arg_ic_num == 0 or when
    # every prediction is ic-correct (pred_arg_num == arg_ic_num).
    print('(Role ic) Correct Mean Confidence: {:.2f}, Incorrect Mean Confidence: {:.2f}'.format(correct_confidence / arg_ic_num, incorrect_confidence / (pred_arg_num - arg_ic_num)))
    print(f'Auto Rate: {if_reasonable_num / arg_ic_num}')

    # Compute Expected Calibration Error (ECE)
    assert len(calibrate_record) == (pred_arg_num - invalid_role)
    print(invalid_role, len(calibrate_record))
    label_idx, pred_idx, prob = zip(*calibrate_record)
    labels = torch.tensor(label_idx)
    # NOTE(review): `preds` is rebound here, shadowing the function parameter.
    preds = torch.zeros(len(calibrate_record), len(LABEL2ID), dtype=torch.float32)
    # one-hot-like probability matrix: predicted class gets the stated confidence
    preds[range(len(calibrate_record)), pred_idx] = torch.tensor(prob)
    # NOTE(review): num_classes is hard-coded to 23; should presumably be
    # len(LABEL2ID) — confirm the label vocabulary size.
    metric = MulticlassCalibrationError(num_classes=23, n_bins=50, norm='l1')
    result = metric(preds, labels)
    print('Expected Calibration Error: {:.5f}'.format(result))
143+
144+
def read_question(question_path):
    """Parse the role-query CSV into a mapping of event type to role queries.

    Each line has the form '<event>_<role>,<query>'.  Returns
    {event: [(role, query), ...]} preserving file order.
    """
    event2query = {}
    with open(question_path, 'r') as f:
        for raw_line in f:
            raw_line = raw_line.strip()
            event_arg, query = raw_line.split(",")
            event, arg = event_arg.split("_")
            event2query.setdefault(event, []).append((arg, query))
    return event2query
158+
159+
def filter_invalid_answer(preds, only_head=False):
    """Filter out invalid answers (e.g. negative word indices), then add 1 to
    the end index, because the prompt explicitly states end_word_index is
    inclusive.  (Translated from the original Chinese docstring.)

    Mutates `preds` in place; each argument starts as
    [role, span, start, end, words, confidence, if_reasonable] and leaves
    with the `words` list (index 4) removed and an exclusive end index.
    When `only_head` is True, spans are additionally reduced to their
    syntactic head token.  Returns the number of removed answers.
    """

    def if_invalid(argument):
        # Reject answers with wrong field types, negative indices, or a span
        # text containing a "no answer" placeholder phrase.
        filter_words = ['unknown', 'Unknown', 'unspecified', 'not specified', 'not mentioned', 'None', 'none', 'not mentioned', 'not applicable', 'N/A']
        if not isinstance(argument[1], str):
            return True
        elif not isinstance(argument[0], str):
            return True
        elif not (isinstance(argument[2], int) and isinstance(argument[3], int)):
            return True
        elif not (argument[2] >= 0 and argument[3] >= 0):
            return True
        elif [i for i in filter_words if i in argument[1]]:
            return True
        return False

    count = 0
    for example_id in preds:
        # iterate over a reversed copy so removal during iteration is safe
        for argument in preds[example_id][::-1]:
            if if_invalid(argument):
                preds[example_id].remove(argument)
                count += 1
    for example_id in preds:
        for argument in preds[example_id]:

            if only_head:
                words = argument.pop(4)

                if argument[2] >= len(words):
                    # start index out of sentence range: skip head-finding but
                    # still make the end index exclusive before continuing
                    argument[3] += 1
                    continue

                doc = nlp(' '.join(words))
                argument[2], argument[3] = find_head(argument[2], argument[3], doc)
                assert argument[2] == argument[3]
                argument[1] = words[argument[2]]
            else:
                argument.pop(4)
            # inclusive -> exclusive end index (applies to both branches above)
            argument[3] += 1

    return count
202+
203+
204+ def join (word_list ):
205+ res = ''
206+ for idx , word in enumerate (word_list ):
207+ if idx == 0 :
208+ res = word
209+ else :
210+ if "'" in word :
211+ res += word
212+ elif word == '-' :
213+ res += word
214+ elif word_list [idx - 1 ] == '-' :
215+ res += word
216+ else :
217+ res += (' ' + word )
218+ return res
219+
220+
def read_gold_example(path):
    """Load gold EAE annotations from a JSON-lines file.

    Each line is a JSON object with 'id', 'words', and an 'event' carrying
    'event_type' and an 'argument' list.

    Returns:
        gold: {id: [[role, text, start, end, words], ...]}
        gold_event: {id: event_type}
        gold_words: {id: words}
    """
    gold, gold_event, gold_words = {}, {}, {}
    with open(path) as f:
        for raw in f:
            example = json.loads(raw)
            ex_id = example['id']
            gold_event[ex_id] = example['event']['event_type']
            gold_words[ex_id] = example['words']
            gold[ex_id] = [
                [arg['role'], arg['text'], arg['start'], arg['end'], example['words']]
                for arg in example['event']['argument']
            ]
    return gold, gold_event, gold_words
240+
241+
def get_vocab():
    """Build role-label vocabularies from const.ROLE.

    Returns (label2id, id2label) dicts, with the 'None' label reserved at
    index 0 and ROLE labels following in iteration order.
    """
    all_labels = ['None'] + [label for label in ROLE]
    label2id = {label: idx for idx, label in enumerate(all_labels)}
    id2label = dict(enumerate(all_labels))
    return label2id, id2label
249+
250+
def main(result_dir, gold_path, question_path, only_head=False):
    """Run the full evaluation: load queries and gold, collect predictions
    from per-example JSON result files, filter invalid answers, and score.

    Args:
        result_dir: directory of files named '<example_id>.json'; each maps
            'Question<i>' keys to lists of answer dicts with keys 'span',
            'start_word_index', 'end_word_index', 'confidence',
            'if_reasonable'.
        gold_path: JSON-lines gold file (see read_gold_example).
        question_path: role-query CSV (see read_question).
        only_head: if True, evaluate on argument head words only.
    """

    Event2Query = read_question(question_path)
    gold, gold_event, gold_words = read_gold_example(gold_path)

    preds = {}
    for file in os.listdir(result_dir):
        # strip the '.json' suffix to recover the example id
        example_id = file[:-5]
        preds[example_id] = []
        file_path = os.path.join(result_dir, file)
        # print(file)
        with open(file_path, 'r') as f:
            res = json.load(f)
        question_num = len(res)
        questions = [f"Question{i + 1}" for i in range(question_num)]
        event_type = gold_event[example_id]
        words = gold_words[example_id]
        # Question order matches the event's role order from the query file
        all_role = [i[0] for i in Event2Query[event_type]]
        assert len(all_role) == question_num
        for idx, ques in enumerate(questions):
            answers = res[ques]
            for ans in answers:
                role = all_role[idx]
                span = ans['span']
                confidence = ans['confidence']
                if_reasonable = ans['if_reasonable']
                # [role, span, start, end, words, confidence, if_reasonable]
                # — the layout filter_invalid_answer and evaluate expect
                preds[example_id].append([role, span, ans['start_word_index'], ans['end_word_index'], words, confidence, if_reasonable])


    invalid_arg_num = filter_invalid_answer(preds, only_head)
    print(invalid_arg_num)
    evaluate(preds, gold, only_head)
283+
284+
# Module-level label maps (built from const.ROLE); evaluate() reads LABEL2ID.
LABEL2ID, ID2LABEL = get_vocab()
if __name__ == "__main__":
    # Default paths for the ACE05-E+ closed-setting EAE run.
    result_dir = './Output/EAE/Full_Testset/EAE_E+_Closed'
    question_path = './Code/description_queries_new.csv'
    gold_path = "./data/ACE05-E+/EAE_E+_gold.json"
    only_head = False
    main(result_dir, gold_path, question_path, only_head)
292+
0 commit comments