Skip to content

Commit 5e40c3f

Browse files
author
Yuan Gong
committed
add HF space
1 parent 4398914 commit 5e40c3f

File tree

2 files changed

+38
-26
lines changed

2 files changed

+38
-26
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
[**[Paper]**](https://arxiv.org/pdf/2307.03183.pdf)
1414

15+
[**[HuggingFace Space]**](https://huggingface.co/spaces/yuangongfdu/whisper-at) (Try Whisper-AT without Coding!)
16+
1517
[**[Colab Demo]**](https://colab.research.google.com/drive/1BbOGWCMjkOlOY5PbEMGk5RomRSqMcy_Q?usp=sharing)
1618

1719
[**[Local Notebook Demo]**](https://github.com/YuanGongND/whisper-at/blob/main/sample/whisper_at_demo.ipynb) (for users without Colab access)

app.py

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,38 +6,48 @@
66
paper_link = "https://arxiv.org/pdf/2307.03183.pdf"
77
paper_text = "[Paper]"
88

# Pre-load every supported Whisper-AT checkpoint once at module import time,
# so each request pays only inference cost, not model-loading cost.
# NOTE(review): this keeps four models resident in memory simultaneously —
# confirm the HF Space hardware tier has enough RAM/VRAM for all of them.
model_large = whisper.load_model("large-v1")
model_tiny = whisper.load_model("tiny")
model_tiny_en = whisper.load_model("tiny.en")
model_small = whisper.load_model("small")

# Map the UI radio-button labels to the loaded model objects; note that the
# "large" label deliberately maps to the "large-v1" checkpoint.
mdl_dict = {"tiny": model_tiny, "tiny.en": model_tiny_en, "small": model_small, "large": model_large}
2515

26-
asr_output = ""
27-
for segment in result['segments']:
28-
asr_output = asr_output + str(segment['start']).zfill(1) + 's-' + str(segment['end']).zfill(1) + 's: ' + segment['text'] + '\n'
29-
at_output = ""
30-
for segment in audio_tag_result:
31-
print(segment)
32-
at_output = at_output + str(segment['time']['start']).zfill(1) + 's-' + str(segment['time']['end']).zfill(1) + 's: ' + ' ,'.join([x[0] for x in segment['audio tags']]) + '\n'
33-
print(at_output)
34-
return asr_output, at_output
16+
def round_time_resolution(time_resolution):
    """Snap a requested tagging resolution to the nearest multiple of 0.4 s.

    The Whisper-AT tagging head operates on 0.4-second frames, so the
    user-supplied value (string or number) is converted to a whole number
    of frames and mapped back to seconds.
    """
    frame_count = round(float(time_resolution) / 0.4)
    return frame_count * 0.4
21+
22+
def predict(audio_path_m, audio_path_t, model_size, time_resolution):
    """Run Whisper-AT joint speech recognition + audio tagging on one recording.

    Parameters
    ----------
    audio_path_m : str or None
        Filepath of the microphone recording (None if the mic was not used).
    audio_path_t : str or None
        Filepath of the uploaded audio file (None if nothing was uploaded).
    model_size : str
        Key into the module-level ``mdl_dict`` selecting a pre-loaded model.
    time_resolution : str or float
        Requested audio-tagging window in seconds; rounded to the nearest
        multiple of 0.4 s (the model's native tagging resolution).

    Returns
    -------
    tuple[str, str]
        (timestamped ASR transcript, timestamped audio-tag summary), or the
        same guidance message in both slots when input validation fails.
    """
    # Exactly one of the two audio inputs must be provided.  The original
    # `((a is None) != (b is None)) == False` is an XOR written the hard way;
    # "both None or both set" is simply equality of the two None-checks.
    if (audio_path_m is None) == (audio_path_t is None):
        msg = "Please upload and only upload one recording, either upload the audio file or record using microphone."
        return msg, msg

    audio_path = audio_path_m or audio_path_t
    audio_tagging_time_resolution = round_time_resolution(time_resolution)
    model = mdl_dict[model_size]
    result = model.transcribe(audio_path, at_time_res=audio_tagging_time_resolution)
    audio_tag_result = whisper.parse_at_label(
        result, language='follow_asr', top_k=5, p_threshold=-1,
        include_class_list=list(range(527)))

    # Format each ASR segment as "<start>s-<end>s: <text>", one per line.
    asr_output = "".join(
        f"{segment['start']:.1f}s-{segment['end']:.1f}s: {segment['text']}\n"
        for segment in result['segments'])

    # Format each tagging segment as "<start>s-<end>s: tag1, tag2, ...".
    at_output = "".join(
        f"{segment['time']['start']:.1f}s-{segment['time']['end']:.1f}s: "
        + ', '.join(tag[0] for tag in segment['audio tags']) + '\n'
        for segment in audio_tag_result)

    return asr_output, at_output
3543

3644
# Build the Gradio demo: two audio inputs (microphone OR file upload — the
# predict() function enforces exactly one), a model-size selector, and a
# time-resolution text box; outputs are the transcript and the tag summary.
# Fix: the resolution label read "Must be must be ..." (doubled words).
iface = gr.Interface(fn=predict,
                     inputs=[gr.Audio(type="filepath", source='microphone', label='Please either upload an audio file or record using the microphone.', show_label=True),
                             gr.Audio(type="filepath"),
                             gr.Radio(["tiny", "tiny.en", "small", "large"], value='large', label="Model size", info="The larger the model, the better the performance and the slower the speed."),
                             gr.Textbox(value='10', label='Time Resolution in Seconds (Must be an integer multiple of 0.4, e.g., 0.4, 2, 10)')],
                     outputs=[gr.Textbox(label="Speech Output"), gr.Textbox(label="Audio Tag Output")],
                     cache_examples=True,
                     title="Quick Demo of Whisper-AT",
                     description="We are glad to introduce Whisper-AT - A new joint audio tagging and speech recognition model. It outputs background sound labels in addition to text." + f"<a href='{paper_link}'>{paper_text}</a> " + f"<a href='{link}'>{text}</a> <br>" +
                                 "Whisper-AT is authored by Yuan Gong, Sameer Khurana, Leonid Karlinsky, and James Glass (MIT & MIT-IBM Watson AI Lab). It is an Interspeech 2023 paper.")
# share=True exposes a public Gradio link; debug=True surfaces server errors.
iface.launch(debug=True, share=True)

0 commit comments

Comments
 (0)