train-utterance-level-similarity-model.py
"""Train a customized utterance-level similarity model.

Given a CSV of utterance-level similarity scores (produced by the
predefined utterance-level MIA), the top-k most similar utterances are
used as positive examples and the top-k least similar ones as negative
examples to train a binary classifier on their features.
"""
import argparse
import os
import random

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset.dataset import CertainUtteranceDataset
from model.customized_similarity_model import UtteranceLevelModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main(args):
    random.seed(args.seed)

    TOP_K = args.top_k

    assert (
        args.utterance_list is not None
    ), "Require a csv file of utterance-level similarity. Please run the predefined utterance-level MIA first."

    df = pd.read_csv(args.utterance_list, index_col=False)

    # Collect the utterances from the csv file and sort them by
    # similarity in ascending order.
    utterances = [x for x in df["Unseen_utterance"].values if str(x) != "nan"]
    similarity = [x for x in df["Unseen_utterance_sim"].values if str(x) != "nan"]
    _, sorted_utterances = zip(*sorted(zip(similarity, utterances)))
    sorted_utterances = list(sorted_utterances)

    # The k least similar utterances serve as negatives, the k most
    # similar ones as positives.
    negative_utterances = sorted_utterances[:TOP_K]
    positive_utterances = sorted_utterances[-TOP_K:]
    train_dataset = CertainUtteranceDataset(
        args.base_path, positive_utterances, negative_utterances, args.model
    )

    # The next k utterances on each side form the evaluation split.
    eval_negative_utterances = sorted_utterances[TOP_K : 2 * TOP_K]
    eval_positive_utterances = sorted_utterances[-2 * TOP_K : -TOP_K]
    eval_dataset = CertainUtteranceDataset(
        args.base_path, eval_positive_utterances, eval_negative_utterances, args.model
    )

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        collate_fn=train_dataset.collate_fn,
    )
    eval_dataloader = DataLoader(
        eval_dataset,
        batch_size=args.eval_batch_size,
        shuffle=False,  # evaluation order does not affect the mean loss
        num_workers=args.num_workers,
        collate_fn=eval_dataset.collate_fn,
    )

    # Build the similarity model; infer the input dimension from one sample.
    feature, _ = train_dataset[0]
    input_dim = feature.shape[-1]
    model = UtteranceLevelModel(input_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    criterion = torch.nn.BCEWithLogitsLoss(reduction="none")

    min_loss = float("inf")
    early_stopping = 0

    for epoch in range(args.n_epochs):
        # Train the model
        model.train()
        for features, labels in tqdm(
            train_dataloader, dynamic_ncols=True, desc=f"Train | Epoch {epoch + 1}"
        ):
            features = [torch.FloatTensor(feature).to(device) for feature in features]
            labels = torch.FloatTensor(labels).to(device)
            pred = model(features)
            loss = torch.mean(criterion(pred, labels))
            optimizer.zero_grad()  # otherwise gradients accumulate across steps
            loss.backward()
            optimizer.step()

        # Evaluate the model
        model.eval()
        total_loss = []
        for features, labels in tqdm(eval_dataloader, dynamic_ncols=True, desc="Eval"):
            features = [torch.FloatTensor(feature).to(device) for feature in features]
            labels = torch.FloatTensor(labels).to(device)
            with torch.no_grad():
                pred = model(features)
                loss = criterion(pred, labels)
            total_loss += loss.cpu().tolist()
        total_loss = np.mean(total_loss)

        # Save the model whenever the evaluation loss improves.
        if total_loss < min_loss:
            min_loss = total_loss
            print(f"Saving model (epoch = {epoch + 1:4d}, loss = {min_loss:.4f})")
            torch.save(
                model.state_dict(),
                os.path.join(
                    args.output_path,
                    f"customized-utterance-similarity-model-{args.model}.pt",
                ),
            )
            early_stopping = 0
        else:
            print(f"Not saving model (epoch = {epoch + 1:4d}, loss = {total_loss:.4f})")
            early_stopping += 1

        # Stop early after 5 consecutive epochs without improvement.
        if early_stopping >= 5:
            break


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--base_path", help="directory of extracted features of the LibriSpeech dataset"
    )
    parser.add_argument("--output_path", help="directory to save the model")
    parser.add_argument(
        "--model", help="which self-supervised model was used to extract the features"
    )
    parser.add_argument("--seed", type=int, default=57, help="random seed")
    parser.add_argument(
        "--top_k", type=int, default=500, help="how many utterances to pick"
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=32, help="training batch size"
    )
    parser.add_argument(
        "--eval_batch_size", type=int, default=32, help="evaluation batch size"
    )
    parser.add_argument(
        "--utterance_list",
        type=str,
        default=None,
        help="csv file of utterance-level similarity",
    )
    parser.add_argument("--n_epochs", type=int, default=30, help="number of training epochs")
    parser.add_argument("--num_workers", type=int, default=2, help="number of workers")
    parser.add_argument("--lr", type=float, default=1e-5, help="learning rate")
    args = parser.parse_args()

    main(args)
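
# Example invocation (a sketch only; the paths and the model name below
# are hypothetical and should be adapted to your own setup):
#
#   python train-utterance-level-similarity-model.py \
#       --base_path /path/to/librispeech-features \
#       --output_path ./checkpoints \
#       --model hubert \
#       --utterance_list utterance-level-similarity.csv
#
# The best checkpoint (lowest evaluation loss) is written to
# --output_path as customized-utterance-similarity-model-<model>.pt.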