1212#include " timer.hpp"
1313#include " opencv2/opencv.hpp"
1414#include " ax_sys_api.h"
15+ #include " LLMPostprocess.hpp"
1516
1617typedef std::function<void (int *, int , const char *, float , void *)> LLMRuningCallback;
1718
@@ -26,6 +27,8 @@ struct LLMAttrType
2627
2728 std::string filename_post_axmodel = " tinyllama-int8/tinyllama_post.axmodel" ;
2829
30+ bool b_use_topk = false ;
31+
2932 std::string filename_vpm_encoder_axmodedl = " minicpmv/vpm_resampler_version0_fp16.axmodel" ;
3033 std::string filename_vpm_resampler_axmodedl = " minicpmv/vpm_resampler_version0_fp16.axmodel" ;
3134 int vpm_width = 280 ;
@@ -44,13 +47,13 @@ struct LLMAttrType
4447 int kv_cache_num = 1024 ; // auto calc
4548 int kv_cache_size = 256 ; // auto calc
4649
50+ float temperature = 0 .7f ;
51+ float top_p = 0 .9f ;
4752 bool b_use_mmap_load_embed = false ;
4853 bool b_dynamic_load_axmodel_layer = false ;
4954
5055 bool b_use_mmap_load_layer = true ;
5156
52- bool b_use_topk = false ;
53-
5457 // bool b_live_print = true;
5558 LLMRuningCallback runing_callback = nullptr ;
5659 void *reserve = nullptr ;
@@ -84,41 +87,37 @@ class LLM
8487
8588 bool b_stop = false ;
8689
87- static int FindMax (unsigned short *p, int n, float *val = 0 )
90+ int post_process (unsigned short *p, int n, std::vector< int > &history , float *val = 0 )
8891 {
89- float max_val = -MAXFLOAT;
90- int max_index = 0 ;
92+ std::vector<float > logits (n);
9193 for (int i = 0 ; i < n; i++)
9294 {
9395 unsigned int proc = p[i] << 16 ;
94- float tmp = *reinterpret_cast <float *>(&proc);
95- if (tmp > max_val)
96- {
97- max_val = tmp;
98- max_index = i;
99- }
96+ logits[i] = *reinterpret_cast <float *>(&proc);
10097 }
98+ LLMPostprocess postprocess;
99+ postprocess.set_temperature (true , _attr.temperature );
100+ postprocess.set_repetition_penalty (true , 1 .2f );
101+ // postprocess.set_top_k_sampling(true, 40);
102+ postprocess.set_top_p_sampling (true , _attr.top_p );
101103
102- // for (int i = 0; i < n; i += 4)
103- // {
104- // uint16x4_t bf16_data = vld1_u16(&p[i]);
105- // uint32x4_t float_data = vmovl_u16(bf16_data);
106- // float32x4_t tmp_floats = vreinterpretq_f32_u32(vshlq_n_u32(float_data, 16));
104+ return postprocess.apply (logits, history);
107105
108- // for (int j = 0; j < 4; j++)
106+ // float max_val = -MAXFLOAT;
107+ // int max_index = 0;
108+ // for (int i = 0; i < n; i++)
109+ // {
110+ // unsigned int proc = p[i] << 16;
111+ // float tmp = *reinterpret_cast<float *>(&proc);
112+ // if (tmp > max_val)
109113 // {
110- // float tmp = vgetq_lane_f32(tmp_floats, j);
111- // if (tmp > max_val)
112- // {
113- // max_val = tmp;
114- // max_index = i + j;
115- // }
114+ // max_val = tmp;
115+ // max_index = i;
116116 // }
117117 // }
118-
119- if (val)
120- *val = max_val;
121- return max_index;
118+ // if (val)
119+ // *val = max_val;
120+ // return max_index;
122121 }
123122
124123public:
@@ -552,7 +551,7 @@ class LLM
552551 AX_SYS_MinvalidateCache (output_post.phyAddr , output_post.pVirAddr , output_post.nSize );
553552 unsigned short *post_out = (unsigned short *)output_post.pVirAddr ;
554553 float max_val = -MAXFLOAT;
555- max_index = FindMax (post_out, _attr.tokens_embed_num , &max_val);
554+ max_index = post_process (post_out, _attr.tokens_embed_num , token_ids , &max_val);
556555 }
557556 next_token = max_index;
558557
@@ -654,7 +653,7 @@ class LLM
654653 AX_SYS_MinvalidateCache (output_post.phyAddr , output_post.pVirAddr , output_post.nSize );
655654 unsigned short *post_out = (unsigned short *)output_post.pVirAddr ;
656655 float max_val = -MAXFLOAT;
657- max_index = FindMax (post_out, _attr.tokens_embed_num , &max_val);
656+ max_index = post_process (post_out, _attr.tokens_embed_num , token_ids , &max_val);
658657 }
659658 next_token = max_index;
660659
@@ -676,7 +675,7 @@ class LLM
676675 if (_attr.runing_callback )
677676 {
678677 cached_token.push_back (max_index);
679- if (cached_token.size () >= 3 )
678+ if (cached_token.size () >= 5 )
680679 {
681680 float t_cost_ms = t_cost.cost ();
682681 float token_per_sec = token_ids.size () / (t_cost_ms / 1000 );
0 commit comments