Commit 44b32fd

[Ready to merge] Pruned-transducer-stateless2 recipe for aidatatang_200zh (k2-fsa#375)
* add pruned-rnnt2 model for aidatatang_200zh
* do some changes
* change for README.md
* do some changes
1 parent fe522bc commit 44b32fd

27 files changed, +3978 -0 lines changed

egs/aidatatang_200zh/ASR/README.md

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
Note: This recipe is trained with the code from this PR: https://github.com/k2-fsa/icefall/pull/375
# Pre-trained Transducer-Stateless2 models for the Aidatatang_200zh dataset with icefall.
The model was trained on the full [Aidatatang_200zh](https://www.openslr.org/62) dataset with the scripts in [icefall](https://github.com/k2-fsa/icefall), based on the latest version of k2.
## Training procedure
The main repositories are listed below. We will update the training and decoding scripts as new versions are released.
k2: https://github.com/k2-fsa/k2
icefall: https://github.com/k2-fsa/icefall
lhotse: https://github.com/lhotse-speech/lhotse
* Install k2 and lhotse. For k2, see the installation guide at https://k2.readthedocs.io/en/latest/installation/index.html; for lhotse, see https://lhotse.readthedocs.io/en/latest/getting-started.html#installation. The latest versions should work. Please also install the requirements listed in icefall.
* Clone icefall (https://github.com/k2-fsa/icefall) and check out the commit shown above.
```
git clone https://github.com/k2-fsa/icefall
cd icefall
```
* Prepare the data.
```
cd egs/aidatatang_200zh/ASR
bash ./prepare.sh
```
* Training
```
export CUDA_VISIBLE_DEVICES="0,1"
./pruned_transducer_stateless2/train.py \
  --world-size 2 \
  --num-epochs 30 \
  --start-epoch 0 \
  --exp-dir pruned_transducer_stateless2/exp \
  --lang-dir data/lang_char \
  --max-duration 250
```
## Evaluation results
The decoding results (WER%) on Aidatatang_200zh (dev and test) are listed below. These results were obtained by averaging the models from epoch 11 to epoch 29.
The WERs are:
|                                    | dev        | test       | comment                                   |
|------------------------------------|------------|------------|-------------------------------------------|
| greedy search                      | 5.53       | 6.59       | --epoch 29, --avg 19, --max-duration 100  |
| modified beam search (beam size 4) | 5.27       | 6.33       | --epoch 29, --avg 19, --max-duration 100  |
| fast beam search (set as default)  | 5.30       | 6.34       | --epoch 29, --avg 19, --max-duration 1500 |
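
The "averaging models from epoch 11 to 29" step can be sketched as follows. This is a minimal illustration, not the recipe's actual code: it assumes `--epoch 29 --avg 19` selects the last 19 epoch checkpoints ending at epoch 29 (which matches the range stated above), and it uses plain floats in place of model tensors. The function names `epochs_to_average` and `average_parameters` are hypothetical.

```python
from typing import Dict, List


def epochs_to_average(epoch: int, avg: int) -> List[int]:
    """Return the epoch indices covered by `--epoch` and `--avg`.

    Assumption: the last `avg` epoch checkpoints ending at `epoch` are used,
    which matches the "epoch 11 to 29" statement for --epoch 29 --avg 19.
    """
    start = epoch - avg + 1
    if start < 0:
        raise ValueError("avg is larger than the number of available epochs")
    return list(range(start, epoch + 1))


def average_parameters(checkpoints: List[Dict[str, float]]) -> Dict[str, float]:
    """Element-wise mean of parameter dictionaries (floats stand in for tensors)."""
    if not checkpoints:
        raise ValueError("need at least one checkpoint")
    n = len(checkpoints)
    return {name: sum(c[name] for c in checkpoints) / n for name in checkpoints[0]}


if __name__ == "__main__":
    selected = epochs_to_average(epoch=29, avg=19)
    print(f"first={selected[0]} last={selected[-1]} count={len(selected)}")

    # Toy checkpoints with a single scalar "weight" each.
    toy = [{"w": 1.0}, {"w": 2.0}, {"w": 3.0}]
    print(average_parameters(toy))
```

Averaging several late checkpoints is a common way to smooth out step-to-step noise in the final weights; the real scripts operate on full model state dictionaries rather than scalars.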
Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
## Results

### Aidatatang_200zh Char training results (Pruned Transducer Stateless2)

#### 2022-05-16

Using the code from this PR: https://github.com/k2-fsa/icefall/pull/375

The WERs are:

|                                    | dev        | test       | comment                                   |
|------------------------------------|------------|------------|-------------------------------------------|
| greedy search                      | 5.53       | 6.59       | --epoch 29, --avg 19, --max-duration 100  |
| modified beam search (beam size 4) | 5.27       | 6.33       | --epoch 29, --avg 19, --max-duration 100  |
| fast beam search (set as default)  | 5.30       | 6.34       | --epoch 29, --avg 19, --max-duration 1500 |
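
For readers unfamiliar with how such error rates are obtained, here is a small sketch of an edit-distance-based error rate. It assumes each Chinese character is one token (the recipe uses `data/lang_char`), so the "WER" is effectively a character error rate. The scoring code in icefall is not shown in this commit excerpt, so the functions `edit_distance` and `error_rate` below are illustrative stand-ins, not the project's implementation.

```python
from typing import List, Sequence


def edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Levenshtein distance between two token sequences."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            substitution = prev[j - 1] + (0 if r == h else 1)
            deletion = prev[j] + 1       # reference token missing from hypothesis
            insertion = curr[j - 1] + 1  # extra token in hypothesis
            curr.append(min(substitution, deletion, insertion))
        prev = curr
    return prev[-1]


def error_rate(refs: List[str], hyps: List[str]) -> float:
    """Total edit errors divided by total reference tokens, in percent."""
    errors = sum(edit_distance(list(r), list(h)) for r, h in zip(refs, hyps))
    total = sum(len(r) for r in refs)
    return 100.0 * errors / total


if __name__ == "__main__":
    # One deleted character out of five reference characters.
    print(error_rate(["今天天气好"], ["今天气好"]))
```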

The training command to reproduce these results is:

```
export CUDA_VISIBLE_DEVICES="0,1"

./pruned_transducer_stateless2/train.py \
  --world-size 2 \
  --num-epochs 30 \
  --start-epoch 0 \
  --exp-dir pruned_transducer_stateless2/exp \
  --lang-dir data/lang_char \
  --max-duration 250 \
  --save-every-n 1000
```

The TensorBoard training log can be found at
https://tensorboard.dev/experiment/xS7kgYf2RwyDpQAOdS8rAA/#scalars

The decoding command is:
```
epoch=29
avg=19

## greedy search
./pruned_transducer_stateless2/decode.py \
  --epoch $epoch \
  --avg $avg \
  --exp-dir pruned_transducer_stateless2/exp \
  --lang-dir ./data/lang_char \
  --max-duration 100

## modified beam search
./pruned_transducer_stateless2/decode.py \
  --epoch $epoch \
  --avg $avg \
  --exp-dir pruned_transducer_stateless2/exp \
  --lang-dir ./data/lang_char \
  --max-duration 100 \
  --decoding-method modified_beam_search \
  --beam-size 4

## fast beam search
./pruned_transducer_stateless2/decode.py \
  --epoch $epoch \
  --avg $avg \
  --exp-dir ./pruned_transducer_stateless2/exp \
  --lang-dir ./data/lang_char \
  --max-duration 1500 \
  --decoding-method fast_beam_search \
  --beam 4 \
  --max-contexts 4 \
  --max-states 8
```

A pre-trained model and decoding logs can be found at <https://huggingface.co/luomingshuang/icefall_asr_aidatatang-200zh_pruned_transducer_stateless2>

egs/aidatatang_200zh/ASR/local/__init__.py

Whitespace-only changes.
Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file computes fbank features of the aidatatang_200zh dataset.
It looks for manifests in the directory data/manifests/aidatatang_200zh.

The generated fbank features are saved in data/fbank.
"""

import argparse
import logging
import os
from pathlib import Path

import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
    src_dir = Path("data/manifests/aidatatang_200zh")
    output_dir = Path("data/fbank")
    # os.cpu_count() may return None, so fall back to a single job.
    num_jobs = min(15, os.cpu_count() or 1)

    dataset_parts = (
        "train",
        "dev",
        "test",
    )
    manifests = read_manifests_if_cached(
        dataset_parts=dataset_parts, output_dir=src_dir
    )
    assert manifests is not None

    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

    with get_executor() as ex:  # Initialize the executor only once.
        for partition, m in manifests.items():
            if (output_dir / f"cuts_{partition}.json.gz").is_file():
                logging.info(f"{partition} already exists - skipping.")
                continue
            logging.info(f"Processing {partition}")
            cut_set = CutSet.from_manifests(
                recordings=m["recordings"],
                supervisions=m["supervisions"],
            )
            # Augment the training data with 0.9x and 1.1x speed-perturbed copies.
            if "train" in partition:
                cut_set = (
                    cut_set
                    + cut_set.perturb_speed(0.9)
                    + cut_set.perturb_speed(1.1)
                )
            cut_set = cut_set.compute_and_store_features(
                extractor=extractor,
                storage_path=f"{output_dir}/feats_{partition}",
                # when an executor is specified, make more partitions
                num_jobs=num_jobs if ex is None else 80,
                executor=ex,
                storage_type=LilcomHdf5Writer,
            )
            cut_set.to_json(output_dir / f"cuts_{partition}.json.gz")


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num-mel-bins",
        type=int,
        default=80,
        help="""The number of mel bins for Fbank""",
    )

    return parser.parse_args()


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)

    args = get_args()
    compute_fbank_aidatatang_200zh(num_mel_bins=args.num_mel_bins)
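
The effect of the speed perturbation in `compute_fbank_aidatatang_200zh` on the amount of training data can be worked out with a little arithmetic. The sketch below assumes that playing audio at speed factor `f` scales its duration by `1 / f`, and that the training statistics recorded elsewhere in this commit (494715 cuts, 422.6 hours) were computed on the perturbed cut set written by this script. The helper names are hypothetical.

```python
from typing import Sequence

# The original data (1.0) plus the two perturbed copies used in the script.
SPEED_FACTORS = (1.0, 0.9, 1.1)


def duration_multiplier(factors: Sequence[float] = SPEED_FACTORS) -> float:
    """Total duration relative to the unperturbed data.

    Assumption: a copy played at speed `f` lasts `1 / f` times as long.
    """
    return sum(1.0 / f for f in factors)


def estimate_original_hours(
    total_hours: float, factors: Sequence[float] = SPEED_FACTORS
) -> float:
    """Back out the unperturbed duration from the perturbed total."""
    return total_hours / duration_multiplier(factors)


if __name__ == "__main__":
    print(f"multiplier: {duration_multiplier():.4f}")
    print(f"estimated original hours: {estimate_original_hours(422.6):.1f}")
    # Each utterance appears three times, so the cut count divides by 3.
    print(f"estimated original utterances: {494715 // len(SPEED_FACTORS)}")
```

Under these assumptions the multiplier is slightly above 3, so the 422.6 hours correspond to roughly 140 hours of unperturbed training audio.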
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
../../../librispeech/ASR/local/compute_fbank_musan.py
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
#                                       Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This file displays duration statistics of utterances in a manifest.
You can use the displayed values to choose minimum/maximum durations
for removing short and long utterances during training.
See the function `remove_short_and_long_utt()`
in ../../../librispeech/ASR/transducer/train.py
for usage.
"""


from lhotse import load_manifest


def main():
    paths = [
        "./data/fbank/cuts_train.json.gz",
        "./data/fbank/cuts_dev.json.gz",
        "./data/fbank/cuts_test.json.gz",
    ]

    for path in paths:
        print(f"Starting display the statistics for {path}")
        cuts = load_manifest(path)
        cuts.describe()


if __name__ == "__main__":
    main()

"""
Starting display the statistics for ./data/fbank/cuts_train.json.gz
Cuts count: 494715
Total duration (hours): 422.6
Speech duration (hours): 422.6 (100.0%)
***
Duration statistics (seconds):
mean    3.1
std     1.2
min     1.0
25%     2.3
50%     2.7
75%     3.5
99%     7.2
99.5%   8.0
99.9%   9.5
max     18.1
Starting display the statistics for ./data/fbank/cuts_dev.json.gz
Cuts count: 24216
Total duration (hours): 20.2
Speech duration (hours): 20.2 (100.0%)
***
Duration statistics (seconds):
mean    3.0
std     1.0
min     1.2
25%     2.3
50%     2.7
75%     3.4
99%     6.7
99.5%   7.3
99.9%   8.8
max     11.3
Starting display the statistics for ./data/fbank/cuts_test.json.gz
Cuts count: 48144
Total duration (hours): 40.2
Speech duration (hours): 40.2 (100.0%)
***
Duration statistics (seconds):
mean    3.0
std     1.1
min     0.9
25%     2.3
50%     2.6
75%     3.4
99%     6.9
99.5%   7.5
99.9%   9.0
max     21.8
"""
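
The statistics above are meant to guide the choice of duration thresholds for filtering utterances. A minimal sketch of such a filter follows. The thresholds of 1.0 and 12.0 seconds are hypothetical values picked by eye from the percentiles above, not values taken from this commit, and `ToyCut` is a stand-in for a lhotse cut that only carries an id and a duration.

```python
from dataclasses import dataclass
from typing import Iterable, List


@dataclass
class ToyCut:
    """Stand-in for a lhotse cut; only the fields needed for filtering."""

    id: str
    duration: float  # seconds


def keep_cut(
    cut: ToyCut, min_duration: float = 1.0, max_duration: float = 12.0
) -> bool:
    """Return True if the cut duration lies within the (hypothetical) bounds."""
    return min_duration <= cut.duration <= max_duration


def filter_cuts(
    cuts: Iterable[ToyCut],
    min_duration: float = 1.0,
    max_duration: float = 12.0,
) -> List[ToyCut]:
    """Drop utterances that are too short or too long for training."""
    return [c for c in cuts if keep_cut(c, min_duration, max_duration)]


if __name__ == "__main__":
    cuts = [ToyCut("short", 0.9), ToyCut("typical", 3.1), ToyCut("long", 21.8)]
    print([c.id for c in filter_cuts(cuts)])
```

With a real lhotse `CutSet`, a predicate like `keep_cut` would typically be passed to `CutSet.filter` instead of building a list by hand.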
