|
8 | 8 |
|
9 | 9 | from tests import get_tests_data_path |
10 | 10 | from TTS.tts.configs.shared_configs import BaseDatasetConfig, BaseTTSConfig |
11 | | -from TTS.tts.datasets import load_tts_samples |
| 11 | +from TTS.tts.datasets import add_formatter, load_tts_samples |
12 | 12 | from TTS.tts.datasets.dataset import TTSDataset |
13 | 13 | from TTS.tts.utils.text.tokenizer import TTSTokenizer |
14 | 14 | from TTS.utils.audio import AudioProcessor |
@@ -251,3 +251,36 @@ def check_conditions(idx, linear_input, mel_input, stop_target, mel_lengths): |
251 | 251 | # check batch zero-frame conditions (zero-frame disabled) |
252 | 252 | # assert (linear_input * stop_target.unsqueeze(2)).sum() == 0 |
253 | 253 | # assert (mel_input * stop_target.unsqueeze(2)).sum() == 0 |
| 254 | + |
| 255 | + |
def test_custom_formatted_dataset_with_loader():
    """Register two user-defined formatters and load samples through both.

    Checks that ``add_formatter`` makes custom formatters resolvable by name
    from ``BaseDatasetConfig.formatter`` and that ``load_tts_samples`` merges
    the two dataset configs and applies the requested eval split
    (``eval_split_size=0.2`` over the combined samples).
    """

    def custom_formatter(path, metafile, **kwargs):
        # Parse a pipe-separated metadata file ("<audio_path>|<text>") into the
        # list-of-dicts sample format expected by load_tts_samples.
        # Explicit encoding so the parse does not depend on the platform default.
        with open(os.path.join(path, metafile), encoding="utf-8") as f:
            lines = f.readlines()
        items = []
        for line in lines:
            file_path, text = line.split("|", 1)
            items.append({"text": text, "audio_file": file_path, "root_path": path, "speaker_name": "test"})
        return items

    def custom_formatter2(path, *args, **kwargs):
        # Same as custom_formatter, but appends a ".wav" extension to every
        # audio file path. Plain loop instead of a side-effect comprehension.
        items = custom_formatter(path, *args, **kwargs)
        for item in items:
            item["audio_file"] = f"{item['audio_file']}.wav"
        return items

    add_formatter("custom_formatter1", custom_formatter)
    add_formatter("custom_formatter2", custom_formatter2)
    dataset1 = BaseDatasetConfig(
        formatter="custom_formatter1",
        meta_file_train="metadata.csv",
        path=c.data_path,
    )
    dataset2 = BaseDatasetConfig(
        formatter="custom_formatter2",
        meta_file_train="metadata.csv",
        path=c.data_path,
    )
    dataset_configs = [dataset1, dataset2]
    train_samples, eval_samples = load_tts_samples(dataset_configs, eval_split=True, eval_split_size=0.2)
    # Expected counts for the test fixture's metadata.csv loaded through both
    # formatters with a 0.2 eval split.
    assert len(train_samples) == 14
    assert len(eval_samples) == 2
0 commit comments