
Commit 9e2c60e

Chris Warren-Smith committed
LLM: plugin module - initial commit
1 parent bbe0c2f commit 9e2c60e

4 files changed, 39 additions & 56 deletions


llama/llama-sb.cpp

Lines changed: 1 addition & 9 deletions
@@ -97,13 +97,9 @@ void Llama::configure_sampler(float temperature) {
   }
 }
 
-string Llama::generate(const string &prompt, int max_tokens, float temperature, bool echo, bool clear_cache) {
+string Llama::generate(const string &prompt, int max_tokens, float temperature) {
   string out;
 
-  if (clear_cache) {
-    // llama_kv_cache_clear(_ctx);
-  }
-
   // find the number of tokens in the prompt
   int n_prompt = -llama_tokenize(_vocab, prompt.c_str(), prompt.size(), nullptr, 0, true, true);
 
@@ -133,10 +129,6 @@ string Llama::generate(const string &prompt, int max_tokens, float temperature,
     batch = llama_batch_get_one(&decoder_start_token_id, 1);
   }
 
-  if (echo) {
-    out += prompt;
-  }
-
   for (int n_pos = 0; n_pos + batch.n_tokens < n_prompt + max_tokens;) {
     // evaluate the current batch with the transformer model
     if (llama_decode(_ctx, batch)) {
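
The hunk above ends at the head of the generation loop. For orientation, a loop of this shape typically completes as sketched below. This is modeled on llama.cpp's simple example, not on this commit's code: the _smpl sampler member and the 128-byte piece buffer are assumptions, while llama_sampler_sample, llama_vocab_is_eog, llama_token_to_piece and llama_batch_get_one are standard llama.cpp API calls.

    llama_token tok;
    for (int n_pos = 0; n_pos + batch.n_tokens < n_prompt + max_tokens;) {
      // evaluate the current batch with the transformer model
      if (llama_decode(_ctx, batch)) {
        _last_error = "llama_decode failed";
        break;
      }
      n_pos += batch.n_tokens;

      // sample the next token from the logits of the last decoded position
      tok = llama_sampler_sample(_smpl, _ctx, -1);
      if (llama_vocab_is_eog(_vocab, tok)) {
        break;  // end-of-generation token reached
      }

      // detokenize and accumulate the generated text
      char piece[128];
      int n = llama_token_to_piece(_vocab, tok, piece, sizeof(piece), 0, true);
      if (n > 0) {
        out.append(piece, n);
      }

      // the sampled token becomes the next single-token batch
      batch = llama_batch_get_one(&tok, 1);
    }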

llama/llama-sb.h

Lines changed: 1 addition & 5 deletions
@@ -19,11 +19,7 @@ struct Llama {
   void append_response(const string &response);
   const string build_chat_prompt(const string &user_msg);
   bool construct(string model_path, int n_ctx, bool disable_log);
-  string generate(const string &prompt,
-                  int max_tokens = 128,
-                  float temperature = 0.8f,
-                  bool echo = true,
-                  bool clear_cache = true);
+  string generate(const string &prompt, int max_tokens, float temperature);
   const char *last_error() { return _last_error.c_str(); }
   void reset();
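
With the default arguments removed from generate(), every caller now spells out max_tokens and temperature explicitly. A minimal usage sketch against this header; the model path and settings are hypothetical:

    #include <cstdio>
    #include "llama-sb.h"

    int main() {
      Llama llama;
      if (llama.construct("model.gguf", 2048, true)) {    // hypothetical model path
        string reply = llama.generate("Hello", 32, 0.8f); // prompt, max_tokens, temperature
        printf("%s\n", reply.c_str());
      } else {
        printf("error: %s\n", llama.last_error());
      }
      return 0;
    }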

llama/main.cpp

Lines changed: 7 additions & 9 deletions
@@ -62,14 +62,14 @@ static int cmd_llama_chat(var_s *self, int argc, slib_par_t *arg, var_s *retval)
   if (id != -1) {
     Llama &llama = g_map.at(id);
     auto prompt = get_param_str(argc, arg, 0, "");
-    int max_tokens = get_param_int(argc, arg, 0, 512);
-    var_num_t temperature = get_param_num(argc, arg, 0, 0);
+    int max_tokens = get_param_int(argc, arg, 1, 32);
+    var_num_t temperature = get_param_num(argc, arg, 2, 0.8f);
 
     // build accumulated prompt
     string updated_prompt = llama.build_chat_prompt(prompt);
 
     // run generation WITHOUT clearing cache
-    string response = llama.generate(updated_prompt, max_tokens, temperature, false, false);
+    string response = llama.generate(updated_prompt, max_tokens, temperature);
 
     // append assistant reply to history
     llama.append_response(response);
@@ -111,11 +111,9 @@ static int cmd_llama_generate(var_s *self, int argc, slib_par_t *arg, var_s *retval)
   if (id != -1) {
     Llama &llama = g_map.at(id);
     auto prompt = get_param_str(argc, arg, 0, "");
-    int max_tokens = get_param_int(argc, arg, 0, 512);
-    var_num_t temperature = get_param_num(argc, arg, 0, 0);
-
-    // run generation WITHOUT clearing cache
-    string response = llama.generate(prompt, max_tokens, temperature, false, true);
+    int max_tokens = get_param_int(argc, arg, 1, 32);
+    var_num_t temperature = get_param_num(argc, arg, 2, 0.8f);
+    string response = llama.generate(prompt, max_tokens, temperature);
     v_setstr(retval, response.c_str());
     result = 1;
   }
@@ -127,7 +125,7 @@ static int cmd_create_llama(int argc, slib_par_t *params, var_t *retval) {
   int result;
   auto model = expand_path(get_param_str(argc, params, 0, ""));
   int n_ctx = get_param_int(argc, params, 0, 2048);
-  int disable_log = get_param_int(argc, params, 0, 1);
+  int disable_log = get_param_int(argc, params, 1, 1);
   int id = ++g_nextId;
   Llama &llama = g_map[id];
   if (llama.construct(model, n_ctx, disable_log)) {
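
The substance of this change is the argument indices: max_tokens, temperature and disable_log previously all read parameter position 0, so they re-read the first argument (the prompt or model path) instead of their own slot. After the fix the positions line up as annotated below; the BASIC-side call shape in the comment is an assumption, not documented API:

    // assumed call shape: response = llm.generate(prompt, max_tokens, temperature)
    auto prompt           = get_param_str(argc, arg, 0, "");    // position 0: prompt text
    int max_tokens        = get_param_int(argc, arg, 1, 32);    // position 1, default 32
    var_num_t temperature = get_param_num(argc, arg, 2, 0.8f);  // position 2, default 0.8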

llama/test_main.cpp

Lines changed: 30 additions & 33 deletions
@@ -1,5 +1,4 @@
 #include "llama-sb.h"
-
 #include <cstdio>
 #include <cstring>
 
@@ -18,49 +17,47 @@ int main(int argc, char ** argv) {
   int n_predict = 32;
 
   // parse command line arguments
-  {
-    int i = 1;
-    for (; i < argc; i++) {
-      if (strcmp(argv[i], "-m") == 0) {
-        if (i + 1 < argc) {
-          model_path = argv[++i];
-        } else {
-          print_usage(argc, argv);
-          return 1;
-        }
-      } else if (strcmp(argv[i], "-n") == 0) {
-        if (i + 1 < argc) {
-          try {
-            n_predict = std::stoi(argv[++i]);
-          } catch (...) {
-            print_usage(argc, argv);
-            return 1;
-          }
-        } else {
+  int i = 1;
+  for (; i < argc; i++) {
+    if (strcmp(argv[i], "-m") == 0) {
+      if (i + 1 < argc) {
+        model_path = argv[++i];
+      } else {
+        print_usage(argc, argv);
+        return 1;
+      }
+    } else if (strcmp(argv[i], "-n") == 0) {
+      if (i + 1 < argc) {
+        try {
+          n_predict = std::stoi(argv[++i]);
+        } catch (...) {
           print_usage(argc, argv);
           return 1;
         }
       } else {
-        // prompt starts here
-        break;
+        print_usage(argc, argv);
+        return 1;
       }
+    } else {
+      // prompt starts here
+      break;
     }
-    if (model_path.empty()) {
-      print_usage(argc, argv);
-      return 1;
-    }
-    if (i < argc) {
-      prompt = argv[i++];
-      for (; i < argc; i++) {
-        prompt += " ";
-        prompt += argv[i];
-      }
+  }
+  if (model_path.empty()) {
+    print_usage(argc, argv);
+    return 1;
+  }
+  if (i < argc) {
+    prompt = argv[i++];
+    for (; i < argc; i++) {
+      prompt += " ";
+      prompt += argv[i];
     }
   }
 
   Llama llama;
   if (llama.construct(model_path, 1024, true)) {
-    string out = llama. generate(prompt, n_predict, 0.8f, true, true);
+    string out = llama. generate(prompt, n_predict, 0.8f);
     printf("\033[33m");
     printf(out.c_str());
     printf("\n\033[0m");
