Merged
77 changes: 15 additions & 62 deletions projects/llm_framework/main_melotts/src/main.cpp
@@ -243,7 +243,6 @@ class llm_task {
try {
std::vector<int16_t> wav_pcm_data;
if (msg_str.empty()) {
SLOGI("empty");
if (out_callback_) {
std::string output = wav_pcm_data.empty() ? std::string()
: std::string((char *)wav_pcm_data.data(),
@@ -253,15 +252,13 @@ class llm_task {
return false;
}

// Convert text to phonemes and tones
std::vector<int> phones_bef, tones_bef;
lexicon_->convert(msg_str, phones_bef, tones_bef);
auto phones = intersperse(phones_bef, 0);
auto tones = intersperse(tones_bef, 0);
int phone_len = phones.size();
std::vector<int> langids(phone_len, 3);

// Run the encoder to generate hidden representations
auto encoder_output =
encoder_->Run(phones, tones, langids, g_matrix, mode_config_.noise_scale, mode_config_.noise_scale_w,
mode_config_.get_length_scale(), mode_config_.sdp_ratio);
@@ -270,7 +267,6 @@ class llm_task {
auto zp_info = encoder_output.at(0).GetTensorTypeAndShapeInfo();
auto zp_shape = zp_info.GetShape();

// Calculate decoder parameters
int zp_size = decoder_->GetInputSize(0) / sizeof(float);
int dec_len = zp_size / zp_shape[1];
int audio_slice_len = decoder_->GetOutputSize(0) / sizeof(float);
@@ -283,12 +279,10 @@ class llm_task {
int dec_slice_num =
static_cast<int>(std::ceil(static_cast<double>(zp_shape[2]) / static_cast<double>(effective_frames)));

// SOLA parameters setup
const int sola_buffer_frame = pad_frames * samples_per_frame; // Overlap buffer length
const int sola_search_frame = pad_frames * samples_per_frame; // Search window length
const int block_frame = (dec_len - 2 * pad_frames) * samples_per_frame; // Effective block length
const int sola_buffer_frame = pad_frames * samples_per_frame;
const int sola_search_frame = pad_frames * samples_per_frame;
const int block_frame = (dec_len - 2 * pad_frames) * samples_per_frame;

// Create fade-in/fade-out windows for smooth transitions
std::vector<float> fade_in_window(sola_buffer_frame);
std::vector<float> fade_out_window(sola_buffer_frame);

@@ -297,46 +291,35 @@ class llm_task {
fade_out_window[i] = 1.0f - fade_in_window[i];
}

// Initialize SOLA buffer
std::vector<float> sola_buffer(sola_buffer_frame, 0.0f);
bool first_frame = true;

std::vector<float> pcmlist;

// Main decoding loop - process each slice
for (int i = 0; i < dec_slice_num; i++) {
// Calculate start position for current batch input
int input_start = i * effective_frames;
// Consider forward padding, but ensure non-negative
if (i > 0) {
input_start -= pad_frames;
}
input_start = std::max(0, input_start);

// Actual input length
int actual_len = std::min(dec_len, static_cast<int>(zp_shape[2] - input_start));

// Calculate effective output range (frame level)
int output_start_frame, output_end_frame;

if (i == 0) {
// First frame: skip padding at beginning
output_start_frame = 0;
output_end_frame = effective_frames - 1;
} else if (i == dec_slice_num - 1) {
// Last frame: calculate from current segment start
output_start_frame = i * effective_frames;
// Last frame extends to encoder's maximum output length
output_end_frame = static_cast<int>(zp_shape[2]) - 1;
output_end_frame = static_cast<int>(zp_shape[2]) - 1;
} else {
// Middle frames: standard calculation
output_start_frame = i * effective_frames;
output_end_frame = (i + 1) * effective_frames - 1;
}
// Prepare decoder input, initialize all to zero

std::vector<float> zp(zp_size, 0);

// Copy data to decoder input
for (int n = 0; n < zp_shape[1]; n++) {
int copy_size = std::min(actual_len, static_cast<int>(zp_shape[2] - input_start));
if (copy_size > 0) {
@@ -345,70 +328,50 @@ class llm_task {
}
}

// Run decoder
std::vector<float> decoder_output(audio_slice_len);
decoder_->SetInput(zp.data(), 0);
decoder_->SetInput(g_matrix.data(), 1);

if (0 != decoder_->Run()) {
SLOGI("Inference #%d: decoding failed", i + 1);
throw std::string("decoder_ RunSync error");
}

decoder_->GetOutput(decoder_output.data(), 0);

// === SOLA Processing Logic ===
if (first_frame) {
// Special handling for first frame - should not skip initial content
// First frame starts directly from decoder output without skipping
int audio_start = 0; // Start from beginning, don't skip pad_frames

// Calculate data length for first frame
// First frame should preserve complete decoder output, only reserving sola_buffer_frame at the end
// for next frame alignment
int audio_len = decoder_output.size() - sola_buffer_frame;

// Boundary check
audio_len = std::max(0, audio_len); // Ensure non-negative
int audio_start = 0;
int audio_len = decoder_output.size() - sola_buffer_frame;
audio_len = std::max(0, audio_len);

// Add first frame data
if (audio_len > 0) {
pcmlist.insert(pcmlist.end(), decoder_output.begin() + audio_start,
decoder_output.begin() + audio_start + audio_len);
}

// Save sola_buffer_frame length from the end to SOLA buffer for next frame alignment
int buffer_start = audio_len;

// Ensure sufficient data is available for copying
if (buffer_start + sola_buffer_frame <= decoder_output.size()) {
std::copy(decoder_output.begin() + buffer_start,
decoder_output.begin() + buffer_start + sola_buffer_frame, sola_buffer.begin());
} else {
// Possible case: first frame data is shorter than sola_buffer_frame
int available = static_cast<int>(decoder_output.size() - buffer_start);
if (available > 0) {
std::copy(decoder_output.begin() + buffer_start, decoder_output.end(), sola_buffer.begin());
// Fill with zeros
std::fill(sola_buffer.begin() + available, sola_buffer.end(), 0.0f);
} else {
// Completely insufficient data, fill all with zeros
std::fill(sola_buffer.begin(), sola_buffer.end(), 0.0f);
}
}

first_frame = false;

} else {
// Non-first frame: SOLA alignment required
int audio_start = pad_frames * samples_per_frame;

// 1. Prepare search window - beginning portion of current frame
std::vector<float> search_window(sola_buffer_frame + sola_search_frame);
std::copy(decoder_output.begin() + audio_start,
decoder_output.begin() + audio_start + search_window.size(), search_window.begin());

// 2. Find best alignment point (calculate cross-correlation)
int best_offset = 0;
float best_correlation = -1.0;

@@ -421,7 +384,6 @@ class llm_task {
energy += search_window[j + offset] * search_window[j + offset];
}

// Normalize correlation (avoid division by zero)
float normalized_correlation = (energy > 1e-8) ? correlation / std::sqrt(energy) : 0.0f;

if (normalized_correlation > best_correlation) {
@@ -430,30 +392,25 @@ class llm_task {
}
}

// 3. Apply alignment offset

int aligned_start = audio_start + best_offset;

// 4. Smooth transition processing (crossfade in alignment region)
std::vector<float> crossfade_region(sola_buffer_frame);

for (int j = 0; j < sola_buffer_frame; j++) {
// Apply fade-in/fade-out window functions
crossfade_region[j] =
decoder_output[aligned_start + j] * fade_in_window[j] + sola_buffer[j] * fade_out_window[j];
}

// 5. Add crossfade region to output
pcmlist.insert(pcmlist.end(), crossfade_region.begin(), crossfade_region.end());

int remaining_start = aligned_start + sola_buffer_frame;

if (i == dec_slice_num - 1) {
int total_expected_samples = audio_len * samples_per_frame / 512;

int processed_samples = static_cast<int>(pcmlist.size());

int remaining_needed = total_expected_samples - processed_samples;
remaining_needed = std::max(0, remaining_needed);
int processed_samples = static_cast<int>(pcmlist.size());
int remaining_needed = total_expected_samples - processed_samples;
remaining_needed = std::max(0, remaining_needed);

int remaining_len =
std::min(remaining_needed, static_cast<int>(decoder_output.size() - remaining_start));
Expand All @@ -465,7 +422,6 @@ class llm_task {

} else {
int remaining_len = (dec_len - 2 * pad_frames) * samples_per_frame - sola_buffer_frame;

remaining_len =
std::min(remaining_len, static_cast<int>(decoder_output.size() - remaining_start));

@@ -495,31 +451,28 @@ class llm_task {
pcmlist.resize(audio_len);
}

// Post-processing: resample and convert to int16

double src_ratio =
static_cast<double>(mode_config_.audio_rate) / static_cast<double>(mode_config_.mode_rate);
std::vector<float> tmp_pcm((pcmlist.size() * src_ratio + 1));
int len;

resample_audio(pcmlist.data(), pcmlist.size(), tmp_pcm.data(), &len, src_ratio);

// Convert to 16-bit PCM

wav_pcm_data.reserve(len);
std::transform(tmp_pcm.begin(), tmp_pcm.begin() + len, std::back_inserter(wav_pcm_data),
[](const auto val) { return static_cast<int16_t>(val * INT16_MAX); });

// Call the output callback function with the result
if (out_callback_) {
out_callback_(
std::string(reinterpret_cast<char *>(wav_pcm_data.data()), wav_pcm_data.size() * sizeof(int16_t)),
finish);
}

} catch (const std::exception &e) {
SLOGI("TTS processing exception: %s", e.what());
return true;
} catch (...) {
SLOGI("TTS processing encountered an unknown exception");
return true;
}
return false;
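Note on the SOLA (synchronous overlap-add) splice performed in the loop above: each new decoder slice is aligned against the tail of the audio already emitted by searching for the offset with the highest normalized cross-correlation, then the overlap region is crossfaded. The sketch below is a minimal standalone illustration of that idea; the function name, buffer sizes, and the linear fade are my own assumptions (main.cpp builds explicit fade-in/fade-out windows), and it is not the project's implementation.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Illustrative SOLA-style splice. `tail` holds the last `tail.size()` samples
// already emitted; `block` is the new decoder output. Assumes
// block.size() >= search + 2 * tail.size().
void sola_append(std::vector<float> &out, std::vector<float> &tail,
                 const std::vector<float> &block, int search)
{
    const int overlap = static_cast<int>(tail.size());
    int best_offset   = 0;
    float best_score  = -1.0f;
    // Find the offset in the head of `block` that best matches the stored tail
    // (normalized cross-correlation, guarding against division by zero).
    for (int offset = 0; offset <= search; ++offset) {
        float corr = 0.0f, energy = 0.0f;
        for (int j = 0; j < overlap; ++j) {
            corr   += block[offset + j] * tail[j];
            energy += block[offset + j] * block[offset + j];
        }
        const float score = (energy > 1e-8f) ? corr / std::sqrt(energy) : 0.0f;
        if (score > best_score) {
            best_score  = score;
            best_offset = offset;
        }
    }
    // Crossfade the aligned head of the new block against the stored tail
    // (linear fade here for brevity).
    for (int j = 0; j < overlap; ++j) {
        const float fade_in = static_cast<float>(j) / static_cast<float>(overlap);
        out.push_back(block[best_offset + j] * fade_in + tail[j] * (1.0f - fade_in));
    }
    // Emit the body of the block and keep its last `overlap` samples as the
    // tail for the next call.
    const int body_end = static_cast<int>(block.size()) - overlap;
    for (int j = best_offset + overlap; j < body_end; ++j) {
        out.push_back(block[j]);
    }
    std::copy(block.end() - overlap, block.end(), tail.begin());
}
```

Crossfading at the best-correlation offset is what keeps slice boundaries from producing audible clicks when the decoder is run piecewise over the encoder output.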
@@ -932,4 +885,4 @@ int main(int argc, char *argv[])
}
llm.llm_firework_exit();
return 0;
}
}
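For reference, the post-processing step near the end of the TTS path resamples the assembled float PCM to the output rate and scales it to 16-bit samples (the `val * INT16_MAX` transform in the diff). Below is a minimal sketch of that final conversion; the clamp is my own addition as a common safeguard against out-of-range samples, not something the diff performs, and it assumes C++17 for `std::clamp`.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Convert float PCM (nominally in [-1, 1]) to 16-bit PCM, clamping first so
// out-of-range values saturate instead of wrapping around.
std::vector<int16_t> float_to_int16(const std::vector<float> &pcm)
{
    std::vector<int16_t> out;
    out.reserve(pcm.size());
    for (float v : pcm) {
        const float scaled = std::clamp(v, -1.0f, 1.0f) * 32767.0f;
        out.push_back(static_cast<int16_t>(scaled));
    }
    return out;
}
```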