Skip to content

Commit 4e1a030

Browse files
committed
implement leejet#646
1 parent c0576f7 commit 4e1a030

1 file changed

Lines changed: 20 additions & 18 deletions

File tree

ggml_extend.hpp

Lines changed: 20 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -1907,15 +1907,7 @@ struct GGMLRunner {
19071907
return gf;
19081908
}
19091909

1910-
bool alloc_compute_buffer(get_graph_cb_t get_graph) {
1911-
if (compute_allocr != nullptr) {
1912-
return true;
1913-
}
1914-
reset_compute_ctx();
1915-
struct ggml_cgraph* gf = get_compute_graph(get_graph);
1916-
backend_tensor_data_map.clear();
1917-
compute_allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(runtime_backend));
1918-
1910+
bool alloc_compute_buffer(struct ggml_cgraph* gf) {
19191911
if (!ggml_gallocr_reserve(compute_allocr, gf)) {
19201912
// failed to allocate the compute buffer
19211913
LOG_ERROR("%s: failed to allocate the compute buffer\n", get_desc().c_str());
@@ -2180,25 +2172,35 @@ struct GGMLRunner {
21802172
return ggml_get_tensor(cache_ctx, name.c_str());
21812173
}
21822174

2183-
bool compute(get_graph_cb_t get_graph,
2184-
int n_threads,
2185-
bool free_compute_buffer_immediately = true,
2186-
struct ggml_tensor** output = nullptr,
2187-
struct ggml_context* output_ctx = nullptr) {
2175+
bool compute(
2176+
get_graph_cb_t get_graph,
2177+
int n_threads,
2178+
bool free_compute_buffer_immediately = true,
2179+
struct ggml_tensor** output = nullptr,
2180+
struct ggml_context* output_ctx = nullptr
2181+
) {
21882182
if (!offload_params_to_runtime_backend()) {
21892183
LOG_ERROR("%s offload params to runtime backend failed", get_desc().c_str());
21902184
return false;
21912185
}
2192-
if (!alloc_compute_buffer(get_graph)) {
2193-
LOG_ERROR("%s alloc compute buffer failed", get_desc().c_str());
2194-
return false;
2186+
2187+
bool buffer_initialized = compute_allocr == nullptr;
2188+
if (buffer_initialized) {
2189+
reset_compute_ctx();
2190+
2191+
compute_allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(runtime_backend));
2192+
backend_tensor_data_map.clear();
21952193
}
2196-
reset_compute_ctx();
2194+
21972195
struct ggml_cgraph* gf = get_compute_graph(get_graph);
2196+
2197+
if (buffer_initialized && !alloc_compute_buffer(gf)) return false;
2198+
21982199
if (!ggml_gallocr_alloc_graph(compute_allocr, gf)) {
21992200
LOG_ERROR("%s alloc compute graph failed", get_desc().c_str());
22002201
return false;
22012202
}
2203+
22022204
copy_data_to_backend_tensor();
22032205
if (ggml_backend_is_cpu(runtime_backend)) {
22042206
ggml_backend_cpu_set_n_threads(runtime_backend, n_threads);

0 commit comments

Comments (0)