-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathinference.py
More file actions
126 lines (106 loc) · 4.95 KB
/
inference.py
File metadata and controls
126 lines (106 loc) · 4.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import argparse
import librosa
import torch
from data_util import audioset_classes
from helpers.decode import batched_decode_preds
from helpers.encode import ManyHotEncoder
from models.atstframe.ATSTF_wrapper import ATSTWrapper
from models.beats.BEATs_wrapper import BEATsWrapper
from models.frame_passt.fpasst_wrapper import FPaSSTWrapper
from models.m2d.M2D_wrapper import M2DWrapper
from models.asit.ASIT_wrapper import ASiTWrapper
from models.frame_mn.Frame_MN_wrapper import FrameMNWrapper
from models.prediction_wrapper import PredictionsWrapper
from models.frame_mn.utils import NAME_TO_WIDTH
def _build_model(model_name):
    """Create the prediction model for ``model_name`` with its strong checkpoint.

    Every supported model loads the checkpoint named
    ``f"{model_name}_strong_1"``.

    Args:
        model_name: One of "BEATs", "ATST-F", "fpasst", "M2D", "ASIT", or a
            name starting with "frame_mn".

    Returns:
        A ``PredictionsWrapper`` around the selected backbone wrapper.

    Raises:
        NotImplementedError: If ``model_name`` is not one of the names above.
    """
    checkpoint = f"{model_name}_strong_1"

    # A plain if-chain (instead of a name -> class table) so that only the
    # wrapper class that is actually requested gets looked up.
    if model_name == "BEATs":
        return PredictionsWrapper(BEATsWrapper(), checkpoint=checkpoint)
    if model_name == "ATST-F":
        return PredictionsWrapper(ATSTWrapper(), checkpoint=checkpoint)
    if model_name == "fpasst":
        return PredictionsWrapper(FPaSSTWrapper(), checkpoint=checkpoint)
    if model_name == "M2D":
        m2d = M2DWrapper()
        # M2D passes its embedding size explicitly, read from its own config.
        return PredictionsWrapper(m2d, checkpoint=checkpoint, embed_dim=m2d.m2d.cfg.feature_d)
    if model_name == "ASIT":
        return PredictionsWrapper(ASiTWrapper(), checkpoint=checkpoint)
    if model_name.startswith("frame_mn"):
        width = NAME_TO_WIDTH(model_name)
        frame_mn = FrameMNWrapper(width)
        # The embedding size is read off the bias of layer 'features.16.1'.
        embed_dim = frame_mn.state_dict()['frame_mn.features.16.1.bias'].shape[0]
        return PredictionsWrapper(frame_mn, checkpoint=checkpoint, embed_dim=embed_dim)

    raise NotImplementedError(f"Model {model_name} not (yet) implemented")


def sound_event_detection(args: argparse.Namespace) -> None:
    """Run sound event detection on one audio clip and print the events.

    The clip is loaded as 16 kHz mono, cut into 10-second segments (the last
    one zero-padded to full length), each segment is passed through the
    selected model, and the per-segment outputs are concatenated, squashed
    with a sigmoid and decoded into event tables that are printed per
    detection threshold.

    Args:
        args: Parsed command-line options. Reads ``model_name``, ``cuda``,
            ``audio_file``, ``median_window`` and ``detection_thresholds``.

    Raises:
        NotImplementedError: If ``args.model_name`` is not supported.
        ValueError: If the audio file contains no samples.
    """
    use_cuda = args.cuda and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    model = _build_model(args.model_name)
    model.eval()
    model.to(device)

    sample_rate = 16_000  # all our models are trained on 16 kHz audio
    segment_duration = 10  # all models are trained on 10-second pieces
    segment_samples = segment_duration * sample_rate

    # Load the audio and add a leading axis of size 1 in front of the samples.
    samples, _ = librosa.core.load(args.audio_file, sr=sample_rate, mono=True)
    waveform = torch.from_numpy(samples[None, :]).to(device)
    waveform_len = waveform.shape[1]
    if waveform_len == 0:
        # Without this guard no segment is produced and torch.cat() below
        # fails with an error that does not mention the input file.
        raise ValueError(f"Audio file {args.audio_file} contains no samples")

    audio_len = waveform_len / sample_rate  # in seconds
    print("Audio length (seconds): ", audio_len)

    # encoder manages decoding of model predictions into dataframes
    # containing event labels, onsets and offsets
    encoder = ManyHotEncoder(audioset_classes.as_strong_train_classes, audio_len=audio_len)

    # Process the clip in consecutive 10-second segments.
    # NOTE(review): predictions for the zero-padded tail of the last segment
    # are kept; presumably decoding bounds events by audio_len - confirm.
    segment_predictions = []
    with torch.no_grad():
        for start_idx in range(0, waveform_len, segment_samples):
            # Slicing past the end is clamped, so the last segment may be short.
            segment = waveform[:, start_idx:start_idx + segment_samples]
            missing = segment_samples - segment.shape[1]
            if missing > 0:
                segment = torch.nn.functional.pad(segment, (0, missing))

            mel = model.mel_forward(segment)
            y_segment, _ = model(mel)
            segment_predictions.append(y_segment)

    # Concatenate all predictions along the time axis and convert them
    # into probabilities.
    y_strong = torch.sigmoid(torch.cat(segment_predictions, dim=2))

    (
        scores_unprocessed,
        scores_postprocessed,
        decoded_predictions
    ) = batched_decode_preds(
        y_strong.float(),
        [args.audio_file],
        encoder,
        median_filter=args.median_window,
        thresholds=args.detection_thresholds,
    )

    # One table of detected events per threshold, ordered by onset.
    for th in decoded_predictions:
        print("***************************************")
        print(f"Detected events using threshold {th}:")
        print(decoded_predictions[th].sort_values(by="onset"))
        print("***************************************")
def parse_args(argv=None):
    """Parse and validate the command-line options of the inference script.

    Args:
        argv: Argument strings to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``model_name``, ``audio_file``,
        ``detection_thresholds``, ``median_window`` and ``cuda``.

    Exits with an argparse usage error (status 2) if ``--model_name`` is
    neither one of the fixed names nor a name starting with "frame_mn".
    """
    fixed_model_names = ("BEATs", "ASIT", "ATST-F", "fpasst", "M2D")

    parser = argparse.ArgumentParser(description='Example of parser. ')
    # model names: BEATs, ASIT, ATST-F, fpasst, M2D, or any name starting
    # with "frame_mn"
    parser.add_argument('--model_name', type=str, default='BEATs')
    parser.add_argument('--audio_file', type=str,
                        default='test_files/752547__iscence__milan_metro_coming_in_station.wav')
    # nargs='+' keeps the option a sequence of thresholds when it is given on
    # the command line, matching the tuple default.
    parser.add_argument('--detection_thresholds', type=float, nargs='+',
                        default=(0.1, 0.2, 0.5))
    # NOTE(review): a value given on the command line becomes a float while
    # the default is the int 9 - confirm the decoder accepts both.
    parser.add_argument('--median_window', type=float, default=9)
    parser.add_argument('--cuda', action='store_true', default=False)
    args = parser.parse_args(argv)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    name = args.model_name
    if name not in fixed_model_names and not name.startswith("frame_mn"):
        parser.error(
            f"unsupported --model_name {name!r}; expected one of "
            f"{', '.join(fixed_model_names)} or a name starting with 'frame_mn'"
        )
    return args


if __name__ == "__main__":
    sound_event_detection(parse_args())