55//
66// Copyright(C) 2026 Chris Warren-Smith
77
8- #include < cstdio>
9- #include < cstring>
10- #include < string>
118#include < vector>
12-
139#include " llama.h"
1410#include " llama-sb.h"
1511
@@ -46,11 +42,11 @@ const string Llama::build_chat_prompt(const string &user_msg) {
4642 return _chat_prompt;
4743}
4844
49- bool Llama::create (string model_path, int n_ctx, bool disable_log) {
45+ bool Llama::construct (string model_path, int n_ctx, bool disable_log) {
5046 if (disable_log) {
5147 // only print errors
5248 llama_log_set ([](enum ggml_log_level level, const char * text, void * /* user_data */ ) {
53- if (level >= GGML_LOG_LEVEL_ERROR) {
49+ if (level >= GGML_LOG_LEVEL_ERROR && text[ 0 ] != ' . ' && text[ 0 ] != ' \n ' ) {
5450 fprintf (stderr, " %s" , text);
5551 }
5652 }, nullptr );
@@ -59,7 +55,7 @@ bool Llama::create(string model_path, int n_ctx, bool disable_log) {
5955 ggml_backend_load_all ();
6056
6157 llama_model_params mparams = llama_model_default_params ();
62- mparams.n_gpu_layers = 0 ;
58+ mparams.n_gpu_layers = 99 ;
6359
6460 _model = llama_model_load_from_file (model_path.c_str (), mparams);
6561 if (!_model) {
@@ -68,81 +64,107 @@ bool Llama::create(string model_path, int n_ctx, bool disable_log) {
6864 llama_context_params cparams = llama_context_default_params ();
6965 cparams.n_ctx = n_ctx;
7066 cparams.n_batch = n_ctx;
67+ cparams.no_perf = true ;
7168
7269 _ctx = llama_init_from_model (_model, cparams);
7370 if (!_ctx) {
7471 _last_error = " failed to create context" ;
7572 } else {
7673 _vocab = llama_model_get_vocab (_model);
77- configure_sampler (0 );
7874 }
7975 }
8076 return _last_error.empty ();
8177}
8278
8379void Llama::configure_sampler (float temperature) {
8480 if (temperature != _temperature || _sampler == nullptr ) {
85- if (_sampler) {
81+ if (_sampler != nullptr ) {
8682 llama_sampler_free (_sampler);
8783 }
8884 auto sparams = llama_sampler_chain_default_params ();
85+ sparams.no_perf = false ;
8986 _sampler = llama_sampler_chain_init (sparams);
9087 _temperature = temperature;
9188
9289 // llama_sampler_chain_reset(sampler);
9390 if (temperature <= 0 .0f ) {
9491 llama_sampler_chain_add (_sampler, llama_sampler_init_greedy ());
9592 } else {
93+ llama_sampler_chain_add (_sampler, llama_sampler_init_min_p (0 .05f , 1 ));
9694 llama_sampler_chain_add (_sampler, llama_sampler_init_temp (temperature));
95+ llama_sampler_chain_add (_sampler, llama_sampler_init_dist (LLAMA_DEFAULT_SEED));
9796 }
9897 }
9998}
10099
101- static std::vector<llama_token> tokenize (const llama_vocab *vocab, const string &text) {
102- int n = -llama_tokenize (vocab, text.c_str (), text.size (), nullptr , 0 , true , true );
103- std::vector<llama_token> tokens (n);
104- llama_tokenize (vocab, text.c_str (), text.size (), tokens.data (), tokens.size (), true , true );
105- return tokens;
106- }
107-
108100string Llama::generate (const string &prompt, int max_tokens, float temperature, bool echo, bool clear_cache) {
109101 string out;
110102
111103 if (clear_cache) {
112104 // llama_kv_cache_clear(_ctx);
113105 }
114106
115- auto prompt_tokens = tokenize (_vocab, prompt);
107+ // find the number of tokens in the prompt
108+ int n_prompt = -llama_tokenize (_vocab, prompt.c_str (), prompt.size (), nullptr , 0 , true , true );
109+
110+ // allocate space for the tokens and tokenize the prompt
111+ std::vector<llama_token> prompt_tokens (n_prompt);
112+ if (llama_tokenize (_vocab, prompt.c_str (), prompt.size (), prompt_tokens.data (), prompt_tokens.size (), true , true ) < 0 ) {
113+ _last_error = " failed tokenize the prompt" ;
114+ return out;
115+ }
116+
117+ // initialize the sampler
116118 configure_sampler (temperature);
117119
120+ // prepare a batch for the prompt
118121 llama_batch batch = llama_batch_get_one (prompt_tokens.data (), prompt_tokens.size ());
122+ if (llama_model_has_encoder (_model)) {
123+ if (llama_encode (_ctx, batch)) {
124+ _last_error = " failed to eval" ;
125+ return out;
126+ }
119127
120- if (llama_decode (_ctx, batch)) {
121- _last_error = " decode failed" ;
122- return out;
128+ llama_token decoder_start_token_id = llama_model_decoder_start_token (_model);
129+ if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
130+ decoder_start_token_id = llama_vocab_bos (_vocab);
131+ }
132+
133+ batch = llama_batch_get_one (&decoder_start_token_id, 1 );
123134 }
124135
125136 if (echo) {
126137 out += prompt;
127138 }
128139
129- for (int i = 0 ; i < max_tokens; ++i ) {
130- llama_token tok = llama_sampler_sample (_sampler, _ctx, - 1 );
131-
132- if ( llama_vocab_is_eog (_vocab, tok)) {
140+ for (int n_pos = 0 ; n_pos + batch. n_tokens < n_prompt + max_tokens; ) {
141+ // evaluate the current batch with the transformer model
142+ if ( llama_decode (_ctx, batch)) {
143+ _last_error = " failed to eval " ;
133144 break ;
134145 }
135146
136- char buf[128 ];
137- int n = llama_token_to_piece (_vocab, tok, buf, sizeof (buf), 0 , true );
147+ n_pos += batch.n_tokens ;
138148
139- if (n > 0 ) {
140- out.append (buf, n);
149+ // sample the next token
150+ llama_token new_token_id = llama_sampler_sample (_sampler, _ctx, -1 );
151+
152+ // is it an end of generation?
153+ if (llama_vocab_is_eog (_vocab, new_token_id)) {
154+ break ;
141155 }
142- batch = llama_batch_get_one (&tok, 1 );
143- if (llama_decode (_ctx, batch)) {
156+
157+ char buf[128 ];
158+ int n = llama_token_to_piece (_vocab, new_token_id, buf, sizeof (buf), 0 , true );
159+ if (n < 0 ) {
160+ _last_error = " failed to convert token to piece" ;
144161 break ;
162+ } else if (n > 0 ) {
163+ out.append (buf, n);
145164 }
165+
166+ // prepare the next batch with the sampled token
167+ batch = llama_batch_get_one (&new_token_id, 1 );
146168 }
147169
148170 return out;
0 commit comments