22#include " src/utils.h"
33
44#include " online2/onlinebin-util.h"
5+ #include " fst/fstlib.h"
6+ #include " fstext/table-matcher.h"
7+ #include " fstext/kaldi-fst-io.h"
58#include " lat/kaldi-lattice.h"
69#include " lat/sausages.h"
710
@@ -11,6 +14,8 @@ namespace alex_asr {
1114 Decoder::Decoder (const string model_path) :
1215 feature_pipeline_ (NULL ),
1316 hclg_ (NULL ),
17+ lm_small_ (NULL ),
18+ lm_big_ (NULL ),
1419 decoder_ (NULL ),
1520 trans_model_ (NULL ),
1621 am_nnet2_ (NULL ),
@@ -19,7 +24,6 @@ namespace alex_asr {
1924 words_ (NULL ),
2025 config_ (NULL ),
2126 decodable_ (NULL )
22-
2327 {
2428 // Change dir to model_path. Change back when leaving the scope.
2529 local_cwd cwd_to_model_path (model_path);
@@ -35,6 +39,8 @@ namespace alex_asr {
3539 Decoder::~Decoder () {
3640 delete feature_pipeline_;
3741 delete hclg_;
42+ delete lm_small_;
43+ delete lm_big_;
3844 delete decoder_;
3945 delete trans_model_;
4046 delete am_nnet2_;
@@ -109,6 +115,37 @@ namespace alex_asr {
109115 WordBoundaryInfoNewOpts word_boundary_info_opts;
110116 word_boundary_info_ = new WordBoundaryInfo (word_boundary_info_opts, config_->word_boundary_rxfilename );
111117 }
118+
119+ if (config_->rescore == true ) {
120+ LoadLM (config_->lm_small_rxfilename , &lm_small_);
121+ LoadLM (config_->lm_big_rxfilename , &lm_big_);
122+ }
123+ }
124+
125+ void Decoder::LoadLM (
126+ const string path,
127+ fst::MapFst<fst::StdArc, LatticeArc, fst::StdToLatticeMapper<BaseFloat> > **lm_fst
128+ ) {
129+ int num_states_cache = 50000 ;
130+
131+ if (FileExists (path)) {
132+ fst::VectorFst<fst::StdArc> *std_lm_fst = fst::ReadFstKaldi (path);
133+ fst::Project (std_lm_fst, fst::PROJECT_OUTPUT);
134+ if (std_lm_fst->Properties (fst::kILabelSorted , true ) == 0 ) {
135+ fst::ILabelCompare<fst::StdArc> ilabel_comp;
136+ fst::ArcSort (std_lm_fst, ilabel_comp);
137+ }
138+
139+ fst::CacheOptions cache_opts (true , num_states_cache);
140+ fst::MapFstOptions mapfst_opts (cache_opts);
141+ fst::StdToLatticeMapper<BaseFloat> mapper;
142+ *lm_fst = new fst::MapFst<fst::StdArc, LatticeArc, fst::StdToLatticeMapper<BaseFloat> >(*std_lm_fst, mapper, mapfst_opts);
143+ delete std_lm_fst;
144+
145+ KALDI_VLOG (2 ) << " LM loaded: " << path;
146+ } else {
147+ KALDI_ERR << " LM" << path << " doesn't exist." ;
148+ }
112149 }
113150
114151 void Decoder::Reset () {
@@ -213,9 +250,7 @@ namespace alex_asr {
213250 return ok;
214251 }
215252
216- bool Decoder::GetLattice (fst::VectorFst<fst::LogArc> *fst_out,
217- double *tot_lik, bool end_of_utterance) {
218- CompactLattice lat;
253+ bool Decoder::GetPrunedLattice (CompactLattice *lat) {
219254 Lattice raw_lat;
220255
221256 if (decoder_->NumFramesDecoded () == 0 )
@@ -224,27 +259,76 @@ namespace alex_asr {
224259 if (!config_->decoder_opts .determinize_lattice )
225260 KALDI_ERR << " --determinize-lattice=false option is not supported at the moment" ;
226261
227- bool ok = decoder_->GetRawLattice (&raw_lat, end_of_utterance);
262+ bool ok = decoder_->GetRawLattice (&raw_lat);
263+
228264
229265 BaseFloat lat_beam = config_->decoder_opts .lattice_beam ;
230- DeterminizeLatticePhonePrunedWrapper (
231- *trans_model_, &raw_lat, lat_beam, &lat, config_->decoder_opts .det_opts );
266+ if (!config_->rescore ) {
267+ DeterminizeLatticePhonePrunedWrapper (*trans_model_, &raw_lat, lat_beam, lat, config_->decoder_opts .det_opts );
268+ } else {
269+ CompactLattice pruned_lat;
270+
271+ DeterminizeLatticePhonePrunedWrapper (*trans_model_, &raw_lat, lat_beam, &pruned_lat, config_->decoder_opts .det_opts );
272+ ok = ok && RescoreLattice (pruned_lat, lat);
273+ }
274+
275+ return ok;
276+ }
277+
278+ bool Decoder::RescoreLattice (CompactLattice lat, CompactLattice *rescored_lattice) {
279+ CompactLattice intermidiate_lattice;
280+ bool ok = true ;
281+
282+ ok = ok && RescoreLatticeWithLM (lat, -1.0 , lm_small_, &intermidiate_lattice);
283+ ok = ok && RescoreLatticeWithLM (intermidiate_lattice, 1.0 , lm_big_, rescored_lattice);
284+
285+ return ok;
286+ }
287+
288+ bool Decoder::RescoreLatticeWithLM (
289+ CompactLattice lat,
290+ float lm_scale,
291+ fst::MapFst<fst::StdArc, LatticeArc, fst::StdToLatticeMapper<BaseFloat> > *lm_fst,
292+ CompactLattice *rescored_lattice) {
293+
294+ // Taken from https://github.com/kaldi-asr/kaldi/blob/9b9b561e2b3d2bdf64c84f8626175953f0885264/src/latbin/lattice-lmrescore.cc
295+
296+ Lattice lattice;
297+ ConvertLattice (lat, &lattice);
298+
299+ fst::ScaleLattice (fst::GraphLatticeScale (1.0 / lm_scale), &lattice);
300+ ArcSort (&lattice, fst::OLabelCompare<LatticeArc>());
301+
302+ Lattice composed_lat;
303+ fst::TableComposeOptions compose_opts (fst::TableMatcherOptions (), true , fst::SEQUENCE_FILTER, fst::MATCH_INPUT);
304+ fst::TableComposeCache<fst::Fst<LatticeArc> > lm_compose_cache (compose_opts);
305+ TableCompose (lattice, *lm_fst, &composed_lat, &lm_compose_cache);
306+ Invert (&composed_lat);
307+
308+ DeterminizeLattice (composed_lat, rescored_lattice);
309+ fst::ScaleLattice (fst::GraphLatticeScale (lm_scale), rescored_lattice);
232310
311+ return rescored_lattice->Start () != fst::kNoStateId ;
312+ }
313+
314+ bool Decoder::GetLattice (fst::VectorFst<fst::LogArc> *fst_out,
315+ double *tot_lik, bool end_of_utterance) {
316+ CompactLattice lat;
317+ bool ok = true ;
318+
319+ ok = this ->GetPrunedLattice (&lat);
233320 *tot_lik = CompactLatticeToWordsPost (lat, fst_out);
234321
235322 return ok;
236323 }
237324
238325 bool Decoder::GetTimeAlignment (std::vector<int > *words, std::vector<int > *times, std::vector<int > *lengths) {
239- Lattice lat;
240326 CompactLattice compact_lat;
241327 CompactLattice best_path;
242328 CompactLattice aligned_best_path;
243329 bool ok = true ;
244330
245- ok = ok && decoder_->GetRawLattice (&lat);
246- BaseFloat lat_beam = config_->decoder_opts .lattice_beam ;
247- DeterminizeLatticePhonePrunedWrapper (*trans_model_, &lat, lat_beam, &compact_lat, config_->decoder_opts .det_opts );
331+ ok = this ->GetPrunedLattice (&compact_lat);
248332 CompactLatticeShortestPath (compact_lat, &best_path);
249333
250334 if (config_->word_boundary_rxfilename == " " ) {
@@ -264,9 +348,7 @@ namespace alex_asr {
264348 CompactLattice aligned_best_path;
265349 bool ok = true ;
266350
267- ok = ok && decoder_->GetRawLattice (&lat);
268- BaseFloat lat_beam = config_->decoder_opts .lattice_beam ;
269- DeterminizeLatticePhonePrunedWrapper (*trans_model_, &lat, lat_beam, &compact_lat, config_->decoder_opts .det_opts );
351+ ok = this ->GetPrunedLattice (&compact_lat);
270352 CompactLatticeShortestPath (compact_lat, &best_path);
271353
272354 if (config_->word_boundary_rxfilename != " " ) {
0 commit comments