Skip to content

Commit e8e4cc9

Browse files
committed
[update] update vad & whisper
1 parent 169a595 commit e8e4cc9

4 files changed

Lines changed: 96 additions & 87 deletions

File tree

projects/llm_framework/main_vad/mode_silero-vad-model.json renamed to projects/llm_framework/main_vad/mode_silero-vad.json

File renamed without changes.

projects/llm_framework/main_vad/src/main.cpp

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ static void __sigint(int iSigNo)
3232
static std::string base_model_path_;
3333
static std::string base_model_config_path_;
3434

35+
typedef std::function<void(const bool &data)> task_callback_t;
36+
3537
#define CONFIG_AUTO_SET(obj, key) \
3638
if (config_body.contains(#key)) \
3739
mode_config_.key = config_body[#key]; \
@@ -50,16 +52,14 @@ class llm_task {
5052
bool enoutput_;
5153
bool enstream_;
5254
bool printed = false;
55+
task_callback_t out_callback_;
5356
std::atomic_bool audio_flage_;
5457
int delay_audio_frame_ = 100;
5558
buffer_t *pcmdata;
5659
std::string wake_wav_file_;
5760

58-
std::function<void(const std::string &)> out_callback_;
59-
6061
bool parse_config(const nlohmann::json &config_body)
6162
{
62-
fprintf(stderr, "%s\n", mode_config_.ToString().c_str());
6363
try {
6464
model_ = config_body.at("model");
6565
response_format_ = config_body.at("response_format");
@@ -136,15 +136,14 @@ class llm_task {
136136
return 0;
137137
}
138138

139-
void set_output(std::function<void(const std::string &)> out_callback)
139+
void set_output(task_callback_t out_callback)
140140
{
141141
out_callback_ = out_callback;
142142
}
143143

144144
void sys_pcm_on_data(const std::string &raw)
145145
{
146146
static int count = 0;
147-
int32_t k = 0;
148147
if (count < delay_audio_frame_) {
149148
buffer_write_char(pcmdata, raw.c_str(), raw.length());
150149
count++;
@@ -167,6 +166,9 @@ class llm_task {
167166
if (vad_->IsSpeechDetected() && !printed) {
168167
printed = true;
169168
SLOGI("Detected speech!");
169+
if (out_callback_) {
170+
out_callback_(true);
171+
}
170172
}
171173
if (!vad_->IsSpeechDetected()) {
172174
printed = false;
@@ -177,8 +179,11 @@ class llm_task {
177179
const auto &segment = vad_->Front();
178180
float duration = segment.samples.size() / static_cast<float>(sample_rate);
179181
SLOGI("Duration: %.3f seconds", duration);
180-
k += 1;
182+
// k += 1;
181183
vad_->Pop();
184+
if (out_callback_) {
185+
out_callback_(false);
186+
}
182187
}
183188
}
184189

@@ -203,18 +208,31 @@ class llm_task {
203208
};
204209
#undef CONFIG_AUTO_SET
205210

206-
class llm_kws : public StackFlow {
211+
class llm_vad : public StackFlow {
207212
private:
208213
int task_count_;
209214
std::string audio_url_;
210215
std::unordered_map<int, std::shared_ptr<llm_task>> llm_task_;
211216

212217
public:
213-
llm_kws() : StackFlow("vad")
218+
llm_vad() : StackFlow("vad")
214219
{
215220
task_count_ = 1;
216221
}
217222

223+
void task_output(const std::weak_ptr<llm_task> llm_task_obj_weak,
224+
const std::weak_ptr<llm_channel_obj> llm_channel_weak, const bool &data)
225+
{
226+
auto llm_task_obj = llm_task_obj_weak.lock();
227+
auto llm_channel = llm_channel_weak.lock();
228+
if (!(llm_task_obj && llm_channel)) {
229+
return;
230+
}
231+
std::string tmp_msg1;
232+
const bool *next_data = &data;
233+
llm_channel->send(llm_task_obj->response_format_, (*next_data), LLM_NO_ERROR);
234+
}
235+
218236
void task_pause(const std::weak_ptr<llm_task> llm_task_obj_weak,
219237
const std::weak_ptr<llm_channel_obj> llm_channel_weak)
220238
{
@@ -350,9 +368,8 @@ class llm_kws : public StackFlow {
350368
if (ret == 0) {
351369
llm_channel->set_output(llm_task_obj->enoutput_);
352370
llm_channel->set_stream(llm_task_obj->enstream_);
353-
llm_task_obj->set_output([llm_task_obj, llm_channel](const std::string &data) {
354-
llm_channel->send(llm_task_obj->response_format_, true, LLM_NO_ERROR);
355-
});
371+
llm_task_obj->set_output(std::bind(&llm_vad::task_output, this, std::weak_ptr<llm_task>(llm_task_obj),
372+
std::weak_ptr<llm_channel_obj>(llm_channel), std::placeholders::_1));
356373

357374
for (const auto input : llm_task_obj->inputs_) {
358375
if (input.find("sys") != std::string::npos) {
@@ -364,7 +381,7 @@ class llm_kws : public StackFlow {
364381
llm_task_obj->audio_flage_ = true;
365382
} else if (input.find("vad") != std::string::npos) {
366383
llm_channel->subscriber_work_id(
367-
"", std::bind(&llm_kws::task_user_data, this, std::weak_ptr<llm_task>(llm_task_obj),
384+
"", std::bind(&llm_vad::task_user_data, this, std::weak_ptr<llm_task>(llm_task_obj),
368385
std::weak_ptr<llm_channel_obj>(llm_channel), std::placeholders::_1,
369386
std::placeholders::_2));
370387
}
@@ -430,7 +447,7 @@ class llm_kws : public StackFlow {
430447
return 0;
431448
}
432449

433-
~llm_kws()
450+
~llm_vad()
434451
{
435452
while (1) {
436453
auto iteam = llm_task_.begin();
@@ -452,7 +469,7 @@ int main(int argc, char *argv[])
452469
signal(SIGTERM, __sigint);
453470
signal(SIGINT, __sigint);
454471
mkdir("/tmp/llm", 0777);
455-
llm_kws llm;
472+
llm_vad llm;
456473
while (!main_exit_flage) {
457474
sleep(1);
458475
}

projects/llm_framework/main_whisper/mode_whisper-tiny.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
],
1515
"mode_param": {
1616
"model_type": "tiny",
17-
"language": "zh",
17+
"language": "ja",
1818
"encoder": "tiny-encoder.axmodel",
1919
"decoder_main": "tiny-decoder-main.axmodel",
2020
"decoder_loop": "tiny-decoder-loop.axmodel",

projects/llm_framework/main_whisper/src/main.cpp

Lines changed: 64 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -88,15 +88,15 @@ class llm_task {
8888
bool enoutput_;
8989
bool enstream_;
9090
bool ensleep_;
91-
bool endpoint_;
9291
std::atomic_bool superior_flage_;
9392
std::atomic_bool audio_flage_;
9493
std::atomic_bool awake_flage_;
94+
std::atomic_bool vad_endpoint_;
9595
std::string superior_id_;
9696
static int ax_init_flage_;
9797
task_callback_t out_callback_;
9898
int awake_delay_ = 50;
99-
int delay_audio_frame_ = 100;
99+
int delay_audio_frame_ = 1000;
100100
buffer_t *pcmdata;
101101

102102
std::function<void(void)> pause;
@@ -301,23 +301,34 @@ class llm_task {
301301

302302
void sys_pcm_on_data(const std::string &raw)
303303
{
304+
static int count = 0;
305+
if (count < delay_audio_frame_) {
306+
buffer_write_char(pcmdata, raw.c_str(), raw.length());
307+
count++;
308+
return;
309+
}
310+
buffer_write_char(pcmdata, raw.c_str(), raw.length());
311+
buffer_position_set(pcmdata, 0);
312+
count = 0;
313+
std::vector<float> floatSamples;
314+
{
315+
int16_t audio_val;
316+
while (buffer_read_u16(pcmdata, (unsigned short *)&audio_val, 1)) {
317+
float normalizedSample = (float)audio_val / INT16_MAX;
318+
floatSamples.push_back(normalizedSample);
319+
}
320+
}
321+
buffer_position_set(pcmdata, 0);
322+
304323
if (WHISPER_N_TEXT_STATE_MAP.find(mode_config_.model_type) == WHISPER_N_TEXT_STATE_MAP.end()) {
305324
fprintf(stderr, "Can NOT find n_text_state for model_type: %s\n", mode_config_.model_type.c_str());
306325
return;
307326
}
308327

309328
int WHISPER_N_TEXT_STATE = WHISPER_N_TEXT_STATE_MAP[mode_config_.model_type];
310329

311-
AudioFile<float> audio_file;
312-
if (!audio_file.load("demo.wav")) {
313-
printf("load wav failed!\n");
314-
return;
315-
}
316-
awake_flage_ = false;
317-
auto &samples = audio_file.samples[0];
318-
319330
auto mel = librosa::Feature::melspectrogram(
320-
samples, mode_config_.whisper_sample_rate, mode_config_.whisper_n_fft, mode_config_.whisper_hop_length,
331+
floatSamples, mode_config_.whisper_sample_rate, mode_config_.whisper_n_fft, mode_config_.whisper_hop_length,
321332
"hann", true, "reflect", 2.0f, mode_config_.whisper_n_mels, 0.0f, mode_config_.whisper_sample_rate / 2.0f);
322333
int n_mel = mel.size();
323334
int n_len = mel[0].size();
@@ -457,57 +468,18 @@ class llm_task {
457468
s += str;
458469
}
459470

460-
if (mode_config_.language == "en")
471+
if (mode_config_.language == "en" || mode_config_.language == "ja") {
461472
printf("Result: %s\n", s.c_str());
462-
else {
473+
if (out_callback_) out_callback_(s, true);
474+
} else {
463475
const opencc::SimpleConverter converter(mode_config_.t2s.c_str());
464476
std::string simple_str = converter.Convert(s);
465477
printf("Result: %s\n", simple_str.c_str());
478+
if ((!simple_str.empty()) && out_callback_) {
479+
out_callback_(simple_str, true);
480+
}
466481
}
467-
/////////////////////////////////////////////////////////////////////
468-
// static int count = 0;
469-
// if (count < delay_audio_frame_) {
470-
// buffer_write_char(pcmdata, raw.c_str(), raw.length());
471-
// count++;
472-
// return;
473-
// }
474-
// buffer_write_char(pcmdata, raw.c_str(), raw.length());
475-
// buffer_position_set(pcmdata, 0);
476-
// count = 0;
477-
// std::vector<float> floatSamples;
478-
// {
479-
// int16_t audio_val;
480-
// while (buffer_read_u16(pcmdata, (unsigned short *)&audio_val, 1)) {
481-
// float normalizedSample = (float)audio_val / INT16_MAX;
482-
// floatSamples.push_back(normalizedSample);
483-
// }
484-
// }
485-
// buffer_position_set(pcmdata, 0);
486-
// if (awake_flage_ && recognizer_stream_) {
487-
// recognizer_stream_.reset();
488-
// awake_flage_ = false;
489-
// }
490-
// if (!recognizer_stream_) {
491-
// recognizer_stream_ = recognizer_->CreateStream();
492-
// }
493-
// recognizer_stream_->AcceptWaveform(mode_config_.feat_config.sampling_rate, floatSamples.data(),
494-
// floatSamples.size());
495-
// while (recognizer_->IsReady(recognizer_stream_.get())) {
496-
// recognizer_->DecodeStream(recognizer_stream_.get());
497-
// }
498-
// std::string text = recognizer_->GetResult(recognizer_stream_.get()).text;
499-
// std::string lower_text;
500-
// lower_text.resize(text.size());
501-
// std::transform(text.begin(), text.end(), lower_text.begin(), [](const char c) { return std::tolower(c); });
502-
// if ((!lower_text.empty()) && out_callback_) out_callback_(lower_text, false);
503-
// bool is_endpoint = recognizer_->IsEndpoint(recognizer_stream_.get());
504-
// if (is_endpoint) {
505-
// std::cout << "asr have a is_endpoint \n";
506-
// recognizer_stream_->Finalize();
507-
// if ((!lower_text.empty()) && out_callback_) {
508-
// out_callback_(lower_text, true);
509-
// }
510-
// recognizer_stream_.reset();
482+
511483
if (ensleep_) {
512484
if (pause) pause();
513485
}
@@ -588,24 +560,24 @@ class llm_whisper : public StackFlow {
588560
if (!(llm_task_obj && llm_channel)) {
589561
return;
590562
}
591-
std::string base64_data;
592-
int len = encode_base64(data, base64_data);
563+
std::string tmp_msg1;
564+
const std::string *next_data = &data;
565+
if (finish) {
566+
tmp_msg1 = data + ".";
567+
next_data = &tmp_msg1;
568+
}
593569
if (llm_channel->enstream_) {
594570
static int count = 0;
595571
nlohmann::json data_body;
596-
data_body["index"] = count++;
597-
if (!finish)
598-
data_body["delta"] = base64_data;
599-
else
600-
data_body["delta"] = std::string("");
572+
data_body["index"] = count++;
573+
data_body["delta"] = (*next_data);
601574
data_body["finish"] = finish;
602575
if (finish) count = 0;
576+
SLOGI("send stream:%s", next_data->c_str());
603577
llm_channel->send(llm_task_obj->response_format_, data_body, LLM_NO_ERROR);
604578
} else if (finish) {
605-
llm_channel->send(llm_task_obj->response_format_, base64_data, LLM_NO_ERROR);
606-
}
607-
if (llm_task_obj->response_format_.find("sys") != std::string::npos) {
608-
unit_call("audio", "queue_play", data);
579+
SLOGI("send utf-8:%s", next_data->c_str());
580+
llm_channel->send(llm_task_obj->response_format_, (*next_data), LLM_NO_ERROR);
609581
}
610582
}
611583

@@ -744,6 +716,20 @@ class llm_whisper : public StackFlow {
744716
task_work(llm_task_obj, llm_channel);
745717
}
746718

719+
void vad_endpoint(const std::weak_ptr<llm_task> llm_task_obj_weak,
720+
const std::weak_ptr<llm_channel_obj> llm_channel_weak, const std::string &object,
721+
const std::string &data)
722+
{
723+
auto llm_task_obj = llm_task_obj_weak.lock();
724+
auto llm_channel = llm_channel_weak.lock();
725+
if (!(llm_task_obj && llm_channel)) {
726+
return;
727+
}
728+
if (data == "true" || data == "false") {
729+
llm_task_obj->vad_endpoint_ = (data == "true");
730+
}
731+
}
732+
747733
void work(const std::string &work_id, const std::string &object, const std::string &data) override
748734
{
749735
SLOGI("llm_asr::work:%s", data.c_str());
@@ -814,7 +800,7 @@ class llm_whisper : public StackFlow {
814800
audio_url_ = unit_call("audio", "cap", input);
815801
std::weak_ptr<llm_task> _llm_task_obj = llm_task_obj;
816802
llm_channel->subscriber(audio_url_, [_llm_task_obj](pzmq *_pzmq, const std::string &raw) {
817-
// _llm_task_obj.lock()->sys_pcm_on_data(raw);
803+
_llm_task_obj.lock()->sys_pcm_on_data(raw);
818804
});
819805
llm_task_obj->audio_flage_ = true;
820806
} else if (input.find("asr") != std::string::npos) {
@@ -830,10 +816,10 @@ class llm_whisper : public StackFlow {
830816
std::weak_ptr<llm_channel_obj>(llm_channel), std::placeholders::_1,
831817
std::placeholders::_2));
832818
} else if (input.find("vad") != std::string::npos) {
833-
llm_task_obj->endpoint_ = true;
834-
task_pause(work_id, "");
819+
llm_task_obj->vad_endpoint_ = true;
820+
// task_pause(work_id, "");
835821
llm_channel->subscriber_work_id(
836-
input, std::bind(&llm_whisper::kws_awake, this, std::weak_ptr<llm_task>(llm_task_obj),
822+
input, std::bind(&llm_whisper::vad_endpoint, this, std::weak_ptr<llm_task>(llm_task_obj),
837823
std::weak_ptr<llm_channel_obj>(llm_channel), std::placeholders::_1,
838824
std::placeholders::_2));
839825
}
@@ -880,6 +866,12 @@ class llm_whisper : public StackFlow {
880866
std::bind(&llm_whisper::kws_awake, this, std::weak_ptr<llm_task>(llm_task_obj),
881867
std::weak_ptr<llm_channel_obj>(llm_channel), std::placeholders::_1, std::placeholders::_2));
882868
llm_task_obj->inputs_.push_back(data);
869+
} else if (data.find("vad") != std::string::npos) {
870+
llm_task_obj->vad_endpoint_ = true;
871+
ret = llm_channel->subscriber_work_id(
872+
data,
873+
std::bind(&llm_whisper::vad_endpoint, this, std::weak_ptr<llm_task>(llm_task_obj),
874+
std::weak_ptr<llm_channel_obj>(llm_channel), std::placeholders::_1, std::placeholders::_2));
883875
}
884876
if (ret) {
885877
error_body["code"] = -20;

0 commit comments

Comments
 (0)