Commit 39189b7

Author: Chris Warren-Smith (committed)

LLM: plugin module - initial commit

Parent: 50c5c2a

3 files changed: 105 additions & 143 deletions

llama/llama-sb.cpp

Lines changed: 28 additions & 19 deletions
@@ -22,6 +22,30 @@ Llama::Llama() :
   _n_ctx(0) {
 }

+Llama::~Llama() {
+  if (_sampler) {
+    llama_sampler_free(_sampler);
+  }
+  if (_ctx) {
+    llama_free(_ctx);
+  }
+  if (_model) {
+    llama_model_free(_model);
+  }
+}
+
+void Llama::append_response(const string &response) {
+  _chat_prompt += response;
+  _chat_prompt += "\n";
+}
+
+const string Llama::build_chat_prompt(const string &user_msg) {
+  _chat_prompt += "User: ";
+  _chat_prompt += user_msg;
+  _chat_prompt += "\nAssistant: ";
+  return _chat_prompt;
+}
+
 bool Llama::create(string model_path, int n_ctx, bool disable_log) {
   if (disable_log) {
     // only print errors
@@ -56,25 +80,6 @@ bool Llama::create(string model_path, int n_ctx, bool disable_log) {
   return _last_error.empty();
 }

-Llama::~Llama() {
-  if (_sampler) {
-    llama_sampler_free(_sampler);
-  }
-  if (_ctx) {
-    llama_free(_ctx);
-  }
-  if (_model) {
-    llama_model_free(_model);
-  }
-}
-
-string Llama::build_chat_prompt(const string &user_msg) {
-  _chat_prompt += "User: ";
-  _chat_prompt += user_msg;
-  _chat_prompt += "\nAssistant: ";
-  return _chat_prompt;
-}
-
 void Llama::configure_sampler(float temperature) {
   if (temperature != _temperature || _sampler == nullptr) {
     if (_sampler) {
@@ -143,3 +148,7 @@ string Llama::generate(const string &prompt, int max_tokens, float temperature,
   return out;
 }

+void Llama::reset() {
+  // llama_kv_cache_clear(it->second->ctx);
+  _chat_prompt.clear();
+}
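
A minimal usage sketch of the refactored Llama class (not part of the commit): it strings the new chat helpers together in the same order cmd_llama_chat in main.cpp uses them. The model path, context size and prompt text below are placeholder values.

  #include "llama-sb.h"
  #include <cstdio>

  int main() {
    Llama llama;
    // "model.gguf" is a placeholder path; 2048-token context, logging disabled
    if (!llama.create("model.gguf", 2048, true)) {
      printf("create failed: %s\n", llama.last_error());
      return 1;
    }
    // one chat turn: accumulate history, then generate without clearing the KV cache
    auto prompt = llama.build_chat_prompt("Hello");
    auto reply = llama.generate(prompt, 512, 0.8f, false, false);
    llama.append_response(reply);
    printf("%s\n", reply.c_str());
    // forget the accumulated history before starting a new conversation
    llama.reset();
    return 0;
  }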

llama/llama-sb.h

Lines changed: 3 additions & 1 deletion
@@ -16,16 +16,18 @@ struct Llama {
   explicit Llama();
   ~Llama();

+  void append_response(const string &response);
+  const string build_chat_prompt(const string &user_msg);
   bool create(string model_path, int n_ctx, bool disable_log);
   string generate(const string &prompt,
                   int max_tokens = 128,
                   float temperature = 0.8f,
                   bool echo = true,
                   bool clear_cache = true);
   const char *last_error() { return _last_error.c_str(); }
+  void reset();

 private:
-  string build_chat_prompt(const string &user_msg);
   void configure_sampler(float temperature);

   llama_model *_model;

llama/main.cpp

Lines changed: 74 additions & 123 deletions
@@ -20,7 +20,7 @@
 int g_nextId = 1;
 robin_hood::unordered_map<int, Llama> g_map;

-static int get_id(var_s *map, var_s *retval) {
+static int get_class_id(var_s *map, var_s *retval) {
   int result = -1;
   if (is_map(map)) {
     int id = map->v.m.id;
@@ -34,130 +34,90 @@ static int get_id(var_s *map, var_s *retval) {
   return result;
 }

-const char *llm_llama_chat(int id,
-                           const char * user_message,
-                           int max_tokens,
-                           float temperature) {
-  static thread_local string result;
-
-  // auto it = g_llamas.find(h);
-  // if (it == g_llamas.end()) {
-  //   result = "[invalid llama handle]";
-  //   return result.c_str();
-  // }
+static string expand_path(const char *path) {
+  string result;
+  if (path && path[0] == '~') {
+    const char *home = getenv("HOME");
+    if (home != nullptr) {
+      result.append(home);
+      result.append(path + 1);
+    } else {
+      result = path;
+    }
+  } else {
+    result = path;
+  }
+  return result;
+}

-  // Llama * llm = it->second.get();
+//
+// print llama.chat("Hello")
+//
+static int cmd_llama_chat(var_s *self, int argc, slib_par_t *arg, var_s *retval) {
+  int result = 0;
+  if (argc < 1) {
+    error(retval, "llama.chat", 1, 3);
+  } else {
+    int id = get_class_id(self, retval);
+    if (id != -1) {
+      Llama &llama = g_map.at(id);
+      auto prompt = get_param_str(argc, arg, 0, "");
+      int max_tokens = get_param_int(argc, arg, 0, 512);
+      var_num_t temperature = get_param_num(argc, arg, 0, 0);

-  // // build accumulated prompt
-  // string prompt = build_chat_prompt(llm, user_message);
+      // build accumulated prompt
+      string updated_prompt = llama.build_chat_prompt(prompt);

-  // // run generation WITHOUT clearing cache
-  // result = llm->generate(prompt,
-  //                        max_tokens,
-  //                        temperature,
-  //                        false); // echo = false
+      // run generation WITHOUT clearing cache
+      string response = llama.generate(updated_prompt, max_tokens, temperature, false, false);

-  // // append assistant reply to history
-  // llm->chat_prompt += result;
-  // llm->chat_prompt += "\n";
+      // append assistant reply to history
+      llama.append_response(response);

-  return result.c_str();
+      v_setstr(retval, response.c_str());
+      result = 1;
+    }
+  }
+  return result;
 }

 //
-// make the model forget everything
+// llama.reset() - make the model forget everything
 //
-void llm_llama_reset() {
-  // std::lock_guard<std::mutex> lock(g_mutex);
-
-  // auto it = g_llamas.find(h);
-  // if (it == g_llamas.end()) return;
-
-  // llama_kv_cache_clear(it->second->ctx);
-  // it->second->chat_prompt.clear();
+static int cmd_llama_reset(var_s *self, int argc, slib_par_t *arg, var_s *retval) {
+  int result = 0;
+  if (argc != 0) {
+    error(retval, "llama.reset", 0, 0);
+  } else {
+    int id = get_class_id(self, retval);
+    if (id != -1) {
+      Llama &llama = g_map.at(id);
+      llama.reset();
+      result = 1;
+    }
+  }
+  return result;
 }

 //
-// string generate(prompt, max_tokens, temperature)
+// print llama.generate("please generate as simple program in BASIC to draw a cat", 1024, 0.8)
 //
-const char *llm_llama_generate(const char * prompt,
-                               int max_tokens,
-                               float temperature) {
-  // static thread_local string result;
-
-  // std::lock_guard<std::mutex> lock(g_mutex);
-
-  // auto it = g_llamas.find(h);
-  // if (it == g_llamas.end()) {
-  //   result = "[invalid llama handle]";
-  //   return result.c_str();
-  // }
-
-  // try {
-  //   result = it->second->generate(prompt,
-  //                                 max_tokens,
-  //                                 temperature);
-  // } catch (const std::exception & e) {
-  //   result = e.what();
-  // }
-
-  // return result.c_str();
-  return nullptr;
-}
-
-
-static int llm_llama_create(const char *model_path, int n_ctx) {
-  // std::lock_guard<std::mutex> lock(g_mutex);
-
-  // llama_handle id = g_next_id++;
-
-  // try {
-  //   g_llamas[id] = std::make_unique<Llama>(model_path, n_ctx);
-  // } catch (...) {
-  //   return 0;
-  // }
-
-  return 0;
-}
-
-void llm_llama_destroy() {
-  // std::lock_guard<std::mutex> lock(g_mutex);
-  // g_llamas.erase(h);
-}
-
-// const char *llm_llama_generate(llama_handle h,
-//                                const char *prompt,
-//                                int max_tokens,
-//                                float temperature) {
-//   static thread_local string result;
-
-//   auto it = g_llamas.find(h);
-//   if (it == g_llamas.end()) {
-//     result = "[invalid llama handle]";
-//     return result.c_str();
-//   }
-
-//   try {
-//     result = it->second->generate(prompt, max_tokens, temperature);
-//   } catch (const std::exception &e) {
-//     result = e.what();
-//   }
-
-//   return result.c_str();
-// }
-
-string expand_path(const char *path) {
-  string result;
-  if (path && path[0] == '~') {
-    const char *home = getenv("HOME");
-    if (home != nullptr) {
-      result.append(home);
-      result.append(path + 1);
-    } else {
-      result = path;
-    }
+static int cmd_llama_generate(var_s *self, int argc, slib_par_t *arg, var_s *retval) {
+  int result = 0;
+  if (argc < 1) {
+    error(retval, "llama.generate", 1, 3);
   } else {
-    result = path;
+    int id = get_class_id(self, retval);
+    if (id != -1) {
+      Llama &llama = g_map.at(id);
+      auto prompt = get_param_str(argc, arg, 0, "");
+      int max_tokens = get_param_int(argc, arg, 0, 512);
+      var_num_t temperature = get_param_num(argc, arg, 0, 0);
+
+      // run generation WITHOUT clearing cache
+      string response = llama.generate(prompt, max_tokens, temperature, false, true);
+      v_setstr(retval, response.c_str());
+    }
   }
   return result;
 }
@@ -171,18 +131,9 @@ static int cmd_create_llama(int argc, slib_par_t *params, var_t *retval) {
   Llama &llama = g_map[id];
   if (llama.create(model, n_ctx, disable_log)) {
     map_init_id(retval, id, CLASS_ID);
-
-    // v_create_callback(map, "getVoltage", cmd_analoginput_getvoltage);
-    // v_create_callback(map, "getVoltageSync", cmd_analoginput_getvoltagesync);
-    // v_create_callback(map, "getReference", cmd_analoginput_getreference);
-    // v_create_callback(map, "read", cmd_analoginput_read);
-    // v_create_callback(map, "readSync", cmd_analoginput_readsync);
-    // v_create_callback(map, "getOverflowCount", cmd_analoginput_getoverflowcount);
-    // v_create_callback(map, "available", cmd_analoginput_available);
-    // v_create_callback(map, "readBuffered", cmd_analoginput_readbuffered);
-    // v_create_callback(map, "getVoltageBuffered", cmd_analoginput_getvoltagebuffered);
-    // v_create_callback(map, "getSampleRate", cmd_analoginput_getsamplerate);
-
+    v_create_callback(retval, "chat", cmd_llama_chat);
+    v_create_callback(retval, "generate", cmd_llama_generate);
+    v_create_callback(retval, "reset", cmd_llama_reset);
     result = 1;
   } else {
     error(retval, llama.last_error());
