Skip to content

Commit ce9e076

Browse files
committed
[fix] fix stream output
1 parent 79f11b2 commit ce9e076

1 file changed

Lines changed: 22 additions & 8 deletions

File tree

  • projects/llm_framework/main_melotts/src

projects/llm_framework/main_melotts/src/main.cpp

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -233,9 +233,20 @@ class llm_task {
233233
src_delete(src_state);
234234
}
235235

236-
bool TTS(const std::string &msg_str)
236+
bool TTS(const std::string &msg_str, bool finish)
237237
{
238238
try {
239+
std::vector<int16_t> wav_pcm_data;
240+
if (msg_str.empty()) {
241+
SLOGI("empty");
242+
if (out_callback_) {
243+
std::string output = wav_pcm_data.empty() ?
244+
std::string() :
245+
std::string((char *)wav_pcm_data.data(), wav_pcm_data.size() * sizeof(int16_t));
246+
out_callback_(output, finish);
247+
}
248+
return false;
249+
}
239250
std::vector<int> phones_bef, tones_bef;
240251
lexicon_->convert(msg_str, phones_bef, tones_bef);
241252
// Add blank between words
@@ -284,11 +295,10 @@ class llm_task {
284295
std::vector<float> tmp_pcm((pcmlist.size() * src_ratio + 1));
285296
int len;
286297
resample_audio(pcmlist.data(), pcmlist.size(), tmp_pcm.data(), &len, src_ratio);
287-
std::vector<int16_t> wav_pcm_data;
288298
std::transform(tmp_pcm.begin(), tmp_pcm.begin() + len, std::back_inserter(wav_pcm_data),
289299
[](const auto val) { return (int16_t)(val * INT16_MAX); });
290300
if (out_callback_)
291-
out_callback_(std::string((char *)wav_pcm_data.data(), wav_pcm_data.size() * sizeof(int16_t)), true);
301+
out_callback_(std::string((char *)wav_pcm_data.data(), wav_pcm_data.size() * sizeof(int16_t)), finish);
292302
} catch (...) {
293303
return true;
294304
}
@@ -371,15 +381,17 @@ class llm_tts : public StackFlow {
371381
return;
372382
}
373383
std::string base64_data;
374-
int len = encode_base64(data, base64_data);
384+
if (!data.empty()) {
385+
int len = encode_base64(data, base64_data);
386+
}
375387
if (llm_channel->enstream_) {
376388
static int count = 0;
377389
nlohmann::json data_body;
378390
data_body["index"] = count++;
379-
if (!finish)
391+
if (!data.empty())
380392
data_body["delta"] = base64_data;
381393
else
382-
data_body["delta"] = std::string("");
394+
data_body["delta"] = "";
383395
data_body["finish"] = finish;
384396
if (finish) count = 0;
385397
llm_channel->send(llm_task_obj->response_format_, data_body, LLM_NO_ERROR);
@@ -436,7 +448,7 @@ class llm_tts : public StackFlow {
436448
for (auto cutf8 : tmp_data) {
437449
if (is_breakpoint(cutf8)) {
438450
llm_task_obj->tts_string_stream_buff += cutf8;
439-
ret = llm_task_obj->TTS(llm_task_obj->tts_string_stream_buff);
451+
ret = llm_task_obj->TTS(llm_task_obj->tts_string_stream_buff, false);
440452
llm_task_obj->tts_string_stream_buff.clear();
441453
if (ret) {
442454
error_body["code"] = -11;
@@ -450,13 +462,15 @@ class llm_tts : public StackFlow {
450462
if (finish_flage) {
451463
if (!llm_task_obj->tts_string_stream_buff.empty()) {
452464
llm_task_obj->tts_string_stream_buff.push_back('.');
453-
ret = llm_task_obj->TTS(llm_task_obj->tts_string_stream_buff);
465+
ret = llm_task_obj->TTS(llm_task_obj->tts_string_stream_buff, true);
454466
llm_task_obj->tts_string_stream_buff.clear();
455467
if (ret) {
456468
error_body["code"] = -11;
457469
error_body["message"] = "Model run failed.";
458470
llm_channel->send("None", "None", error_body, llm_channel->work_id_);
459471
}
472+
} else {
473+
llm_task_obj->TTS("", true);
460474
}
461475
}
462476
}

0 commit comments

Comments
 (0)