Skip to content

Commit 0c4836c

Browse files
committed
add unittest for the custom formatters
1 parent 6a315a4 commit 0c4836c

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

tests/data_tests/test_dataset_formatters.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,14 @@ def test_common_voice_preprocessor(self): # pylint: disable=no-self-use
1515

1616
assert items[-1]["text"] == "Competition for limited resources has also resulted in some local conflicts."
1717
assert items[-1]["audio_file"] == os.path.join(get_tests_input_path(), "clips", "common_voice_en_19737074.wav")
18+
19+
def test_custom_formatter_with_existing_name(self):
    """Check that registering a formatter under an already-used name raises ``ValueError``."""
    from TTS.tts.datasets import add_formatter

    def custom_formatter(root_path, meta_file, ignored_speakers=None):
        # Minimal stand-in: this test only registers the formatter, it never calls it.
        return []

    formatter_name = "custom_formatter"

    # First registration under this name.
    add_formatter(formatter_name, custom_formatter)

    # Registering the same name a second time is expected to be rejected.
    self.assertRaises(ValueError, add_formatter, formatter_name, custom_formatter)

tests/data_tests/test_loader.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from tests import get_tests_data_path
1010
from TTS.tts.configs.shared_configs import BaseDatasetConfig, BaseTTSConfig
11-
from TTS.tts.datasets import load_tts_samples
11+
from TTS.tts.datasets import add_formatter, load_tts_samples
1212
from TTS.tts.datasets.dataset import TTSDataset
1313
from TTS.tts.utils.text.tokenizer import TTSTokenizer
1414
from TTS.utils.audio import AudioProcessor
@@ -251,3 +251,36 @@ def check_conditions(idx, linear_input, mel_input, stop_target, mel_lengths):
251251
# check batch zero-frame conditions (zero-frame disabled)
252252
# assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
253253
# assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
254+
255+
256+
def test_custom_formatted_dataset_with_loader():
    """Load samples through two user-registered formatters and check the split sizes.

    Two formatters are registered with ``add_formatter`` and referenced by name
    from two ``BaseDatasetConfig`` objects that point at the same metadata file.
    ``load_tts_samples`` is then asked for a train/eval split, and the sizes of
    both parts are asserted.
    """

    def custom_formatter(path, metafile, **kwargs):
        """Parse ``<audio_file>|<text>`` lines of ``metafile`` (under ``path``) into sample dicts."""
        items = []
        # Pin the encoding so parsing does not depend on the platform's locale default.
        with open(os.path.join(path, metafile), encoding="utf-8") as f:
            for line in f:
                # Split on the first "|" only, so the text itself may contain "|".
                file_path, text = line.split("|", 1)
                items.append(
                    {
                        # Drop the line terminator so it does not leak into the transcript.
                        "text": text.rstrip("\r\n"),
                        "audio_file": file_path,
                        "root_path": path,
                        "speaker_name": "test",
                    }
                )
        return items

    def custom_formatter2(x, *args, **kwargs):
        """Same as ``custom_formatter`` but with ``.wav`` appended to every ``audio_file``."""
        items = custom_formatter(x, *args, **kwargs)
        # Plain loop: the update is a side effect, so no throwaway list is built.
        for item in items:
            item["audio_file"] = f"{item['audio_file']}.wav"
        return items

    add_formatter("custom_formatter1", custom_formatter)
    add_formatter("custom_formatter2", custom_formatter2)
    dataset1 = BaseDatasetConfig(
        formatter="custom_formatter1",
        meta_file_train="metadata.csv",
        path=c.data_path,
    )
    dataset2 = BaseDatasetConfig(
        formatter="custom_formatter2",
        meta_file_train="metadata.csv",
        path=c.data_path,
    )
    dataset_configs = [dataset1, dataset2]
    train_samples, eval_samples = load_tts_samples(dataset_configs, eval_split=True, eval_split_size=0.2)
    assert len(train_samples) == 14
    assert len(eval_samples) == 2

0 commit comments

Comments
 (0)