Skip to content

Commit 6cac242

Browse files
authored
Merge pull request #489 from idiap/dev
v0.27.1
2 parents 96472c8 + ccc75ae commit 6cac242

File tree

57 files changed

+4176
-3754
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+4176
-3754
lines changed

.github/workflows/docker.yaml

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ jobs:
3737
steps:
3838
- uses: actions/checkout@v4
3939
- name: Log in to the Container registry
40-
uses: docker/login-action@v1
40+
uses: docker/login-action@v3
4141
with:
4242
registry: ghcr.io
4343
username: ${{ github.actor }}
@@ -66,14 +66,14 @@ jobs:
6666
fi
6767
tags="${base}:${VERSION},${base}:latest,${base}:${{ github.sha }}"
6868
fi
69-
echo "::set-output name=tags::${tags}"
69+
echo "tags=${tags}" >> $GITHUB_OUTPUT
7070
- name: Set up QEMU
71-
uses: docker/setup-qemu-action@v1
71+
uses: docker/setup-qemu-action@v3
7272
- name: Set up Docker Buildx
7373
id: buildx
74-
uses: docker/setup-buildx-action@v1
74+
uses: docker/setup-buildx-action@v3
7575
- name: Build and push
76-
uses: docker/build-push-action@v2
76+
uses: docker/build-push-action@v6
7777
with:
7878
context: .
7979
platforms: linux/${{ matrix.arch }}
@@ -91,7 +91,7 @@ jobs:
9191
steps:
9292
- uses: actions/checkout@v4
9393
- name: Log in to the Container registry
94-
uses: docker/login-action@v1
94+
uses: docker/login-action@v3
9595
with:
9696
registry: ghcr.io
9797
username: ${{ github.actor }}
@@ -120,14 +120,14 @@ jobs:
120120
fi
121121
tags="${base}:${VERSION},${base}:latest,${base}:${{ github.sha }}"
122122
fi
123-
echo "::set-output name=tags::${tags}"
123+
echo "tags=${tags}" >> $GITHUB_OUTPUT
124124
- name: Set up QEMU
125-
uses: docker/setup-qemu-action@v1
125+
uses: docker/setup-qemu-action@v3
126126
- name: Set up Docker Buildx
127127
id: buildx
128-
uses: docker/setup-buildx-action@v1
128+
uses: docker/setup-buildx-action@v3
129129
- name: Build and push
130-
uses: docker/build-push-action@v2
130+
uses: docker/build-push-action@v6
131131
with:
132132
context: .
133133
file: dockerfiles/Dockerfile.dev

TTS/tts/datasets/__init__.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import numpy as np
1010

1111
from TTS.tts.datasets.dataset import *
12-
from TTS.tts.datasets.formatters import *
12+
from TTS.tts.datasets.formatters import _FORMATTER_REGISTRY, Formatter, register_formatter
1313

1414
logger = logging.getLogger(__name__)
1515

@@ -162,15 +162,12 @@ def load_attention_mask_meta_data(metafile_path):
162162
return meta_data
163163

164164

165-
def _get_formatter_by_name(name):
165+
def _get_formatter_by_name(name: str) -> Formatter:
166166
"""Returns the respective preprocessing function."""
167-
thismodule = sys.modules[__name__]
168-
if not hasattr(thismodule, name.lower()):
169-
msg = (
170-
f"{name} formatter not found. If it is a custom formatter, pass the function to load_tts_samples() instead."
171-
)
167+
if name.lower() not in _FORMATTER_REGISTRY:
168+
msg = f"{name} formatter not found. If it is a custom formatter, make sure to call register_formatter() first."
172169
raise ValueError(msg)
173-
return getattr(thismodule, name.lower())
170+
return _FORMATTER_REGISTRY[name.lower()]
174171

175172

176173
def find_unique_chars(data_samples):

TTS/tts/datasets/formatters.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,39 @@
55
import xml.etree.ElementTree as ET
66
from glob import glob
77
from pathlib import Path
8+
from typing import Any, Protocol
89

910
from tqdm import tqdm
1011

1112
logger = logging.getLogger(__name__)
1213

14+
15+
class Formatter(Protocol):
    """Structural type describing the call signature of a dataset formatter.

    A formatter is any callable that accepts ``root_path``, ``meta_file`` and
    ``ignored_speakers`` plus arbitrary keyword arguments and returns a list
    of dictionaries with string keys.
    """

    def __call__(
        self,
        root_path: str | os.PathLike[Any],
        meta_file: str | os.PathLike[Any],
        ignored_speakers: list[str] | None,
        **kwargs,
    ) -> list[dict[str, Any]]: ...
23+
24+
25+
_FORMATTER_REGISTRY: dict[str, Formatter] = {}
26+
27+
28+
def register_formatter(name: str, formatter: Formatter) -> None:
    """Store a formatter in the registry under a case-insensitive name.

    Args:
        name: Name of the formatter; it is lower-cased before being used as
            the registry key.
        formatter: Formatter function.

    Raises:
        ValueError: If the lower-cased name is already present in the
            registry.
    """
    key = name.lower()
    if key not in _FORMATTER_REGISTRY:
        _FORMATTER_REGISTRY[key] = formatter
        return
    # The message reports the name exactly as the caller passed it.
    msg = f"Formatter {name} already exists."
    raise ValueError(msg)
39+
40+
1341
########################
1442
# DATASETS
1543
########################
@@ -659,3 +687,35 @@ def bel_tts_formatter(root_path, meta_file, **kwargs): # pylint: disable=unused
659687
text = cols[1]
660688
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
661689
return items
690+
691+
692+
### Registrations
# Table of (registry name, formatter function) pairs; they are registered in
# the order listed.
for _formatter_name, _formatter_fn in (
    ("cml_tts", cml_tts),
    ("coqui", coqui),
    ("tweb", tweb),
    ("mozilla", mozilla),
    ("mozilla_de", mozilla_de),
    ("mailabs", mailabs),
    ("ljspeech", ljspeech),
    ("ljspeech_test", ljspeech_test),
    ("thorsten", thorsten),
    ("sam_accenture", sam_accenture),
    ("ruslan", ruslan),
    ("css10", css10),
    ("nancy", nancy),
    ("common_voice", common_voice),
    ("libri_tts", libri_tts),
    ("custom_turkish", custom_turkish),
    ("brspeech", brspeech),
    ("vctk", vctk),
    ("vctk_old", vctk_old),
    ("synpaflex", synpaflex),
    ("open_bible", open_bible),
    ("mls", mls),
    ("voxceleb2", voxceleb2),
    ("voxceleb1", voxceleb1),
    ("emotion", emotion),
    ("baker", baker),
    ("kokoro", kokoro),
    ("kss", kss),
    ("bel_tts_formatter", bel_tts_formatter),
):
    register_formatter(_formatter_name, _formatter_fn)
# Do not leave the loop variables behind as module attributes.
del _formatter_name, _formatter_fn

TTS/tts/layers/bark/inference_funcs.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,10 @@ def generate_text_semantic(
7676
)
7777
else:
7878
semantic_history = None
79-
encoded_text = torch.LongTensor(_tokenize(model.tokenizer, text)) + model.config.TEXT_ENCODING_OFFSET
79+
encoded_text = (
80+
torch.tensor(_tokenize(model.tokenizer, text), device=model.device, dtype=torch.long)
81+
+ model.config.TEXT_ENCODING_OFFSET
82+
)
8083
if len(encoded_text) > 256:
8184
p = (len(encoded_text) - 256) / len(encoded_text) * 100
8285
logger.warning("warning, text too long, lopping of last %.1f%%", p)
@@ -99,11 +102,9 @@ def generate_text_semantic(
99102
)
100103
else:
101104
semantic_history = torch.full((256,), model.config.SEMANTIC_PAD_TOKEN, dtype=torch.int64)
102-
x = (
103-
torch.cat([encoded_text, semantic_history, torch.tensor([model.config.SEMANTIC_INFER_TOKEN])])
104-
.unsqueeze(0)
105-
.to(model.device)
106-
)
105+
x = torch.cat(
106+
[encoded_text, semantic_history, torch.tensor([model.config.SEMANTIC_INFER_TOKEN], device=model.device)]
107+
).unsqueeze(0)
107108
assert x.shape[1] == 256 + 256 + 1
108109

109110
n_tot_steps = 768

TTS/tts/models/bark.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import logging
12
import os
23
import warnings
34
from dataclasses import dataclass
45
from pathlib import Path
56
from typing import Any
67

8+
import numpy as np
79
import torch
810
import torchaudio
911
from coqpit import Coqpit
@@ -25,7 +27,14 @@
2527
from TTS.tts.layers.bark.model import GPT
2628
from TTS.tts.layers.bark.model_fine import FineGPT
2729
from TTS.tts.models.base_tts import BaseTTS
28-
from TTS.utils.generic_utils import warn_synthesize_config_deprecated, warn_synthesize_speaker_id_deprecated
30+
from TTS.utils.generic_utils import (
31+
is_pytorch_at_least_2_4,
32+
slugify,
33+
warn_synthesize_config_deprecated,
34+
warn_synthesize_speaker_id_deprecated,
35+
)
36+
37+
logger = logging.getLogger(__name__)
2938

3039

3140
@dataclass
@@ -209,6 +218,46 @@ def _clone_voice(
209218
metadata = {"name": self.config["model"]}
210219
return voice, metadata
211220

221+
def get_voices(self, voice_dir: str | os.PathLike[Any]) -> dict[str, Path]:
222+
"""Return all available voices in the given directory.
223+
224+
Args:
225+
voice_dir: Directory to search for voices.
226+
227+
Returns:
228+
Dictionary mapping a speaker ID to its voice file.
229+
"""
230+
# For Bark we overwrite the base method to also allow loading the npz
231+
# files included with the original model.
232+
return {path.stem: path for path in Path(voice_dir).iterdir() if path.suffix in (".npz", ".pth")}
233+
234+
def load_voice_file(
    self,
    speaker_id: str,
    voice_dir: str | os.PathLike[Any],
) -> dict[str, Any]:
    """Load the voice for the given speaker.

    For Bark we overwrite the base method to also allow loading the npz
    files included with the original model.

    Args:
        speaker_id:
            Speaker ID to load.
        voice_dir:
            Directory where to look for the voice.

    Returns:
        The loaded voice. For an ``.npz`` file this is a dictionary with one
        tensor per array stored in the archive; otherwise it is the object
        returned by ``torch.load`` for the file.

    Raises:
        FileNotFoundError: If ``get_voices(voice_dir)`` has no entry for
            ``speaker_id``.
    """
    voices = self.get_voices(voice_dir)
    if speaker_id not in voices:
        msg = f"Voice file `{slugify(speaker_id)}.pth` or .npz for speaker `{speaker_id}` not found in: {voice_dir}"
        raise FileNotFoundError(msg)
    voice_path = voices[speaker_id]
    if voice_path.suffix == ".npz":
        # np.load returns a lazily-read NpzFile that keeps the archive open.
        # torch.tensor copies each array, so the file can be closed as soon
        # as the dictionary has been built.
        with np.load(voice_path) as np_voice:
            voice = {key: torch.tensor(np_voice[key]) for key in np_voice.files}
    else:
        voice = torch.load(voice_path, map_location="cpu", weights_only=is_pytorch_at_least_2_4())
    logger.info("Loaded voice `%s` from: %s", speaker_id, voice_path)
    return voice
260+
212261
def synthesize(
213262
self,
214263
text: str,

0 commit comments

Comments
 (0)