Skip to content
This repository was archived by the owner on Sep 11, 2024. It is now read-only.

Commit 40b7bad

Browse files
committed
Final lattice can be rescored with a bigger LM.
1 parent adaf35a commit 40b7bad

File tree

8 files changed

+125
-17
lines changed

8 files changed

+125
-17
lines changed

Makefile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ ADDLIBS = $(FSTROOT)/src/lib/.libs/libfst.a \
4747
$(KALDI_DIR)/src/tree/kaldi-tree.a \
4848
$(KALDI_DIR)/src/matrix/kaldi-matrix.a \
4949
$(KALDI_DIR)/src/util/kaldi-util.a \
50-
$(KALDI_DIR)/src/base/kaldi-base.a
50+
$(KALDI_DIR)/src/base/kaldi-base.a \
51+
$(KALDI_DIR)/src/fstext/kaldi-fstext.a
5152

5253
LDFLAGS = $(ADDLIBS) -llapack_atlas -lcblas -latlas -lf77blas -lm -lpthread -ldl
5354

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ Example of `alex_asr.conf` that should reside in `model_dir`:
7070
# with configuration of the pitch extractor.
7171
--bits_per_sample=16 # 8/16; How many bits per sample frame?
7272
73+
--rescore=True # Rescore lattice with a bigger LM?
74+
--lm_small=G_small.fst # G.fst of the original (smaller) LM used in HCLG
75+
--lm_big=G_big.fst # G.fst corresponding to the bigger LM
76+
7377
# These parameters specify filenames of configuration of the particular parts of the decoder. Detailed below.
7478
--cfg_decoder=decoder.cfg
7579
--cfg_decodable=decodable.cfg

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def run(self):
5858
packages=find_packages(exclude=["alex_asr/decoder.cpp"]),
5959
include_package_data=True,
6060
cmdclass={'build_ext': build_ext_with_make},
61-
version='1.0.4',
61+
version='1.0.5',
6262
install_requires=install_requires,
6363
setup_requires=['cython>=0.19.1', 'nose>=1.0'],
6464
ext_modules=[

src/decoder.cc

Lines changed: 96 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
#include "src/utils.h"
33

44
#include "online2/onlinebin-util.h"
5+
#include "fst/fstlib.h"
6+
#include "fstext/table-matcher.h"
7+
#include "fstext/kaldi-fst-io.h"
58
#include "lat/kaldi-lattice.h"
69
#include "lat/sausages.h"
710

@@ -11,6 +14,8 @@ namespace alex_asr {
1114
Decoder::Decoder(const string model_path) :
1215
feature_pipeline_(NULL),
1316
hclg_(NULL),
17+
lm_small_(NULL),
18+
lm_big_(NULL),
1419
decoder_(NULL),
1520
trans_model_(NULL),
1621
am_nnet2_(NULL),
@@ -19,7 +24,6 @@ namespace alex_asr {
1924
words_(NULL),
2025
config_(NULL),
2126
decodable_(NULL)
22-
2327
{
2428
// Change dir to model_path. Change back when leaving the scope.
2529
local_cwd cwd_to_model_path(model_path);
@@ -35,6 +39,8 @@ namespace alex_asr {
3539
Decoder::~Decoder() {
3640
delete feature_pipeline_;
3741
delete hclg_;
42+
delete lm_small_;
43+
delete lm_big_;
3844
delete decoder_;
3945
delete trans_model_;
4046
delete am_nnet2_;
@@ -109,6 +115,37 @@ namespace alex_asr {
109115
WordBoundaryInfoNewOpts word_boundary_info_opts;
110116
word_boundary_info_ = new WordBoundaryInfo(word_boundary_info_opts, config_->word_boundary_rxfilename);
111117
}
118+
119+
if(config_->rescore == true) {
120+
LoadLM(config_->lm_small_rxfilename, &lm_small_);
121+
LoadLM(config_->lm_big_rxfilename, &lm_big_);
122+
}
123+
}
124+
125+
void Decoder::LoadLM(
126+
const string path,
127+
fst::MapFst<fst::StdArc, LatticeArc, fst::StdToLatticeMapper<BaseFloat> > **lm_fst
128+
) {
129+
int num_states_cache = 50000;
130+
131+
if (FileExists(path)) {
132+
fst::VectorFst<fst::StdArc> *std_lm_fst = fst::ReadFstKaldi(path);
133+
fst::Project(std_lm_fst, fst::PROJECT_OUTPUT);
134+
if (std_lm_fst->Properties(fst::kILabelSorted, true) == 0) {
135+
fst::ILabelCompare<fst::StdArc> ilabel_comp;
136+
fst::ArcSort(std_lm_fst, ilabel_comp);
137+
}
138+
139+
fst::CacheOptions cache_opts(true, num_states_cache);
140+
fst::MapFstOptions mapfst_opts(cache_opts);
141+
fst::StdToLatticeMapper<BaseFloat> mapper;
142+
*lm_fst = new fst::MapFst<fst::StdArc, LatticeArc, fst::StdToLatticeMapper<BaseFloat> >(*std_lm_fst, mapper, mapfst_opts);
143+
delete std_lm_fst;
144+
145+
KALDI_VLOG(2) << "LM loaded: " << path;
146+
} else {
147+
KALDI_ERR << "LM" << path << "doesn't exist.";
148+
}
112149
}
113150

114151
void Decoder::Reset() {
@@ -213,9 +250,7 @@ namespace alex_asr {
213250
return ok;
214251
}
215252

216-
bool Decoder::GetLattice(fst::VectorFst<fst::LogArc> *fst_out,
217-
double *tot_lik, bool end_of_utterance) {
218-
CompactLattice lat;
253+
bool Decoder::GetPrunedLattice(CompactLattice *lat) {
219254
Lattice raw_lat;
220255

221256
if (decoder_->NumFramesDecoded() == 0)
@@ -224,27 +259,76 @@ namespace alex_asr {
224259
if (!config_->decoder_opts.determinize_lattice)
225260
KALDI_ERR << "--determinize-lattice=false option is not supported at the moment";
226261

227-
bool ok = decoder_->GetRawLattice(&raw_lat, end_of_utterance);
262+
bool ok = decoder_->GetRawLattice(&raw_lat);
263+
228264

229265
BaseFloat lat_beam = config_->decoder_opts.lattice_beam;
230-
DeterminizeLatticePhonePrunedWrapper(
231-
*trans_model_, &raw_lat, lat_beam, &lat, config_->decoder_opts.det_opts);
266+
if(!config_->rescore) {
267+
DeterminizeLatticePhonePrunedWrapper(*trans_model_, &raw_lat, lat_beam, lat, config_->decoder_opts.det_opts);
268+
} else {
269+
CompactLattice pruned_lat;
270+
271+
DeterminizeLatticePhonePrunedWrapper(*trans_model_, &raw_lat, lat_beam, &pruned_lat, config_->decoder_opts.det_opts);
272+
ok = ok && RescoreLattice(pruned_lat, lat);
273+
}
274+
275+
return ok;
276+
}
277+
278+
/// Rescores `lat` in two passes: first with lm_small_ at scale -1.0, then with
/// lm_big_ at scale 1.0 (presumably removing the small-LM scores before adding
/// the big-LM ones). The final lattice is written to `*rescored_lattice`.
///
/// @return true only when both passes succeed; the second pass is not run if
///         the first one fails.
bool Decoder::RescoreLattice(CompactLattice lat, CompactLattice *rescored_lattice) {
    CompactLattice after_small_lm;

    if (!RescoreLatticeWithLM(lat, -1.0, lm_small_, &after_small_lm)) {
        return false;
    }

    return RescoreLatticeWithLM(after_small_lm, 1.0, lm_big_, rescored_lattice);
}
287+
288+
/// Composes `lat` with `*lm_fst`, with graph scores scaled by 1/lm_scale before
/// the composition and by lm_scale after determinization, and writes the
/// determinized result to `*rescored_lattice`.
///
/// @param lat               Lattice to rescore.
/// @param lm_scale          Scale applied to the graph scores; its inverse is
///                          applied before composition.
/// @param lm_fst            LM FST to compose with.
/// @param rescored_lattice  Output lattice.
/// @return true when the resulting lattice has a start state.
bool Decoder::RescoreLatticeWithLM(
    CompactLattice lat,
    float lm_scale,
    fst::MapFst<fst::StdArc, LatticeArc, fst::StdToLatticeMapper<BaseFloat> > *lm_fst,
    CompactLattice *rescored_lattice) {

    // Taken from https://github.com/kaldi-asr/kaldi/blob/9b9b561e2b3d2bdf64c84f8626175953f0885264/src/latbin/lattice-lmrescore.cc

    // Table-composition setup, matching on the input side of the LM FST.
    fst::TableComposeOptions table_opts(fst::TableMatcherOptions(), true, fst::SEQUENCE_FILTER, fst::MATCH_INPUT);
    fst::TableComposeCache<fst::Fst<LatticeArc> > table_cache(table_opts);

    // Expand the compact lattice, pre-scale the graph scores by the inverse
    // of lm_scale and sort the arcs on output label for the composition.
    Lattice expanded_lat;
    ConvertLattice(lat, &expanded_lat);
    fst::ScaleLattice(fst::GraphLatticeScale(1.0 / lm_scale), &expanded_lat);
    ArcSort(&expanded_lat, fst::OLabelCompare<LatticeArc>());

    Lattice lm_composed;
    TableCompose(expanded_lat, *lm_fst, &lm_composed, &table_cache);
    Invert(&lm_composed);

    // Determinize into the output, then apply lm_scale to the graph scores.
    DeterminizeLattice(lm_composed, rescored_lattice);
    fst::ScaleLattice(fst::GraphLatticeScale(lm_scale), rescored_lattice);

    const bool has_start_state = rescored_lattice->Start() != fst::kNoStateId;
    return has_start_state;
}
313+
314+
/// Fetches the pruned lattice via GetPrunedLattice, converts it with
/// CompactLatticeToWordsPost into `*fst_out`, and stores that call's return
/// value in `*tot_lik`.
///
/// @param fst_out           Output FST filled by CompactLatticeToWordsPost.
/// @param tot_lik           Output; receives the value returned by
///                          CompactLatticeToWordsPost.
/// @param end_of_utterance  NOTE(review): not forwarded to GetPrunedLattice,
///                          so it currently has no effect here - confirm
///                          whether that is intended.
/// @return the status reported by GetPrunedLattice.
bool Decoder::GetLattice(fst::VectorFst<fst::LogArc> *fst_out,
                         double *tot_lik, bool end_of_utterance) {
    CompactLattice pruned_lat;
    const bool ok = this->GetPrunedLattice(&pruned_lat);

    // The conversion runs regardless of `ok`, as before.
    *tot_lik = CompactLatticeToWordsPost(pruned_lat, fst_out);

    return ok;
}
237324

238325
bool Decoder::GetTimeAlignment(std::vector<int> *words, std::vector<int> *times, std::vector<int> *lengths) {
239-
Lattice lat;
240326
CompactLattice compact_lat;
241327
CompactLattice best_path;
242328
CompactLattice aligned_best_path;
243329
bool ok = true;
244330

245-
ok = ok && decoder_->GetRawLattice(&lat);
246-
BaseFloat lat_beam = config_->decoder_opts.lattice_beam;
247-
DeterminizeLatticePhonePrunedWrapper(*trans_model_, &lat, lat_beam, &compact_lat, config_->decoder_opts.det_opts);
331+
ok = this->GetPrunedLattice(&compact_lat);
248332
CompactLatticeShortestPath(compact_lat, &best_path);
249333

250334
if(config_->word_boundary_rxfilename == "") {
@@ -264,9 +348,7 @@ namespace alex_asr {
264348
CompactLattice aligned_best_path;
265349
bool ok = true;
266350

267-
ok = ok && decoder_->GetRawLattice(&lat);
268-
BaseFloat lat_beam = config_->decoder_opts.lattice_beam;
269-
DeterminizeLatticePhonePrunedWrapper(*trans_model_, &lat, lat_beam, &compact_lat, config_->decoder_opts.det_opts);
351+
ok = this->GetPrunedLattice(&compact_lat);
270352
CompactLatticeShortestPath(compact_lat, &best_path);
271353

272354
if(config_->word_boundary_rxfilename != "") {

src/decoder.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ namespace alex_asr {
3232
void FrameIn(unsigned char *buffer, int32 buffer_length);
3333
void FrameIn(VectorBase<BaseFloat> *waveform_in);
3434
bool GetBestPath(std::vector<int> *v_out, BaseFloat *prob);
35-
bool GetLattice(fst::VectorFst<fst::LogArc> * out_fst, double *tot_lik, bool end_of_utt=true);
35+
bool GetLattice(fst::VectorFst<fst::LogArc> *out_fst, double *tot_lik, bool end_of_utt=true);
3636
bool GetTimeAlignment(std::vector<int> *words, std::vector<int> *times, std::vector<int> *lengths);
3737
bool GetTimeAlignmentWithWordConfidence(std::vector<int> *words, std::vector<int> *times, std::vector<int> *lengths, std::vector<float> *confs);
3838
string GetWord(int word_id);
@@ -51,6 +51,8 @@ namespace alex_asr {
5151
FeaturePipeline *feature_pipeline_;
5252

5353
fst::StdFst *hclg_;
54+
fst::MapFst<fst::StdArc, LatticeArc, fst::StdToLatticeMapper<BaseFloat> > *lm_small_;
55+
fst::MapFst<fst::StdArc, LatticeArc, fst::StdToLatticeMapper<BaseFloat> > *lm_big_;
5456
LatticeFasterOnlineDecoder *decoder_;
5557
TransitionModel *trans_model_;
5658
nnet2::AmNnet *am_nnet2_;
@@ -63,8 +65,19 @@ namespace alex_asr {
6365

6466
void InitTransformMatrices();
6567
void LoadDecoder();
68+
void LoadLM(
69+
const string path,
70+
fst::MapFst<fst::StdArc, LatticeArc, fst::StdToLatticeMapper<BaseFloat> > **lm_fst
71+
);
6672
void ParseConfig();
6773
void Deallocate();
74+
bool GetPrunedLattice(CompactLattice *lat);
75+
bool RescoreLattice(CompactLattice lat, CompactLattice *rescored_lattice);
76+
bool RescoreLatticeWithLM(
77+
CompactLattice lat,
78+
float lm_scale,
79+
fst::MapFst<fst::StdArc, LatticeArc, fst::StdToLatticeMapper<BaseFloat> > *lm_fst,
80+
CompactLattice *rescored_lattice);
6881
bool FileExists(const std::string& name);
6982
};
7083

src/decoder_config.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ namespace alex_asr {
1111
use_ivectors(false),
1212
use_cmvn(false),
1313
use_pitch(false),
14+
rescore(false),
1415
cfg_decoder(""),
1516
cfg_decodable(""),
1617
cfg_mfcc(""),
@@ -48,6 +49,9 @@ namespace alex_asr {
4849
po->Register("use_cmvn", &use_cmvn, "Are we using cmvn transform?");
4950
po->Register("use_pitch", &use_pitch, "Are we using pitch feature?");
5051
po->Register("bits_per_sample", &bits_per_sample, "Bits per sample for input.");
52+
po->Register("rescore", &rescore, "Rescore lattice with bigger LM?");
53+
po->Register("lm_small", &lm_small_rxfilename, "G.fst of the LM used in HCLG.");
54+
po->Register("lm_big", &lm_big_rxfilename, "G.fst of the LM to rescore with.");
5155

5256
po->Register("cfg_decoder", &cfg_decoder, "");
5357
po->Register("cfg_decodable", &cfg_decodable, "");

src/decoder_config.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ namespace alex_asr {
5858
bool use_ivectors;
5959
bool use_cmvn;
6060
bool use_pitch;
61+
bool rescore;
6162

6263
std::string cfg_decoder;
6364
std::string cfg_decodable;
@@ -72,6 +73,8 @@ namespace alex_asr {
7273

7374
std::string model_rxfilename;
7475
std::string fst_rxfilename;
76+
std::string lm_small_rxfilename;
77+
std::string lm_big_rxfilename;
7578
std::string words_rxfilename;
7679
std::string word_boundary_rxfilename;
7780
std::string lda_mat_rspecifier;

test/asr_model_digits/alex_asr.conf

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@
88
--cfg_decoder=decoder.conf
99
--cfg_decodable=decodable.conf
1010
--cfg_endpoint=endpoint.conf
11+
--rescore=false

0 commit comments

Comments
 (0)