Skip to content

Commit 169a595

Browse files
committed
[fix] fix whisper bug
1 parent c2156b2 commit 169a595

1 file changed

Lines changed: 31 additions & 34 deletions

File tree

  • projects/llm_framework/main_whisper/src

projects/llm_framework/main_whisper/src/main.cpp

Lines changed: 31 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -48,22 +48,22 @@ typedef struct {
4848
std::string model_type;
4949
std::string language;
5050
std::string t2s;
51-
int whisper_sample_rate;
52-
int whisper_n_fft;
53-
int awake_delay;
54-
int whisper_hop_length;
55-
int whisper_chunk_size;
56-
int whisper_n_mels;
57-
int whisper_sot;
58-
int whisper_eot;
59-
int whisper_blank;
60-
int whisper_no_timestamps;
61-
int whisper_no_speech;
62-
int whisper_translate;
63-
int whisper_transcribe;
64-
int whisper_vocab_size;
65-
int whisper_n_text_ctx;
66-
float neg_inf = -std::numeric_limits<float>::infinity();
51+
int whisper_sample_rate = 16000;
52+
int whisper_n_fft = 400;
53+
int awake_delay = 1;
54+
int whisper_hop_length = 160;
55+
int whisper_chunk_size = 30;
56+
int whisper_n_mels = 400;
57+
int whisper_sot = 50258;
58+
int whisper_eot = 50257;
59+
int whisper_blank = 220;
60+
int whisper_no_timestamps = 50363;
61+
int whisper_no_speech = 50362;
62+
int whisper_translate = 50358;
63+
int whisper_transcribe = 50359;
64+
int whisper_vocab_size = 51865;
65+
int whisper_n_text_ctx = 448;
66+
float neg_inf = -std::numeric_limits<float>::infinity();
6767
} whisper_config;
6868

6969
typedef std::function<void(const std::string &data, bool finish)> task_callback_t;
@@ -149,6 +149,7 @@ class llm_task {
149149

150150
void supress_tokens(std::vector<float> &logits, bool is_initial)
151151
{
152+
mode_config_.neg_inf = -std::numeric_limits<float>::infinity();
152153
if (is_initial) {
153154
logits[mode_config_.whisper_eot] = mode_config_.neg_inf;
154155
logits[mode_config_.whisper_blank] = mode_config_.neg_inf;
@@ -179,6 +180,14 @@ class llm_task {
179180
return WHISPER_LANG_CODES[i];
180181
}
181182

183+
double get_current_time()
184+
{
185+
struct timeval tv;
186+
gettimeofday(&tv, NULL);
187+
188+
return tv.tv_sec * 1000.0 + tv.tv_usec / 1000.0;
189+
}
190+
182191
int load_model(const nlohmann::json &config_body)
183192
{
184193
if (parse_config(config_body)) {
@@ -290,14 +299,6 @@ class llm_task {
290299
out_callback_ = out_callback;
291300
}
292301

293-
double get_current_time()
294-
{
295-
struct timeval tv;
296-
gettimeofday(&tv, NULL);
297-
298-
return tv.tv_sec * 1000.0 + tv.tv_usec / 1000.0;
299-
}
300-
301302
void sys_pcm_on_data(const std::string &raw)
302303
{
303304
if (WHISPER_N_TEXT_STATE_MAP.find(mode_config_.model_type) == WHISPER_N_TEXT_STATE_MAP.end()) {
@@ -373,6 +374,8 @@ class llm_task {
373374
}
374375
end = get_current_time();
375376
printf("Encoder run take %.2f ms\n", (end - start));
377+
378+
// detect language
376379
SOT_SEQUENCE[1] = detect_language(mode_config_.language);
377380

378381
// decoder_main
@@ -394,17 +397,10 @@ class llm_task {
394397
std::copy(decoder_main_logits.begin() + 3 * mode_config_.whisper_vocab_size, decoder_main_logits.end(),
395398
logits.begin());
396399
supress_tokens(logits, true);
397-
for (int i = 0; i < logits.size(); i++) {
398-
printf("logits[%d] = %f\n", i, logits[i]);
399-
}
400+
400401
max_token_id = argmax(logits);
401-
FILE* fp = fopen("logits.bin", "wb");
402-
fwrite(logits.data(), sizeof(float), logits.size(), fp);
403-
fclose(fp);
404-
std::cout << "Data written successfully!" << std::endl;
405-
printf("max_token_id = %d\n", max_token_id);
406402
printf("First token: %d \t take %.2fms\n", max_token_id, (end - start));
407-
403+
mode_config_.neg_inf = -std::numeric_limits<float>::infinity();
408404
std::vector<float> mask(mode_config_.whisper_n_text_ctx);
409405
for (int n = 0; n < mode_config_.whisper_n_text_ctx - offset - 1; n++) {
410406
mask[n] = mode_config_.neg_inf;
@@ -532,7 +528,8 @@ class llm_task {
532528
}
533529
AX_ENGINE_NPU_ATTR_T npu_attr;
534530
memset(&npu_attr, 0, sizeof(npu_attr));
535-
ret = AX_ENGINE_Init(&npu_attr);
531+
npu_attr.eHardMode = AX_ENGINE_VIRTUAL_NPU_DISABLE;
532+
ret = AX_ENGINE_Init(&npu_attr);
536533
if (0 != ret) {
537534
fprintf(stderr, "Init ax-engine failed{0x%8x}.\n", ret);
538535
}

0 commit comments

Comments
 (0)