Skip to content

Commit b1d813f

Browse files
author
shenwzh3
committed
v1
1 parent 59b6e16 commit b1d813f

8 files changed

Lines changed: 1571 additions & 0 deletions

File tree

dataloader.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from dataset import *
2+
import pickle
3+
from torch.utils.data.sampler import SubsetRandomSampler
4+
from torch.utils.data import DataLoader
5+
import os
6+
import argparse
7+
import numpy as np
8+
from transformers import BertTokenizer
9+
10+
def get_train_valid_sampler(trainset):
    """Return a SubsetRandomSampler covering every index of *trainset*.

    The sampler visits all examples exactly once per epoch, in a random
    order (no train/valid split is performed here despite the name).
    """
    indices = list(range(len(trainset)))
    return SubsetRandomSampler(indices)
14+
15+
16+
def load_vocab(dataset_name):
    """Load the speaker and label vocabularies for *dataset_name*.

    Args:
        dataset_name: directory name under ../data (e.g. 'IEMOCAP', 'MELD').

    Returns:
        (speaker_vocab, label_vocab, person_vec) where the two vocabs are
        the unpickled mappings (presumably {'stoi': ..., 'itos': ...} —
        see callers) and person_vec is always None: personality vectors
        are currently disabled.
    """
    # Use context managers so the pickle file handles are closed promptly
    # instead of being leaked by pickle.load(open(...)).
    with open('../data/%s/speaker_vocab.pkl' % dataset_name, 'rb') as f:
        speaker_vocab = pickle.load(f)
    with open('../data/%s/label_vocab.pkl' % dataset_name, 'rb') as f:
        label_vocab = pickle.load(f)
    # Personality-vector loading was an abandoned experiment; callers
    # receive None and must not rely on this value.
    person_vec = None
    return speaker_vocab, label_vocab, person_vec
32+
33+
34+
def get_IEMOCAP_loaders(dataset_name = 'IEMOCAP', batch_size=32, num_workers=0, pin_memory=False, args = None):
    """Build train/valid/test DataLoaders and vocabularies for one dataset.

    Args:
        dataset_name: directory name under ../data.
        batch_size:   batch size shared by all three loaders.
        num_workers:  DataLoader worker count.
        pin_memory:   forwarded to every DataLoader.
        args:         experiment namespace forwarded to IEMOCAPDataset.

    Returns:
        (train_loader, valid_loader, test_loader,
         speaker_vocab, label_vocab, person_vec)
    """
    print('building vocab.. ')
    speaker_vocab, label_vocab, person_vec = load_vocab(dataset_name)

    print('building datasets..')
    trainset = IEMOCAPDataset(dataset_name, 'train', speaker_vocab, label_vocab, args)
    devset = IEMOCAPDataset(dataset_name, 'dev', speaker_vocab, label_vocab, args)

    # Options common to the three loaders.
    shared = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)

    train_loader = DataLoader(trainset,
                              sampler=get_train_valid_sampler(trainset),
                              collate_fn=trainset.collate_fn,
                              **shared)

    valid_loader = DataLoader(devset,
                              sampler=get_train_valid_sampler(devset),
                              collate_fn=devset.collate_fn,
                              **shared)

    # Test loader keeps the natural dataset order (no sampler).
    testset = IEMOCAPDataset(dataset_name, 'test', speaker_vocab, label_vocab, args)
    test_loader = DataLoader(testset,
                             collate_fn=testset.collate_fn,
                             **shared)

    return train_loader, valid_loader, test_loader, speaker_vocab, label_vocab, person_vec

dataset.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
import torch
2+
from torch.utils.data import Dataset
3+
from torch.nn.utils.rnn import pad_sequence
4+
import pickle, pandas as pd
5+
import json
6+
import numpy as np
7+
import random
8+
from pandas import DataFrame
9+
10+
11+
class IEMOCAPDataset(Dataset):
    """Dialogue-level ERC dataset backed by precomputed RoBERTa features.

    Each item is a whole dialogue: per-utterance feature vectors ('cls'),
    label ids, speaker ids, dialogue length and raw utterance texts.
    ``collate_fn`` pads a batch of dialogues and builds the graph tensors
    (predecessor adjacency and same-speaker masks) consumed by the model.
    """

    def __init__(self, dataset_name = 'IEMOCAP', split = 'train', speaker_vocab=None, label_vocab=None, args = None, tokenizer = None):
        # speaker_vocab / label_vocab: mappings with 'stoi' (and, per the
        # callers, 'itos') keys produced by load_vocab().
        # tokenizer is accepted but never used by read().
        self.speaker_vocab = speaker_vocab
        self.label_vocab = label_vocab
        self.args = args
        self.data = self.read(dataset_name, split, tokenizer)
        print(len(self.data))

        # Cached length; __len__ returns this.
        self.len = len(self.data)

    def read(self, dataset_name, split, tokenizer):
        """Load ../data/<name>/<split>_data_roberta.json.feature and convert
        each dialogue into parallel per-utterance lists.

        Utterances missing a 'label' key get label -1 (ignored downstream
        via the padding value in collate_fn).
        """
        with open('../data/%s/%s_data_roberta.json.feature'%(dataset_name, split), encoding='utf-8') as f:
            raw_data = json.load(f)

        # process dialogue
        dialogs = []
        # raw_data = sorted(raw_data, key=lambda x:len(x))
        for d in raw_data:
            # if len(d) < 5 or len(d) > 6:
            #     continue
            utterances = []
            labels = []
            speakers = []
            features = []
            for i,u in enumerate(d):
                utterances.append(u['text'])
                labels.append(self.label_vocab['stoi'][u['label']] if 'label' in u.keys() else -1)
                speakers.append(self.speaker_vocab['stoi'][u['speaker']])
                # 'cls' is presumably the precomputed RoBERTa [CLS] vector
                # for this utterance — TODO confirm against the feature
                # extraction script.
                features.append(u['cls'])
            dialogs.append({
                'utterances': utterances,
                'labels': labels,
                'speakers':speakers,
                'features': features
            })
        # NOTE(review): shuffles for every split, including dev/test, so
        # item order is seed-dependent — confirm this is intended.
        random.shuffle(dialogs)
        return dialogs

    def __getitem__(self, index):
        '''
        :param index: dialogue index
        :return:
            feature: FloatTensor (N, D) per-utterance features,
            label:   LongTensor (N,) label ids (-1 = unlabeled),
            speaker: list[int] speaker ids,
            length:  int, number of utterances N,
            text:    list[str] raw utterances
        '''
        return torch.FloatTensor(self.data[index]['features']), \
               torch.LongTensor(self.data[index]['labels']),\
               self.data[index]['speakers'], \
               len(self.data[index]['labels']), \
               self.data[index]['utterances']

    def __len__(self):
        return self.len

    def get_adj(self, speakers, max_dialog_len):
        '''
        get adj matrix
        :param speakers:  (B, N) per-dialogue speaker-id lists
        :param max_dialog_len: pad size N
        :return:
            adj: (B, N, N) adj[:,i,:] means the direct predecessors of node i
        Each utterance i is linked to at most two predecessors: the nearest
        earlier utterance by the SAME speaker and the nearest earlier one
        by a DIFFERENT speaker.
        '''
        adj = []
        for speaker in speakers:
            a = torch.zeros(max_dialog_len, max_dialog_len)
            for i,s in enumerate(speaker):
                get_local_pred = False
                get_global_pred = False
                # Scan backwards from i-1 until both predecessor kinds found.
                for j in range(i - 1, -1, -1):
                    if speaker[j] == s and not get_local_pred:
                        get_local_pred = True
                        a[i,j] = 1
                    elif speaker[j] != s and not get_global_pred:
                        get_global_pred = True
                        a[i,j] = 1
                    if get_global_pred and get_local_pred:
                        break
                adj.append(a)
        return torch.stack(adj)

    def get_adj_v1(self, speakers, max_dialog_len):
        '''
        get adj matrix
        :param speakers: (B, N)
        :param max_dialog_len:
        :return:
            adj: (B, N, N) adj[:,i,:] means the direct predecessors of node i
        Variant used by collate_fn: utterance i is linked to ALL earlier
        utterances up to (and including) the args.windowp-th most recent
        same-speaker utterance.
        '''
        adj = []
        for speaker in speakers:
            a = torch.zeros(max_dialog_len, max_dialog_len)
            for i,s in enumerate(speaker):
                cnt = 0
                for j in range(i - 1, -1, -1):
                    a[i,j] = 1
                    if speaker[j] == s:
                        cnt += 1
                        # Stop after windowp same-speaker predecessors.
                        if cnt==self.args.windowp:
                            break
                adj.append(a)
        return torch.stack(adj)

    def get_s_mask(self, speakers, max_dialog_len):
        '''
        :param speakers: (B, N) per-dialogue speaker-id lists
        :param max_dialog_len: pad size N
        :return:
            s_mask: (B, N, N) s_mask[:,i,:] means the speaker informations for predecessors of node i, where 1 denotes the same speaker, 0 denotes the different speaker
            s_mask_onehot (B, N, N, 2) onehot emcoding of s_mask
        Padded positions (beyond len(speaker)) stay all-zero in both outputs.
        '''
        s_mask = []
        s_mask_onehot = []
        for speaker in speakers:
            s = torch.zeros(max_dialog_len, max_dialog_len, dtype = torch.long)
            s_onehot = torch.zeros(max_dialog_len, max_dialog_len, 2)
            for i in range(len(speaker)):
                for j in range(len(speaker)):
                    if speaker[i] == speaker[j]:
                        s[i,j] = 1
                        s_onehot[i,j,1] = 1
                    else:
                        s_onehot[i,j,0] = 1

            s_mask.append(s)
            s_mask_onehot.append(s_onehot)
        return torch.stack(s_mask), torch.stack(s_mask_onehot)

    def collate_fn(self, data):
        '''
        :param data: list of __getitem__ tuples:
            features, labels, speakers, length, utterances
        :return:
            features: (B, N, D) padded
            labels: (B, N) padded with -1
            adj: (B, N, N) adj[:,i,:] means the direct predecessors of node i
            s_mask: (B, N, N) s_mask[:,i,:] means the speaker informations for predecessors of node i, where 1 denotes the same speaker, 0 denotes the different speaker
            lengths: (B, )
            utterances: not a tensor
        '''
        max_dialog_len = max([d[3] for d in data])
        # (sic: 'feaures' is a typo but only a local name.)
        feaures = pad_sequence([d[0] for d in data], batch_first = True) # (B, N, D)
        labels = pad_sequence([d[1] for d in data], batch_first = True, padding_value = -1) # (B, N )
        adj = self.get_adj_v1([d[2] for d in data], max_dialog_len)
        s_mask, s_mask_onehot = self.get_s_mask([d[2] for d in data], max_dialog_len)
        lengths = torch.LongTensor([d[3] for d in data])
        speakers = pad_sequence([torch.LongTensor(d[2]) for d in data], batch_first = True, padding_value = -1)
        utterances = [d[4] for d in data]

        return feaures, labels, adj,s_mask, s_mask_onehot,lengths, speakers, utterances

evaluate.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
import os
2+
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
3+
import numpy as np, argparse, time, pickle, random
4+
import torch
5+
import torch.nn as nn
6+
import torch.optim as optim
7+
from dataloader import IEMOCAPDataset
8+
from model import *
9+
from sklearn.metrics import f1_score, confusion_matrix, accuracy_score, classification_report, \
10+
precision_recall_fscore_support
11+
from trainer import train_or_eval_model, save_badcase
12+
from dataset import IEMOCAPDataset
13+
from dataloader import get_IEMOCAP_loaders
14+
from transformers import AdamW
15+
import copy
16+
17+
# We use seed = 100 for reproduction of the results reported in the paper.
seed = 100


def seed_everything(seed=seed):
    """Seed every RNG in play (python, numpy, torch CPU and CUDA) and put
    cuDNN in deterministic mode so runs are reproducible."""
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed,
                       torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        apply_seed(seed)
    # Trade cuDNN autotuning for bitwise-reproducible kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
29+
30+
31+
def evaluate(model, dataloader, cuda, args, speaker_vocab, label_vocab):
    """Run *model* over *dataloader* and print test accuracy / F1 scores.

    For IEMOCAP / MELD / EmoryNLP prints weighted F1; otherwise prints
    micro F1 over classes 1..6 plus macro F1. Returns None; does nothing
    if the loader yields no batches.
    """
    preds, labels = [], []
    dialogs, speakers = [], []

    model.eval()

    for batch in dataloader:

        features, label, adj, s_mask, s_mask_onehot, lengths, speaker, utterances = batch
        if cuda:
            features = features.cuda()
            label = label.cuda()
            adj = adj.cuda()
            s_mask_onehot = s_mask_onehot.cuda()
            s_mask = s_mask.cuda()
            lengths = lengths.cuda()

        log_prob = model(features, adj, s_mask, s_mask_onehot, lengths)  # (B, N, C)

        preds.extend(torch.argmax(log_prob, dim=2).cpu().numpy().tolist())  # (B, N)
        labels.extend(label.cpu().numpy().tolist())                         # (B, N)
        dialogs.extend(utterances)
        speakers.extend(speaker)

    if not preds:
        return

    # Flatten, dropping padded positions (label == -1).
    new_preds, new_labels = [], []
    for label_row, pred_row in zip(labels, preds):
        for gold, guess in zip(label_row, pred_row):
            if gold != -1:
                new_labels.append(gold)
                new_preds.append(guess)

    avg_accuracy = round(accuracy_score(new_labels, new_preds) * 100, 2)
    if args.dataset_name in ['IEMOCAP', 'MELD', 'EmoryNLP']:
        avg_fscore = round(f1_score(new_labels, new_preds, average='weighted') * 100, 2)
        print('test_accuracy', avg_accuracy)
        print('test_f1', avg_fscore)
        return
    else:
        # DailyDialog-style evaluation: micro F1 excludes class 0 (neutral).
        avg_micro_fscore = round(f1_score(new_labels, new_preds, average='micro', labels=list(range(1, 7))) * 100, 2)
        avg_macro_fscore = round(f1_score(new_labels, new_preds, average='macro') * 100, 2)
        print('test_accuracy', avg_accuracy)
        print('test_micro_f1', avg_micro_fscore)
        print('test_macro_f1', avg_macro_fscore)
        return
83+
84+
# Script entry point: parse CLI flags, load the saved model weights and
# run evaluation on the test split.
if __name__ == '__main__':

    #path = './saved_models/'

    # ---- CLI arguments ------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--bert_model_dir', type=str, default='')
    parser.add_argument('--bert_tokenizer_dir', type=str, default='')

    # Path to the trained state_dict to evaluate.
    parser.add_argument('--state_dict_file', type=str, default='')

    parser.add_argument('--bert_dim', type = int, default=1024)
    parser.add_argument('--hidden_dim', type = int, default=300)
    parser.add_argument('--mlp_layers', type=int, default=2, help='Number of output mlp layers.')
    parser.add_argument('--gnn_layers', type=int, default=2, help='Number of gnn layers.')
    parser.add_argument('--emb_dim', type=int, default=1024, help='Feature size.')

    parser.add_argument('--attn_type', type=str, default='rgcn', choices=['dotprod','linear','bilinear', 'rgcn'], help='Feature size.')
    parser.add_argument('--no_rel_attn', action='store_true', default=False, help='no relation for edges' )

    parser.add_argument('--max_sent_len', type=int, default=200,
                        help='max content length for each text, if set to 0, then the max length has no constrain')

    parser.add_argument('--no_cuda', action='store_true', default=False, help='does not use GPU')

    parser.add_argument('--dataset_name', default='IEMOCAP', type= str, help='dataset name, IEMOCAP or MELD or DailyDialog')

    # windowp is consumed by IEMOCAPDataset.get_adj_v1 when building edges.
    parser.add_argument('--windowp', type=int, default=1,
                        help='context window size for constructing edges in graph model for past utterances')

    parser.add_argument('--windowf', type=int, default=0,
                        help='context window size for constructing edges in graph model for future utterances')

    parser.add_argument('--max_grad_norm', type=float, default=5.0, help='Gradient clipping.')

    parser.add_argument('--lr', type=float, default=1e-3, metavar='LR', help='learning rate')


    parser.add_argument('--dropout', type=float, default=0, metavar='dropout', help='dropout rate')

    parser.add_argument('--batch_size', type=int, default=8, metavar='BS', help='batch size')

    parser.add_argument('--epochs', type=int, default=20, metavar='E', help='number of epochs')

    parser.add_argument('--tensorboard', action='store_true', default=False, help='Enables tensorboard log')

    parser.add_argument('--nodal_att_type', type=str, default=None, choices=['global','past'], help='type of nodal attention')

    args = parser.parse_args()
    print(args)

    # Fix all RNG seeds before any data loading / model construction.
    seed_everything()

    args.cuda = torch.cuda.is_available() and not args.no_cuda

    if args.cuda:
        print('Running on GPU')
    else:
        print('Running on CPU')

    if args.tensorboard:
        from tensorboardX import SummaryWriter

        writer = SummaryWriter()

    # ---- Data and model ----------------------------------------------
    cuda = args.cuda
    n_epochs = args.epochs
    batch_size = args.batch_size
    train_loader, valid_loader, test_loader, speaker_vocab, label_vocab, person_vec = get_IEMOCAP_loaders(dataset_name=args.dataset_name, batch_size=batch_size, num_workers=0, args = args)
    n_classes = len(label_vocab['itos'])

    print('building model..')
    model = DAGERC_fushion(args, n_classes)

    if torch.cuda.device_count() > 1:
        print('Multi-GPU...........')
        model = nn.DataParallel(model,device_ids = range(torch.cuda.device_count()))
    if cuda:
        model.cuda()

    # ---- Load weights and evaluate on the test split ------------------
    # NOTE(review): torch.load with no map_location assumes the checkpoint
    # device matches the current machine — confirm for CPU-only runs.
    state_dict = torch.load(args.state_dict_file)
    model.load_state_dict(state_dict)
    evaluate(model, test_loader, cuda, args, speaker_vocab, label_vocab)

0 commit comments

Comments
 (0)