-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathex_dcase2016task2.py
More file actions
517 lines (442 loc) · 20.4 KB
/
ex_dcase2016task2.py
File metadata and controls
517 lines (442 loc) · 20.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
import argparse
import random
from pathlib import Path
from typing import Dict
import pytorch_lightning as pl
import torch
import torch.nn as nn
import transformers
from einops import rearrange
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader
import wandb
from data_util.dcase2016task2 import (get_training_dataset, get_validation_dataset, get_test_dataset,
label_vocab_nlabels, label_vocab_as_dict)
from helpers.augment import frame_shift, time_mask, mixup, filter_augmentation, mixstyle, RandomResizeCrop
from helpers.score import get_events_for_all_files, combine_target_events, EventBasedScore, SegmentBasedScore
from helpers.utils import worker_init_fn
from models.asit.ASIT_wrapper import ASiTWrapper
from models.atstframe.ATSTF_wrapper import ATSTWrapper
from models.beats.BEATs_wrapper import BEATsWrapper
from models.frame_passt.fpasst_wrapper import FPaSSTWrapper
from models.m2d.M2D_wrapper import M2DWrapper
from models.prediction_wrapper import PredictionsWrapper
class PLModule(pl.LightningModule):
    """LightningModule for fine-tuning a pre-trained audio transformer on DCASE 2016 Task 2.

    The selected backbone is wrapped in a ``PredictionsWrapper`` that produces frame-level
    ("strong") class logits. Training optimizes binary cross-entropy on those logits.
    Validation and test collect all predictions, turn them into events and report the
    event-based onset F-measure (200 ms and 50 ms collars) and the 1 s segment-based
    error rate.

    :param config: parsed command-line configuration (``argparse.Namespace``) holding the
        model, augmentation, optimizer and schedule settings defined in ``__main__``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # translate the --pretrained choice into the checkpoint suffix used by PredictionsWrapper
        if config.pretrained == "scratch":
            checkpoint = None
        elif config.pretrained == "ssl":
            checkpoint = "ssl"
        elif config.pretrained == "weak":
            checkpoint = "weak"
        elif config.pretrained == "strong":
            checkpoint = "strong_1"
        else:
            raise ValueError(f"Unknown pretrained checkpoint: {config.pretrained}")

        # load transformer model
        if config.model_name == "BEATs":
            beats = BEATsWrapper()
            model = PredictionsWrapper(beats, checkpoint=f"BEATs_{checkpoint}" if checkpoint else None,
                                       seq_model_type=config.seq_model_type,
                                       n_classes_strong=self.config.n_classes)
        elif config.model_name == "ATST-F":
            atst = ATSTWrapper()
            model = PredictionsWrapper(atst, checkpoint=f"ATST-F_{checkpoint}" if checkpoint else None,
                                       seq_model_type=config.seq_model_type,
                                       n_classes_strong=self.config.n_classes)
        elif config.model_name == "fpasst":
            fpasst = FPaSSTWrapper()
            model = PredictionsWrapper(fpasst, checkpoint=f"fpasst_{checkpoint}" if checkpoint else None,
                                       seq_model_type=config.seq_model_type,
                                       n_classes_strong=self.config.n_classes)
        elif config.model_name == "M2D":
            m2d = M2DWrapper()
            # M2D's embedding size is taken from its own config rather than the wrapper default
            model = PredictionsWrapper(m2d, checkpoint=f"M2D_{checkpoint}" if checkpoint else None,
                                       seq_model_type=config.seq_model_type,
                                       n_classes_strong=self.config.n_classes,
                                       embed_dim=m2d.m2d.cfg.feature_d)
        elif config.model_name == "ASIT":
            asit = ASiTWrapper()
            model = PredictionsWrapper(asit, checkpoint=f"ASIT_{checkpoint}" if checkpoint else None,
                                       seq_model_type=config.seq_model_type,
                                       n_classes_strong=self.config.n_classes)
        else:
            raise NotImplementedError(f"Model {config.model_name} not (yet) implemented")
        self.model = model

        self.strong_loss = nn.BCEWithLogitsLoss()
        # frequency-warping augmentation; time_scale=(1.0, 1.0) presumably leaves the time axis as is
        self.freq_warp = RandomResizeCrop((1, 1.0), time_scale=(1.0, 1.0))

        # label vocabulary of the task, used to map between class indices and label names
        task_path = Path(self.config.task_path)
        label_vocab, nlabels = label_vocab_nlabels(task_path)
        self.label_to_idx = label_vocab_as_dict(label_vocab, key="label", value="idx")
        self.idx_to_label: Dict[int, str] = {
            idx: label for (label, idx) in self.label_to_idx.items()
        }

        # evaluation metrics
        self.event_onset_200ms_fms = EventBasedScore(
            label_to_idx=self.label_to_idx,
            name="event_onset_200ms_fms",
            scores=("f_measure", "precision", "recall"),
            params={"evaluate_onset": True, "evaluate_offset": False, "t_collar": 0.2}
        )
        self.event_onset_50ms_fms = EventBasedScore(
            label_to_idx=self.label_to_idx,
            name="event_onset_50ms_fms",
            scores=("f_measure", "precision", "recall"),
            params={"evaluate_onset": True, "evaluate_offset": False, "t_collar": 0.05}
        )
        self.segment_1s_er = SegmentBasedScore(
            label_to_idx=self.label_to_idx,
            name="segment_1s_er",
            scores=("error_rate",),
            params={"time_resolution": 1.0},
            maximize=False,
        )

        # a single post-processing configuration (aligned with the HEAR challenge setup)
        self.postprocessing_grid = {
            "median_filter_ms": [
                250
            ],
            "min_duration": [
                125
            ]
        }

        # buffers filled by _score_step and consumed (then cleared) by _score_epoch_end
        self.preds, self.tgts, self.fnames, self.timestamps = [], [], [], []

    def forward(self, audio):
        """Return the frame-level (strong) logits for a batch of raw audio.

        :param audio: batch of waveforms; converted to mel features by the wrapped model
        :return: first output of the wrapped model (the strong logits)
        """
        mel = self.model.mel_forward(audio)
        y_strong, _ = self.model(mel)
        return y_strong

    def separate_params(self):
        """Split trainable parameters into backbone, sequence-model and head parameters.

        If the wrapped model supports it (``has_separate_params()``), the backbone
        parameters are replaced by the model's own depth-ordered grouping, which enables
        layer-wise learning-rate decay.

        :return: ``(pt_params, seq_params, head_params)``
        :raises ValueError: on a parameter that belongs to none of the three parts, or when
            layer-wise lr decay is requested for a model that cannot group its parameters.
        """
        pt_params = []
        seq_params = []
        head_params = []
        for name, p in self.named_parameters():
            # every parameter lives under self.model; strip that prefix
            name = name[len("model."):]
            if name.startswith('model'):
                # the transformer
                pt_params.append(p)
            elif name.startswith('seq_model'):
                # the optional sequence model
                seq_params.append(p)
            elif name.startswith('strong_head') or name.startswith('weak_head'):
                # the prediction head
                head_params.append(p)
            else:
                raise ValueError(f"Unexpected key in model: {name}")

        if self.model.has_separate_params():
            # split parameters into groups according to their depth in the network
            # based on this, we can apply layer-wise learning rate decay
            pt_params = self.model.separate_params()
        else:
            if self.config.lr_decay != 1.0:
                raise ValueError(f"Model has no separate_params function. Can't apply layer-wise lr decay, but "
                                 f"learning rate decay is set to {self.config.lr_decay}.")
        return pt_params, seq_params, head_params

    def get_optimizer(
            self,
            lr,
            lr_decay=1.0,
            transformer_lr=None,
            transformer_frozen=False,
            adamw=False,
            weight_decay=0.01,
            betas=(0.9, 0.999)
    ):
        """Build an Adam/AdamW optimizer with per-part learning rates.

        :param lr: learning rate of the prediction head and the sequence model
        :param lr_decay: layer-wise decay factor, applied only when the backbone returns
            grouped parameters; group ``i`` (0-based) gets ``transformer_lr * lr_decay ** (i + 1)``
        :param transformer_lr: backbone learning rate (defaults to ``lr``)
        :param transformer_frozen: if True, backbone and sequence-model parameters are detached
            and left out of the optimizer
        :param adamw: use AdamW (decoupled weight decay) instead of Adam
        :param weight_decay: weight decay for parameters with >= 2 dims; parameters with
            <= 1 dim (biases, norm weights) never receive weight decay
        :param betas: Adam beta coefficients
        :return: the configured ``torch.optim`` optimizer
        :raises ValueError: if ``weight_decay > 0`` is requested without AdamW
        """
        # Adam would apply weight decay as plain L2 regularization; only AdamW is supported here
        if weight_decay > 0 and not adamw:
            raise ValueError(f"weight_decay={weight_decay} requires AdamW (adamw=True).")

        pt_params, seq_params, head_params = self.separate_params()
        param_groups = [
            {'params': head_params, 'lr': lr},  # model head (besides base model and seq model)
        ]

        if transformer_frozen:
            for p in pt_params + seq_params:
                # grouped backbone parameters arrive as a list of lists
                if isinstance(p, list):
                    for p_i in p:
                        p_i.detach_()
                else:
                    p.detach_()
        else:
            if transformer_lr is None:
                transformer_lr = lr
            if isinstance(pt_params, list) and pt_params and isinstance(pt_params[0], list):
                # apply lr decay
                scale_lrs = [transformer_lr * (lr_decay ** i) for i in range(1, len(pt_params) + 1)]
                param_groups = param_groups + [{"params": pt_params[i], "lr": scale_lrs[i]} for i in
                                               range(len(pt_params))]
            else:
                param_groups.append(
                    {'params': pt_params, 'lr': transformer_lr},  # pretrained model
                )
            param_groups.append(
                {'params': seq_params, 'lr': lr},  # sequence model
            )

        # do not apply weight decay to biases and batch norms
        param_groups_split = []
        for param_group in param_groups:
            params_1d, params_2d = [], []
            group_lr = param_group['lr']
            for param in param_group['params']:
                if param.ndimension() >= 2:
                    params_2d.append(param)
                else:
                    params_1d.append(param)
            param_groups_split += [
                {'params': params_2d, 'lr': group_lr, 'weight_decay': weight_decay},
                # explicit 0.0 is required: a group without 'weight_decay' inherits the
                # optimizer-level default, which would decay biases/norms after all
                {'params': params_1d, 'lr': group_lr, 'weight_decay': 0.0},
            ]

        if adamw:
            print(f"\nUsing adamw weight_decay={weight_decay}!\n")
            return torch.optim.AdamW(param_groups_split, lr=lr, weight_decay=weight_decay, betas=betas)
        return torch.optim.Adam(param_groups_split, lr=lr, betas=betas)

    def get_lr_scheduler(
            self,
            optimizer,
            num_training_steps,
            schedule_mode="cos",
            gamma: float = 0.999996,
            num_warmup_steps=4000,
            lr_end=1e-7,
    ):
        """Create the learning-rate scheduler selected by ``schedule_mode``.

        :param optimizer: optimizer whose learning rates are scheduled
        :param num_training_steps: total number of optimizer steps (cos/linear schedules)
        :param schedule_mode: ``"exp"``, ``"cos"``/``"cosine"`` or ``"linear"``
        :param gamma: per-step decay factor of the exponential schedule
        :param num_warmup_steps: warm-up steps of the cosine and linear schedules
        :param lr_end: final learning rate of the linear schedule
        :raises RuntimeError: for an unknown ``schedule_mode``
        """
        if schedule_mode in {"exp"}:
            return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)
        if schedule_mode in {"cosine", "cos"}:
            return transformers.get_cosine_schedule_with_warmup(
                optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_training_steps,
            )
        if schedule_mode in {"linear"}:
            print("Linear schedule!")
            return transformers.get_polynomial_decay_schedule_with_warmup(
                optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_training_steps,
                power=1.0,
                lr_end=lr_end,
            )
        raise RuntimeError(f"schedule_mode={schedule_mode} Unknown.")

    def configure_optimizers(self):
        """
        This is the way PyTorch Lightning requires optimizers and learning rate schedulers to be defined.
        The specified items are used automatically in the optimization loop (no need to call optimizer.step() yourself).
        The scheduler is stepped after every optimizer step; its warm-up length comes from config.warmup_steps.
        :return: list with the optimizer and list with the learning rate scheduler config
        """
        optimizer = self.get_optimizer(self.config.max_lr,
                                       lr_decay=self.config.lr_decay,
                                       transformer_lr=self.config.transformer_lr,
                                       transformer_frozen=self.config.transformer_frozen,
                                       adamw=not self.config.no_adamw,
                                       weight_decay=self.config.weight_decay)

        num_training_steps = self.trainer.estimated_stepping_batches

        scheduler = self.get_lr_scheduler(optimizer, num_training_steps,
                                          schedule_mode=self.config.schedule_mode,
                                          num_warmup_steps=self.config.warmup_steps,
                                          lr_end=self.config.lr_end)
        lr_scheduler_config = {
            "scheduler": scheduler,
            "interval": "step",
            "frequency": 1
        }
        return [optimizer], [lr_scheduler_config]

    def training_step(self, train_batch, batch_idx):
        """
        :param train_batch: contains one batch from train dataloader: (audios, labels, fnames, timestamps)
        :param batch_idx: index of the batch (unused)
        :return: the strong-label BCE loss used to update the model parameters
        """
        audios, labels, fnames, timestamps = train_batch

        if self.config.transformer_frozen:
            # keep frozen parts in eval mode (Lightning switches the module to train mode each epoch)
            self.model.model.eval()
            self.model.seq_model.eval()

        mel = self.model.mel_forward(audios)

        # time rolling
        if self.config.frame_shift_range > 0:
            mel, labels = frame_shift(
                mel,
                labels,
                shift_range=self.config.frame_shift_range
            )

        # mixup
        if self.config.mixup_p > random.random():
            mel, labels = mixup(
                mel,
                targets=labels
            )

        # mixstyle
        if self.config.mixstyle_p > random.random():
            mel = mixstyle(
                mel
            )

        # time masking (the third return value is not needed here)
        if self.config.max_time_mask_size > 0:
            mel, labels, _ = time_mask(
                mel,
                labels,
                max_mask_ratio=self.config.max_time_mask_size
            )

        # frequency masking
        if self.config.filter_augment_p > random.random():
            mel, _ = filter_augmentation(
                mel
            )

        # frequency warping; the augmentation operates without the singleton dim 1
        if self.config.freq_warp_p > random.random():
            mel = mel.squeeze(1)
            mel = self.freq_warp(mel)
            mel = mel.unsqueeze(1)

        # forward through network; use strong head
        y_hat_strong, _ = self.model(mel)

        loss = self.strong_loss(y_hat_strong, labels)

        # logging
        self.log('epoch', self.current_epoch)
        for i, param_group in enumerate(self.trainer.optimizers[0].param_groups):
            self.log(f'trainer/lr_optimizer_{i}', param_group['lr'])
        self.log("train/loss", loss.detach().cpu(), prog_bar=True)

        return loss

    def _score_step(self, batch):
        """Run inference on one evaluation batch and buffer predictions, targets and metadata."""
        audios, labels, fnames, timestamps = batch
        strong_preds = self.forward(audios)
        self.preds.append(strong_preds)
        self.tgts.append(labels)
        self.fnames.append(fnames)
        self.timestamps.append(timestamps)

    def _score_epoch_end(self, name="val"):
        """Compute and log loss and event/segment metrics over all buffered batches.

        :param name: ``"val"`` scores against the ``valid`` split ground truth, anything
            else against ``test``; also used as the logging prefix.
        """
        preds = torch.cat(self.preds)
        tgts = torch.cat(self.tgts)
        fnames = [item for sublist in self.fnames for item in sublist]
        timestamps = torch.cat(self.timestamps)

        val_loss = self.strong_loss(preds, tgts)
        self.log(f"{name}/loss", val_loss, prog_bar=True)

        # the following function expects one prediction per timestamp (sequence dimension must be flattened)
        seq_len = preds.size(-1)
        preds = rearrange(preds, 'bs c t -> (bs t) c').float()
        timestamps = rearrange(timestamps, 'bs t -> (bs t)').float()
        # repeat each file name once per frame so it lines up with the flattened predictions
        fnames = [fname for fname in fnames for _ in range(seq_len)]

        predicted_events_by_postprocessing = get_events_for_all_files(
            preds,
            fnames,
            timestamps,
            self.idx_to_label,
            self.postprocessing_grid
        )

        # we only have one postprocessing configurations (aligned with HEAR challenge)
        key = list(predicted_events_by_postprocessing.keys())[0]
        predicted_events = predicted_events_by_postprocessing[key]

        # load ground truth for the evaluated split
        task_path = Path(self.config.task_path)
        test_target_events = combine_target_events(["valid" if name == "val" else "test"], task_path)

        onset_fms = self.event_onset_200ms_fms(predicted_events, test_target_events)
        onset_fms_50 = self.event_onset_50ms_fms(predicted_events, test_target_events)
        segment_1s_er = self.segment_1s_er(predicted_events, test_target_events)

        self.log(f"{name}/onset_fms", onset_fms[0][1])
        self.log(f"{name}/onset_fms_50", onset_fms_50[0][1])
        self.log(f"{name}/segment_1s_er", segment_1s_er[0][1])

        # free buffers
        self.preds, self.tgts, self.fnames, self.timestamps = [], [], [], []

    def validation_step(self, batch, batch_idx):
        self._score_step(batch)

    def on_validation_epoch_end(self):
        self._score_epoch_end(name="val")

    def test_step(self, batch, batch_idx):
        self._score_step(batch)

    def on_test_epoch_end(self):
        self._score_epoch_end(name="test")
def train(config):
    """Fine-tune and evaluate a ``PLModule`` on DCASE 2016 Task 2.

    Creates the train/validation/test data loaders and trains with a PyTorch Lightning
    ``Trainer`` that logs to Weights & Biases. It then runs the test loop on the fitted
    module (no checkpoint path is passed), prints the returned results and closes the
    wandb run.

    :param config: parsed command-line configuration (argparse.Namespace)
    """
    # all logging goes to Weights & Biases
    run_logger = WandbLogger(
        project="PTSED",
        notes="Downstream Training on office sound event detection.",
        tags=["DCASE 2016 Task 2", "Sound Event Detection"],
        config=config,
        name=config.experiment_name
    )

    training_data = get_training_dataset(config.task_path, wavmix_p=config.wavmix_p)
    validation_data = get_validation_dataset(config.task_path)
    testing_data = get_test_dataset(config.task_path)

    def make_loader(dataset, **loader_kwargs):
        # options shared by every split; split-specific ones arrive via loader_kwargs
        return DataLoader(dataset=dataset,
                          worker_init_fn=worker_init_fn,
                          num_workers=config.num_workers,
                          batch_size=config.batch_size,
                          **loader_kwargs)

    train_loader = make_loader(training_data, shuffle=True)
    # evaluation loaders keep a fixed order and the final partial batch
    valid_loader = make_loader(validation_data, shuffle=False, drop_last=False)
    test_loader = make_loader(testing_data, shuffle=False, drop_last=False)

    module = PLModule(config)

    # sanity validation is disabled; validation runs every check_val_every_n_epoch epochs
    trainer = pl.Trainer(
        max_epochs=config.n_epochs,
        logger=run_logger,
        accelerator='auto',
        devices=config.num_devices,
        precision=config.precision,
        num_sanity_val_steps=0,
        check_val_every_n_epoch=config.check_val_every_n_epoch
    )

    trainer.fit(
        module,
        train_dataloaders=train_loader,
        val_dataloaders=valid_loader,
    )

    print(trainer.test(module, dataloaders=test_loader))
    wandb.finish()
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Configuration Parser. ')
# general
parser.add_argument('--task_path', type=str, required=True)
parser.add_argument('--experiment_name', type=str, default="DCASE2016Task2")
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--num_workers', type=int, default=16)
parser.add_argument('--num_devices', type=int, default=1)
parser.add_argument('--precision', type=int, default=16)
parser.add_argument('--check_val_every_n_epoch', type=int, default=10)
# model
parser.add_argument('--model_name', type=str,
choices=["ATST-F", "BEATs", "fpasst", "M2D", "ASIT"],
default="ATST-F") # used also for training
# "scratch" = no pretraining
# "ssl" = SSL pre-trained
# "weak" = AudioSet Weak pre-trained
# "strong" = AudioSet Strong pre-trained
parser.add_argument('--pretrained', type=str, choices=["scratch", "ssl", "weak", "strong"],
default="strong")
parser.add_argument('--seq_model_type', type=str, choices=["rnn"],
default=None)
parser.add_argument('--n_classes', type=int, default=11)
# training
parser.add_argument('--n_epochs', type=int, default=300)
# augmentation
parser.add_argument('--wavmix_p', type=float, default=0.5)
parser.add_argument('--freq_warp_p', type=float, default=0.0)
parser.add_argument('--filter_augment_p', type=float, default=0.0)
parser.add_argument('--frame_shift_range', type=float, default=0.0) # in seconds
parser.add_argument('--mixup_p', type=float, default=0.5)
parser.add_argument('--mixstyle_p', type=float, default=0.0)
parser.add_argument('--max_time_mask_size', type=float, default=0.0)
# optimizer
parser.add_argument('--no_adamw', action='store_true', default=False)
parser.add_argument('--weight_decay', type=float, default=0.001)
parser.add_argument('--transformer_frozen', action='store_true', dest='transformer_frozen',
default=False,
help='Disable training for the transformer.')
# lr schedule
parser.add_argument('--schedule_mode', type=str, default="cos")
parser.add_argument('--max_lr', type=float, default=1.06e-4)
parser.add_argument('--transformer_lr', type=float, default=None)
parser.add_argument('--lr_decay', type=float, default=1.0)
parser.add_argument('--lr_end', type=float, default=1e-7)
parser.add_argument('--warmup_steps', type=int, default=100)
args = parser.parse_args()
train(args)