
Commit bbe0c2f

Author: Chris Warren-Smith (committed)
LLM: plugin module - initial commit
1 parent 39189b7, commit bbe0c2f

4 files changed: 89 additions & 201 deletions


llama/llama-sb.cpp

Lines changed: 52 additions & 30 deletions
@@ -5,11 +5,7 @@
 //
 // Copyright(C) 2026 Chris Warren-Smith
 
-#include <cstdio>
-#include <cstring>
-#include <string>
 #include <vector>
-
 #include "llama.h"
 #include "llama-sb.h"
 
@@ -46,11 +42,11 @@ const string Llama::build_chat_prompt(const string &user_msg) {
   return _chat_prompt;
 }
 
-bool Llama::create(string model_path, int n_ctx, bool disable_log) {
+bool Llama::construct(string model_path, int n_ctx, bool disable_log) {
   if (disable_log) {
     // only print errors
     llama_log_set([](enum ggml_log_level level, const char * text, void * /* user_data */) {
-      if (level >= GGML_LOG_LEVEL_ERROR) {
+      if (level >= GGML_LOG_LEVEL_ERROR && text[0] != '.' && text[0] != '\n') {
         fprintf(stderr, "%s", text);
       }
     }, nullptr);
@@ -59,7 +55,7 @@ bool Llama::create(string model_path, int n_ctx, bool disable_log) {
   ggml_backend_load_all();
 
   llama_model_params mparams = llama_model_default_params();
-  mparams.n_gpu_layers = 0;
+  mparams.n_gpu_layers = 99;
 
   _model = llama_model_load_from_file(model_path.c_str(), mparams);
   if (!_model) {
@@ -68,81 +64,107 @@ bool Llama::create(string model_path, int n_ctx, bool disable_log) {
     llama_context_params cparams = llama_context_default_params();
     cparams.n_ctx = n_ctx;
     cparams.n_batch = n_ctx;
+    cparams.no_perf = true;
 
     _ctx = llama_init_from_model(_model, cparams);
     if (!_ctx) {
       _last_error = "failed to create context";
     } else {
       _vocab = llama_model_get_vocab(_model);
-      configure_sampler(0);
     }
   }
   return _last_error.empty();
 }
 
 void Llama::configure_sampler(float temperature) {
   if (temperature != _temperature || _sampler == nullptr) {
-    if (_sampler) {
+    if (_sampler != nullptr) {
       llama_sampler_free(_sampler);
     }
     auto sparams = llama_sampler_chain_default_params();
+    sparams.no_perf = false;
     _sampler = llama_sampler_chain_init(sparams);
     _temperature = temperature;
 
     // llama_sampler_chain_reset(sampler);
     if (temperature <= 0.0f) {
       llama_sampler_chain_add(_sampler, llama_sampler_init_greedy());
     } else {
+      llama_sampler_chain_add(_sampler, llama_sampler_init_min_p(0.05f, 1));
       llama_sampler_chain_add(_sampler, llama_sampler_init_temp(temperature));
+      llama_sampler_chain_add(_sampler, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
     }
   }
 }
 
-static std::vector<llama_token> tokenize(const llama_vocab *vocab, const string &text) {
-  int n = -llama_tokenize(vocab, text.c_str(), text.size(), nullptr, 0, true, true);
-  std::vector<llama_token> tokens(n);
-  llama_tokenize(vocab, text.c_str(), text.size(), tokens.data(), tokens.size(), true, true);
-  return tokens;
-}
-
 string Llama::generate(const string &prompt, int max_tokens, float temperature, bool echo, bool clear_cache) {
   string out;
 
   if (clear_cache) {
     // llama_kv_cache_clear(_ctx);
   }
 
-  auto prompt_tokens = tokenize(_vocab, prompt);
+  // find the number of tokens in the prompt
+  int n_prompt = -llama_tokenize(_vocab, prompt.c_str(), prompt.size(), nullptr, 0, true, true);
+
+  // allocate space for the tokens and tokenize the prompt
+  std::vector<llama_token> prompt_tokens(n_prompt);
+  if (llama_tokenize(_vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) {
+    _last_error = "failed tokenize the prompt";
+    return out;
+  }
+
+  // initialize the sampler
   configure_sampler(temperature);
 
+  // prepare a batch for the prompt
   llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
+  if (llama_model_has_encoder(_model)) {
+    if (llama_encode(_ctx, batch)) {
+      _last_error = "failed to eval";
+      return out;
+    }
 
-  if (llama_decode(_ctx, batch)) {
-    _last_error = "decode failed";
-    return out;
+    llama_token decoder_start_token_id = llama_model_decoder_start_token(_model);
+    if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
+      decoder_start_token_id = llama_vocab_bos(_vocab);
+    }
+
+    batch = llama_batch_get_one(&decoder_start_token_id, 1);
   }
 
   if (echo) {
     out += prompt;
   }
 
-  for (int i = 0; i < max_tokens; ++i) {
-    llama_token tok = llama_sampler_sample(_sampler, _ctx, -1);
-
-    if (llama_vocab_is_eog(_vocab, tok)) {
+  for (int n_pos = 0; n_pos + batch.n_tokens < n_prompt + max_tokens;) {
+    // evaluate the current batch with the transformer model
+    if (llama_decode(_ctx, batch)) {
+      _last_error = "failed to eval";
       break;
     }
 
-    char buf[128];
-    int n = llama_token_to_piece(_vocab, tok, buf, sizeof(buf), 0, true);
+    n_pos += batch.n_tokens;
 
-    if (n > 0) {
-      out.append(buf, n);
+    // sample the next token
+    llama_token new_token_id = llama_sampler_sample(_sampler, _ctx, -1);
+
+    // is it an end of generation?
+    if (llama_vocab_is_eog(_vocab, new_token_id)) {
+      break;
     }
-    batch = llama_batch_get_one(&tok, 1);
-    if (llama_decode(_ctx, batch)) {
+
+    char buf[128];
+    int n = llama_token_to_piece(_vocab, new_token_id, buf, sizeof(buf), 0, true);
+    if (n < 0) {
+      _last_error = "failed to convert token to piece";
       break;
+    } else if (n > 0) {
+      out.append(buf, n);
     }
+
+    // prepare the next batch with the sampled token
+    batch = llama_batch_get_one(&new_token_id, 1);
  }
 
   return out;
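
With these changes, generate() follows the tokenize, optional encode, then decode/sample loop familiar from llama.cpp, while callers only see the two-step construct/generate interface. Below is a minimal sketch of driving the class directly, assuming only the declarations visible in llama-sb.h; the model path, prompt, and token budget are placeholders, not values from the commit.

#include <cstdio>
#include <string>
#include "llama-sb.h"

int main() {
  Llama llama;
  // placeholder model path; n_ctx = 2048, suppress non-error logging
  if (!llama.construct("model.gguf", 2048, true)) {
    fprintf(stderr, "failed to load model\n");
    return 1;
  }
  // max_tokens = 64, temperature = 0.8, no prompt echo, clear the cache first
  std::string reply = llama.generate("Write a haiku about BASIC.", 64, 0.8f, false, true);
  printf("%s\n", reply.c_str());
  return 0;
}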

llama/llama-sb.h

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ struct Llama {
 
   void append_response(const string &response);
   const string build_chat_prompt(const string &user_msg);
-  bool create(string model_path, int n_ctx, bool disable_log);
+  bool construct(string model_path, int n_ctx, bool disable_log);
   string generate(const string &prompt,
                   int max_tokens = 128,
                   float temperature = 0.8f,

llama/main.cpp

Lines changed: 2 additions & 1 deletion
@@ -117,6 +117,7 @@ static int cmd_llama_generate(var_s *self, int argc, slib_par_t *arg, var_s *ret
       // run generation WITHOUT clearing cache
       string response = llama.generate(prompt, max_tokens, temperature, false, true);
       v_setstr(retval, response.c_str());
+      result = 1;
     }
   }
   return result;
@@ -129,7 +130,7 @@ static int cmd_create_llama(int argc, slib_par_t *params, var_t *retval) {
   int disable_log = get_param_int(argc, params, 0, 1);
   int id = ++g_nextId;
   Llama &llama = g_map[id];
-  if (llama.create(model, n_ctx, disable_log)) {
+  if (llama.construct(model, n_ctx, disable_log)) {
     map_init_id(retval, id, CLASS_ID);
     v_create_callback(retval, "chat", cmd_llama_chat);
     v_create_callback(retval, "generate", cmd_llama_generate);

0 commit comments
