Skip to content

Commit b1df925

Browse files
committed
Fix SOLA detail issue causing first frame problems
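Resolved an issue in the SOLA (Synchronized Overlap-Add) implementation that affected the first frame of audio processing: the first decoded slice is now written to the output starting from its very first sample (the leading padding region is not skipped), and only its trailing overlap region is held back in the SOLA buffer so it can be aligned and cross-faded with the next slice.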
1 parent a151aff commit b1df925

2 files changed

Lines changed: 228 additions & 22 deletions

File tree

projects/llm_framework/main_melotts/src/main.cpp

Lines changed: 227 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -253,14 +253,19 @@ class llm_task {
253253
}
254254
return false;
255255
}
256+
SLOGI("开始处理文本: %s", msg_str.c_str());
257+
258+
// 文本转音素处理部分保持不变
256259
std::vector<int> phones_bef, tones_bef;
257260
lexicon_->convert(msg_str, phones_bef, tones_bef);
258-
// Add blank between words
259-
auto phones = intersperse(phones_bef, 0);
260-
auto tones = intersperse(tones_bef, 0);
261-
int phone_len = phones.size();
262-
int MELOTTS_LANG_IDS = MELOTTS_LANG_IDS_MAP[mode_config_.mode];
263-
std::vector<int> langids(phone_len, MELOTTS_LANG_IDS);
261+
auto phones = intersperse(phones_bef, 0);
262+
auto tones = intersperse(tones_bef, 0);
263+
int phone_len = phones.size();
264+
std::vector<int> langids(phone_len, 3);
265+
266+
SLOGI("音素转换完成,长度: %d", phone_len);
267+
268+
// 运行encoder获取latent representation
264269
auto encoder_output =
265270
encoder_->Run(phones, tones, langids, g_matrix, mode_config_.noise_scale, mode_config_.noise_scale_w,
266271
mode_config_.get_length_scale(), mode_config_.sdp_ratio);
@@ -269,66 +274,267 @@ class llm_task {
269274
auto zp_info = encoder_output.at(0).GetTensorTypeAndShapeInfo();
270275
auto zp_shape = zp_info.GetShape();
271276

272-
// Decoder parameters setup
273-
int zp_size = decoder_->GetInputSize(0) / sizeof(float);
274-
int dec_len = zp_size / zp_shape[1];
275-
int audio_slice_len = decoder_->GetOutputSize(0) / sizeof(float);
276-
const int pad_frames = 16;
277+
SLOGI("Encoder输出完成,形状: [%ld, %ld, %ld],预期音频长度: %d", zp_shape[0], zp_shape[1], zp_shape[2],
278+
audio_len);
279+
280+
// 解码器参数设置
281+
int zp_size = decoder_->GetInputSize(0) / sizeof(float);
282+
int dec_len = zp_size / zp_shape[1];
283+
int audio_slice_len = decoder_->GetOutputSize(0) / sizeof(float);
284+
285+
// 定义pad长度(每侧填充帧数)
286+
const int pad_frames = 16;
287+
// 每个音频帧的采样点数量
277288
const int samples_per_frame = 512;
278-
const int effective_frames = dec_len - 2 * pad_frames;
289+
290+
SLOGI("解码器配置:帧长度=%d, 音频切片长度=%d, pad长度=%d, 每帧采样点=%d", dec_len, audio_slice_len,
291+
pad_frames, samples_per_frame);
292+
293+
// 每次有效处理的帧数量
294+
const int effective_frames = dec_len - 2 * pad_frames;
295+
296+
// 计算需要的解码次数 - 确保所有类型一致
279297
int dec_slice_num =
280298
static_cast<int>(std::ceil(static_cast<double>(zp_shape[2]) / static_cast<double>(effective_frames)));
281-
SolaProcessor sola(pad_frames, samples_per_frame);
299+
300+
SLOGI("将进行 %d 次推理,每次有效帧数: %d", dec_slice_num, effective_frames);
301+
302+
// === SOLA算法参数设置 ===
303+
const int sola_buffer_frame = pad_frames * samples_per_frame; // 重叠缓冲区长度
304+
const int sola_search_frame = pad_frames * samples_per_frame; // 搜索窗口长度
305+
const int block_frame = (dec_len - 2 * pad_frames) * samples_per_frame; // 有效块长度
306+
307+
// 创建淡入淡出窗口
308+
std::vector<float> fade_in_window(sola_buffer_frame);
309+
std::vector<float> fade_out_window(sola_buffer_frame);
310+
311+
for (int i = 0; i < sola_buffer_frame; i++) {
312+
fade_in_window[i] = static_cast<float>(i) / sola_buffer_frame;
313+
fade_out_window[i] = 1.0f - fade_in_window[i];
314+
}
315+
316+
// 初始化SOLA缓冲区
317+
std::vector<float> sola_buffer(sola_buffer_frame, 0.0f);
318+
bool first_frame = true;
319+
282320
std::vector<float> pcmlist;
283321

284322
for (int i = 0; i < dec_slice_num; i++) {
323+
// 计算当前批次的输入起始位置
285324
int input_start = i * effective_frames;
325+
// 考虑前向pad,但确保不为负
286326
if (i > 0) {
287327
input_start -= pad_frames;
288328
}
289-
input_start = std::max(0, input_start);
329+
input_start = std::max(0, input_start);
330+
331+
// 实际输入长度
290332
int actual_len = std::min(dec_len, static_cast<int>(zp_shape[2] - input_start));
333+
334+
// 计算输出的有效范围(帧级别)
335+
int output_start_frame, output_end_frame;
336+
337+
if (i == 0) {
338+
// 第一帧:跳过前面的pad部分
339+
output_start_frame = 0;
340+
output_end_frame = effective_frames - 1;
341+
} else if (i == dec_slice_num - 1) {
342+
// 最后一帧:从当前段起始计算
343+
output_start_frame = i * effective_frames;
344+
// 最后到编码器输出的最大长度
345+
output_end_frame = static_cast<int>(zp_shape[2]) - 1;
346+
} else {
347+
// 中间帧:标准计算
348+
output_start_frame = i * effective_frames;
349+
output_end_frame = (i + 1) * effective_frames - 1;
350+
}
351+
352+
SLOGI("第 %d 次推理: 输入帧范围=[%d-%d],实际长度=%d,输出帧范围=[%d-%d]", i + 1, input_start,
353+
input_start + actual_len - 1, actual_len, output_start_frame, output_end_frame);
354+
355+
// 准备decoder输入,全部初始化为0
291356
std::vector<float> zp(zp_size, 0);
292357

358+
// 复制数据到decoder输入
293359
for (int n = 0; n < zp_shape[1]; n++) {
294360
int copy_size = std::min(actual_len, static_cast<int>(zp_shape[2] - input_start));
295361
if (copy_size > 0) {
296362
memcpy(zp.data() + n * dec_len, zp_data + n * zp_shape[2] + input_start,
297363
sizeof(float) * copy_size);
298364
}
299365
}
300-
// Run decoder
366+
367+
// 运行decoder
301368
std::vector<float> decoder_output(audio_slice_len);
302369
decoder_->SetInput(zp.data(), 0);
303370
decoder_->SetInput(g_matrix.data(), 1);
371+
372+
SLOGI("第 %d 次推理:开始解码...", i + 1);
373+
304374
if (0 != decoder_->Run()) {
375+
SLOGI("第 %d 次推理:解码失败", i + 1);
305376
throw std::string("decoder_ RunSync error");
306377
}
378+
307379
decoder_->GetOutput(decoder_output.data(), 0);
308-
std::vector<float> processed_output = sola.ProcessFrame(decoder_output, i, dec_slice_num, actual_len);
309380

310-
pcmlist.insert(pcmlist.end(), processed_output.begin(), processed_output.end());
381+
// === SOLA处理流程 ===
382+
if (first_frame) {
383+
// 首帧特殊处理 - 不应跳过前面的内容
384+
// 首帧直接从解码器输出开始,不跳过任何内容
385+
int audio_start = 0; // 从头开始,不跳过pad_frames
386+
387+
// 计算首帧应该添加的数据长度
388+
// 首帧应该保留完整解码输出,只留出末尾的sola_buffer_frame用于下一帧衔接
389+
int audio_len = decoder_output.size() - sola_buffer_frame;
390+
391+
// 边界检查
392+
audio_len = std::max(0, audio_len); // 确保不为负
393+
394+
// 添加首帧数据
395+
if (audio_len > 0) {
396+
pcmlist.insert(pcmlist.end(), decoder_output.begin() + audio_start,
397+
decoder_output.begin() + audio_start + audio_len);
398+
}
399+
400+
// 保存末尾的sola_buffer_frame长度到SOLA缓冲区,用于下一帧对齐
401+
int buffer_start = audio_len;
402+
403+
// 确保有足够数据可供复制
404+
if (buffer_start + sola_buffer_frame <= decoder_output.size()) {
405+
std::copy(decoder_output.begin() + buffer_start,
406+
decoder_output.begin() + buffer_start + sola_buffer_frame, sola_buffer.begin());
407+
} else {
408+
// 可能的情况:首帧数据总长度不足sola_buffer_frame
409+
int available = static_cast<int>(decoder_output.size() - buffer_start);
410+
if (available > 0) {
411+
std::copy(decoder_output.begin() + buffer_start, decoder_output.end(), sola_buffer.begin());
412+
// 填充零
413+
std::fill(sola_buffer.begin() + available, sola_buffer.end(), 0.0f);
414+
} else {
415+
// 完全没有足够数据,全部填零
416+
std::fill(sola_buffer.begin(), sola_buffer.end(), 0.0f);
417+
}
418+
}
419+
420+
first_frame = false;
421+
422+
SLOGI("第 %d 次推理: 首帧处理,从位置%d开始添加%d采样点到输出,保存%d样本到SOLA缓冲区", i + 1,
423+
audio_start, audio_len, sola_buffer_frame);
424+
} else {
425+
// 非首帧:需要执行SOLA对齐
426+
int audio_start = pad_frames * samples_per_frame;
427+
428+
// 1. 准备搜索窗口 - 当前帧的开头部分
429+
std::vector<float> search_window(sola_buffer_frame + sola_search_frame);
430+
std::copy(decoder_output.begin() + audio_start,
431+
decoder_output.begin() + audio_start + search_window.size(), search_window.begin());
432+
433+
// 2. 寻找最佳对齐点(计算互相关)
434+
int best_offset = 0;
435+
float best_correlation = -1.0;
436+
437+
for (int offset = 0; offset <= sola_search_frame; offset++) {
438+
float correlation = 0.0;
439+
float energy = 0.0;
440+
441+
for (int j = 0; j < sola_buffer_frame; j++) {
442+
correlation += sola_buffer[j] * search_window[j + offset];
443+
energy += search_window[j + offset] * search_window[j + offset];
444+
}
445+
446+
// 归一化相关性(避免除零)
447+
float normalized_correlation = (energy > 1e-8) ? correlation / std::sqrt(energy) : 0.0f;
448+
449+
if (normalized_correlation > best_correlation) {
450+
best_correlation = normalized_correlation;
451+
best_offset = offset;
452+
}
453+
}
454+
455+
SLOGI("第 %d 次推理: SOLA找到最佳对齐偏移量 %d,相关系数 %f", i + 1, best_offset, best_correlation);
456+
457+
// 3. 应用对齐偏移
458+
int aligned_start = audio_start + best_offset;
459+
460+
// 4. 平滑过渡处理(对齐区域的crossfade)
461+
std::vector<float> crossfade_region(sola_buffer_frame);
462+
463+
for (int j = 0; j < sola_buffer_frame; j++) {
464+
// 应用淡入淡出窗口函数
465+
crossfade_region[j] =
466+
decoder_output[aligned_start + j] * fade_in_window[j] + sola_buffer[j] * fade_out_window[j];
467+
}
468+
469+
// 5. 添加crossfade区域到输出
470+
pcmlist.insert(pcmlist.end(), crossfade_region.begin(), crossfade_region.end());
471+
472+
// 6. 添加剩余有效音频数据
473+
int remaining_start = aligned_start + sola_buffer_frame;
474+
int remaining_len = (i == dec_slice_num - 1)
475+
? (actual_len - 2 * pad_frames) * samples_per_frame - sola_buffer_frame
476+
: (dec_len - 2 * pad_frames) * samples_per_frame - sola_buffer_frame;
477+
478+
// 边界检查
479+
remaining_len = std::min(remaining_len, static_cast<int>(decoder_output.size() - remaining_start));
480+
481+
if (remaining_len > 0) {
482+
pcmlist.insert(pcmlist.end(), decoder_output.begin() + remaining_start,
483+
decoder_output.begin() + remaining_start + remaining_len);
484+
}
485+
486+
// 7. 更新SOLA缓冲区,为下一帧准备
487+
int buffer_start = remaining_start + remaining_len;
488+
489+
// 检查是否还有足够的数据用于下一个缓冲区
490+
if (buffer_start + sola_buffer_frame <= decoder_output.size()) {
491+
std::copy(decoder_output.begin() + buffer_start,
492+
decoder_output.begin() + buffer_start + sola_buffer_frame, sola_buffer.begin());
493+
} else {
494+
// 如果不足,就用零填充
495+
int avail = static_cast<int>(decoder_output.size() - buffer_start);
496+
if (avail > 0) {
497+
std::copy(decoder_output.begin() + buffer_start, decoder_output.end(), sola_buffer.begin());
498+
}
499+
std::fill(sola_buffer.begin() + avail, sola_buffer.end(), 0.0f);
500+
}
501+
502+
SLOGI("第 %d 次推理: 添加 %d + %d 采样点到输出,累计长度: %zu", i + 1, sola_buffer_frame,
503+
remaining_len, pcmlist.size());
504+
}
311505
}
312506

507+
SLOGI("所有推理完成,生成PCM长度: %zu", pcmlist.size());
508+
509+
// 后续处理:重采样和转换为int16
313510
double src_ratio = (mode_config_.audio_rate * 1.0f) / (mode_config_.mode_rate * 1.0f);
314511
std::vector<float> tmp_pcm((pcmlist.size() * src_ratio + 1));
315512
int len;
513+
514+
SLOGI("开始音频重采样,源采样率: %f,目标采样率: %f,比率: %f", mode_config_.mode_rate * 1.0f,
515+
mode_config_.audio_rate * 1.0f, src_ratio);
516+
316517
resample_audio(pcmlist.data(), pcmlist.size(), tmp_pcm.data(), &len, src_ratio);
317518

318-
// Convert to 16-bit PCM
519+
SLOGI("重采样完成,重采样后长度: %d", len);
520+
521+
// 转换为16位PCM
319522
wav_pcm_data.reserve(len);
320523
std::transform(tmp_pcm.begin(), tmp_pcm.begin() + len, std::back_inserter(wav_pcm_data),
321524
[](const auto val) { return (int16_t)(val * INT16_MAX); });
322525

323-
// Call callback function with output
526+
SLOGI("最终生成音频长度: %zu 个采样点", wav_pcm_data.size());
527+
528+
// 调用回调函数输出结果
324529
if (out_callback_)
325530
out_callback_(std::string((char *)wav_pcm_data.data(), wav_pcm_data.size() * sizeof(int16_t)), finish);
326531

532+
SLOGI("TTS处理完成,输出回调已调用");
327533
} catch (const std::exception &e) {
328-
SLOGI("TTS processing exception: %s", e.what());
534+
SLOGI("TTS处理异常: %s", e.what());
329535
return true;
330536
} catch (...) {
331-
SLOGI("TTS processing encountered unknown exception");
537+
SLOGI("TTS处理发生未知异常");
332538
return true;
333539
}
334540
return false;

projects/llm_framework/main_melotts/src/runner/Lexicon.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#include <iostream>
1010
#include "../../../../../SDK/components/utilities/include/sample_log.h"
1111
// Debug logging switch - set to true to enable debug logs
12-
static bool DEBUG_LOGGING = false;
12+
static bool DEBUG_LOGGING = true;
1313
// Macro for debug logging
1414
#define DEBUG_LOG(fmt, ...) \
1515
do { \

0 commit comments

Comments
 (0)