-
Notifications
You must be signed in to change notification settings - Fork 26
Expand file tree
/
Copy pathpykaldi-online-latgen-recogniser.py
More file actions
executable file
·121 lines (108 loc) · 4.62 KB
/
pykaldi-online-latgen-recogniser.py
File metadata and controls
executable file
·121 lines (108 loc) · 4.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""
Example Python script showing PyOnlineLatgenRecogniser decoding.
Requires arguments specifying the AM, the HCLG graph, etc.
"""
#!/usr/bin/env python
# Copyright (c) 2013, Ondrej Platek, Ufal MFF UK <oplatek@ufal.mff.cuni.cz>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License. #
from __future__ import unicode_literals
from kaldi.utils import load_wav, wst2dict, lattice_to_nbest
from kaldi.decoders import PyOnlineLatgenRecogniser
import sys
import fst
import time
import os
# Verbose diagnostics are enabled only when the DEBUG environment variable is
# exactly the string 'true'; any other value, or an unset variable, disables
# them.  dict.get() replaces the deprecated dict.has_key() lookup.
DEBUG = os.environ.get('DEBUG') == 'true'
def write_decoded(f, wav_name, word_ids, wst):
    """Write the best ASR hypothesis of one utterance to a file.

    f - file-like object which receives one UTF-8 encoded line
    wav_name - utterance name, written as the first token of the line
    word_ids - non-empty n-best list of (weight, path) pairs;
               only the first pair is written
    wst - mapping from word id to word, or None to write the ids themselves
    """
    assert len(word_ids) > 0
    top_hypothesis = word_ids[0]
    _, top_path = top_hypothesis
    if wst is None:
        tokens = [unicode(word_id) for word_id in top_path]
    else:
        tokens = [wst[word_id] for word_id in top_path]
    # The newline is joined in as the last token, so every line ends in ' \n'.
    out_line = u' '.join([wav_name] + tokens + ['\n'])
    if DEBUG:
        print('%s best path' % (wav_name))
        # Only the first (best) entry of the n-best list is reported.
        print('best path %d: %s' % (0, str(top_hypothesis)))
    f.write(out_line.encode('UTF-8'))
# @profile
def decode(d, pcm, batch_size=None):
    """Performs simulated on-line decoding of one utterance.

    The audio is queued in small chunks and immediately decoded.

    d - PyOnlineLatgenRecogniser
    pcm - raw bytes which are interpreted as audio with sample width 2 Bytes
    batch_size - number of samples queued per chunk; when None, the
                 module-level audio_batch_size set by the script entry
                 point is used, so two-argument calls keep working

    Returns the tuple (lattice, likelihood, number of decoded frames).
    Raises ValueError if the batch size is not positive.
    """
    if batch_size is None:
        # Legacy fallback: this global only exists when the file is run as a
        # script; callers importing decode() should pass batch_size.
        batch_size = audio_batch_size
    if batch_size <= 0:
        # A non-positive chunk length would never advance through the audio.
        raise ValueError(
            'audio batch size must be positive, got %r' % (batch_size,))
    frame_len = 2 * batch_size  # 16-bit audio so 1 sample = 2 chars
    decoded_frames, pcm_len = 0, len(pcm)
    start = time.time()
    for begin in range(0, pcm_len, frame_len):
        # Slicing clamps at the end of the buffer, so the last chunk may be
        # shorter than frame_len.
        d.frame_in(pcm[begin:begin + frame_len])
        # Keep decoding until the recogniser reports no newly decoded frames.
        dec_t = d.decode(max_frames=10)
        while dec_t > 0:
            decoded_frames += dec_t
            dec_t = d.decode(max_frames=10)
            if (decoded_frames % 10) == 0 and DEBUG:
                startbp = time.time()
                bp_result = d.get_best_path()
                print("one best path decode after frame %d: %s secs" % (
                    decoded_frames, str(time.time() - startbp)))
                print("Hypothesis %s" % (bp_result,))
    print("forward decode: %s secs" % str(time.time() - start))
    start = time.time()
    d.finalize_decoding()
    lik, lat = d.get_lattice()
    print("backward decode: %s secs" % str(time.time() - start))
    # Reset the recogniser before returning; decode_wrap reuses the same
    # instance for every utterance.
    d.reset(reset_pipeline=False)
    return (lat, lik, decoded_frames)


def decode_wrap(argv, audio_batch_size, wav_paths, file_output, wst_path=None):
    """Prepares the setup for decoding.
    After decoding also saves the results.

    argv - list of options handed to PyOnlineLatgenRecogniser.setup
    audio_batch_size - number of samples queued per chunk, forwarded to decode
    wav_paths - iterable of (wav_name, wav_path) pairs
    file_output - open file which receives one hypothesis line per utterance
    wst_path - path to the word symbol table, or None
    """
    wst = wst2dict(wst_path)
    d = PyOnlineLatgenRecogniser()
    d.setup(argv)
    sw, sr = 2, 16000  # 16-bit audio so 1 sample_width = 2 chars
    for wav_name, wav_path in wav_paths:
        pcm = load_wav(wav_path, def_sample_width=sw, def_sample_rate=sr)
        print('%s has %f sec' % (wav_name, (float(len(pcm)) / sw) / sr))
        # Pass the batch size explicitly instead of relying on a global.
        lat, lik, decoded_frames = decode(d, pcm, audio_batch_size)
        if wst_path is not None:
            # wst_path defaults to None; only attach symbols when a table
            # was actually given.
            # NOTE(review): the table is re-read for every utterance and
            # looks loop-invariant - confirm the fst binding copies it on
            # assignment before hoisting.
            lat.isyms = lat.osyms = fst.read_symbols_text(wst_path)
        if DEBUG:
            # Debug artefacts: SVG rendering and a dump of the lattice.
            with open('pykaldi_%s.svg' % wav_name, 'w') as f:
                f.write(lat._repr_svg_())
            lat.write('%s_pykaldi.fst' % wav_name)
        if decoded_frames > 0:
            print("Log-likelihood per frame for utterance %s is %f over %d frames" % (
                wav_name, (lik / decoded_frames), decoded_frames))
        else:
            # Avoid dividing by zero when nothing was decoded.
            print("No frames were decoded for utterance %s" % wav_name)
        word_ids = lattice_to_nbest(lat, n=10)
        write_decoded(file_output, wav_name, word_ids, wst)
def parse_scp(lines):
    """Parses the lines of an scp listing into (wav_name, wav_path) pairs.

    lines - iterable of lines, each holding a name and a path separated
            by whitespace; the path itself may contain spaces

    Blank lines are skipped.
    Raises ValueError for a line which holds a name but no path.
    """
    entries = []
    for line_no, line in enumerate(lines, 1):
        # sep=None splits on any run of whitespace, so tabs and repeated
        # spaces between the name and the path are accepted.
        fields = line.strip().split(None, 1)
        if not fields:
            continue
        if len(fields) != 2:
            raise ValueError(
                'scp line %d has no path: %r' % (line_no, line))
        entries.append(tuple(fields))
    return entries


if __name__ == '__main__':
    if len(sys.argv) < 5:
        sys.stderr.write(
            'Usage: %s audio_scp audio_batch_size dec_hypo wst_path '
            '[decoder options]\n' % sys.argv[0])
        sys.exit(1)
    # audio_batch_size stays a module-level name: decode() falls back to it
    # when it is called without an explicit batch size.
    audio_scp, audio_batch_size = sys.argv[1], int(sys.argv[2])
    dec_hypo, wst_path = sys.argv[3], sys.argv[4]
    argv = sys.argv[5:]
    sys.stderr.write('Python args: %s\n' % str(sys.argv))
    # Parse the listing before opening the output, so that a malformed
    # audio_scp does not truncate an existing dec_hypo file.
    with open(audio_scp, 'rb') as r:
        scp = parse_scp(r)
    # decode every listed recording and write the hypotheses to dec_hypo
    with open(dec_hypo, 'wb') as w:
        decode_wrap(argv, audio_batch_size, scp, w, wst_path)