Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Refactor SOLA component code
Streamline and simplify code in the SOLA module for improved readability and maintenance
  • Loading branch information
yuyun2000 committed May 15, 2025
commit 10e4bdf828912595765ef4031b95435073054ea7
119 changes: 12 additions & 107 deletions projects/llm_framework/main_melotts/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,6 @@ class llm_task {
try {
std::vector<int16_t> wav_pcm_data;
if (msg_str.empty()) {
SLOGI("empty");
if (out_callback_) {
std::string output = wav_pcm_data.empty() ? std::string()
: std::string((char *)wav_pcm_data.data(),
Expand All @@ -252,19 +251,14 @@ class llm_task {
}
return false;
}
SLOGI("Processing text: %s", msg_str.c_str());

// Convert text to phonemes and tones
std::vector<int> phones_bef, tones_bef;
lexicon_->convert(msg_str, phones_bef, tones_bef);
auto phones = intersperse(phones_bef, 0);
auto tones = intersperse(tones_bef, 0);
int phone_len = phones.size();
std::vector<int> langids(phone_len, 3);

SLOGI("Phoneme conversion completed, length: %d", phone_len);

// Run the encoder to generate hidden representations
auto encoder_output =
encoder_->Run(phones, tones, langids, g_matrix, mode_config_.noise_scale, mode_config_.noise_scale_w,
mode_config_.get_length_scale(), mode_config_.sdp_ratio);
Expand All @@ -273,33 +267,22 @@ class llm_task {
auto zp_info = encoder_output.at(0).GetTensorTypeAndShapeInfo();
auto zp_shape = zp_info.GetShape();

SLOGI("Encoder output completed, shape: [%ld, %ld, %ld], expected audio length: %d", zp_shape[0],
zp_shape[1], zp_shape[2], audio_len);

// Calculate decoder parameters
int zp_size = decoder_->GetInputSize(0) / sizeof(float);
int dec_len = zp_size / zp_shape[1];
int audio_slice_len = decoder_->GetOutputSize(0) / sizeof(float);

const int pad_frames = 16;
const int pad_frames = 24;
const int samples_per_frame = 512;

SLOGI("Decoder configuration: frame length=%d, audio slice length=%d, pad length=%d, samples per frame=%d",
dec_len, audio_slice_len, pad_frames, samples_per_frame);

const int effective_frames = dec_len - 2 * pad_frames;

int dec_slice_num =
static_cast<int>(std::ceil(static_cast<double>(zp_shape[2]) / static_cast<double>(effective_frames)));

SLOGI("Will perform %d inferences, each with effective frames: %d", dec_slice_num, effective_frames);
const int sola_buffer_frame = pad_frames * samples_per_frame;
const int sola_search_frame = pad_frames * samples_per_frame;
const int block_frame = (dec_len - 2 * pad_frames) * samples_per_frame;

// SOLA parameters setup
const int sola_buffer_frame = pad_frames * samples_per_frame; // Overlap buffer length
const int sola_search_frame = pad_frames * samples_per_frame; // Search window length
const int block_frame = (dec_len - 2 * pad_frames) * samples_per_frame; // Effective block length

// Create fade-in/fade-out windows for smooth transitions
std::vector<float> fade_in_window(sola_buffer_frame);
std::vector<float> fade_out_window(sola_buffer_frame);

Expand All @@ -308,50 +291,35 @@ class llm_task {
fade_out_window[i] = 1.0f - fade_in_window[i];
}

// Initialize SOLA buffer
std::vector<float> sola_buffer(sola_buffer_frame, 0.0f);
bool first_frame = true;

std::vector<float> pcmlist;

// Main decoding loop - process each slice
for (int i = 0; i < dec_slice_num; i++) {
// Calculate start position for current batch input
int input_start = i * effective_frames;
// Consider forward padding, but ensure non-negative
if (i > 0) {
input_start -= pad_frames;
}
input_start = std::max(0, input_start);

// Actual input length
int actual_len = std::min(dec_len, static_cast<int>(zp_shape[2] - input_start));

// Calculate effective output range (frame level)
int output_start_frame, output_end_frame;

if (i == 0) {
// First frame: skip padding at beginning
output_start_frame = 0;
output_end_frame = effective_frames - 1;
} else if (i == dec_slice_num - 1) {
// Last frame: calculate from current segment start
output_start_frame = i * effective_frames;
// Last frame extends to encoder's maximum output length
output_end_frame = static_cast<int>(zp_shape[2]) - 1;
output_end_frame = static_cast<int>(zp_shape[2]) - 1;
} else {
// Middle frames: standard calculation
output_start_frame = i * effective_frames;
output_end_frame = (i + 1) * effective_frames - 1;
}

SLOGI("Inference #%d: input frame range=[%d-%d], actual length=%d, output frame range=[%d-%d]", i + 1,
input_start, input_start + actual_len - 1, actual_len, output_start_frame, output_end_frame);

// Prepare decoder input, initialize all to zero
std::vector<float> zp(zp_size, 0);

// Copy data to decoder input
for (int n = 0; n < zp_shape[1]; n++) {
int copy_size = std::min(actual_len, static_cast<int>(zp_shape[2] - input_start));
if (copy_size > 0) {
Expand All @@ -360,76 +328,49 @@ class llm_task {
}
}

// Run decoder
std::vector<float> decoder_output(audio_slice_len);
decoder_->SetInput(zp.data(), 0);
decoder_->SetInput(g_matrix.data(), 1);

SLOGI("Inference #%d: starting decoding...", i + 1);

if (0 != decoder_->Run()) {
SLOGI("Inference #%d: decoding failed", i + 1);
throw std::string("decoder_ RunSync error");
}

decoder_->GetOutput(decoder_output.data(), 0);

// === SOLA Processing Logic ===
if (first_frame) {
// Special handling for first frame - should not skip initial content
// First frame starts directly from decoder output without skipping
int audio_start = 0; // Start from beginning, don't skip pad_frames
int audio_start = 0;
int audio_len = decoder_output.size() - sola_buffer_frame;
audio_len = std::max(0, audio_len);

// Calculate data length for first frame
// First frame should preserve complete decoder output, only reserving sola_buffer_frame at the end
// for next frame alignment
int audio_len = decoder_output.size() - sola_buffer_frame;

// Boundary check
audio_len = std::max(0, audio_len); // Ensure non-negative

// Add first frame data
if (audio_len > 0) {
pcmlist.insert(pcmlist.end(), decoder_output.begin() + audio_start,
decoder_output.begin() + audio_start + audio_len);
}

// Save sola_buffer_frame length from the end to SOLA buffer for next frame alignment
int buffer_start = audio_len;

// Ensure sufficient data is available for copying
if (buffer_start + sola_buffer_frame <= decoder_output.size()) {
std::copy(decoder_output.begin() + buffer_start,
decoder_output.begin() + buffer_start + sola_buffer_frame, sola_buffer.begin());
} else {
// Possible case: first frame data is shorter than sola_buffer_frame
int available = static_cast<int>(decoder_output.size() - buffer_start);
if (available > 0) {
std::copy(decoder_output.begin() + buffer_start, decoder_output.end(), sola_buffer.begin());
// Fill with zeros
std::fill(sola_buffer.begin() + available, sola_buffer.end(), 0.0f);
} else {
// Completely insufficient data, fill all with zeros
std::fill(sola_buffer.begin(), sola_buffer.end(), 0.0f);
}
}

first_frame = false;

SLOGI(
"Inference #%d: First frame processing, added %d samples from position %d to output, saved %d "
"samples to SOLA buffer",
i + 1, audio_len, audio_start, sola_buffer_frame);
} else {
// Non-first frame: SOLA alignment required
int audio_start = pad_frames * samples_per_frame;

// 1. Prepare search window - beginning portion of current frame
std::vector<float> search_window(sola_buffer_frame + sola_search_frame);
std::copy(decoder_output.begin() + audio_start,
decoder_output.begin() + audio_start + search_window.size(), search_window.begin());

// 2. Find best alignment point (calculate cross-correlation)
int best_offset = 0;
float best_correlation = -1.0;

Expand All @@ -442,7 +383,6 @@ class llm_task {
energy += search_window[j + offset] * search_window[j + offset];
}

// Normalize correlation (avoid division by zero)
float normalized_correlation = (energy > 1e-8) ? correlation / std::sqrt(energy) : 0.0f;

if (normalized_correlation > best_correlation) {
Expand All @@ -451,48 +391,35 @@ class llm_task {
}
}

SLOGI("Inference #%d: SOLA found best alignment offset %d with correlation coefficient %f", i + 1,
best_offset, best_correlation);

// 3. Apply alignment offset
int aligned_start = audio_start + best_offset;

// 4. Smooth transition processing (crossfade in alignment region)
std::vector<float> crossfade_region(sola_buffer_frame);

for (int j = 0; j < sola_buffer_frame; j++) {
// Apply fade-in/fade-out window functions
crossfade_region[j] =
decoder_output[aligned_start + j] * fade_in_window[j] + sola_buffer[j] * fade_out_window[j];
}

// 5. Add crossfade region to output
pcmlist.insert(pcmlist.end(), crossfade_region.begin(), crossfade_region.end());

int remaining_start = aligned_start + sola_buffer_frame;

if (i == dec_slice_num - 1) {
int total_expected_samples = audio_len * samples_per_frame / 512;

int processed_samples = static_cast<int>(pcmlist.size());

int remaining_needed = total_expected_samples - processed_samples;
remaining_needed = std::max(0, remaining_needed);
int processed_samples = static_cast<int>(pcmlist.size());
int remaining_needed = total_expected_samples - processed_samples;
remaining_needed = std::max(0, remaining_needed);

int remaining_len =
std::min(remaining_needed, static_cast<int>(decoder_output.size() - remaining_start));

SLOGI("Inference #%d (final): Expected total=%d, processed=%d, needed=%d, available=%d", i + 1,
total_expected_samples, processed_samples, remaining_needed, remaining_len);

if (remaining_len > 0) {
pcmlist.insert(pcmlist.end(), decoder_output.begin() + remaining_start,
decoder_output.begin() + remaining_start + remaining_len);
}

} else {
int remaining_len = (dec_len - 2 * pad_frames) * samples_per_frame - sola_buffer_frame;

remaining_len =
std::min(remaining_len, static_cast<int>(decoder_output.size() - remaining_start));

Expand All @@ -514,55 +441,33 @@ class llm_task {
}
std::fill(sola_buffer.begin() + avail, sola_buffer.end(), 0.0f);
}

SLOGI("Inference #%d: Added %d + %d samples to output, cumulative length: %zu", i + 1,
sola_buffer_frame, remaining_len, pcmlist.size());
}
}
}

SLOGI("All inference completed, raw generated PCM length: %zu", pcmlist.size());

if (pcmlist.size() > audio_len) {
SLOGI("Truncating output from %zu to %d samples as per encoder prediction", pcmlist.size(), audio_len);
pcmlist.resize(audio_len);
}

SLOGI("Final PCM length after truncation: %zu", pcmlist.size());

// Post-processing: resample and convert to int16
double src_ratio =
static_cast<double>(mode_config_.audio_rate) / static_cast<double>(mode_config_.mode_rate);
std::vector<float> tmp_pcm((pcmlist.size() * src_ratio + 1));
int len;

SLOGI("Starting audio resampling, source rate: %f, target rate: %f, ratio: %f",
static_cast<float>(mode_config_.mode_rate), static_cast<float>(mode_config_.audio_rate), src_ratio);

resample_audio(pcmlist.data(), pcmlist.size(), tmp_pcm.data(), &len, src_ratio);

SLOGI("Resampling completed, length after resampling: %d", len);

// Convert to 16-bit PCM
wav_pcm_data.reserve(len);
std::transform(tmp_pcm.begin(), tmp_pcm.begin() + len, std::back_inserter(wav_pcm_data),
[](const auto val) { return static_cast<int16_t>(val * INT16_MAX); });

SLOGI("Final audio length: %zu samples", wav_pcm_data.size());

// Call the output callback function with the result
if (out_callback_) {
out_callback_(
std::string(reinterpret_cast<char *>(wav_pcm_data.data()), wav_pcm_data.size() * sizeof(int16_t)),
finish);
}

SLOGI("TTS processing completed, output callback invoked");
} catch (const std::exception &e) {
SLOGI("TTS processing exception: %s", e.what());
return true;
} catch (...) {
SLOGI("TTS processing encountered an unknown exception");
return true;
}
return false;
Expand Down Expand Up @@ -975,4 +880,4 @@ int main(int argc, char *argv[])
}
llm.llm_firework_exit();
return 0;
}
}