|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -import inspect |
16 | 15 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
17 | 16 |
|
18 | | -import numpy as np |
19 | 17 | import PIL.Image |
20 | 18 | import torch |
21 | 19 | from transformers import ( |
|
46 | 44 | from ...utils.torch_utils import randn_tensor |
47 | 45 | from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin |
48 | 46 | from .pipeline_output import StableDiffusionXLPipelineOutput |
| 47 | +from .pipeline_stable_diffusion_xl_utils import ( |
| 48 | + StableDiffusionXLMixin, |
| 49 | + rescale_noise_cfg, |
| 50 | + retrieve_latents, |
| 51 | + retrieve_timesteps, |
| 52 | +) |
49 | 53 |
|
50 | 54 |
|
51 | 55 | if is_invisible_watermark_available(): |
|
91 | 95 | """ |
92 | 96 |
|
93 | 97 |
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    r"""
    Rescale the classifier-free-guidance noise prediction to reduce overexposure, following Section 3.4 of
    [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891).

    Args:
        noise_cfg (`torch.Tensor`):
            The predicted noise tensor for the guided diffusion process.
        noise_pred_text (`torch.Tensor`):
            The predicted noise tensor for the text-guided diffusion process.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            A rescale factor applied to the noise predictions.

    Returns:
        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
    """
    # Per-sample standard deviations over all non-batch dimensions.
    text_dims = list(range(1, noise_pred_text.ndim))
    cfg_dims = list(range(1, noise_cfg.ndim))
    std_text = noise_pred_text.std(dim=text_dims, keepdim=True)
    std_cfg = noise_cfg.std(dim=cfg_dims, keepdim=True)
    # Match the guided prediction's std to the text prediction's std (fixes overexposure).
    rescaled = noise_cfg * (std_text / std_cfg)
    # Interpolate between rescaled and original guidance to avoid "plain looking" images.
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
119 | | - |
120 | | - |
def mask_pil_to_torch(mask, height, width):
    """
    Convert an inpainting mask to a `torch.Tensor`.

    Accepts a single `PIL.Image.Image` or `numpy.ndarray`, or a list of either; a single mask is treated as a
    one-element batch. PIL masks are resized to `(width, height)`, converted to grayscale, and scaled to [0, 1];
    numpy masks are stacked as-is (assumed to already be float — TODO confirm against callers).
    """
    # Promote a single mask to a batch of one.
    if isinstance(mask, (PIL.Image.Image, np.ndarray)):
        mask = [mask]

    if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
        resized = [m.resize((width, height), resample=PIL.Image.LANCZOS) for m in mask]
        grayscale = [np.array(m.convert("L"))[None, None, :] for m in resized]
        # Stack to (batch, 1, H, W) and normalize 8-bit values into [0, 1].
        mask = np.concatenate(grayscale, axis=0).astype(np.float32) / 255.0
    elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
        mask = np.concatenate([m[None, None, :] for m in mask], axis=0)

    return torch.from_numpy(mask)
135 | | - |
136 | | - |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    # Extract latents from a VAE encoder output, regardless of its concrete wrapper type.
    if hasattr(encoder_output, "latent_dist"):
        # Distribution-style output: draw a sample or take the mode, per `sample_mode`.
        if sample_mode == "sample":
            return encoder_output.latent_dist.sample(generator)
        if sample_mode == "argmax":
            return encoder_output.latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        # Plain output carrying precomputed latents.
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
149 | | - |
150 | | - |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    def _require_keyword(keyword: str, schedule_name: str) -> None:
        # Fail early if this scheduler's `set_timesteps` cannot accept the custom schedule keyword.
        if keyword not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" {schedule_name} schedules. Please check whether you are using the correct scheduler."
            )

    uses_custom_schedule = timesteps is not None or sigmas is not None
    if timesteps is not None:
        _require_keyword("timesteps", "timestep")
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    elif sigmas is not None:
        _require_keyword("sigmas", "sigmas")
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)

    timesteps = scheduler.timesteps
    if uses_custom_schedule:
        # A custom schedule determines the step count; the passed-in value is ignored.
        num_inference_steps = len(timesteps)
    return timesteps, num_inference_steps
209 | | - |
210 | | - |
211 | 98 | class StableDiffusionXLInpaintPipeline( |
212 | 99 | DiffusionPipeline, |
213 | 100 | StableDiffusionMixin, |
| 101 | + StableDiffusionXLMixin, |
214 | 102 | TextualInversionLoaderMixin, |
215 | 103 | StableDiffusionXLLoraLoaderMixin, |
216 | 104 | FromSingleFileMixin, |
|
0 commit comments