|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -import inspect |
16 | 15 | from typing import Any, Callable, Dict, List, Optional, Union |
17 | 16 |
|
18 | 17 | import numpy as np |
|
24 | 23 | from ...models.autoencoders import AutoencoderKL |
25 | 24 | from ...models.transformers import FluxTransformer2DModel |
26 | 25 | from ...schedulers import FlowMatchEulerDiscreteScheduler |
27 | | -from ...utils import ( |
28 | | - is_torch_xla_available, |
29 | | - logging, |
30 | | - replace_example_docstring, |
31 | | -) |
| 26 | +from ...utils import is_torch_xla_available, logging, replace_example_docstring |
32 | 27 | from ...utils.torch_utils import randn_tensor |
33 | 28 | from ..pipeline_utils import DiffusionPipeline |
34 | | -from .pipeline_flux_utils import FluxControlMixin |
| 29 | +from .pipeline_flux_utils import FluxControlMixin, calculate_shift, retrieve_latents, retrieve_timesteps |
35 | 30 | from .pipeline_output import FluxPipelineOutput |
36 | 31 |
|
37 | 32 |
|
|
85 | 80 | """ |
86 | 81 |
|
87 | 82 |
|
88 | | -# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift |
89 | | -def calculate_shift( |
90 | | - image_seq_len, |
91 | | - base_seq_len: int = 256, |
92 | | - max_seq_len: int = 4096, |
93 | | - base_shift: float = 0.5, |
94 | | - max_shift: float = 1.15, |
95 | | -): |
96 | | - m = (max_shift - base_shift) / (max_seq_len - base_seq_len) |
97 | | - b = base_shift - m * base_seq_len |
98 | | - mu = image_seq_len * m + b |
99 | | - return mu |
100 | | - |
101 | | - |
102 | | -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents |
103 | | -def retrieve_latents( |
104 | | - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" |
105 | | -): |
106 | | - if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": |
107 | | - return encoder_output.latent_dist.sample(generator) |
108 | | - elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": |
109 | | - return encoder_output.latent_dist.mode() |
110 | | - elif hasattr(encoder_output, "latents"): |
111 | | - return encoder_output.latents |
112 | | - else: |
113 | | - raise AttributeError("Could not access latents of provided encoder_output") |
114 | | - |
115 | | - |
116 | | -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps |
117 | | -def retrieve_timesteps( |
118 | | - scheduler, |
119 | | - num_inference_steps: Optional[int] = None, |
120 | | - device: Optional[Union[str, torch.device]] = None, |
121 | | - timesteps: Optional[List[int]] = None, |
122 | | - sigmas: Optional[List[float]] = None, |
123 | | - **kwargs, |
124 | | -): |
125 | | - r""" |
126 | | - Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles |
127 | | - custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. |
128 | | -
|
129 | | - Args: |
130 | | - scheduler (`SchedulerMixin`): |
131 | | - The scheduler to get timesteps from. |
132 | | - num_inference_steps (`int`): |
133 | | - The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` |
134 | | - must be `None`. |
135 | | - device (`str` or `torch.device`, *optional*): |
136 | | - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. |
137 | | - timesteps (`List[int]`, *optional*): |
138 | | - Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, |
139 | | - `num_inference_steps` and `sigmas` must be `None`. |
140 | | - sigmas (`List[float]`, *optional*): |
141 | | - Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, |
142 | | - `num_inference_steps` and `timesteps` must be `None`. |
143 | | -
|
144 | | - Returns: |
145 | | - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the |
146 | | - second element is the number of inference steps. |
147 | | - """ |
148 | | - if timesteps is not None and sigmas is not None: |
149 | | - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") |
150 | | - if timesteps is not None: |
151 | | - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) |
152 | | - if not accepts_timesteps: |
153 | | - raise ValueError( |
154 | | - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" |
155 | | - f" timestep schedules. Please check whether you are using the correct scheduler." |
156 | | - ) |
157 | | - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) |
158 | | - timesteps = scheduler.timesteps |
159 | | - num_inference_steps = len(timesteps) |
160 | | - elif sigmas is not None: |
161 | | - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) |
162 | | - if not accept_sigmas: |
163 | | - raise ValueError( |
164 | | - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" |
165 | | - f" sigmas schedules. Please check whether you are using the correct scheduler." |
166 | | - ) |
167 | | - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) |
168 | | - timesteps = scheduler.timesteps |
169 | | - num_inference_steps = len(timesteps) |
170 | | - else: |
171 | | - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) |
172 | | - timesteps = scheduler.timesteps |
173 | | - return timesteps, num_inference_steps |
174 | | - |
175 | | - |
176 | 83 | class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxControlMixin, FluxLoraLoaderMixin, FromSingleFileMixin): |
177 | 84 | r""" |
178 | 85 | The Flux pipeline for image inpainting. |
|
0 commit comments