@@ -501,11 +501,36 @@ class AutoEncoderKLModel : public GGMLBlock {
501501 bool double_z = true ;
502502 } dd_config;
503503
504+ static std::string get_tensor_name (const std::string& prefix, const std::string& name) {
505+ return prefix.empty () ? name : prefix + " ." + name;
506+ }
507+
508+ void detect_decoder_ch (const String2TensorStorage& tensor_storage_map,
509+ const std::string& prefix,
510+ int & decoder_ch) {
511+ auto conv_in_iter = tensor_storage_map.find (get_tensor_name (prefix, " decoder.conv_in.weight" ));
512+ if (conv_in_iter != tensor_storage_map.end () && conv_in_iter->second .n_dims >= 4 && conv_in_iter->second .ne [3 ] > 0 ) {
513+ int last_ch_mult = dd_config.ch_mult .back ();
514+ int64_t conv_in_out_channels = conv_in_iter->second .ne [3 ];
515+ if (last_ch_mult > 0 && conv_in_out_channels % last_ch_mult == 0 ) {
516+ decoder_ch = static_cast <int >(conv_in_out_channels / last_ch_mult);
517+ LOG_INFO (" vae decoder: ch = %d" , decoder_ch);
518+ } else {
519+ LOG_WARN (" vae decoder: failed to infer ch from %s (%" PRId64 " / %d)" ,
520+ get_tensor_name (prefix, " decoder.conv_in.weight" ).c_str (),
521+ conv_in_out_channels,
522+ last_ch_mult);
523+ }
524+ }
525+ }
526+
504527public:
505- AutoEncoderKLModel (SDVersion version = VERSION_SD1,
506- bool decode_only = true ,
507- bool use_linear_projection = false ,
508- bool use_video_decoder = false )
528+ AutoEncoderKLModel (SDVersion version = VERSION_SD1,
529+ bool decode_only = true ,
530+ bool use_linear_projection = false ,
531+ bool use_video_decoder = false ,
532+ const String2TensorStorage& tensor_storage_map = {},
533+ const std::string& prefix = " " )
509534 : version(version), decode_only(decode_only), use_video_decoder(use_video_decoder) {
510535 if (sd_version_is_dit (version)) {
511536 if (sd_version_is_flux2 (version)) {
@@ -519,7 +544,9 @@ class AutoEncoderKLModel : public GGMLBlock {
519544 if (use_video_decoder) {
520545 use_quant = false ;
521546 }
522- blocks[" decoder" ] = std::shared_ptr<GGMLBlock>(new Decoder (dd_config.ch ,
547+ int decoder_ch = dd_config.ch ;
548+ detect_decoder_ch (tensor_storage_map, prefix, decoder_ch);
549+ blocks[" decoder" ] = std::shared_ptr<GGMLBlock>(new Decoder (decoder_ch,
523550 dd_config.out_ch ,
524551 dd_config.ch_mult ,
525552 dd_config.num_res_blocks ,
@@ -662,7 +689,7 @@ struct AutoEncoderKL : public VAE {
662689 break ;
663690 }
664691 }
665- ae = AutoEncoderKLModel (version, decode_only, use_linear_projection, use_video_decoder);
692+ ae = AutoEncoderKLModel (version, decode_only, use_linear_projection, use_video_decoder, tensor_storage_map, prefix );
666693 ae.init (params_ctx, tensor_storage_map, prefix);
667694 }
668695
0 commit comments