Skip to content
This repository was archived by the owner on Sep 11, 2024. It is now read-only.

Commit 444e479

Browse files
committed
Alignment for best path can be obtained.
1 parent 30b7942 commit 444e479

File tree

4 files changed

+45
-0
lines changed

4 files changed

+45
-0
lines changed

alex_asr/decoder.pyx

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ cdef extern from "src/decoder.h" namespace "alex_asr":
2222
void FrameIn(unsigned char *frame, size_t frame_len) except +
2323
bool GetBestPath(vector[int] *v_out, float *lik) except +
2424
bool GetLattice(alex_asr.fst.libfst.LogVectorFst *fst_out, double *tot_lik) except +
25+
bool GetTimeAlignment(vector[int] *words, vector[int] *times, vector[int] *durations) except +
2526
string GetWord(int word_id) except +
2627
void InputFinished() except +
2728
bool EndpointDetected() except +
@@ -126,6 +127,24 @@ cdef class Decoder:
126127
self.utt_decoded = 0
127128
return (lik, r)
128129

130+
def get_time_alignment(self):
131+
"""get_best_path(self)
132+
Get time alignment of the current 1-best decoding hypothesis.
133+
134+
Returns:
135+
tuple: (list of word id's, list of start times, list of durations)
136+
"""
137+
138+
cdef vector[int] w
139+
cdef vector[int] t
140+
cdef vector[int] d
141+
self.thisptr.GetTimeAlignment(address(w), address(t), address(d))
142+
words = [w[i] for i in xrange(w.size()) if w[i] != 0]
143+
times = [t[i] for i in xrange(t.size()) if w[i] != 0]
144+
durations = [d[i] for i in xrange(d.size()) if w[i] != 0]
145+
146+
return (words, times, durations)
147+
129148
def get_word(self, word_id):
130149
"""get_word(self, word_id)
131150
Get word string form given word id.

src/decoder.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "src/utils.h"
33

44
#include "online2/onlinebin-util.h"
5+
#include "lat/kaldi-lattice.h"
56

67
using namespace kaldi;
78

@@ -218,6 +219,22 @@ namespace alex_asr {
218219
return ok;
219220
}
220221

222+
bool Decoder::GetTimeAlignment(std::vector<int> *words, std::vector<int> *times, std::vector<int> *lengths) {
223+
Lattice lat;
224+
CompactLattice compact_lat;
225+
CompactLattice compact_best_path;
226+
bool ok = true;
227+
228+
ok = ok && decoder_->GetRawLattice(&lat);
229+
BaseFloat lat_beam = config_->decoder_opts.lattice_beam;
230+
DeterminizeLatticePhonePrunedWrapper(*trans_model_, &lat, lat_beam, &compact_lat, config_->decoder_opts.det_opts);
231+
232+
CompactLatticeShortestPath(compact_lat, &compact_best_path);
233+
ok = ok && CompactLatticeToWordAlignment(compact_best_path, words, times, lengths);
234+
235+
return ok;
236+
}
237+
221238
string Decoder::GetWord(int word_id) {
222239
return words_->Find(word_id);
223240
}

src/decoder.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
#include "feat/online-feature.h"
1212
#include "matrix/matrix-lib.h"
1313
#include "util/common-utils.h"
14+
#include "lat/kaldi-lattice.h"
15+
#include "lat/lattice-functions.h"
1416
#include "nnet2/online-nnet2-decodable.h"
1517
#include "online2/online-gmm-decodable.h"
1618
#include "online2/online-endpoint.h"
@@ -29,6 +31,7 @@ namespace alex_asr {
2931
void FrameIn(VectorBase<BaseFloat> *waveform_in);
3032
bool GetBestPath(std::vector<int> *v_out, BaseFloat *prob);
3133
bool GetLattice(fst::VectorFst<fst::LogArc> * out_fst, double *tot_lik, bool end_of_utt=true);
34+
bool GetTimeAlignment(std::vector<int> *words, std::vector<int> *times, std::vector<int> *lengths);
3235
string GetWord(int word_id);
3336
void InputFinished();
3437
bool EndpointDetected();

test/test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,5 +45,11 @@ def word_ids_to_str_hyp(decoder, word_ids):
4545
for arc in state.arcs:
4646
print (' %s' % decoder.get_word(arc.ilabel))
4747

48+
print ('Resulting time alignment:')
49+
words, times, durations = decoder.get_time_alignment()
50+
words = word_ids_to_str_hyp(decoder, words).split()
4851

52+
for (word, time, duration) in zip(words, times, durations):
53+
if word != "<eps>":
54+
print (word, time, duration)
4955

0 commit comments

Comments
 (0)