Skip to content

Commit e8323ca

Browse files
authored
feat: add flux2 small decoder support (#1402)
1 parent dd75372 commit e8323ca

2 files changed

Lines changed: 37 additions & 6 deletions

File tree

docs/flux2.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
- gguf: https://huggingface.co/city96/FLUX.2-dev-gguf/tree/main
99
- Download vae
1010
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-dev/tree/main
11+
- Download FLUX.2-small-decoder (full_encoder_small_decoder.safetensors) as an alternative VAE option
12+
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-small-decoder/tree/main
1113
- Download Mistral-Small-3.2-24B-Instruct-2506-GGUF
1214
- gguf: https://huggingface.co/unsloth/Mistral-Small-3.2-24B-Instruct-2506-GGUF/tree/main
1315

@@ -31,6 +33,8 @@
3133
- gguf: https://huggingface.co/leejet/FLUX.2-klein-base-4B-GGUF/tree/main
3234
- Download vae
3335
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-dev/tree/main
36+
- Download FLUX.2-small-decoder (full_encoder_small_decoder.safetensors) as an alternative VAE option
37+
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-small-decoder/tree/main
3438
- Download Qwen3 4b
3539
- safetensors: https://huggingface.co/Comfy-Org/flux2-klein-4B/tree/main/split_files/text_encoders
3640
- gguf: https://huggingface.co/unsloth/Qwen3-4B-GGUF/tree/main

src/auto_encoder_kl.hpp

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -501,11 +501,36 @@ class AutoEncoderKLModel : public GGMLBlock {
501501
bool double_z = true;
502502
} dd_config;
503503

504+
// Joins a tensor name onto an optional prefix as "prefix.name";
// an empty prefix yields the bare name unchanged.
static std::string get_tensor_name(const std::string& prefix, const std::string& name) {
    if (prefix.empty()) {
        return name;
    }
    return prefix + "." + name;
}
507+
508+
void detect_decoder_ch(const String2TensorStorage& tensor_storage_map,
509+
const std::string& prefix,
510+
int& decoder_ch) {
511+
auto conv_in_iter = tensor_storage_map.find(get_tensor_name(prefix, "decoder.conv_in.weight"));
512+
if (conv_in_iter != tensor_storage_map.end() && conv_in_iter->second.n_dims >= 4 && conv_in_iter->second.ne[3] > 0) {
513+
int last_ch_mult = dd_config.ch_mult.back();
514+
int64_t conv_in_out_channels = conv_in_iter->second.ne[3];
515+
if (last_ch_mult > 0 && conv_in_out_channels % last_ch_mult == 0) {
516+
decoder_ch = static_cast<int>(conv_in_out_channels / last_ch_mult);
517+
LOG_INFO("vae decoder: ch = %d", decoder_ch);
518+
} else {
519+
LOG_WARN("vae decoder: failed to infer ch from %s (%" PRId64 " / %d)",
520+
get_tensor_name(prefix, "decoder.conv_in.weight").c_str(),
521+
conv_in_out_channels,
522+
last_ch_mult);
523+
}
524+
}
525+
}
526+
504527
public:
505-
AutoEncoderKLModel(SDVersion version = VERSION_SD1,
506-
bool decode_only = true,
507-
bool use_linear_projection = false,
508-
bool use_video_decoder = false)
528+
AutoEncoderKLModel(SDVersion version = VERSION_SD1,
529+
bool decode_only = true,
530+
bool use_linear_projection = false,
531+
bool use_video_decoder = false,
532+
const String2TensorStorage& tensor_storage_map = {},
533+
const std::string& prefix = "")
509534
: version(version), decode_only(decode_only), use_video_decoder(use_video_decoder) {
510535
if (sd_version_is_dit(version)) {
511536
if (sd_version_is_flux2(version)) {
@@ -519,7 +544,9 @@ class AutoEncoderKLModel : public GGMLBlock {
519544
if (use_video_decoder) {
520545
use_quant = false;
521546
}
522-
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder(dd_config.ch,
547+
int decoder_ch = dd_config.ch;
548+
detect_decoder_ch(tensor_storage_map, prefix, decoder_ch);
549+
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder(decoder_ch,
523550
dd_config.out_ch,
524551
dd_config.ch_mult,
525552
dd_config.num_res_blocks,
@@ -662,7 +689,7 @@ struct AutoEncoderKL : public VAE {
662689
break;
663690
}
664691
}
665-
ae = AutoEncoderKLModel(version, decode_only, use_linear_projection, use_video_decoder);
692+
ae = AutoEncoderKLModel(version, decode_only, use_linear_projection, use_video_decoder, tensor_storage_map, prefix);
666693
ae.init(params_ctx, tensor_storage_map, prefix);
667694
}
668695

0 commit comments

Comments
 (0)