Skip to content

Commit e3c70bc

Browse files
committed
Implement Sola algorithm for smoother audio transitions
Apply the Synchronized Overlap-Add (SOLA) algorithm to smooth the connection between audio segments output by the decoder, resulting in more natural-sounding transitions between segments.
1 parent 5782f89 commit e3c70bc

3 files changed

Lines changed: 315 additions & 22 deletions

File tree

projects/llm_framework/main_melotts/src/main.cpp

Lines changed: 45 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "Lexicon.hpp"
1010
#include <ax_sys_api.h>
1111
#include "AudioFile.h"
12+
#include "SolaProcessor.h"
1213
#include "Lexicon.hpp"
1314

1415
#include <signal.h>
@@ -263,49 +264,71 @@ class llm_task {
263264
auto encoder_output =
264265
encoder_->Run(phones, tones, langids, g_matrix, mode_config_.noise_scale, mode_config_.noise_scale_w,
265266
mode_config_.get_length_scale(), mode_config_.sdp_ratio);
266-
float *zp_data = encoder_output.at(0).GetTensorMutableData<float>();
267-
int audio_len = encoder_output.at(2).GetTensorMutableData<int>()[0];
268-
auto zp_info = encoder_output.at(0).GetTensorTypeAndShapeInfo();
269-
auto zp_shape = zp_info.GetShape();
270-
int zp_size = decoder_->GetInputSize(0) / sizeof(float);
271-
int dec_len = zp_size / zp_shape[1];
272-
int audio_slice_len = decoder_->GetOutputSize(0) / sizeof(float);
273-
std::vector<float> decoder_output(audio_slice_len);
274-
int dec_slice_num = int(std::ceil(zp_shape[2] * 1.0 / dec_len));
267+
float *zp_data = encoder_output.at(0).GetTensorMutableData<float>();
268+
int audio_len = encoder_output.at(2).GetTensorMutableData<int>()[0];
269+
auto zp_info = encoder_output.at(0).GetTensorTypeAndShapeInfo();
270+
auto zp_shape = zp_info.GetShape();
271+
272+
// Decoder parameters setup
273+
int zp_size = decoder_->GetInputSize(0) / sizeof(float);
274+
int dec_len = zp_size / zp_shape[1];
275+
int audio_slice_len = decoder_->GetOutputSize(0) / sizeof(float);
276+
const int pad_frames = 16;
277+
const int samples_per_frame = 512;
278+
const int effective_frames = dec_len - 2 * pad_frames;
279+
int dec_slice_num =
280+
static_cast<int>(std::ceil(static_cast<double>(zp_shape[2]) / static_cast<double>(effective_frames)));
281+
SolaProcessor sola(pad_frames, samples_per_frame);
275282
std::vector<float> pcmlist;
283+
276284
for (int i = 0; i < dec_slice_num; i++) {
285+
int input_start = i * effective_frames;
286+
if (i > 0) {
287+
input_start -= pad_frames;
288+
}
289+
input_start = std::max(0, input_start);
290+
int actual_len = std::min(dec_len, static_cast<int>(zp_shape[2] - input_start));
277291
std::vector<float> zp(zp_size, 0);
278-
int actual_size = (i + 1) * dec_len < zp_shape[2] ? dec_len : zp_shape[2] - i * dec_len;
292+
279293
for (int n = 0; n < zp_shape[1]; n++) {
280-
memcpy(zp.data() + n * dec_len, zp_data + n * zp_shape[2] + i * dec_len,
281-
sizeof(float) * actual_size);
294+
int copy_size = std::min(actual_len, static_cast<int>(zp_shape[2] - input_start));
295+
if (copy_size > 0) {
296+
memcpy(zp.data() + n * dec_len, zp_data + n * zp_shape[2] + input_start,
297+
sizeof(float) * copy_size);
298+
}
282299
}
300+
// Run decoder
301+
std::vector<float> decoder_output(audio_slice_len);
283302
decoder_->SetInput(zp.data(), 0);
284303
decoder_->SetInput(g_matrix.data(), 1);
285304
if (0 != decoder_->Run()) {
286-
printf("Run decoder model failed!\n");
287305
throw std::string("decoder_ RunSync error");
288306
}
289307
decoder_->GetOutput(decoder_output.data(), 0);
290-
actual_size = (i + 1) * audio_slice_len < audio_len ? audio_slice_len : audio_len - i * audio_slice_len;
291-
if (decoder_output.size() > actual_size) {
292-
pcmlist.reserve(pcmlist.size() + actual_size);
293-
std::copy(decoder_output.begin(), decoder_output.begin() + actual_size,
294-
std::back_inserter(pcmlist));
295-
} else {
296-
pcmlist.reserve(pcmlist.size() + decoder_output.size());
297-
std::copy(decoder_output.begin(), decoder_output.end(), std::back_inserter(pcmlist));
298-
}
308+
std::vector<float> processed_output = sola.ProcessFrame(decoder_output, i, dec_slice_num, actual_len);
309+
310+
pcmlist.insert(pcmlist.end(), processed_output.begin(), processed_output.end());
299311
}
312+
300313
double src_ratio = (mode_config_.audio_rate * 1.0f) / (mode_config_.mode_rate * 1.0f);
301314
std::vector<float> tmp_pcm((pcmlist.size() * src_ratio + 1));
302315
int len;
303316
resample_audio(pcmlist.data(), pcmlist.size(), tmp_pcm.data(), &len, src_ratio);
317+
318+
// Convert to 16-bit PCM
319+
wav_pcm_data.reserve(len);
304320
std::transform(tmp_pcm.begin(), tmp_pcm.begin() + len, std::back_inserter(wav_pcm_data),
305321
[](const auto val) { return (int16_t)(val * INT16_MAX); });
322+
323+
// Call callback function with output
306324
if (out_callback_)
307325
out_callback_(std::string((char *)wav_pcm_data.data(), wav_pcm_data.size() * sizeof(int16_t)), finish);
326+
327+
} catch (const std::exception &e) {
328+
SLOGI("TTS processing exception: %s", e.what());
329+
return true;
308330
} catch (...) {
331+
SLOGI("TTS processing encountered unknown exception");
309332
return true;
310333
}
311334
return false;

projects/llm_framework/main_melotts/src/runner/Lexicon.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class Lexicon {
3232
public:
3333
Lexicon(const std::string& lexicon_filename, const std::string& tokens_filename) : max_phrase_length(0)
3434
{
35+
SLOGI("词典加载: %zu 发音表加载: %zu", tokens_filename, lexicon_filename);
3536
std::unordered_map<std::string, int> tokens;
3637
std::ifstream ifs(tokens_filename);
3738
assert(ifs.is_open());
Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
#ifndef SOLA_PROCESSOR_H
2+
#define SOLA_PROCESSOR_H
3+
4+
#include <algorithm>
5+
#include <cmath>
6+
#include <functional>
7+
#include <string>
8+
#include <vector>
9+
10+
/**
 * SolaProcessor - Synchronized Overlap-Add (SOLA) stitching of decoder slices
 *
 * Each call to ProcessFrame() receives the raw samples of one decoder slice.
 * The first slice is trimmed of its leading padding and its trailing samples
 * are remembered; every following slice is searched for the position that
 * correlates best with the remembered tail, linearly crossfaded onto it, and
 * then appended.
 *
 * All reads from the decoder output are bounds-checked: samples that would lie
 * past the end of the supplied vector are treated as silence (0.0f) instead of
 * being read out of range.
 */
class SolaProcessor {
public:
    /**
     * Constructor
     *
     * @param padFrames        Number of padding frames at the beginning and end
     *                         of every slice (negative values are treated as 0)
     * @param samplesPerFrame  Number of audio samples in each frame
     *                         (negative values are treated as 0)
     */
    SolaProcessor(int padFrames, int samplesPerFrame)
        : pad_frames_(std::max(padFrames, 0)),
          samples_per_frame_(std::max(samplesPerFrame, 0)),
          first_frame_(true)
    {
        Initialize();
    }

    /**
     * Reset the processor to its initial state: the next call to
     * ProcessFrame() is handled as a first frame and the overlap buffer is
     * cleared.
     */
    void Reset()
    {
        first_frame_ = true;
        std::fill(sola_buffer_.begin(), sola_buffer_.end(), 0.0f);
    }

    /**
     * Process a single audio slice
     *
     * @param decoder_output  Raw audio data from decoder
     * @param frameIndex      Current slice index (accepted for interface
     *                        compatibility; not used by the algorithm)
     * @param totalFrames     Total number of slices (accepted for interface
     *                        compatibility; not used by the algorithm)
     * @param actualFrameLen  Number of frames of this slice that carry data
     * @return Processed audio samples to append to the output stream
     */
    std::vector<float> ProcessFrame(const std::vector<float>& decoder_output, int frameIndex, int totalFrames,
                                    int actualFrameLen)
    {
        (void)frameIndex;
        (void)totalFrames;

        std::vector<float> processed_output;

        if (first_frame_) {
            ProcessFirstFrame(decoder_output, processed_output, actualFrameLen);
            first_frame_ = false;
        } else {
            ProcessSubsequentFrame(decoder_output, processed_output, actualFrameLen);
        }

        return processed_output;
    }

private:
    /**
     * Copy dst.size() samples of src starting at index `start` into dst.
     * Positions that fall outside src are written as 0.0f, so the call never
     * reads out of range whatever `start` is.
     */
    static void CopyZeroPadded(const std::vector<float>& src, int start, std::vector<float>& dst)
    {
        const int total = static_cast<int>(src.size());
        const int first = std::min(std::max(start, 0), total);
        const int count = std::min(static_cast<int>(dst.size()), total - first);

        std::copy(src.begin() + first, src.begin() + first + count, dst.begin());
        std::fill(dst.begin() + count, dst.end(), 0.0f);
    }

    /**
     * Derive the overlap/search lengths and build the linear fade windows.
     */
    void Initialize()
    {
        // Overlap buffer and search range both span one padding region.
        sola_buffer_frame_ = pad_frames_ * samples_per_frame_;
        sola_search_frame_ = pad_frames_ * samples_per_frame_;

        fade_in_window_.resize(sola_buffer_frame_);
        fade_out_window_.resize(sola_buffer_frame_);

        for (int i = 0; i < sola_buffer_frame_; i++) {
            fade_in_window_[i]  = static_cast<float>(i) / sola_buffer_frame_;
            fade_out_window_[i] = 1.0f - fade_in_window_[i];
        }

        sola_buffer_.assign(sola_buffer_frame_, 0.0f);
    }

    /**
     * Process the first audio slice: skip the leading padding, emit the
     * payload and remember the samples that follow it for the next slice.
     *
     * @param decoder_output    Raw audio data from decoder
     * @param processed_output  Output buffer for processed audio
     * @param actualFrameLen    Actual length of the slice in frames
     */
    void ProcessFirstFrame(const std::vector<float>& decoder_output, std::vector<float>& processed_output,
                           int actualFrameLen)
    {
        const int total       = static_cast<int>(decoder_output.size());
        const int audio_start = std::min(pad_frames_ * samples_per_frame_, total);

        // A slice shorter than its two padding regions has no payload; clamp
        // to [0, available] so the iterator range below is always valid.
        int audio_len = (actualFrameLen - 2 * pad_frames_) * samples_per_frame_;
        audio_len     = std::max(0, std::min(audio_len, total - audio_start));

        processed_output.insert(processed_output.end(), decoder_output.begin() + audio_start,
                                decoder_output.begin() + audio_start + audio_len);

        // Save the part right after the payload for alignment with the next slice.
        UpdateSolaBuffer(decoder_output, audio_start + audio_len);
    }

    /**
     * Process a subsequent slice: align it with the saved tail, crossfade
     * and append the rest of its payload.
     *
     * @param decoder_output    Raw audio data from decoder
     * @param processed_output  Output buffer for processed audio
     * @param actualFrameLen    Actual length of the slice in frames
     */
    void ProcessSubsequentFrame(const std::vector<float>& decoder_output, std::vector<float>& processed_output,
                                int actualFrameLen)
    {
        const int audio_start = pad_frames_ * samples_per_frame_;

        // 1. Search window: everything an overlap-sized window can cover while
        //    sliding over the search range (zero-padded if the slice is short).
        std::vector<float> search_window(sola_buffer_frame_ + sola_search_frame_);
        CopyZeroPadded(decoder_output, audio_start, search_window);

        // 2. Best alignment of the saved tail inside the search window.
        const int aligned_start = audio_start + FindBestOffset(search_window);

        // 3. Crossfade saved tail -> new slice.
        const std::vector<float> crossfade_region = CreateCrossfade(decoder_output, aligned_start);
        processed_output.insert(processed_output.end(), crossfade_region.begin(), crossfade_region.end());

        // 4. Remaining payload and new tail.
        AddRemainingAudio(decoder_output, processed_output, aligned_start, actualFrameLen);
    }

    /**
     * Find the offset in [0, sola_search_frame_] at which the search window
     * correlates best with the saved overlap buffer.
     *
     * The score is the cross-correlation divided by the square root of the
     * window energy. It is not divided by the buffer energy (constant over
     * all offsets), so it is not limited to [-1, 1]; the running maximum is
     * therefore seeded from the first candidate rather than from a constant.
     *
     * @param search_window  Window of audio samples to search in; must hold
     *                       sola_buffer_frame_ + sola_search_frame_ samples
     * @return Optimal offset for alignment (first one on ties)
     */
    int FindBestOffset(const std::vector<float>& search_window) const
    {
        int best_offset        = 0;
        float best_correlation = 0.0f;

        for (int offset = 0; offset <= sola_search_frame_; offset++) {
            float correlation = 0.0f;
            float energy      = 0.0f;

            for (int j = 0; j < sola_buffer_frame_; j++) {
                const float sample = search_window[j + offset];
                correlation += sola_buffer_[j] * sample;
                energy += sample * sample;
            }

            // Silent windows score 0 instead of dividing by ~0.
            const float normalized_correlation = (energy > 1e-8) ? correlation / std::sqrt(energy) : 0.0f;

            if (offset == 0 || normalized_correlation > best_correlation) {
                best_correlation = normalized_correlation;
                best_offset      = offset;
            }
        }

        return best_offset;
    }

    /**
     * Create the crossfade transition region: the saved tail fades out while
     * the new slice (starting at aligned_start) fades in.
     *
     * @param decoder_output  Raw audio data from decoder
     * @param aligned_start   Starting point after alignment
     * @return Crossfaded audio samples (sola_buffer_frame_ of them)
     */
    std::vector<float> CreateCrossfade(const std::vector<float>& decoder_output, int aligned_start) const
    {
        const int total = static_cast<int>(decoder_output.size());
        std::vector<float> crossfade_region(sola_buffer_frame_);

        for (int j = 0; j < sola_buffer_frame_; j++) {
            const int idx = aligned_start + j;
            // Samples beyond the slice are treated as silence.
            const float incoming = (idx >= 0 && idx < total) ? decoder_output[idx] : 0.0f;
            crossfade_region[j]  = incoming * fade_in_window_[j] + sola_buffer_[j] * fade_out_window_[j];
        }

        return crossfade_region;
    }

    /**
     * Append the payload that follows the crossfade region and refresh the
     * overlap buffer with the samples after it.
     *
     * @param decoder_output    Raw audio data from decoder
     * @param processed_output  Output buffer for processed audio
     * @param aligned_start     Starting point after alignment
     * @param actualFrameLen    Actual length of the slice in frames
     */
    void AddRemainingAudio(const std::vector<float>& decoder_output, std::vector<float>& processed_output,
                           int aligned_start, int actualFrameLen)
    {
        const int total           = static_cast<int>(decoder_output.size());
        const int remaining_start = std::min(aligned_start + sola_buffer_frame_, total);

        // The crossfade already produced sola_buffer_frame_ samples of the
        // payload; clamp what is left to [0, available].
        int remaining_len = (actualFrameLen - 2 * pad_frames_) * samples_per_frame_ - sola_buffer_frame_;
        remaining_len     = std::max(0, std::min(remaining_len, total - remaining_start));

        processed_output.insert(processed_output.end(), decoder_output.begin() + remaining_start,
                                decoder_output.begin() + remaining_start + remaining_len);

        UpdateSolaBuffer(decoder_output, remaining_start + remaining_len);
    }

    /**
     * Update SOLA buffer with the samples starting at buffer_start; whatever
     * the slice cannot supply is filled with zeros.
     *
     * @param decoder_output  Raw audio data from decoder
     * @param buffer_start    Starting point for the new buffer data
     */
    void UpdateSolaBuffer(const std::vector<float>& decoder_output, int buffer_start)
    {
        CopyZeroPadded(decoder_output, buffer_start, sola_buffer_);
    }

private:
    int pad_frames_;         // Number of padding frames
    int samples_per_frame_;  // Number of samples per frame
    int sola_buffer_frame_;  // SOLA buffer length in samples
    int sola_search_frame_;  // SOLA search range in samples

    std::vector<float> fade_in_window_;   // Fade-in window
    std::vector<float> fade_out_window_;  // Fade-out window
    std::vector<float> sola_buffer_;      // Tail of the previous slice

    bool first_frame_;  // Flag for first frame processing
};
268+
269+
#endif // SOLA_PROCESSOR_H

0 commit comments

Comments
 (0)