Skip to content

Commit 8cf6ca5

Browse files
authored
Add files via upload
1 parent b318d63 commit 8cf6ca5

File tree

3 files changed

+627
-0
lines changed

3 files changed

+627
-0
lines changed

Code/EAE/const.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Event type labels in "<Type>.<Subtype>" form; the list order is preserved
# exactly, the comments below only mark where each top-level type starts.
EVENT_TYPE = [
    # Movement
    "Movement.Transport",
    # Personnel
    "Personnel.Elect",
    "Personnel.Start-Position",
    "Personnel.Nominate",
    "Personnel.End-Position",
    # Conflict
    "Conflict.Attack",
    "Conflict.Demonstrate",
    # Contact
    "Contact.Phone-Write",
    "Contact.Meet",
    # Transaction
    "Transaction.Transfer-Money",
    "Transaction.Transfer-Ownership",
    # Business
    "Business.Start-Org",
    "Business.Merge-Org",
    "Business.Declare-Bankruptcy",
    "Business.End-Org",
    # Life
    "Life.Be-Born",
    "Life.Injure",
    "Life.Die",
    "Life.Marry",
    "Life.Divorce",
    # Justice
    "Justice.Sue",
    "Justice.Arrest-Jail",
    "Justice.Execute",
    "Justice.Charge-Indict",
    "Justice.Convict",
    "Justice.Trial-Hearing",
    "Justice.Sentence",
    "Justice.Release-Parole",
    "Justice.Fine",
    "Justice.Pardon",
    "Justice.Appeal",
    "Justice.Extradite",
    "Justice.Acquit",
]
36+
37+
# Entity type tags used as the values of All_Valid_EntTypes.
ENTITY = [
    'PER',
    'ORG',
    'GPE',
    'LOC',
    'FAC',
    'VEH',
    'WEA',
]
38+
39+
# Argument role labels. The order matters to any code that derives label ids
# from the position in this list, so it is kept exactly as is.
ROLE = [
    'Org',
    'Place',
    'Instrument',
    'Vehicle',
    'Attacker',
    'Prosecutor',
    'Agent',
    'Victim',
    'Origin',
    'Target',
    'Giver',
    'Seller',
    'Defendant',
    'Recipient',
    'Entity',
    'Plaintiff',
    'Person',
    'Artifact',
    'Destination',
    'Adjudicator',
    'Beneficiary',
    'Buyer',
]
40+
41+
42+
# Maps an (event type, argument role) pair to the set of entity type tags
# listed for that pair. Insertion order is the same as in the original literal.
All_Valid_EntTypes = {
    ('Movement.Transport', 'Vehicle'): {'VEH'},
    ('Movement.Transport', 'Artifact'): {'VEH', 'PER', 'WEA'},
    ('Movement.Transport', 'Destination'): {'GPE', 'FAC', 'LOC'},
    ('Personnel.Elect', 'Person'): {'PER'},
    ('Movement.Transport', 'Agent'): {'ORG', 'GPE', 'PER'},
    ('Personnel.Start-Position', 'Person'): {'PER'},
    ('Personnel.Start-Position', 'Entity'): {'ORG', 'GPE'},
    ('Personnel.Nominate', 'Person'): {'PER'},
    ('Conflict.Attack', 'Place'): {'GPE', 'FAC', 'LOC'},
    ('Personnel.End-Position', 'Entity'): {'ORG', 'GPE'},
    ('Personnel.End-Position', 'Person'): {'PER'},
    ('Contact.Meet', 'Entity'): {'ORG', 'GPE', 'PER'},
    ('Contact.Meet', 'Place'): {'GPE', 'FAC', 'LOC'},
    ('Life.Marry', 'Person'): {'PER'},
    ('Personnel.Elect', 'Entity'): {'ORG', 'GPE', 'PER'},
    ('Conflict.Attack', 'Target'): {'ORG', 'PER', 'VEH', 'FAC', 'LOC', 'WEA'},
    ('Conflict.Attack', 'Attacker'): {'ORG', 'GPE', 'PER'},
    ('Transaction.Transfer-Money', 'Giver'): {'ORG', 'GPE', 'PER'},
    ('Transaction.Transfer-Money', 'Recipient'): {'ORG', 'GPE', 'PER'},
    ('Conflict.Demonstrate', 'Entity'): {'ORG', 'PER'},
    ('Conflict.Demonstrate', 'Place'): {'GPE', 'FAC', 'LOC'},
    ('Business.End-Org', 'Place'): {'GPE', 'FAC'},
    ('Justice.Sue', 'Plaintiff'): {'ORG', 'PER'},
    ('Life.Injure', 'Victim'): {'PER'},
    ('Life.Injure', 'Agent'): {'GPE', 'PER'},
    ('Life.Die', 'Victim'): {'PER'},
    ('Life.Die', 'Agent'): {'ORG', 'GPE', 'PER'},
    ('Personnel.Start-Position', 'Place'): {'GPE', 'FAC'},
    ('Life.Divorce', 'Place'): {'GPE', 'FAC', 'LOC'},
    ('Life.Die', 'Place'): {'GPE', 'FAC', 'LOC'},
    ('Justice.Arrest-Jail', 'Person'): {'PER'},
    ('Justice.Arrest-Jail', 'Agent'): {'ORG', 'GPE', 'PER'},
    ('Personnel.End-Position', 'Place'): {'GPE', 'FAC'},
    ('Contact.Phone-Write', 'Entity'): {'ORG', 'PER'},
    ('Life.Injure', 'Place'): {'GPE', 'FAC', 'LOC'},
    ('Transaction.Transfer-Ownership', 'Buyer'): {'ORG', 'GPE', 'PER'},
    ('Transaction.Transfer-Ownership', 'Artifact'): {'ORG', 'VEH', 'FAC', 'WEA'},
    ('Transaction.Transfer-Ownership', 'Seller'): {'ORG', 'GPE', 'PER'},
    ('Conflict.Attack', 'Instrument'): {'VEH', 'WEA'},
    ('Life.Die', 'Instrument'): {'VEH', 'WEA'},
    ('Justice.Arrest-Jail', 'Place'): {'GPE', 'FAC'},
    ('Movement.Transport', 'Origin'): {'GPE', 'FAC', 'LOC'},
    ('Business.End-Org', 'Org'): {'ORG'},
    ('Life.Injure', 'Instrument'): {'VEH', 'WEA'},
    ('Transaction.Transfer-Ownership', 'Place'): {'GPE', 'FAC', 'LOC'},
    ('Transaction.Transfer-Ownership', 'Beneficiary'): {'GPE', 'PER'},
    ('Justice.Execute', 'Place'): {'GPE', 'FAC'},
    ('Justice.Execute', 'Agent'): {'ORG', 'GPE', 'PER'},
    ('Conflict.Attack', 'Victim'): {'PER'},
    ('Contact.Phone-Write', 'Place'): {'GPE', 'FAC', 'LOC'},
    ('Justice.Trial-Hearing', 'Defendant'): {'ORG', 'PER'},
    ('Justice.Execute', 'Person'): {'PER'},
    ('Movement.Transport', 'Place'): {'GPE'},
    ('Personnel.Elect', 'Place'): {'GPE', 'LOC'},
    ('Life.Be-Born', 'Place'): {'GPE', 'FAC', 'LOC'},
    ('Justice.Charge-Indict', 'Adjudicator'): {'ORG', 'PER'},
    ('Business.Start-Org', 'Org'): {'ORG'},
    ('Business.Start-Org', 'Place'): {'GPE', 'FAC'},
    ('Justice.Convict', 'Defendant'): {'ORG', 'PER'},
    ('Justice.Convict', 'Adjudicator'): {'ORG'},
    ('Justice.Sentence', 'Defendant'): {'ORG', 'PER'},
    ('Justice.Sentence', 'Adjudicator'): {'ORG', 'GPE', 'PER'},
    ('Business.Declare-Bankruptcy', 'Org'): {'ORG', 'PER'},
    ('Justice.Release-Parole', 'Entity'): {'ORG', 'GPE', 'PER'},
    ('Justice.Release-Parole', 'Person'): {'PER'},
    ('Justice.Charge-Indict', 'Defendant'): {'ORG', 'PER'},
    ('Justice.Trial-Hearing', 'Place'): {'GPE', 'FAC', 'LOC'},
    ('Justice.Trial-Hearing', 'Adjudicator'): {'ORG', 'PER'},
    ('Justice.Trial-Hearing', 'Prosecutor'): {'ORG', 'PER'},
    ('Justice.Charge-Indict', 'Prosecutor'): {'ORG', 'GPE', 'PER'},
    ('Justice.Fine', 'Entity'): {'ORG', 'GPE', 'PER'},
    ('Business.Start-Org', 'Agent'): {'ORG', 'GPE', 'PER'},
    ('Justice.Pardon', 'Adjudicator'): {'ORG', 'PER'},
    ('Justice.Charge-Indict', 'Place'): {'GPE', 'FAC', 'LOC'},
    ('Justice.Appeal', 'Adjudicator'): {'ORG', 'PER'},
    ('Justice.Appeal', 'Plaintiff'): {'ORG', 'GPE', 'PER'},
    ('Justice.Sentence', 'Place'): {'GPE', 'FAC'},
    ('Life.Die', 'Person'): {'PER'},
    ('Life.Be-Born', 'Person'): {'PER'},
    ('Justice.Release-Parole', 'Place'): {'GPE', 'FAC'},
    ('Justice.Sue', 'Defendant'): {'ORG', 'GPE', 'PER'},
    ('Transaction.Transfer-Money', 'Beneficiary'): {'ORG', 'GPE', 'PER'},
    ('Justice.Convict', 'Place'): {'GPE'},
    ('Justice.Extradite', 'Origin'): {'GPE', 'FAC'},
    ('Justice.Extradite', 'Destination'): {'GPE'},
    ('Justice.Appeal', 'Place'): {'GPE', 'FAC'},
    ('Business.Declare-Bankruptcy', 'Place'): {'GPE'},
    ('Justice.Fine', 'Adjudicator'): {'PER'},
    ('Life.Marry', 'Place'): {'GPE', 'FAC'},
    ('Life.Divorce', 'Person'): {'PER'},
    ('Personnel.Nominate', 'Agent'): {'ORG', 'GPE', 'PER'},
    ('Business.Merge-Org', 'Org'): {'ORG'},
    ('Justice.Acquit', 'Defendant'): {'PER'},
    ('Justice.Sue', 'Adjudicator'): {'ORG', 'PER'},
    ('Justice.Sue', 'Place'): {'GPE', 'LOC'},
    ('Justice.Fine', 'Place'): {'GPE', 'FAC'},
    ('Justice.Pardon', 'Place'): {'GPE'},
    ('Justice.Pardon', 'Defendant'): {'PER'},
    ('Justice.Acquit', 'Adjudicator'): {'PER'},
    ('Transaction.Transfer-Money', 'Place'): {'GPE', 'LOC'},
    ('Justice.Extradite', 'Agent'): {'ORG'},
}

Code/EAE/score_EAE_E+.py

Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
import os
2+
import json
3+
import spacy
4+
from spacy.tokens import Doc
5+
from const import ROLE, All_Valid_EntTypes
6+
import torch
7+
from torchmetrics.classification import MulticlassCalibrationError
8+
9+
10+
class WhitespaceTokenizer:
    """Callable tokenizer that splits text on single space characters only.

    Splitting on exactly ``" "`` keeps one token per space-separated piece,
    so token positions line up with a word list that was joined with single
    spaces.
    """

    def __init__(self, vocab):
        # Vocabulary handed to every Doc this tokenizer builds.
        self.vocab = vocab

    def __call__(self, text):
        """Return a ``Doc`` whose words are the space-separated pieces of *text*."""
        return Doc(self.vocab, words=text.split(" "))
17+
18+
19+
# Module-level spaCy pipeline, loaded once at import time; the scoring code
# calls it to parse sentences when reducing spans to their head word.
nlp = spacy.load('en_core_web_sm')
20+
# Replace the pipeline's tokenizer with the space-splitting one so that the
# parsed Doc has exactly one token per space-separated word.
nlp.tokenizer = WhitespaceTokenizer(nlp.vocab)
21+
22+
23+
def safe_div(num, denom):
    """Return ``num / denom``, or ``0`` when *denom* is not strictly positive."""
    return num / denom if denom > 0 else 0
28+
29+
def compute_f1(predicted, gold, matched):
    """Compute precision, recall and F1 from raw counts.

    Args:
        predicted: number of predicted items.
        gold: number of gold items.
        matched: number of predicted items counted as correct.

    Returns:
        A ``(precision, recall, f1)`` tuple; each value falls back to 0 when
        its denominator is not positive (see ``safe_div``).
    """
    p = safe_div(matched, predicted)
    r = safe_div(matched, gold)
    harmonic = safe_div(2 * p * r, p + r)
    return p, r, harmonic
34+
35+
36+
def find_head(arg_start, arg_end, doc):
    """Locate the head token of the span ``[arg_start, arg_end]`` in *doc*.

    Starting from the first token of the span, follow ``token.head`` links
    for as long as the head index stays inside the (inclusive) span. The walk
    stops at a token that is its own head or whose head lies outside the span.

    Args:
        arg_start: index of the first token of the span.
        arg_end: index of the last token of the span (inclusive).
        doc: indexable sequence of tokens exposing ``.head.i``.

    Returns:
        ``(head_index, head_index)`` - a one-token span at the head.
    """
    position = arg_start
    while True:
        parent = doc[position].head.i
        if not (arg_start <= parent <= arg_end):
            # Head lies outside the span: current token is the span's head.
            break
        if parent == position:
            # Token is its own head.
            break
        position = parent

    return (position, position)
48+
49+
50+
def clean_span(tokens, start, end):
    """Strip a leading article from a multi-token span.

    If the first token of the span is "the", "an" or "a" (case-insensitive)
    and the span is longer than one token, the start moves one token to the
    right; otherwise the span is returned unchanged.

    Returns:
        The ``(start, end)`` pair, possibly with ``start`` advanced by one.
    """
    leading_article = tokens[start].lower() in ('a', 'an', 'the')
    if leading_article and start != end:
        return (start + 1, end)
    return start, end
55+
56+
57+
def evaluate(preds, gold, only_head=False):
    """Score predicted event arguments against gold ones and print the results.

    Both inputs are modified in place (gold arguments lose their word list,
    and every argument list is replaced by a list of tuples).

    Args:
        preds: dict mapping example id to a list of predicted arguments, each
            ``[role, span, start, end, confidence, if_reasonable]`` (the
            layout left behind by ``filter_invalid_answer``).
        gold: dict mapping example id to a list of gold arguments, each
            ``[role, span, start, end, words]``.
        only_head: when true, every gold span is reduced to its syntactic
            head word before matching.

    Prints:
        Raw counts; precision/recall/F1 for identification (offsets match),
        classification (offsets and role match) and "ic" (role and span text
        match); mean confidence of correct/incorrect "ic" predictions; the
        share of correct predictions flagged ``if_reasonable``; and the
        expected calibration error over the predictions that carry a usable
        confidence.

    Note:
        Gold arguments are only counted for examples that also appear in
        *preds*, so recall is relative to the predicted subset of examples.
    """
    # Every gold argument carries the sentence tokens at index 4; remove them
    # and, if requested, shrink the span to its head word.
    for example_id in gold:
        for argument in gold[example_id]:
            words = argument.pop(4)
            if not only_head:
                continue
            doc = nlp(' '.join(words))
            # The -1 / +1 around the head search converts the end offset to
            # an inclusive index and back (gold ends appear to be exclusive).
            argument[3] -= 1
            argument[2], argument[3] = clean_span(words, argument[2], argument[3])
            argument[2], argument[3] = find_head(argument[2], argument[3], doc)
            assert argument[2] == argument[3]
            argument[1] = words[argument[2]]
            argument[3] += 1

    # De-duplicate predictions and make both sides hashable/comparable.
    for example_id in preds:
        preds[example_id] = list({tuple(item) for item in preds[example_id]})
        gold[example_id] = [tuple(item) for item in gold[example_id]]

    pred_arg_num = sum(len(preds[example_id]) for example_id in preds)
    gold_arg_num = sum(len(gold[example_id]) for example_id in preds)
    arg_idn_num, arg_class_num, arg_ic_num = 0, 0, 0

    correct_confidence = 0
    incorrect_confidence = 0
    if_reasonable_num = 0
    calibrate_record = []
    invalid_role = 0
    for example_id in preds:
        gold_args = gold[example_id]
        for pred_arg in preds[example_id]:
            role, span, arg_start, arg_end, confidence, if_reasonable = pred_arg

            # Offset-based identification / classification. This is done
            # before the calibration bookkeeping so that the `continue`
            # below cannot skip these counters for a matching prediction.
            gold_idn = {item for item in gold_args
                        if item[2] == arg_start and item[3] == arg_end}
            if gold_idn:
                arg_idn_num += 1
                if any(item[0] == role for item in gold_idn):
                    arg_class_num += 1

            # Text-based match: same role and same span string.
            gold_ic = [item for item in gold_args
                       if item[0] == role and item[1] == span]
            if gold_ic:
                arg_ic_num += 1
                correct_confidence += confidence
                if if_reasonable:
                    if_reasonable_num += 1
                gold_label = gold_ic[0][0]
            else:
                incorrect_confidence += confidence
                gold_label = 'None'

            # Predictions with an unknown role or a zero confidence are left
            # out of the calibration data and counted separately.
            if role not in LABEL2ID or confidence == 0:
                invalid_role += 1
                continue
            # confidence / 100: presumably a 0-100 score rescaled to [0, 1].
            calibrate_record.append(
                [LABEL2ID[gold_label], LABEL2ID[role], confidence / 100])

    print(f"gold_arg_num: {gold_arg_num}, pred_arg_num: {pred_arg_num}, arg_idn_num: {arg_idn_num}, arg_class_num: {arg_class_num}, arg_ic_num: {arg_ic_num}")

    role_id_prec, role_id_rec, role_id_f = compute_f1(pred_arg_num, gold_arg_num, arg_idn_num)
    role_prec, role_rec, role_f = compute_f1(pred_arg_num, gold_arg_num, arg_class_num)
    role_ic_prec, role_ic_rec, role_ic_f = compute_f1(pred_arg_num, gold_arg_num, arg_ic_num)
    print('Role identification: P: {:.2f}, R: {:.2f}, F: {:.2f}'.format(role_id_prec * 100.0, role_id_rec * 100.0, role_id_f * 100.0))
    print('Role: P: {:.2f}, R: {:.2f}, F: {:.2f}'.format(role_prec * 100.0, role_rec * 100.0, role_f * 100.0))
    print('Role ic: P: {:.2f}, R: {:.2f}, F: {:.2f}'.format(role_ic_prec * 100.0, role_ic_rec * 100.0, role_ic_f * 100.0))
    # safe_div keeps the report alive when there are no correct (or no
    # incorrect) predictions instead of raising ZeroDivisionError.
    print('(Role ic) Correct Mean Confidence: {:.2f}, Incorrect Mean Confidence: {:.2f}'.format(
        safe_div(correct_confidence, arg_ic_num),
        safe_div(incorrect_confidence, pred_arg_num - arg_ic_num)))
    print(f' Auto Rate: {safe_div(if_reasonable_num, arg_ic_num)}')

    # Compute Expected Calibration Error (ECE)
    assert len(calibrate_record) == (pred_arg_num - invalid_role)
    print(invalid_role, len(calibrate_record))
    if not calibrate_record:
        # Nothing to calibrate on; zip(*[]) below would fail to unpack.
        print('Expected Calibration Error: n/a (no scorable predictions)')
        return
    label_idx, pred_idx, prob = zip(*calibrate_record)
    labels = torch.tensor(label_idx)
    num_labels = len(LABEL2ID)
    # One row per prediction: the confidence sits in the predicted class
    # column, every other class gets probability 0.
    prob_matrix = torch.zeros(len(calibrate_record), num_labels, dtype=torch.float32)
    prob_matrix[range(len(calibrate_record)), pred_idx] = torch.tensor(prob)
    metric = MulticlassCalibrationError(num_classes=num_labels, n_bins=50, norm='l1')
    result = metric(prob_matrix, labels)
    print('Expected Calibration Error: {:.5f}'.format(result))
143+
144+
145+
def read_question(question_path):
    """Load the per-role query templates, grouped by event type.

    Every non-empty line of the file has the form
    ``<event>_<role>,<query text>``. Only the first comma separates the key
    from the query, so the query text itself may contain commas.

    Args:
        question_path: path of the query file.

    Returns:
        Dict mapping each event type to a list of ``(role, query)`` tuples,
        in file order.

    Raises:
        ValueError: if a non-empty line has no comma, or its key does not
            split into exactly one event and one role on ``"_"``.
    """
    Event2Query = {}
    with open(question_path, 'r', encoding='utf-8') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                # Tolerate blank lines (e.g. trailing newlines at EOF).
                continue
            event_arg, query = line.split(",", 1)
            event, arg = event_arg.split("_")
            Event2Query.setdefault(event, []).append((arg, query))

    return Event2Query
158+
159+
160+
def filter_invalid_answer(preds, only_head=False):
    """Drop malformed predicted answers in place and normalise the rest.

    An answer ``[role, span, start, end, words, ...]`` is discarded when the
    role or span is not a string, an index is not an ``int`` or is negative,
    or the span contains one of the "no answer" phrases (e.g. "unknown",
    "N/A"). For every remaining answer the word list at index 4 is removed
    and the end index is incremented by one (per the original note: the
    prompt states that ``end_word_index`` is inclusive).

    Args:
        preds: dict mapping example id to a list of answers; modified in place.
        only_head: when true, each span whose start lies inside the sentence
            is reduced to its syntactic head word before the end index is
            incremented.

    Returns:
        The number of answers that were discarded.
    """
    filter_words = ('unknown', 'Unknown', 'unspecified', 'not specified',
                    'not mentioned', 'None', 'none', 'not applicable', 'N/A')

    def if_invalid(argument):
        role, span, start, end = argument[:4]
        if not isinstance(span, str) or not isinstance(role, str):
            return True
        if not (isinstance(start, int) and isinstance(end, int)):
            return True
        if start < 0 or end < 0:
            return True
        return any(word in span for word in filter_words)

    count = 0
    for answers in preds.values():
        # Rebuild the list rather than calling list.remove(): remove() deletes
        # the first *equal* element, which can be a different, valid answer
        # (e.g. index 1 compares equal to an invalid 1.0).
        kept = [argument for argument in answers if not if_invalid(argument)]
        count += len(answers) - len(kept)
        answers[:] = kept

    for answers in preds.values():
        for argument in answers:
            words = argument.pop(4)
            # Head reduction is skipped when the start index points past the
            # sentence, since there is no token to start the search from.
            if only_head and argument[2] < len(words):
                doc = nlp(' '.join(words))
                argument[2], argument[3] = find_head(argument[2], argument[3], doc)
                assert argument[2] == argument[3]
                argument[1] = words[argument[2]]
            # Inclusive end -> exclusive end.
            argument[3] += 1

    return count
202+
203+
204+
def join(word_list):
    """Concatenate tokens into a string with simple detokenisation rules.

    A token is attached to the text without a preceding space when it is the
    first token, contains an apostrophe, is a hyphen, or directly follows a
    hyphen; every other token is preceded by a single space.

    Returns:
        The joined string (empty for an empty *word_list*).
    """
    pieces = []
    previous = None
    for position, token in enumerate(word_list):
        attach = (
            position == 0
            or "'" in token
            or token == '-'
            or previous == '-'
        )
        pieces.append(token if attach else ' ' + token)
        previous = token
    return ''.join(pieces)
219+
220+
221+
def read_gold_example(path):
    """Read the gold annotations from a JSON-lines file.

    Each non-empty line is a JSON object with the keys ``id``, ``words`` and
    ``event``; ``event`` holds ``event_type`` and a list ``argument`` whose
    items provide ``role``, ``text``, ``start`` and ``end``.

    Args:
        path: path of the gold file.

    Returns:
        A ``(gold, gold_event, gold_words)`` tuple of dicts keyed by example
        id: the list of ``[role, text, start, end, words]`` arguments, the
        event type, and the word list of the example.
    """
    gold = {}
    gold_event = {}
    gold_words = {}
    with open(path, encoding='utf-8') as f:
        for raw_line in f:
            if not raw_line.strip():
                # A blank line would make json.loads raise.
                continue
            record = json.loads(raw_line)
            example_id = record['id']
            words = record['words']
            gold_event[example_id] = record['event']['event_type']
            gold_words[example_id] = words
            gold[example_id] = [
                [arg['role'], arg['text'], arg['start'], arg['end'], words]
                for arg in record['event']['argument']
            ]

    return gold, gold_event, gold_words
240+
241+
242+
def get_vocab():
    """Build the label <-> id lookup tables for argument roles.

    Id 0 is the ``'None'`` label; the entries of ``ROLE`` follow in order.

    Returns:
        A ``(label2id, id2label)`` pair of dicts.
    """
    all_labels = ['None', *ROLE]
    label2id = {}
    id2label = {}
    for idx, label in enumerate(all_labels):
        label2id[label] = idx
        id2label[idx] = label
    return label2id, id2label
249+
250+
251+
def main(result_dir, gold_path, question_path, only_head=False):
    """Load model answers and gold data, then print the evaluation report.

    Args:
        result_dir: directory with one ``<example_id>.json`` file per example.
            Each file maps ``"Question1"``, ``"Question2"``, ... to a list of
            answers with the keys ``span``, ``confidence``, ``if_reasonable``,
            ``start_word_index`` and ``end_word_index``.
        gold_path: path of the gold JSON-lines file.
        question_path: path of the query file; the i-th query of an event
            type supplies the role for ``Question<i>``.
        only_head: forwarded to the filtering and evaluation steps to score
            head words instead of full spans.
    """
    Event2Query = read_question(question_path)
    gold, gold_event, gold_words = read_gold_example(gold_path)

    preds = {}
    for file in os.listdir(result_dir):
        example_id, extension = os.path.splitext(file)
        if extension.lower() != '.json':
            # Ignore stray entries (e.g. .DS_Store) instead of deriving a
            # bogus example id from them.
            continue
        preds[example_id] = []
        file_path = os.path.join(result_dir, file)
        with open(file_path, 'r') as f:
            res = json.load(f)

        question_num = len(res)
        event_type = gold_event[example_id]
        words = gold_words[example_id]
        all_role = [item[0] for item in Event2Query[event_type]]
        assert len(all_role) == question_num, (
            f"{file}: {question_num} answered questions, "
            f"{len(all_role)} roles defined for {event_type}")
        for idx, role in enumerate(all_role):
            for ans in res[f"Question{idx+1}"]:
                preds[example_id].append([
                    role,
                    ans['span'],
                    ans['start_word_index'],
                    ans['end_word_index'],
                    words,
                    ans['confidence'],
                    ans['if_reasonable'],
                ])

    invalid_arg_num = filter_invalid_answer(preds, only_head)
    print(invalid_arg_num)
    evaluate(preds, gold, only_head)
283+
284+
285+
# Role label <-> id tables shared by the scoring functions above.
LABEL2ID, ID2LABEL = get_vocab()

if __name__ == "__main__":
    # Default locations of the model outputs, gold data and query file.
    only_head = False
    gold_path = "./data/ACE05-E+/EAE_E+_gold.json"
    question_path = './Code/description_queries_new.csv'
    result_dir = './Output/EAE/Full_Testset/EAE_E+_Closed'
    main(
        result_dir=result_dir,
        gold_path=gold_path,
        question_path=question_path,
        only_head=only_head,
    )
292+

0 commit comments

Comments
 (0)