
Commit 29b391c

Merge pull request #550 from idiap/release/0.27.5

v0.27.5

2 parents: 79c4373 + f1a7e02

4 files changed: +11 additions, -24 deletions


TTS/tts/layers/xtts/gpt.py

Lines changed: 0 additions & 1 deletion
@@ -531,6 +531,5 @@ def get_generator(self, fake_inputs, **hf_generate_kwargs):
             eos_token_id=self.stop_audio_token,
             max_length=self.max_gen_mel_tokens + fake_inputs.shape[-1],
             attention_mask=attention_mask,
-            do_stream=True,
             **hf_generate_kwargs,
         )

TTS/tts/layers/xtts/gpt_inference.py

Lines changed: 0 additions & 3 deletions
@@ -3,8 +3,6 @@
 from transformers import GenerationMixin, GPT2PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 
-from TTS.tts.layers.xtts.stream_generator import StreamGenerationConfig
-
 
 class GPT2InferenceModel(GPT2PreTrainedModel, GenerationMixin):
     """Override GPT2LMHeadModel to allow for prefix conditioning."""
@@ -17,7 +15,6 @@ def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache):
         self.final_norm = norm
         self.lm_head = nn.Sequential(norm, linear)
         self.kv_cache = kv_cache
-        self.generation_config = StreamGenerationConfig.from_model_config(config) if self.can_generate() else None
 
     def store_prefix_emb(self, prefix_emb):
         self.cached_prefix_emb = prefix_emb
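
Side note, not part of the diff: with the explicit StreamGenerationConfig assignment removed, the model can fall back to the default generation_config that transformers' PreTrainedModel sets up from the model config. A minimal sketch of that plain-GenerationConfig equivalent, using a placeholder GPT2Config purely for illustration:

from transformers import GPT2Config, GenerationConfig

config = GPT2Config()  # placeholder model config, for illustration only
# Plain GenerationConfig built from the model config; the custom
# StreamGenerationConfig subclass is no longer needed for this.
generation_config = GenerationConfig.from_model_config(config)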

TTS/tts/layers/xtts/stream_generator.py

Lines changed: 10 additions & 19 deletions
@@ -29,25 +29,18 @@ def setup_seed(seed: int) -> None:
     torch.backends.cudnn.deterministic = True
 
 
-class StreamGenerationConfig(GenerationConfig):
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        self.do_stream = kwargs.pop("do_stream", False)
-
-
 class NewGenerationMixin(GenerationMixin):
     @torch.inference_mode()
     def generate(  # noqa: PLR0911
         self,
         inputs: torch.Tensor | None = None,
-        generation_config: StreamGenerationConfig | None = None,
+        generation_config: GenerationConfig | None = None,
         logits_processor: LogitsProcessorList | None = None,
         stopping_criteria: StoppingCriteriaList | None = None,
         prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]] | None = None,
         synced_gpus: bool | None = False,
         assistant_model: PreTrainedModel | None = None,
         streamer: "BaseStreamer | None" = None,
-        use_model_defaults: bool | None = None,
         custom_generate: str | Callable | None = None,
         seed: int = 0,
         **kwargs,
@@ -102,11 +95,6 @@ def generate( # noqa: PLR0911
                 same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistant model
                 is much faster than running generation with the model you're calling generate from. As such, the
                 assistant model should be much smaller.
-            use_model_defaults (`bool`, *optional*):
-                When it is `True`, unset parameters in `generation_config` will be set to the model-specific default
-                generation configuration (`model.generation_config`), as opposed to the global defaults
-                (`GenerationConfig()`). If unset, models saved starting from `v4.50` will consider this flag to be
-                `True`.
             kwargs:
                 Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
                 forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
@@ -137,10 +125,16 @@ def generate( # noqa: PLR0911
         generation_mode_kwargs = self._extract_generation_mode_kwargs(
             custom_generate, kwargs, synced_gpus, assistant_model, streamer
         )
-
-        generation_config, model_kwargs = self._prepare_generation_config(
-            generation_config, use_model_defaults, **kwargs
+        # Check length values before updating the config with defaults.
+        # We'll use it later to define the final min/max length (# 6)
+        has_default_max_length = kwargs.get("max_length") is None and (
+            generation_config is None or generation_config.max_length is None
         )
+        has_default_min_length = kwargs.get("min_length") is None and (
+            generation_config is None or generation_config.min_length is None
+        )
+        generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
+
         generation_mode = generation_config.get_generation_mode(assistant_model)
         self._validate_model_kwargs(model_kwargs.copy())
         self._validate_generation_mode(generation_mode, generation_config, generation_mode_kwargs)
@@ -212,8 +206,6 @@ def generate( # noqa: PLR0911
 
         # 6. Prepare `max_length` depending on other stopping criteria.
         input_ids_length = input_ids.shape[-1]
-        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
-        has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
         generation_config = self._prepare_generated_length(
             generation_config=generation_config,
             has_default_max_length=has_default_max_length,
@@ -500,7 +492,6 @@ def init_stream_support():
         repetition_penalty=1.2,
         early_stopping=True,
         seed=0,
-        do_stream=True,
     )
     stream_result = ""
     for x in generator:
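
For context, a hedged sketch of driving the patched streaming generate() after this change: a plain GenerationConfig replaces StreamGenerationConfig, and neither do_stream nor use_model_defaults is passed anymore. model, input_ids, and tokenizer below are placeholders rather than anything defined in this diff, and the kwargs mirror the docstring example above:

from transformers import GenerationConfig

# Plain GenerationConfig; StreamGenerationConfig and do_stream were removed.
gen_config = GenerationConfig(
    max_new_tokens=128,
    repetition_penalty=1.2,
    early_stopping=True,
)

stream_result = ""
# As in the docstring example, the patched generate() is consumed as an
# iterator that yields tokens incrementally.
for x in model.generate(input_ids, generation_config=gen_config, seed=0):
    stream_result += tokenizer.decode(x, skip_special_tokens=True)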

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "coqui-tts"
-version = "0.27.4"
+version = "0.27.5"
 description = "Deep learning for Text to Speech."
 readme = "README.md"
 requires-python = ">=3.10, <3.15"
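
A quick, optional check that the bumped release is what ends up installed (assumes coqui-tts was installed from this tag; importlib.metadata is standard library):

from importlib.metadata import version

# Should print "0.27.5" once the release cut by this merge is installed.
print(version("coqui-tts"))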
