-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathengine.py
More file actions
187 lines (157 loc) · 6.59 KB
/
engine.py
File metadata and controls
187 lines (157 loc) · 6.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
# File: engine.py
# Core TTS model loading and speech generation logic.
import logging
import random
import numpy as np
import torch
from typing import Optional, Tuple
from pathlib import Path
from chatterbox.tts import ChatterboxTTS # Main TTS engine class
from chatterbox.models.s3gen.const import (
S3GEN_SR,
) # Default sample rate from the engine
# Import the singleton config_manager
from config import config_manager
logger = logging.getLogger(__name__)

# --- Global Module Variables ---
# The single shared TTS engine instance; None until load_model() assigns it.
chatterbox_model: Optional[ChatterboxTTS] = None
# Flag checked by load_model() and synthesize() to decide whether the model is usable.
MODEL_LOADED: bool = False
# Resolved device string ('cuda' or 'cpu'); written by load_model().
model_device: Optional[str] = None
def set_seed(seed_value: int):
    """
    Seed every random number generator this module relies on.

    Seeds, in order: torch (CPU), torch CUDA (current device and all devices,
    only when CUDA is available), Python's `random`, and NumPy's global RNG.
    synthesize() calls this when it receives a non-zero seed.

    Args:
        seed_value: The seed handed to each generator.
    """
    torch.manual_seed(seed_value)

    if torch.cuda.is_available():
        # Current device first, then every device (covers multi-GPU setups).
        for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            seed_cuda(seed_value)

    # Seeding order is kept (random, then numpy) so partial state on error
    # matches what callers have always observed.
    for seed_host_rng in (random.seed, np.random.seed):
        seed_host_rng(seed_value)

    logger.info(f"Global seed set to: {seed_value}")
def _resolve_device(device_setting: str) -> str:
    """Translate the configured device setting into a device usable on this host.

    Args:
        device_setting: Raw value of the `tts_engine.device` config entry.
            Matching is case-insensitive and ignores surrounding whitespace.

    Returns:
        'cuda' or 'cpu'. An explicit 'cuda' request falls back to 'cpu' (with
        a warning) when CUDA is unavailable; unknown values fall back to
        auto-detection (with a warning).
    """
    cuda_available = torch.cuda.is_available()
    auto_device = "cuda" if cuda_available else "cpu"
    normalized = str(device_setting).strip().lower()

    if normalized == "auto":
        return auto_device
    if normalized == "cpu":
        return "cpu"
    if normalized == "cuda":
        if cuda_available:
            return "cuda"
        # Passing 'cuda' through on a host without CUDA can only end in a
        # failed load, so degrade gracefully instead.
        logger.warning(
            "Device setting 'cuda' was requested but CUDA is not available; falling back to 'cpu'."
        )
        return "cpu"

    logger.warning(
        f"Invalid device setting '{device_setting}', defaulting to auto-detection."
    )
    return auto_device


def load_model() -> bool:
    """
    Loads the TTS model.

    The model is loaded through `ChatterboxTTS.from_pretrained(device=...)`,
    i.e. via the library's own download/cache handling rather than the local
    `paths.model_cache` directory. The configured `model.repo_id` is read only
    to give context in log messages; it is not passed to `from_pretrained`.

    Updates the module globals `chatterbox_model`, `MODEL_LOADED`, and
    `model_device`. `MODEL_LOADED` becomes True only after a non-None model
    object has been obtained. Calling this again after a successful load is a
    no-op that returns True.

    Returns:
        bool: True if the model was loaded successfully (or was already
        loaded), False otherwise. Errors are logged, never raised.
    """
    global chatterbox_model, MODEL_LOADED, model_device

    if MODEL_LOADED:
        logger.info("TTS model is already loaded.")
        return True

    try:
        # Determine processing device
        device_setting = config_manager.get_string("tts_engine.device", "auto")
        model_device = _resolve_device(device_setting)
        logger.info(f"Attempting to load TTS model on device: {model_device}")

        # Configured repo id is used for logging/context only; from_pretrained
        # is called without it and may use its own internal default.
        model_repo_id_config = config_manager.get_string(
            "model.repo_id", "ResembleAI/chatterbox"
        )
        logger.info(
            f"Attempting to load model directly using from_pretrained (expected from Hugging Face repository: {model_repo_id_config} or library default)."
        )

        try:
            loaded_model = ChatterboxTTS.from_pretrained(device=model_device)
        except Exception as e_hf:
            logger.error(
                f"Failed to load model using from_pretrained (expected from '{model_repo_id_config}' or library default): {e_hf}",
                exc_info=True,
            )
            chatterbox_model = None
            MODEL_LOADED = False
            return False

        # Identity check rather than truthiness: a valid model object must
        # never be rejected because of how it implements __bool__/__len__.
        if loaded_model is None:
            logger.error(
                "Model loading sequence completed, but chatterbox_model is None. This indicates an unexpected issue."
            )
            chatterbox_model = None
            MODEL_LOADED = False
            return False

        # Publish the model first, flip the flag second, so MODEL_LOADED is
        # never True while chatterbox_model is still unset.
        chatterbox_model = loaded_model
        MODEL_LOADED = True
        logger.info(
            f"Successfully loaded TTS model using from_pretrained (expected from '{model_repo_id_config}' or library default)."
        )
        logger.info(
            f"TTS Model loaded successfully. Engine sample rate: {chatterbox_model.sr} Hz."
        )
        return True

    except Exception as e:
        # Boundary handler: any unexpected failure (config lookup, attribute
        # access on the model, ...) leaves the module in the "not loaded" state.
        logger.error(
            f"An unexpected error occurred during model loading: {e}", exc_info=True
        )
        chatterbox_model = None
        MODEL_LOADED = False
        return False
def synthesize(
    text: str,
    audio_prompt_path: Optional[str] = None,
    temperature: float = 0.8,
    exaggeration: float = 0.5,
    cfg_weight: float = 0.5,
    seed: int = 0,
) -> Tuple[Optional[torch.Tensor], Optional[int]]:
    """
    Generate speech for `text` with the already-loaded TTS model.

    All generation parameters except `seed` are forwarded unchanged to the
    model's `generate()` method. Failures are logged and reported through the
    return value; no exception escapes this function.

    Args:
        text: The text to synthesize.
        audio_prompt_path: Path to an audio file for voice cloning or predefined voice.
        temperature: Controls randomness in generation.
        exaggeration: Controls expressiveness.
        cfg_weight: Classifier-Free Guidance weight.
        seed: Random seed for generation. 0 leaves the RNG state untouched;
            any other value is applied globally via set_seed() first.

    Returns:
        A tuple of (audio waveform tensor, sample rate taken from the model's
        `sr` attribute), or (None, None) if the model is not loaded or
        synthesis fails.
    """
    model_ready = MODEL_LOADED and chatterbox_model is not None
    if not model_ready:
        logger.error("TTS model is not loaded. Cannot synthesize audio.")
        return None, None

    try:
        # A seed of 0 means "do not reseed"; anything else is applied globally.
        if seed == 0:
            logger.info(
                "Using default (potentially random) generation behavior as seed is 0."
            )
        else:
            logger.info(f"Applying user-provided seed for generation: {seed}")
            set_seed(seed)

        logger.debug(
            f"Synthesizing with params: audio_prompt='{audio_prompt_path}', temp={temperature}, "
            f"exag={exaggeration}, cfg_weight={cfg_weight}, seed_applied_globally_if_nonzero={seed}"
        )

        # Hand the request to the core model.
        generation_kwargs = {
            "text": text,
            "audio_prompt_path": audio_prompt_path,
            "temperature": temperature,
            "exaggeration": exaggeration,
            "cfg_weight": cfg_weight,
        }
        waveform = chatterbox_model.generate(**generation_kwargs)

        # NOTE(review): the original author states generate() already returns
        # a CPU tensor, so no device transfer is done here -- confirm against
        # the library if that matters to callers.
        sample_rate = chatterbox_model.sr
        return waveform, sample_rate
    except Exception as synth_error:
        logger.error(f"Error during TTS synthesis: {synth_error}", exc_info=True)
        return None, None
# --- End File: engine.py ---