-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathtransforms.py
More file actions
195 lines (171 loc) · 8.09 KB
/
transforms.py
File metadata and controls
195 lines (171 loc) · 8.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import os
import datasets
import h5py
import numpy as np
import pandas as pd
import torch
import torchaudio
from data_util.audioset_classes import as_strong_train_classes
## Transforms with a similar style to https://github.com/descriptinc/audiotools/blob/master/audiotools/data/transforms.py
logger = datasets.logging.get_logger(__name__)
def target_transform(sample):
    """Strip the weak-label columns from a batch, in place.

    Args:
        sample: batch dict containing "labels" and "label_ids" keys.

    Returns:
        The same dict with both weak-label keys removed.
    """
    for key in ("labels", "label_ids"):
        del sample[key]
    return sample
def strong_label_transform(sample, strong_label_encoder=None):
    """Encode the event annotations of a one-example batch into a strong label matrix.

    Args:
        sample: batch dict whose "events" column holds one dict with
            "onset", "offset" and "event_label" lists.
        strong_label_encoder: encoder providing ``encode_strong_df``;
            required (the default only exists for keyword-style calls).

    Returns:
        The batch with "strong", "event_count" and "gt_string" columns
        added and the raw "events" column removed.

    Raises:
        ValueError: if no strong_label_encoder is given.
    """
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
    if strong_label_encoder is None:
        raise ValueError("strong_label_transform requires a strong_label_encoder")
    events = pd.DataFrame(sample['events'][0])
    # Keep only events whose class belongs to the AudioSet-strong training set.
    events = events[events['event_label'].isin(set(as_strong_train_classes))]
    strong = strong_label_encoder.encode_strong_df(events).T
    sample["strong"] = [strong]
    # Per-class event activity counts (sum over the time axis).
    sample["event_count"] = [strong.sum(1)]
    # encode ground truth events as string - we will use this for evaluation
    sample["gt_string"] = ["++".join([";;".join([str(e[0]), str(e[1]), e[2]]) for e in
                                      zip(sample['events'][0]['onset'], sample['events'][0]['offset'],
                                          sample['events'][0]['event_label'])])]
    del sample['events']
    return sample
class AddPseudoLabelsTransform:
    """Attach precomputed pseudo strong labels, stored in an HDF5 file, to a batch.

    The HDF5 file is expected to contain a "filenames" dataset (bytes) and a
    parallel "strong_logits" dataset, indexed by the same row number.
    """

    def __init__(self, pseudo_labels_file):
        """
        Args:
            pseudo_labels_file: path to the pseudo-label HDF5 file, or None
                to construct a no-op index (lookups will then raise KeyError).
        """
        self.pseudo_labels_file = pseudo_labels_file
        # filename -> row index into the HDF5 datasets. Initialized
        # unconditionally so the attribute always exists.
        self.ex2pseudo_idx = {}
        if self.pseudo_labels_file is not None:
            # Build the index eagerly but close this handle right away;
            # the original code leaked it. The handle used for reading is
            # opened lazily (see pseudo_hdf5_file) so e.g. each dataloader
            # worker process gets its own handle.
            with h5py.File(self.pseudo_labels_file, "r") as f:
                for i, fname in enumerate(f["filenames"]):
                    self.ex2pseudo_idx[fname.decode("UTF-8")] = i
        self._opened_pseudo_hdf5 = None

    @property
    def pseudo_hdf5_file(self):
        """Lazily opened read handle to the pseudo-label HDF5 file."""
        if self._opened_pseudo_hdf5 is None:
            self._opened_pseudo_hdf5 = h5py.File(self.pseudo_labels_file, "r")
        return self._opened_pseudo_hdf5

    def add_pseudo_label_transform(self, sample):
        """Add a "pseudo_strong" column (sigmoid of stored logits) to the batch.

        Args:
            sample: batch dict with a "filename" column matching the index.

        Returns:
            The batch with a list of float tensors under "pseudo_strong".
        """
        indices = [self.ex2pseudo_idx[fn] for fn in sample['filename']]
        pseudo_strong = [
            torch.from_numpy(np.stack(self.pseudo_hdf5_file["strong_logits"][index])).float()
            for index in indices
        ]
        # Convert logits to probabilities.
        pseudo_strong = [torch.sigmoid(t) for t in pseudo_strong]
        sample['pseudo_strong'] = pseudo_strong
        return sample
class SequentialTransform:
    """Compose several batch transforms into a single callable pipeline."""

    def __init__(self, transforms):
        """
        Args:
            transforms: list of callables, each mapping a batch to a batch
        """
        self.transforms = transforms

    def append(self, transform):
        """Add one more transform to the end of the pipeline."""
        self.transforms.append(transform)

    def __call__(self, batch):
        """Run the batch through every transform, in order, and return it."""
        result = batch
        for transform in self.transforms:
            result = transform(result)
        return result
class Mp3DecodeTransform:
    """Decode a batch of mp3 byte strings into cropped/resampled/padded waveforms."""

    def __init__(
        self,
        mp3_bytes_key="mp3_bytes",
        audio_key="audio",
        sample_rate=32000,
        max_length=10.0,
        min_length=None,
        random_sample_crop=True,
        allow_resample=True,
        resampling_method="sinc_interp_kaiser",
        keep_mp3_bytes=False,
        debug_info_key=None,
    ):
        """Decode mp3 bytes to audio waveform.

        Args:
            mp3_bytes_key (str, optional): The key to mp3 bytes in the input batch. Defaults to "mp3_bytes".
            audio_key (str, optional): The key to save the decoded audio in the output batch. Defaults to "audio".
            sample_rate (int, optional): The expected output sampling rate. Defaults to 32000.
            max_length (int, float, optional): the maximum output audio length in seconds if float, otherwise in samples. Defaults to 10.
            min_length (int, float, optional): the minimum output audio length (shorter clips are zero-padded); same units as max_length. Defaults to max_length.
            random_sample_crop (bool, optional): Randomly crop the audio to max_length if it is longer, otherwise return the first crop. Defaults to True.
            allow_resample (bool, optional): Resample the signal if the sampling rates don't match. Defaults to True.
            resampling_method (str, optional): resampling method from torchaudio.transforms.Resample. Defaults to "sinc_interp_kaiser".
            keep_mp3_bytes (bool, optional): keep the original bytes in the output dict. Defaults to False.
            debug_info_key (str, optional): batch key holding per-example identifiers used in
                warnings/errors; if None, the example index is used instead. Defaults to None.

        Raises:
            ImportError: if minimp3py is not installed.
        """
        self.mp3_bytes_key = mp3_bytes_key
        self.audio_key = audio_key
        self.sample_rate = sample_rate
        self.max_length = max_length
        if min_length is None:
            min_length = max_length
        self.min_length = min_length
        self.random_sample_crop = random_sample_crop
        self.allow_resample = allow_resample
        self.resampling_method = resampling_method
        self.keep_mp3_bytes = keep_mp3_bytes
        self.debug_info_key = debug_info_key
        # One cached Resample module per encountered source sample rate.
        self.resamplers_cache = {}
        try:
            import minimp3py  # noqa: F401
        except ImportError as err:
            # Fail fast at construction time with install instructions.
            # Narrowed from a bare `except:` and chained so the original
            # cause is preserved.
            raise ImportError(
                "minimp3py is not installed, please install it using: `CFLAGS='-O3 -march=native' pip install https://github.com/f0k/minimp3py/archive/master.zip`"
            ) from err

    def __call__(self, batch):
        """Decode each mp3 in the batch; add `audio_key` and "sampling_rate" columns."""
        import minimp3py

        data_list = batch[self.mp3_bytes_key]
        if self.debug_info_key is not None:
            file_name_list = batch[self.debug_info_key]
        else:
            file_name_list = range(len(data_list))
        audio_list = []
        for data, file_name in zip(data_list, file_name_list):
            try:
                duration, ch, sr = minimp3py.probe(data)
                # Crop length expressed in *source* samples: seconds * sr if
                # max_length is a float, otherwise target samples rescaled to
                # the source rate.
                if isinstance(self.max_length, float):
                    max_length = int(self.max_length * sr)
                else:
                    max_length = int(self.max_length * sr // self.sample_rate)
                offset = 0
                if self.random_sample_crop and duration > max_length:
                    # NOTE(review): assumes probe() reports duration in
                    # samples, like max_length here — confirm against minimp3py.
                    max_offset = max(int(duration - max_length), 0) + 1
                    offset = torch.randint(max_offset, (1,)).item()
                waveform, _ = minimp3py.read(data, start=offset, length=max_length)
                waveform = waveform[:, 0]  # 0 for the first channel only
                if waveform.dtype != "float32":
                    raise RuntimeError("Unexpected wave type")
                waveform = torch.from_numpy(waveform)
                if len(waveform) == 0:
                    logger.warning(
                        f"Empty waveform for {file_name}, duration {duration}, offset {offset}, max_length {max_length}, sr {sr}, ch {ch}"
                    )
                elif sr != self.sample_rate:
                    # Explicit raise instead of `assert`: asserts are stripped
                    # under `python -O`.
                    if not self.allow_resample:
                        raise RuntimeError(
                            f"Unexpected sample rate {sr} instead of {self.sample_rate} at {file_name}"
                        )
                    if self.resamplers_cache.get(sr) is None:
                        self.resamplers_cache[sr] = torchaudio.transforms.Resample(
                            sr,
                            self.sample_rate,
                            resampling_method=self.resampling_method,
                        )
                    waveform = self.resamplers_cache[sr](waveform)
                # Zero-pad up to min_length (in target-rate samples).
                min_length = self.min_length
                if isinstance(self.min_length, float):
                    min_length = int(self.min_length * self.sample_rate)
                if min_length is not None and len(waveform) < min_length:
                    waveform = torch.concatenate(
                        (
                            waveform,
                            torch.zeros(
                                min_length - len(waveform),
                                dtype=waveform.dtype,
                                device=waveform.device,
                            ),
                        ),
                        dim=0,
                    )
                audio_list.append(waveform)
            except Exception:
                # logger.exception records the traceback; bare `raise`
                # re-raises without rewriting it (unlike `raise e`).
                logger.exception("Error decoding %s", file_name)
                raise
        batch[self.audio_key] = audio_list
        batch["sampling_rate"] = [self.sample_rate] * len(audio_list)
        if not self.keep_mp3_bytes:
            del batch[self.mp3_bytes_key]
        return batch