|
26 | 26 | from ..activations import get_activation |
27 | 27 | from ..modeling_outputs import AutoencoderKLOutput |
28 | 28 | from ..modeling_utils import ModelMixin |
29 | | -from .vae import DecoderOutput, DiagonalGaussianDistribution |
| 29 | +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution |
30 | 30 |
|
31 | 31 |
|
32 | 32 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
@@ -584,7 +584,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
584 | 584 | return hidden_states |
585 | 585 |
|
586 | 586 |
|
587 | | -class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin): |
| 587 | +class AutoencoderKLHunyuanImageRefiner(ModelMixin, AutoencoderMixin, ConfigMixin): |
588 | 588 | r""" |
589 | 589 | A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for |
590 | 590 | HunyuanImage-2.1 Refiner. |
@@ -685,27 +685,6 @@ def enable_tiling( |
685 | 685 | self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width |
686 | 686 | self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor |
687 | 687 |
|
688 | | - def disable_tiling(self) -> None: |
689 | | - r""" |
690 | | - Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing |
691 | | - decoding in one step. |
692 | | - """ |
693 | | - self.use_tiling = False |
694 | | - |
695 | | - def enable_slicing(self) -> None: |
696 | | - r""" |
697 | | - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to |
698 | | - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. |
699 | | - """ |
700 | | - self.use_slicing = True |
701 | | - |
702 | | - def disable_slicing(self) -> None: |
703 | | - r""" |
704 | | - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing |
705 | | - decoding in one step. |
706 | | - """ |
707 | | - self.use_slicing = False |
708 | | - |
709 | 688 | def _encode(self, x: torch.Tensor) -> torch.Tensor: |
710 | 689 | _, _, _, height, width = x.shape |
711 | 690 |
|
|
0 commit comments