|
10 | 10 | #include "ax_cmm_utils.hpp" |
11 | 11 | #include "cqdm.h" |
12 | 12 | #include "timer.hpp" |
| 13 | +#include "LLMPostprocess.hpp" |
13 | 14 |
|
14 | 15 | #include <ax_sys_api.h> |
15 | 16 |
|
@@ -85,41 +86,37 @@ class LLM |
85 | 86 |
|
86 | 87 | bool b_stop = false; |
87 | 88 |
|
88 | | - static int FindMax(unsigned short *p, int n, float *val = 0) |
| 89 | + static int post_process(unsigned short *p, int n, std::vector<int> &history, float *val = 0) |
89 | 90 | { |
90 | | - float max_val = -MAXFLOAT; |
91 | | - int max_index = 0; |
| 91 | + std::vector<float> logits(n); |
92 | 92 | for (int i = 0; i < n; i++) |
93 | 93 | { |
94 | 94 | unsigned int proc = p[i] << 16; |
95 | | - float tmp = *reinterpret_cast<float *>(&proc); |
96 | | - if (tmp > max_val) |
97 | | - { |
98 | | - max_val = tmp; |
99 | | - max_index = i; |
100 | | - } |
| 95 | + logits[i] = *reinterpret_cast<float *>(&proc); |
101 | 96 | } |
| 97 | + LLMPostprocess postprocess; |
| 98 | + postprocess.set_temperature(true, 0.8f); |
| 99 | + postprocess.set_repetition_penalty(true, 1.2f); |
| 100 | + // postprocess.set_top_k_sampling(true, 40); |
| 101 | + postprocess.set_top_p_sampling(true, 0.9f); |
102 | 102 |
|
103 | | - // for (int i = 0; i < n; i += 4) |
104 | | - // { |
105 | | - // uint16x4_t bf16_data = vld1_u16(&p[i]); |
106 | | - // uint32x4_t float_data = vmovl_u16(bf16_data); |
107 | | - // float32x4_t tmp_floats = vreinterpretq_f32_u32(vshlq_n_u32(float_data, 16)); |
| 103 | + return postprocess.apply(logits, history); |
108 | 104 |
|
109 | | - // for (int j = 0; j < 4; j++) |
| 105 | + // float max_val = -MAXFLOAT; |
| 106 | + // int max_index = 0; |
| 107 | + // for (int i = 0; i < n; i++) |
| 108 | + // { |
| 109 | + // unsigned int proc = p[i] << 16; |
| 110 | + // float tmp = *reinterpret_cast<float *>(&proc); |
| 111 | + // if (tmp > max_val) |
110 | 112 | // { |
111 | | - // float tmp = vgetq_lane_f32(tmp_floats, j); |
112 | | - // if (tmp > max_val) |
113 | | - // { |
114 | | - // max_val = tmp; |
115 | | - // max_index = i + j; |
116 | | - // } |
| 113 | + // max_val = tmp; |
| 114 | + // max_index = i; |
117 | 115 | // } |
118 | 116 | // } |
119 | | - |
120 | | - if (val) |
121 | | - *val = max_val; |
122 | | - return max_index; |
| 117 | + // if (val) |
| 118 | + // *val = max_val; |
| 119 | + // return max_index; |
123 | 120 | } |
124 | 121 |
|
125 | 122 | public: |
@@ -456,7 +453,7 @@ class LLM |
456 | 453 | AX_SYS_MinvalidateCache(output_post.phyAddr, output_post.pVirAddr, output_post.nSize); |
457 | 454 | unsigned short *post_out = (unsigned short *)output_post.pVirAddr; |
458 | 455 | float max_val = -MAXFLOAT; |
459 | | - max_index = FindMax(post_out, _attr.tokens_embed_num, &max_val); |
| 456 | + max_index = post_process(post_out, _attr.tokens_embed_num, token_ids, &max_val); |
460 | 457 | } |
461 | 458 | next_token = max_index; |
462 | 459 |
|
@@ -558,7 +555,7 @@ class LLM |
558 | 555 | AX_SYS_MinvalidateCache(output_post.phyAddr, output_post.pVirAddr, output_post.nSize); |
559 | 556 | unsigned short *post_out = (unsigned short *)output_post.pVirAddr; |
560 | 557 | float max_val = -MAXFLOAT; |
561 | | - max_index = FindMax(post_out, _attr.tokens_embed_num, &max_val); |
| 558 | + max_index = post_process(post_out, _attr.tokens_embed_num, token_ids, &max_val); |
562 | 559 | } |
563 | 560 | next_token = max_index; |
564 | 561 |
|
|
0 commit comments