You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
adj: (B, N, N) adj[:,i,:] means the direct predecessors of node i
76
+
'''
77
+
adj= []
78
+
forspeakerinspeakers:
79
+
a=torch.zeros(max_dialog_len, max_dialog_len)
80
+
fori,sinenumerate(speaker):
81
+
get_local_pred=False
82
+
get_global_pred=False
83
+
forjinrange(i-1, -1, -1):
84
+
ifspeaker[j] ==sandnotget_local_pred:
85
+
get_local_pred=True
86
+
a[i,j] =1
87
+
elifspeaker[j] !=sandnotget_global_pred:
88
+
get_global_pred=True
89
+
a[i,j] =1
90
+
ifget_global_predandget_local_pred:
91
+
break
92
+
adj.append(a)
93
+
returntorch.stack(adj)
94
+
95
+
defget_adj_v1(self, speakers, max_dialog_len):
96
+
'''
97
+
get adj matrix
98
+
:param speakers: (B, N)
99
+
:param max_dialog_len:
100
+
:return:
101
+
adj: (B, N, N) adj[:,i,:] means the direct predecessors of node i
102
+
'''
103
+
adj= []
104
+
forspeakerinspeakers:
105
+
a=torch.zeros(max_dialog_len, max_dialog_len)
106
+
fori,sinenumerate(speaker):
107
+
cnt=0
108
+
forjinrange(i-1, -1, -1):
109
+
a[i,j] =1
110
+
ifspeaker[j] ==s:
111
+
cnt+=1
112
+
ifcnt==self.args.windowp:
113
+
break
114
+
adj.append(a)
115
+
returntorch.stack(adj)
116
+
117
+
defget_s_mask(self, speakers, max_dialog_len):
118
+
'''
119
+
:param speakers:
120
+
:param max_dialog_len:
121
+
:return:
122
+
s_mask: (B, N, N) s_mask[:,i,:] means the speaker informations for predecessors of node i, where 1 denotes the same speaker, 0 denotes the different speaker
123
+
s_mask_onehot (B, N, N, 2) onehot emcoding of s_mask
adj: (B, N, N) adj[:,i,:] means the direct predecessors of node i
150
+
s_mask: (B, N, N) s_mask[:,i,:] means the speaker informations for predecessors of node i, where 1 denotes the same speaker, 0 denotes the different speaker
151
+
lengths: (B, )
152
+
utterances: not a tensor
153
+
'''
154
+
max_dialog_len=max([d[3] fordindata])
155
+
feaures=pad_sequence([d[0] fordindata], batch_first=True) # (B, N, D)
156
+
labels=pad_sequence([d[1] fordindata], batch_first=True, padding_value=-1) # (B, N )
0 commit comments