Skip to content

Commit e19220e

Browse files
committed
[update] llm_vlm add pause action & perf post process
1 parent 1f30c63 commit e19220e

3 files changed

Lines changed: 382 additions & 31 deletions

File tree

projects/llm_framework/main_vlm/src/main.cpp

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,8 @@ class llm_task {
125125
CONFIG_AUTO_SET(file_body["mode_param"], b_use_mmap_load_embed);
126126
CONFIG_AUTO_SET(file_body["mode_param"], b_dynamic_load_axmodel_layer);
127127
CONFIG_AUTO_SET(file_body["mode_param"], max_token_len);
128+
CONFIG_AUTO_SET(file_body["mode_param"], temperature);
129+
CONFIG_AUTO_SET(file_body["mode_param"], top_p);
128130

129131
if (mode_config_.filename_tokenizer_model.find("http:") != std::string::npos) {
130132
std::string tokenizer_file;
@@ -171,7 +173,11 @@ class llm_task {
171173
}
172174
};
173175
lLaMa_ = std::make_unique<LLM>();
174-
if (!lLaMa_->Init(mode_config_)) return -2;
176+
if (!lLaMa_->Init(mode_config_)) {
177+
lLaMa_->Deinit();
178+
lLaMa_.reset();
179+
return -2;
180+
}
175181

176182
} catch (...) {
177183
SLOGE("config false");
@@ -293,6 +299,33 @@ class llm_llm : public StackFlow {
293299
}
294300
}
295301

302+
void task_pause(const std::weak_ptr<llm_task> llm_task_obj_weak,
303+
const std::weak_ptr<llm_channel_obj> llm_channel_weak)
304+
{
305+
auto llm_task_obj = llm_task_obj_weak.lock();
306+
auto llm_channel = llm_channel_weak.lock();
307+
if (!(llm_task_obj && llm_channel)) {
308+
return;
309+
}
310+
llm_task_obj->lLaMa_->Stop();
311+
}
312+
313+
void pause(const std::string &work_id, const std::string &object, const std::string &data) override
314+
{
315+
SLOGI("llm_asr::work:%s", data.c_str());
316+
317+
nlohmann::json error_body;
318+
int work_id_num = sample_get_work_id_num(work_id);
319+
if (llm_task_.find(work_id_num) == llm_task_.end()) {
320+
error_body["code"] = -6;
321+
error_body["message"] = "Unit Does Not Exist";
322+
send("None", "None", error_body, work_id);
323+
return;
324+
}
325+
task_pause(llm_task_[work_id_num], get_channel(work_id_num));
326+
send("None", "None", LLM_NO_ERROR, work_id);
327+
}
328+
296329
void task_user_data(const std::weak_ptr<llm_task> llm_task_obj_weak,
297330
const std::weak_ptr<llm_channel_obj> llm_channel_weak, const std::string &object,
298331
const std::string &data)

projects/llm_framework/main_vlm/src/runner/LLM.hpp

Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <cstring>  // std::memcpy for safe bf16 -> fp32 conversion

#include "timer.hpp"
#include "opencv2/opencv.hpp"
#include "ax_sys_api.h"
#include "LLMPostprocess.hpp"
1516

1617
typedef std::function<void(int*, int, const char*, float, void*)> LLMRuningCallback;
1718

@@ -26,6 +27,8 @@ struct LLMAttrType
2627

2728
std::string filename_post_axmodel = "tinyllama-int8/tinyllama_post.axmodel";
2829

30+
bool b_use_topk = false;
31+
2932
std::string filename_vpm_encoder_axmodedl = "minicpmv/vpm_resampler_version0_fp16.axmodel";
3033
std::string filename_vpm_resampler_axmodedl = "minicpmv/vpm_resampler_version0_fp16.axmodel";
3134
int vpm_width = 280;
@@ -44,13 +47,13 @@ struct LLMAttrType
4447
int kv_cache_num = 1024; // auto calc
4548
int kv_cache_size = 256; // auto calc
4649

50+
float temperature = 0.7f;
51+
float top_p = 0.9f;
4752
bool b_use_mmap_load_embed = false;
4853
bool b_dynamic_load_axmodel_layer = false;
4954

5055
bool b_use_mmap_load_layer = true;
5156

52-
bool b_use_topk = false;
53-
5457
// bool b_live_print = true;
5558
LLMRuningCallback runing_callback = nullptr;
5659
void *reserve = nullptr;
@@ -84,41 +87,37 @@ class LLM
8487

8588
bool b_stop = false;
8689

87-
static int FindMax(unsigned short *p, int n, float *val = 0)
90+
int post_process(unsigned short *p, int n, std::vector<int> &history, float *val = 0)
8891
{
89-
float max_val = -MAXFLOAT;
90-
int max_index = 0;
92+
std::vector<float> logits(n);
9193
for (int i = 0; i < n; i++)
9294
{
9395
unsigned int proc = p[i] << 16;
94-
float tmp = *reinterpret_cast<float *>(&proc);
95-
if (tmp > max_val)
96-
{
97-
max_val = tmp;
98-
max_index = i;
99-
}
96+
logits[i] = *reinterpret_cast<float *>(&proc);
10097
}
98+
LLMPostprocess postprocess;
99+
postprocess.set_temperature(true, _attr.temperature);
100+
postprocess.set_repetition_penalty(true, 1.2f);
101+
// postprocess.set_top_k_sampling(true, 40);
102+
postprocess.set_top_p_sampling(true, _attr.top_p);
101103

102-
// for (int i = 0; i < n; i += 4)
103-
// {
104-
// uint16x4_t bf16_data = vld1_u16(&p[i]);
105-
// uint32x4_t float_data = vmovl_u16(bf16_data);
106-
// float32x4_t tmp_floats = vreinterpretq_f32_u32(vshlq_n_u32(float_data, 16));
104+
return postprocess.apply(logits, history);
107105

108-
// for (int j = 0; j < 4; j++)
106+
// float max_val = -MAXFLOAT;
107+
// int max_index = 0;
108+
// for (int i = 0; i < n; i++)
109+
// {
110+
// unsigned int proc = p[i] << 16;
111+
// float tmp = *reinterpret_cast<float *>(&proc);
112+
// if (tmp > max_val)
109113
// {
110-
// float tmp = vgetq_lane_f32(tmp_floats, j);
111-
// if (tmp > max_val)
112-
// {
113-
// max_val = tmp;
114-
// max_index = i + j;
115-
// }
114+
// max_val = tmp;
115+
// max_index = i;
116116
// }
117117
// }
118-
119-
if (val)
120-
*val = max_val;
121-
return max_index;
118+
// if (val)
119+
// *val = max_val;
120+
// return max_index;
122121
}
123122

124123
public:
@@ -552,7 +551,7 @@ class LLM
552551
AX_SYS_MinvalidateCache(output_post.phyAddr, output_post.pVirAddr, output_post.nSize);
553552
unsigned short *post_out = (unsigned short *)output_post.pVirAddr;
554553
float max_val = -MAXFLOAT;
555-
max_index = FindMax(post_out, _attr.tokens_embed_num, &max_val);
554+
max_index = post_process(post_out, _attr.tokens_embed_num, token_ids, &max_val);
556555
}
557556
next_token = max_index;
558557

@@ -654,7 +653,7 @@ class LLM
654653
AX_SYS_MinvalidateCache(output_post.phyAddr, output_post.pVirAddr, output_post.nSize);
655654
unsigned short *post_out = (unsigned short *)output_post.pVirAddr;
656655
float max_val = -MAXFLOAT;
657-
max_index = FindMax(post_out, _attr.tokens_embed_num, &max_val);
656+
max_index = post_process(post_out, _attr.tokens_embed_num, token_ids, &max_val);
658657
}
659658
next_token = max_index;
660659

@@ -676,7 +675,7 @@ class LLM
676675
if (_attr.runing_callback)
677676
{
678677
cached_token.push_back(max_index);
679-
if (cached_token.size() >= 3)
678+
if (cached_token.size() >= 5)
680679
{
681680
float t_cost_ms = t_cost.cost();
682681
float token_per_sec = token_ids.size() / (t_cost_ms / 1000);

0 commit comments

Comments
 (0)