|
30 | 30 | ) |
31 | 31 | from ...utils.torch_utils import randn_tensor |
32 | 32 | from ..pipeline_utils import DiffusionPipeline |
33 | | -from .pipeline_flux_utils import FluxControlMixin, calculate_shift, retrieve_timesteps |
| 33 | +from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_timesteps |
34 | 34 | from .pipeline_output import FluxPipelineOutput |
35 | 35 |
|
36 | 36 |
|
|
81 | 81 |
|
82 | 82 | class FluxControlPipeline( |
83 | 83 | DiffusionPipeline, |
84 | | - FluxControlMixin, |
| 84 | + FluxMixin, |
85 | 85 | FluxLoraLoaderMixin, |
86 | 86 | FromSingleFileMixin, |
87 | 87 | TextualInversionLoaderMixin, |
@@ -235,6 +235,41 @@ def prepare_latents( |
235 | 235 |
|
236 | 236 | return latents, latent_image_ids |
237 | 237 |
|
| 238 | + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image |
| 239 | + def prepare_image( |
| 240 | + self, |
| 241 | + image, |
| 242 | + width, |
| 243 | + height, |
| 244 | + batch_size, |
| 245 | + num_images_per_prompt, |
| 246 | + device, |
| 247 | + dtype, |
| 248 | + do_classifier_free_guidance=False, |
| 249 | + guess_mode=False, |
| 250 | + ): |
| 251 | + if isinstance(image, torch.Tensor): |
| 252 | + pass |
| 253 | + else: |
| 254 | + image = self.image_processor.preprocess(image, height=height, width=width) |
| 255 | + |
| 256 | + image_batch_size = image.shape[0] |
| 257 | + |
| 258 | + if image_batch_size == 1: |
| 259 | + repeat_by = batch_size |
| 260 | + else: |
| 261 | + # image batch size is the same as prompt batch size |
| 262 | + repeat_by = num_images_per_prompt |
| 263 | + |
| 264 | + image = image.repeat_interleave(repeat_by, dim=0) |
| 265 | + |
| 266 | + image = image.to(device=device, dtype=dtype) |
| 267 | + |
| 268 | + if do_classifier_free_guidance and not guess_mode: |
| 269 | + image = torch.cat([image] * 2) |
| 270 | + |
| 271 | + return image |
| 272 | + |
238 | 273 | @torch.no_grad() |
239 | 274 | @replace_example_docstring(EXAMPLE_DOC_STRING) |
240 | 275 | def __call__( |
|
0 commit comments