Skip to content

Commit ed0bfcc

Browse files
committed
refactor(datasets): register formatters in dict
1 parent 90a2bd0 commit ed0bfcc

File tree

5 files changed

+73
-35
lines changed

5 files changed

+73
-35
lines changed

TTS/tts/datasets/__init__.py

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import numpy as np
1010

1111
from TTS.tts.datasets.dataset import *
12-
from TTS.tts.datasets.formatters import *
12+
from TTS.tts.datasets.formatters import _FORMATTER_REGISTRY, Formatter, register_formatter
1313

1414
logger = logging.getLogger(__name__)
1515

@@ -162,32 +162,12 @@ def load_attention_mask_meta_data(metafile_path):
162162
return meta_data
163163

164164

165-
def add_formatter(name: str, formatter: Callable[[str, str, list[str] | None], list[dict]]):
166-
"""Add a formatter to the datasets module. If the formatter already exists, raise an error.
167-
Args:
168-
name (str): The name of the formatter.
169-
formatter (Callable): The formatter function.
170-
Raises:
171-
ValueError: If the formatter already exists.
172-
Returns:
173-
None
174-
"""
175-
thismodule = sys.modules[__name__]
176-
if not hasattr(thismodule, name.lower()):
177-
setattr(thismodule, name.lower(), formatter)
178-
else:
179-
raise ValueError(f"Formatter {name} already exists.")
180-
181-
182-
def _get_formatter_by_name(name):
165+
def _get_formatter_by_name(name: str) -> Formatter:
183166
"""Returns the respective preprocessing function."""
184-
thismodule = sys.modules[__name__]
185-
if not hasattr(thismodule, name.lower()):
186-
msg = (
187-
f"{name} formatter not found. If it is a custom formatter, pass the function to load_tts_samples() instead."
188-
)
167+
if name.lower() not in _FORMATTER_REGISTRY:
168+
msg = f"{name} formatter not found. If it is a custom formatter, make sure to call register_formatter() first."
189169
raise ValueError(msg)
190-
return getattr(thismodule, name.lower())
170+
return _FORMATTER_REGISTRY[name.lower()]
191171

192172

193173
def find_unique_chars(data_samples):

TTS/tts/datasets/formatters.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,39 @@
55
import xml.etree.ElementTree as ET
66
from glob import glob
77
from pathlib import Path
8+
from typing import Any, Protocol
89

910
from tqdm import tqdm
1011

1112
logger = logging.getLogger(__name__)
1213

14+
15+
class Formatter(Protocol):
16+
def __call__(
17+
self,
18+
root_path: str | os.PathLike[Any],
19+
meta_file: str | os.PathLike[Any],
20+
ignored_speakers: list[str] | None,
21+
**kwargs,
22+
) -> list[dict[str, Any]]: ...
23+
24+
25+
_FORMATTER_REGISTRY: dict[str, Formatter] = {}
26+
27+
28+
def register_formatter(name: str, formatter: Formatter) -> None:
29+
"""Add a formatter function to the registry.
30+
31+
Args:
32+
name: Name of the formatter.
33+
formatter: Formatter function.
34+
"""
35+
if name.lower() in _FORMATTER_REGISTRY:
36+
msg = f"Formatter {name} already exists."
37+
raise ValueError(msg)
38+
_FORMATTER_REGISTRY[name.lower()] = formatter
39+
40+
1341
########################
1442
# DATASETS
1543
########################
@@ -659,3 +687,35 @@ def bel_tts_formatter(root_path, meta_file, **kwargs): # pylint: disable=unused
659687
text = cols[1]
660688
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
661689
return items
690+
691+
692+
### Registrations
693+
register_formatter("cml_tts", cml_tts)
694+
register_formatter("coqui", coqui)
695+
register_formatter("tweb", tweb)
696+
register_formatter("mozilla", mozilla)
697+
register_formatter("mozilla_de", mozilla_de)
698+
register_formatter("mailabs", mailabs)
699+
register_formatter("ljspeech", ljspeech)
700+
register_formatter("ljspeech_test", ljspeech_test)
701+
register_formatter("thorsten", thorsten)
702+
register_formatter("sam_accenture", sam_accenture)
703+
register_formatter("ruslan", ruslan)
704+
register_formatter("css10", css10)
705+
register_formatter("nancy", nancy)
706+
register_formatter("common_voice", common_voice)
707+
register_formatter("libri_tts", libri_tts)
708+
register_formatter("custom_turkish", custom_turkish)
709+
register_formatter("brspeech", brspeech)
710+
register_formatter("vctk", vctk)
711+
register_formatter("vctk_old", vctk_old)
712+
register_formatter("synpaflex", synpaflex)
713+
register_formatter("open_bible", open_bible)
714+
register_formatter("mls", mls)
715+
register_formatter("voxceleb2", voxceleb2)
716+
register_formatter("voxceleb1", voxceleb1)
717+
register_formatter("emotion", emotion)
718+
register_formatter("baker", baker)
719+
register_formatter("kokoro", kokoro)
720+
register_formatter("kss", kss)
721+
register_formatter("bel_tts_formatter", bel_tts_formatter)

docs/source/datasets/formatting_your_dataset.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
101101
Load a custom dataset with a custom formatter.
102102

103103
```python
104-
from TTS.tts.datasets import load_tts_samples, add_formatter
104+
from TTS.tts.datasets import load_tts_samples, register_formatter
105105

106106

107107
# custom formatter implementation
@@ -119,7 +119,7 @@ def formatter(root_path, manifest_file, **kwargs): # pylint: disable=unused-arg
119119
items.append({"text":text, "audio_file":wav_file, "speaker_name":speaker_name, "root_path": root_path})
120120
return items
121121

122-
add_formatter("custom_formatter_name", formatter) # Use the custom formatter name in the dataset config
122+
register_formatter("custom_formatter_name", formatter) # Use the custom formatter name in the dataset config
123123
dataset_config = BaseDatasetConfig(
124124
formatter="custom_formatter_name", meta_file_train="", language="en-us", path="dataset-path"
125125
)

tests/data_tests/test_dataset_formatters.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import unittest
33

44
from tests import get_tests_input_path
5-
from TTS.tts.datasets.formatters import common_voice
5+
from TTS.tts.datasets.formatters import common_voice, register_formatter
66

77

88
class TestTTSFormatters(unittest.TestCase):
@@ -17,12 +17,10 @@ def test_common_voice_preprocessor(self): # pylint: disable=no-self-use
1717
assert items[-1]["audio_file"] == os.path.join(get_tests_input_path(), "clips", "common_voice_en_19737074.wav")
1818

1919
def test_custom_formatter_with_existing_name(self):
20-
from TTS.tts.datasets import add_formatter
21-
2220
def custom_formatter(root_path, meta_file, ignored_speakers=None):
2321
return []
2422

25-
add_formatter("custom_formatter", custom_formatter)
23+
register_formatter("custom_formatter", custom_formatter)
2624

2725
with self.assertRaises(ValueError):
28-
add_formatter("custom_formatter", custom_formatter)
26+
register_formatter("custom_formatter", custom_formatter)

tests/data_tests/test_loader.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from tests import get_tests_data_path
1010
from TTS.tts.configs.shared_configs import BaseDatasetConfig, BaseTTSConfig
11-
from TTS.tts.datasets import add_formatter, load_tts_samples
11+
from TTS.tts.datasets import load_tts_samples, register_formatter
1212
from TTS.tts.datasets.dataset import TTSDataset
1313
from TTS.tts.utils.text.tokenizer import TTSTokenizer
1414
from TTS.utils.audio import AudioProcessor
@@ -268,8 +268,8 @@ def custom_formatter2(x, *args, **kwargs):
268268
[item.update({"audio_file": f"{item['audio_file']}.wav"}) for item in items]
269269
return items
270270

271-
add_formatter("custom_formatter1", custom_formatter)
272-
add_formatter("custom_formatter2", custom_formatter2)
271+
register_formatter("custom_formatter1", custom_formatter)
272+
register_formatter("custom_formatter2", custom_formatter2)
273273
dataset1 = BaseDatasetConfig(
274274
formatter="custom_formatter1",
275275
meta_file_train="metadata.csv",

0 commit comments

Comments
 (0)