Skip to content

Commit e51cee7

Browse files
committed
[update] Optimize llm post-processing function
1 parent c66bc50 commit e51cee7

2 files changed

Lines changed: 343 additions & 27 deletions

File tree

projects/llm_framework/main_llm/src/runner/LLM.hpp

Lines changed: 24 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "ax_cmm_utils.hpp"
1111
#include "cqdm.h"
1212
#include "timer.hpp"
13+
#include "LLMPostprocess.hpp"
1314

1415
#include <ax_sys_api.h>
1516

@@ -85,41 +86,37 @@ class LLM
8586

8687
bool b_stop = false;
8788

88-
static int FindMax(unsigned short *p, int n, float *val = 0)
89+
static int post_process(unsigned short *p, int n, std::vector<int> &history, float *val = 0)
8990
{
90-
float max_val = -MAXFLOAT;
91-
int max_index = 0;
91+
std::vector<float> logits(n);
9292
for (int i = 0; i < n; i++)
9393
{
9494
unsigned int proc = p[i] << 16;
95-
float tmp = *reinterpret_cast<float *>(&proc);
96-
if (tmp > max_val)
97-
{
98-
max_val = tmp;
99-
max_index = i;
100-
}
95+
logits[i] = *reinterpret_cast<float *>(&proc);
10196
}
97+
LLMPostprocess postprocess;
98+
postprocess.set_temperature(true, 0.8f);
99+
postprocess.set_repetition_penalty(true, 1.2f);
100+
// postprocess.set_top_k_sampling(true, 40);
101+
postprocess.set_top_p_sampling(true, 0.9f);
102102

103-
// for (int i = 0; i < n; i += 4)
104-
// {
105-
// uint16x4_t bf16_data = vld1_u16(&p[i]);
106-
// uint32x4_t float_data = vmovl_u16(bf16_data);
107-
// float32x4_t tmp_floats = vreinterpretq_f32_u32(vshlq_n_u32(float_data, 16));
103+
return postprocess.apply(logits, history);
108104

109-
// for (int j = 0; j < 4; j++)
105+
// float max_val = -MAXFLOAT;
106+
// int max_index = 0;
107+
// for (int i = 0; i < n; i++)
108+
// {
109+
// unsigned int proc = p[i] << 16;
110+
// float tmp = *reinterpret_cast<float *>(&proc);
111+
// if (tmp > max_val)
110112
// {
111-
// float tmp = vgetq_lane_f32(tmp_floats, j);
112-
// if (tmp > max_val)
113-
// {
114-
// max_val = tmp;
115-
// max_index = i + j;
116-
// }
113+
// max_val = tmp;
114+
// max_index = i;
117115
// }
118116
// }
119-
120-
if (val)
121-
*val = max_val;
122-
return max_index;
117+
// if (val)
118+
// *val = max_val;
119+
// return max_index;
123120
}
124121

125122
public:
@@ -456,7 +453,7 @@ class LLM
456453
AX_SYS_MinvalidateCache(output_post.phyAddr, output_post.pVirAddr, output_post.nSize);
457454
unsigned short *post_out = (unsigned short *)output_post.pVirAddr;
458455
float max_val = -MAXFLOAT;
459-
max_index = FindMax(post_out, _attr.tokens_embed_num, &max_val);
456+
max_index = post_process(post_out, _attr.tokens_embed_num, token_ids, &max_val);
460457
}
461458
next_token = max_index;
462459

@@ -558,7 +555,7 @@ class LLM
558555
AX_SYS_MinvalidateCache(output_post.phyAddr, output_post.pVirAddr, output_post.nSize);
559556
unsigned short *post_out = (unsigned short *)output_post.pVirAddr;
560557
float max_val = -MAXFLOAT;
561-
max_index = FindMax(post_out, _attr.tokens_embed_num, &max_val);
558+
max_index = post_process(post_out, _attr.tokens_embed_num, token_ids, &max_val);
562559
}
563560
next_token = max_index;
564561

0 commit comments

Comments
 (0)