1+ import os
2+ import json
3+ from const import ROLE , EVENT_TYPE
4+ import torch
5+ from torchmetrics .classification import MulticlassCalibrationError
6+
7+
8+ def safe_div (num , denom ):
9+ if denom > 0 :
10+ return num / denom
11+ else :
12+ return 0
13+
14+ def compute_f1 (predicted , gold , matched ):
15+ precision = safe_div (matched , predicted )
16+ recall = safe_div (matched , gold )
17+ f1 = safe_div (2 * precision * recall , precision + recall )
18+ return precision , recall , f1
19+
20+
21+ def evaluate (preds_event , preds_arg , gold_event , gold_arg ):
22+ assert len (preds_event ) == len (gold_event )
23+ assert len (preds_arg ) == len (gold_arg )
24+
25+ # trigger
26+ for example_id in preds_event :
27+ preds_event [example_id ] = list (set ([tuple (i ) for i in preds_event [example_id ]]))
28+ gold_event [example_id ] = [tuple (i ) for i in gold_event [example_id ]]
29+
30+
31+ pred_tri_num , gold_tri_num = 0 , 0
32+ match_idn_num , match_cls_num , match_word_num = 0 , 0 , 0
33+
34+ tri_correct_confidence = 0
35+ tri_incorrect_confidence = 0
36+ tri_if_reasonable_num = 0
37+ for example_id in preds_event :
38+ pred_tri_num += len (preds_event [example_id ])
39+ gold_tri_num += len (gold_event [example_id ])
40+
41+ calibrate_record = []
42+ invalid_num = 0
43+ for example_id in preds_event :
44+ for pred_tri in preds_event [example_id ]:
45+ trigger_start , trigger_end , event_type , trigger_word , tri_confidence , tri_if_reasonable = pred_tri
46+ match_idn = {item for item in gold_event [example_id ] if item [0 ] == trigger_start and item [1 ] == trigger_end }
47+ match_word = [item for item in gold_event [example_id ] if item [3 ] == trigger_word and item [2 ]== event_type ]
48+ if match_word :
49+ match_word_num += 1
50+ tri_correct_confidence += tri_confidence
51+ if tri_if_reasonable :
52+ tri_if_reasonable_num += 1
53+ if tri_confidence == 0 :
54+ invalid_num += 1
55+ continue
56+ gold_label_idx = LABEL2ID [match_word [0 ][2 ]]
57+ pred_label_idx = LABEL2ID [event_type ]
58+ calibrate_record .append ([gold_label_idx , pred_label_idx , tri_confidence / 100 ])
59+ else :
60+ tri_incorrect_confidence += tri_confidence
61+ if event_type not in LABEL2ID :
62+ invalid_num += 1
63+ continue
64+ if tri_confidence == 0 :
65+ invalid_num += 1
66+ continue
67+ gold_label_idx = LABEL2ID ['None' ]
68+ pred_label_idx = LABEL2ID [event_type ]
69+ calibrate_record .append ([gold_label_idx , pred_label_idx , tri_confidence / 100 ])
70+ if match_idn :
71+ match_idn_num += 1
72+ match_cls = {item for item in match_idn if item [2 ] == event_type }
73+ if match_cls :
74+ match_cls_num += 1
75+
76+ print (f"gold_tri_num: { gold_tri_num } , pred_tri_num: { pred_tri_num } , match_idn_num: { match_idn_num } , match_cls_num: { match_cls_num } , match_word_num: { match_word_num } " )
77+
78+ tri_id_prec , tri_id_rec , tri_id_f = compute_f1 (pred_tri_num , gold_tri_num , match_idn_num )
79+ tri_cls_prec , tri_cls_rec , tri_cls_f = compute_f1 (pred_tri_num , gold_tri_num , match_cls_num )
80+ tri_word_prec , tri_word_rec , tri_word_f = compute_f1 (pred_tri_num , gold_tri_num , match_word_num )
81+ print ('Trigger Identification: P: {:.2f}, R: {:.2f}, F: {:.2f}' .format (tri_id_prec * 100.0 , tri_id_rec * 100.0 , tri_id_f * 100.0 ))
82+ print ('Trigger Classification: P: {:.2f}, R: {:.2f}, F: {:.2f}' .format (tri_cls_prec * 100.0 , tri_cls_rec * 100.0 , tri_cls_f * 100.0 ))
83+ print ('Trigger Word Cls: P: {:.2f}, R: {:.2f}, F: {:.2f}' .format (tri_word_prec * 100.0 , tri_word_rec * 100.0 , tri_word_f * 100.0 ))
84+ print ('(Trigger Word Cls) Correct Mean Confidence: {:.2f}, Incorrect Mean Confidence: {:.2f}' .format (tri_correct_confidence / match_word_num , tri_incorrect_confidence / (pred_tri_num - match_word_num )))
85+ print (f' Auto Rate: { tri_if_reasonable_num / match_word_num } ' )
86+
87+ # argument
88+ for example_id in preds_arg :
89+ preds_arg [example_id ] = list (set ([tuple (i ) for i in preds_arg [example_id ]]))
90+ gold_arg [example_id ] = [tuple (i ) for i in gold_arg [example_id ]]
91+
92+
93+ pred_arg_num , gold_arg_num = 0 , 0
94+ arg_idn_num , arg_class_num , arg_ic_num = 0 , 0 , 0
95+
96+ arg_correct_confidence = 0
97+ arg_incorrect_confidence = 0
98+ arg_if_reasonable_num = 0
99+ for example_id in preds_arg :
100+ pred_arg_num += len (preds_arg [example_id ])
101+ gold_arg_num += len (gold_arg [example_id ])
102+
103+ for example_id in preds_arg :
104+ for pred_arg in preds_arg [example_id ]:
105+ start , end , event_type , role , text , arg_confidence , arg_if_reasonable = pred_arg
106+ gold_idn = {item for item in gold_arg [example_id ] if item [0 ] == start and item [1 ] == end }
107+ gold_ic = [item for item in gold_arg [example_id ] if item [2 ] == event_type and item [3 ] == role and item [4 ] == text ]
108+ if gold_ic :
109+ arg_ic_num += 1
110+ arg_correct_confidence += arg_confidence
111+ if arg_if_reasonable :
112+ arg_if_reasonable_num += 1
113+ if arg_confidence == 0 :
114+ invalid_num += 1
115+ continue
116+ gold_label_idx = LABEL2ID [gold_ic [0 ][3 ]]
117+ pred_label_idx = LABEL2ID [role ]
118+ calibrate_record .append ([gold_label_idx , pred_label_idx , arg_confidence / 100 ])
119+ else :
120+ arg_incorrect_confidence += arg_confidence
121+ if role not in LABEL2ID :
122+ invalid_num += 1
123+ continue
124+ if arg_confidence == 0 :
125+ invalid_num += 1
126+ continue
127+ gold_label_idx = LABEL2ID ['None' ]
128+ pred_label_idx = LABEL2ID [role ]
129+ calibrate_record .append ([gold_label_idx , pred_label_idx , arg_confidence / 100 ])
130+ if gold_idn :
131+ arg_idn_num += 1
132+ gold_class = {item for item in gold_idn if item [2 ] == event_type and item [3 ] == role }
133+ if gold_class :
134+ arg_class_num += 1
135+
136+ print (f"gold_arg_num: { gold_arg_num } , pred_arg_num: { pred_arg_num } , arg_idn_num: { arg_idn_num } , arg_class_num: { arg_class_num } , arg_ic_num: { arg_ic_num } " )
137+
138+ role_id_prec , role_id_rec , role_id_f = compute_f1 (pred_arg_num , gold_arg_num , arg_idn_num )
139+ role_prec , role_rec , role_f = compute_f1 (pred_arg_num , gold_arg_num , arg_class_num )
140+ role_ic_prec , role_ic_rec , role_ic_f = compute_f1 (pred_arg_num , gold_arg_num , arg_ic_num )
141+ print ('Role identification: P: {:.2f}, R: {:.2f}, F: {:.2f}' .format (role_id_prec * 100.0 , role_id_rec * 100.0 , role_id_f * 100.0 ))
142+ print ('Role: P: {:.2f}, R: {:.2f}, F: {:.2f}' .format (role_prec * 100.0 , role_rec * 100.0 , role_f * 100.0 ))
143+ print ('Role ic: P: {:.2f}, R: {:.2f}, F: {:.2f}' .format (role_ic_prec * 100.0 , role_ic_rec * 100.0 , role_ic_f * 100.0 ))
144+ print ('(Role ic) Correct Mean Confidence: {:.2f}, Incorrect Mean Confidence: {:.2f}' .format (arg_correct_confidence / arg_ic_num , arg_incorrect_confidence / (pred_arg_num - arg_ic_num )))
145+ print (f' Auto Rate: { arg_if_reasonable_num / arg_ic_num } ' )
146+
147+ # Compute Expected Calibration Error (ECE)
148+ print (invalid_num , len (calibrate_record ))
149+ assert len (calibrate_record ) == (pred_tri_num + pred_arg_num - invalid_num )
150+ label_idx , pred_idx , prob = zip (* calibrate_record )
151+ labels = torch .tensor (label_idx )
152+ preds = torch .zeros (len (calibrate_record ), len (LABEL2ID ), dtype = torch .float32 )
153+ preds [range (len (calibrate_record )), pred_idx ] = torch .tensor (prob )
154+ metric = MulticlassCalibrationError (num_classes = 56 , n_bins = 50 , norm = 'l1' )
155+ result = metric (preds , labels )
156+ print ('Expected Calibration Error: {:.5f}' .format (result ))
157+
158+
159+ def filter_invalid_answer (preds_event , preds_arg ):
160+ "过滤掉不合法的输出,比如索引为负数的答案"
161+
162+ def if_invalid_event (record ):
163+ filter_words = ['unknown' , 'Unknown' , 'unspecified' , 'not specified' , 'not mentioned' , 'None' , 'none' , 'NONE' , 'not mentioned' , 'not applicable' , 'N/A' ]
164+ if not isinstance (record [2 ], str ):
165+ return True
166+ elif not isinstance (record [3 ], str ):
167+ return True
168+ elif not isinstance (record [0 ], int ):
169+ return True
170+ elif not isinstance (record [1 ], int ):
171+ return True
172+ elif not record [0 ]>= 0 :
173+ return True
174+ elif not record [1 ]>= 0 :
175+ return True
176+ elif record [2 ] in filter_words :
177+ return True
178+ elif record [3 ] in filter_words :
179+ return True
180+ return False
181+
182+ def if_invalid_arg (argument ):
183+ filter_words = ['None' , 'NONE' ,"N/A" ]
184+ if not isinstance (argument [2 ], str ):
185+ return True
186+ elif not isinstance (argument [3 ], str ):
187+ return True
188+ elif not isinstance (argument [4 ], str ):
189+ return True
190+ elif not (isinstance (argument [0 ], int ) and isinstance (argument [1 ], int )):
191+ return True
192+ elif not (argument [0 ]>= 0 and argument [1 ]>= 0 ):
193+ return True
194+ elif argument [2 ] in filter_words :
195+ return True
196+ elif argument [3 ] in filter_words :
197+ return True
198+ elif argument [4 ] in filter_words :
199+ return True
200+ return False
201+
202+ count_event , count_arg = 0 , 0
203+
204+ for example_id in preds_event :
205+ for record in preds_event [example_id ][::- 1 ]:
206+ if if_invalid_event (record ):
207+ preds_event [example_id ].remove (record )
208+ count_event += 1
209+ for example_id in preds_event :
210+ for record in preds_event [example_id ]:
211+ record [1 ] += 1 # end + 1
212+
213+ count_arg = 0
214+ for example_id in preds_arg :
215+ for argument in preds_arg [example_id ][::- 1 ]:
216+ if if_invalid_arg (argument ):
217+ preds_arg [example_id ].remove (argument )
218+ count_arg += 1
219+ for example_id in preds_arg :
220+ for argument in preds_arg [example_id ]:
221+ argument [1 ] += 1
222+
223+
224+ return count_event , count_arg
225+
226+
227+ def read_gold_example (path ):
228+ gold_event , gold_arg = {}, {}
229+ with open (path ) as f :
230+ lines = f .readlines ()
231+ for line in lines :
232+ line = json .loads (line )
233+ events = line ['event' ]
234+ gold_event [line ['id' ]] = []
235+ gold_arg [line ['id' ]] = []
236+ for event in events :
237+ event_type = event ['event_type' ]
238+
239+ trigger_start = event ['trigger' ]['start' ]
240+ trigger_end = event ['trigger' ]['end' ]
241+ trigger_word = event ['trigger' ]['text' ]
242+ gold_event [line ['id' ]].append ([trigger_start , trigger_end , event_type , trigger_word ])
243+
244+ for arg in event ['arguments' ]:
245+ start = arg ['start' ]
246+ end = arg ['end' ]
247+ role = arg ['role' ]
248+ text = arg ['text' ]
249+ # print(text)
250+ gold_arg [line ['id' ]].append ([start , end , event_type , role , text ])
251+
252+ return gold_event , gold_arg
253+
254+
255+ def get_vocab ():
256+ all_labels = ['None' ]
257+ for label in EVENT_TYPE + ROLE :
258+ all_labels .append (label )
259+ label2id = {label : idx for idx , label in enumerate (all_labels )}
260+ id2label = {idx : label for idx , label in enumerate (all_labels )}
261+ return label2id , id2label
262+
263+
264+ def main (result_dir , gold_path ):
265+ gold_event , gold_arg = read_gold_example (gold_path )
266+
267+
268+ preds_event , preds_arg = {}, {}
269+ # print(len(os.listdir(result_dir)))
270+ for file in os .listdir (result_dir ):
271+ example_id = file [:- 5 ]
272+ preds_event [example_id ] = []
273+ preds_arg [example_id ] = []
274+ file_path = os .path .join (result_dir , file )
275+ print (file )
276+ with open (file_path , 'r' , encoding = 'utf-8' ) as f :
277+ res = json .load (f )
278+ for event in res :
279+ event_type = event ['event_type' ]
280+ trigger_start = event ['start_word_index' ]
281+ trigger_end = event ['end_word_index' ]
282+ trigger_word = event ['trigger' ]
283+ tri_confidence = event ['confidence' ]
284+ if 'if_reasonable' not in event :
285+ event ['if_reasonable' ] = 0
286+ else :
287+ tri_if_reasonable = event ['if_reasonable' ]
288+ preds_event [example_id ].append ([trigger_start , trigger_end , event_type , trigger_word , tri_confidence , tri_if_reasonable ])
289+
290+ for arg in event ['participants' ]:
291+ start = arg ['start_word_index' ]
292+ end = arg ['end_word_index' ]
293+ role = arg ['role' ]
294+ text = arg ['span' ]
295+ arg_confidence = arg ['confidence' ]
296+ if 'if_reasonable' not in arg :
297+ arg ['if_reasonable' ] = 0
298+ else :
299+ arg_if_reasonable = arg ['if_reasonable' ]
300+ preds_arg [example_id ].append ([start , end , event_type , role , text , arg_confidence , arg_if_reasonable ])
301+
302+ count_event , count_arg = filter_invalid_answer (preds_event , preds_arg )
303+ print (count_event , count_arg )
304+ evaluate (preds_event , preds_arg , gold_event , gold_arg )
305+
306+ LABEL2ID ,ID2LABEL = get_vocab ()
307+ if __name__ == "__main__" :
308+ result_dir = './Output/EE/Full_Testset/EE_E+_Closed'
309+ gold_path = './data/ACE05-E+/EE_E+_gold.json'
310+ main (result_dir , gold_path )
311+
0 commit comments