Skip to content

Commit 44cca3d

Browse files
authored
feat: support safetensors export in convert mode (#1444)
1 parent 0a7ae07 commit 44cca3d

10 files changed

Lines changed: 257 additions & 120 deletions

File tree

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,10 @@ API and command-line option may change frequently.***
7777
- OpenCL
7878
- SYCL
7979
- Supported weight formats
80-
- Pytorch checkpoint (`.ckpt` or `.pth`)
80+
- Pytorch checkpoint (`.ckpt` or `.pth` or `.pt`)
8181
- Safetensors (`.safetensors`)
8282
- GGUF (`.gguf`)
83+
- Convert mode supports converting model weights to `.gguf` or `.safetensors`
8384
- Supported platforms
8485
- Linux
8586
- Mac OS

examples/cli/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ CLI Options:
1414
--metadata-format <string> metadata output format, one of [text, json] (default: text)
1515
--canny apply canny preprocessor (edge detection)
1616
--convert-name convert tensor name (for convert mode)
17+
convert mode writes `.gguf` or `.safetensors` based on the output extension.
18+
`.safetensors` export currently supports f16, bf16, f32, and i32 tensor types only.
19+
i32 is passthrough only; no f32 <-> i32 conversion is performed.
1720
-v, --verbose print extra info
1821
--color colors the logging tags according to level
1922
--taesd-preview-only prevents usage of taesd for decoding the final image. (for use with --preview tae)

src/convert.cpp

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
#include <cstring>
2+
#include <mutex>
3+
#include <regex>
4+
#include <vector>
5+
6+
#include "model.h"
7+
#include "model_io/gguf_io.h"
8+
#include "model_io/safetensors_io.h"
9+
#include "util.h"
10+
11+
#include "ggml-cpu.h"
12+
13+
static ggml_type get_export_tensor_type(ModelLoader& model_loader,
14+
const TensorStorage& tensor_storage,
15+
ggml_type type,
16+
const TensorTypeRules& tensor_type_rules) {
17+
const std::string& name = tensor_storage.name;
18+
ggml_type tensor_type = tensor_storage.type;
19+
ggml_type dst_type = type;
20+
21+
for (const auto& tensor_type_rule : tensor_type_rules) {
22+
std::regex pattern(tensor_type_rule.first);
23+
if (std::regex_search(name, pattern)) {
24+
dst_type = tensor_type_rule.second;
25+
break;
26+
}
27+
}
28+
29+
if (model_loader.tensor_should_be_converted(tensor_storage, dst_type)) {
30+
tensor_type = dst_type;
31+
}
32+
33+
return tensor_type;
34+
}
35+
36+
static bool load_tensors_for_export(ModelLoader& model_loader,
37+
ggml_context* ggml_ctx,
38+
ggml_type type,
39+
const TensorTypeRules& tensor_type_rules,
40+
std::vector<TensorWriteInfo>& tensors) {
41+
std::mutex tensor_mutex;
42+
auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
43+
const std::string& name = tensor_storage.name;
44+
ggml_type tensor_type = get_export_tensor_type(model_loader, tensor_storage, type, tensor_type_rules);
45+
46+
std::lock_guard<std::mutex> lock(tensor_mutex);
47+
ggml_tensor* tensor = ggml_new_tensor(ggml_ctx, tensor_type, tensor_storage.n_dims, tensor_storage.ne);
48+
if (tensor == nullptr) {
49+
LOG_ERROR("ggml_new_tensor failed");
50+
return false;
51+
}
52+
ggml_set_name(tensor, name.c_str());
53+
54+
if (!tensor->data) {
55+
GGML_ASSERT(ggml_nelements(tensor) == 0);
56+
// Avoid crashing writers by setting a dummy pointer for zero-sized tensors.
57+
LOG_DEBUG("setting dummy pointer for zero-sized tensor %s", name.c_str());
58+
tensor->data = ggml_get_mem_buffer(ggml_ctx);
59+
}
60+
61+
TensorWriteInfo write_info;
62+
write_info.tensor = tensor;
63+
write_info.n_dims = tensor_storage.n_dims;
64+
for (int i = 0; i < tensor_storage.n_dims; ++i) {
65+
write_info.ne[i] = tensor_storage.ne[i];
66+
}
67+
68+
*dst_tensor = tensor;
69+
tensors.push_back(std::move(write_info));
70+
71+
return true;
72+
};
73+
74+
bool success = model_loader.load_tensors(on_new_tensor_cb);
75+
LOG_INFO("load tensors done");
76+
return success;
77+
}
78+
79+
bool convert(const char* input_path,
80+
const char* vae_path,
81+
const char* output_path,
82+
sd_type_t output_type,
83+
const char* tensor_type_rules,
84+
bool convert_name) {
85+
ModelLoader model_loader;
86+
87+
if (!model_loader.init_from_file(input_path)) {
88+
LOG_ERROR("init model loader from file failed: '%s'", input_path);
89+
return false;
90+
}
91+
92+
if (vae_path != nullptr && strlen(vae_path) > 0) {
93+
if (!model_loader.init_from_file(vae_path, "vae.")) {
94+
LOG_ERROR("init model loader from file failed: '%s'", vae_path);
95+
return false;
96+
}
97+
}
98+
if (convert_name) {
99+
model_loader.convert_tensors_name();
100+
}
101+
102+
ggml_type type = (ggml_type)output_type;
103+
bool output_is_safetensors = ends_with(output_path, ".safetensors");
104+
TensorTypeRules type_rules = parse_tensor_type_rules(tensor_type_rules);
105+
106+
auto backend = ggml_backend_cpu_init();
107+
size_t mem_size = 1 * 1024 * 1024; // for padding
108+
mem_size += model_loader.get_tensor_storage_map().size() * ggml_tensor_overhead();
109+
mem_size += model_loader.get_params_mem_size(backend, type);
110+
LOG_INFO("model tensors mem size: %.2fMB", mem_size / 1024.f / 1024.f);
111+
ggml_context* ggml_ctx = ggml_init({mem_size, nullptr, false});
112+
113+
if (ggml_ctx == nullptr) {
114+
LOG_ERROR("ggml_init failed for converter");
115+
ggml_backend_free(backend);
116+
return false;
117+
}
118+
119+
std::vector<TensorWriteInfo> tensors;
120+
bool success = load_tensors_for_export(model_loader, ggml_ctx, type, type_rules, tensors);
121+
ggml_backend_free(backend);
122+
123+
std::string error;
124+
if (success) {
125+
if (output_is_safetensors) {
126+
success = write_safetensors_file(output_path, tensors, &error);
127+
} else {
128+
success = write_gguf_file(output_path, tensors, &error);
129+
}
130+
}
131+
132+
if (!success && !error.empty()) {
133+
LOG_ERROR("%s", error.c_str());
134+
}
135+
136+
ggml_free(ggml_ctx);
137+
return success;
138+
}

src/model.cpp

Lines changed: 3 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ const char* unused_tensors[] = {
8181
"first_stage_model.bn.",
8282
};
8383

84-
bool is_unused_tensor(std::string name) {
84+
bool is_unused_tensor(const std::string& name) {
8585
for (size_t i = 0; i < sizeof(unused_tensors) / sizeof(const char*); i++) {
8686
if (starts_with(name, unused_tensors[i])) {
8787
return true;
@@ -687,8 +687,8 @@ std::map<ggml_type, uint32_t> ModelLoader::get_vae_wtype_stat() {
687687
return wtype_stat;
688688
}
689689

690-
static std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std::string& tensor_type_rules) {
691-
std::vector<std::pair<std::string, ggml_type>> result;
690+
TensorTypeRules parse_tensor_type_rules(const std::string& tensor_type_rules) {
691+
TensorTypeRules result;
692692
for (const auto& item : split_string(tensor_type_rules, ',')) {
693693
if (item.size() == 0)
694694
continue;
@@ -1121,91 +1121,6 @@ bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage
11211121
return false;
11221122
}
11231123

1124-
bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules_str) {
1125-
auto tensor_type_rules = parse_tensor_type_rules(tensor_type_rules_str);
1126-
auto get_tensor_type = [&](const TensorStorage& tensor_storage) -> ggml_type {
1127-
const std::string& name = tensor_storage.name;
1128-
ggml_type tensor_type = tensor_storage.type;
1129-
ggml_type dst_type = type;
1130-
1131-
for (const auto& tensor_type_rule : tensor_type_rules) {
1132-
std::regex pattern(tensor_type_rule.first);
1133-
if (std::regex_search(name, pattern)) {
1134-
dst_type = tensor_type_rule.second;
1135-
break;
1136-
}
1137-
}
1138-
1139-
if (tensor_should_be_converted(tensor_storage, dst_type)) {
1140-
tensor_type = dst_type;
1141-
}
1142-
1143-
return tensor_type;
1144-
};
1145-
1146-
auto backend = ggml_backend_cpu_init();
1147-
size_t mem_size = 1 * 1024 * 1024; // for padding
1148-
mem_size += tensor_storage_map.size() * ggml_tensor_overhead();
1149-
mem_size += get_params_mem_size(backend, type);
1150-
LOG_INFO("model tensors mem size: %.2fMB", mem_size / 1024.f / 1024.f);
1151-
ggml_context* ggml_ctx = ggml_init({mem_size, nullptr, false});
1152-
1153-
if (ggml_ctx == nullptr) {
1154-
LOG_ERROR("ggml_init failed for GGUF writer");
1155-
ggml_backend_free(backend);
1156-
return false;
1157-
}
1158-
1159-
std::vector<ggml_tensor*> tensors;
1160-
std::mutex tensor_mutex;
1161-
auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
1162-
const std::string& name = tensor_storage.name;
1163-
ggml_type tensor_type = get_tensor_type(tensor_storage);
1164-
1165-
std::lock_guard<std::mutex> lock(tensor_mutex);
1166-
ggml_tensor* tensor = ggml_new_tensor(ggml_ctx, tensor_type, tensor_storage.n_dims, tensor_storage.ne);
1167-
if (tensor == nullptr) {
1168-
LOG_ERROR("ggml_new_tensor failed");
1169-
return false;
1170-
}
1171-
ggml_set_name(tensor, name.c_str());
1172-
1173-
// LOG_DEBUG("%s %d %s %d[%d %d %d %d] %d[%d %d %d %d]", name.c_str(),
1174-
// ggml_nbytes(tensor), ggml_type_name(tensor_type),
1175-
// tensor_storage.n_dims,
1176-
// tensor_storage.ne[0], tensor_storage.ne[1], tensor_storage.ne[2], tensor_storage.ne[3],
1177-
// tensor->n_dims, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
1178-
1179-
if (!tensor->data) {
1180-
GGML_ASSERT(ggml_nelements(tensor) == 0);
1181-
// avoid crashing the gguf writer by setting a dummy pointer for zero-sized tensors
1182-
LOG_DEBUG("setting dummy pointer for zero-sized tensor %s", name.c_str());
1183-
tensor->data = ggml_get_mem_buffer(ggml_ctx);
1184-
}
1185-
1186-
*dst_tensor = tensor;
1187-
tensors.push_back(tensor);
1188-
1189-
return true;
1190-
};
1191-
1192-
bool success = load_tensors(on_new_tensor_cb);
1193-
ggml_backend_free(backend);
1194-
LOG_INFO("load tensors done");
1195-
1196-
std::string error;
1197-
if (success) {
1198-
success = write_gguf_file(file_path, tensors, &error);
1199-
}
1200-
1201-
if (!success && !error.empty()) {
1202-
LOG_ERROR("%s", error.c_str());
1203-
}
1204-
1205-
ggml_free(ggml_ctx);
1206-
return success;
1207-
}
1208-
12091124
int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type) {
12101125
size_t alignment = 128;
12111126
if (backend != nullptr) {
@@ -1225,28 +1140,3 @@ int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type)
12251140

12261141
return mem_size;
12271142
}
1228-
1229-
bool convert(const char* input_path,
1230-
const char* vae_path,
1231-
const char* output_path,
1232-
sd_type_t output_type,
1233-
const char* tensor_type_rules,
1234-
bool convert_name) {
1235-
ModelLoader model_loader;
1236-
1237-
if (!model_loader.init_from_file(input_path)) {
1238-
LOG_ERROR("init model loader from file failed: '%s'", input_path);
1239-
return false;
1240-
}
1241-
1242-
if (vae_path != nullptr && strlen(vae_path) > 0) {
1243-
if (!model_loader.init_from_file(vae_path, "vae.")) {
1244-
LOG_ERROR("init model loader from file failed: '%s'", vae_path);
1245-
return false;
1246-
}
1247-
}
1248-
if (convert_name) {
1249-
model_loader.convert_tensors_name();
1250-
}
1251-
return model_loader.save_to_gguf_file(output_path, (ggml_type)output_type, tensor_type_rules);
1252-
}

src/model.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,9 @@ enum PMVersion {
189189
};
190190

191191
typedef OrderedMap<std::string, TensorStorage> String2TensorStorage;
192+
using TensorTypeRules = std::vector<std::pair<std::string, ggml_type>>;
193+
194+
TensorTypeRules parse_tensor_type_rules(const std::string& tensor_type_rules);
192195

193196
class ModelLoader {
194197
protected:
@@ -231,7 +234,6 @@ class ModelLoader {
231234
return names;
232235
}
233236

234-
bool save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules);
235237
bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type);
236238
int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT);
237239
~ModelLoader() = default;

src/model_io/gguf_io.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,15 +95,16 @@ bool read_gguf_file(const std::string& file_path,
9595
}
9696

9797
bool write_gguf_file(const std::string& file_path,
98-
const std::vector<ggml_tensor*>& tensors,
98+
const std::vector<TensorWriteInfo>& tensors,
9999
std::string* error) {
100100
gguf_context* gguf_ctx = gguf_init_empty();
101101
if (gguf_ctx == nullptr) {
102102
set_error(error, "gguf_init_empty failed");
103103
return false;
104104
}
105105

106-
for (ggml_tensor* tensor : tensors) {
106+
for (const TensorWriteInfo& write_tensor : tensors) {
107+
ggml_tensor* tensor = write_tensor.tensor;
107108
if (tensor == nullptr) {
108109
set_error(error, "null tensor cannot be written to GGUF");
109110
gguf_free(gguf_ctx);

src/model_io/gguf_io.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ bool read_gguf_file(const std::string& file_path,
1111
std::vector<TensorStorage>& tensor_storages,
1212
std::string* error = nullptr);
1313
bool write_gguf_file(const std::string& file_path,
14-
const std::vector<ggml_tensor*>& tensors,
14+
const std::vector<TensorWriteInfo>& tensors,
1515
std::string* error = nullptr);
1616

1717
#endif // __SD_MODEL_IO_GGUF_IO_H__

0 commit comments

Comments
 (0)