Skip to content

Commit 89ebea4

Browse files
committed
up
1 parent b1a8835 commit 89ebea4

1 file changed

Lines changed: 7 additions & 119 deletions

File tree

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py

Lines changed: 7 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import inspect
1615
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
1716

18-
import numpy as np
1917
import PIL.Image
2018
import torch
2119
from transformers import (
@@ -46,6 +44,12 @@
4644
from ...utils.torch_utils import randn_tensor
4745
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
4846
from .pipeline_output import StableDiffusionXLPipelineOutput
47+
from .pipeline_stable_diffusion_xl_utils import (
48+
StableDiffusionXLMixin,
49+
rescale_noise_cfg,
50+
retrieve_latents,
51+
retrieve_timesteps,
52+
)
4953

5054

5155
if is_invisible_watermark_available():
@@ -91,126 +95,10 @@
9195
"""
9296

9397

94-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg
95-
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    r"""
    Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
    Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://huggingface.co/papers/2305.08891).

    Args:
        noise_cfg (`torch.Tensor`):
            The predicted noise tensor for the guided diffusion process.
        noise_pred_text (`torch.Tensor`):
            The predicted noise tensor for the text-guided diffusion process.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            A rescale factor applied to the noise predictions.

    Returns:
        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. Samples whose `noise_cfg` has no positive
        standard deviation are left unscaled instead of being turned into inf/NaN.
    """
    # Per-sample standard deviation over every dimension except the first.
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # A zero (or NaN) `std_cfg` would make the ratio inf/NaN, and `0.0 * nan` is still NaN, so even
    # `guidance_rescale=0.0` would corrupt the output. Fall back to a neutral factor of 1 for those samples.
    scale = std_text / std_cfg
    scale = torch.where(std_cfg > 0, scale, torch.ones_like(scale))
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * scale
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg
119-
120-
121-
def mask_pil_to_torch(mask, height, width):
    """
    Convert a mask given as PIL image(s) or NumPy array(s) into a `torch.Tensor`.

    A single image or array is treated as a one-element list. PIL images are resized to `(width, height)` with
    LANCZOS resampling, converted to mode "L", stacked along a new leading axis and scaled to float32 by dividing
    by 255. NumPy arrays are only stacked (each gets two leading singleton axes); `height` and `width` are not
    applied to them. The stacked array is then wrapped with `torch.from_numpy`.
    """
    # Normalise a lone image/array into a batch of one.
    masks = [mask] if isinstance(mask, (PIL.Image.Image, np.ndarray)) else mask

    if isinstance(masks, list):
        first = masks[0]
        if isinstance(first, PIL.Image.Image):
            resized = [img.resize((width, height), resample=PIL.Image.LANCZOS) for img in masks]
            planes = [np.array(img.convert("L"))[None, None, :] for img in resized]
            masks = np.concatenate(planes, axis=0)
            masks = masks.astype(np.float32) / 255.0
        elif isinstance(first, np.ndarray):
            masks = np.concatenate([arr[None, None, :] for arr in masks], axis=0)

    return torch.from_numpy(masks)
135-
136-
137-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents
138-
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """
    Extract latents from an encoder output object.

    If `encoder_output` exposes `latent_dist`, the latents are drawn with `latent_dist.sample(generator)` when
    `sample_mode == "sample"`, or taken from `latent_dist.mode()` when `sample_mode == "argmax"`. Otherwise the
    `latents` attribute is returned if present.

    Raises:
        AttributeError: If none of the above access paths applies.
    """
    has_distribution = hasattr(encoder_output, "latent_dist")

    if has_distribution and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    if has_distribution and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    if not hasattr(encoder_output, "latents"):
        raise AttributeError("Could not access latents of provided encoder_output")
    return encoder_output.latents
149-
150-
151-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
152-
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Configure `scheduler` through its `set_timesteps` method and return the schedule it ends up with. Supports a
    custom `timesteps` or `sigmas` list; extra kwargs are passed on to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    custom_timesteps = timesteps is not None
    custom_sigmas = sigmas is not None

    if custom_timesteps and custom_sigmas:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default path: let the scheduler derive its own spacing from the step count.
    if not (custom_timesteps or custom_sigmas):
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # A custom schedule is only usable if `set_timesteps` declares the matching keyword.
    accepted_params = set(inspect.signature(scheduler.set_timesteps).parameters.keys())

    if custom_timesteps:
        if "timesteps" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    # With a custom schedule, the step count is whatever length the scheduler produced.
    schedule = scheduler.timesteps
    return schedule, len(schedule)
209-
210-
21198
class StableDiffusionXLInpaintPipeline(
21299
DiffusionPipeline,
213100
StableDiffusionMixin,
101+
StableDiffusionXLMixin,
214102
TextualInversionLoaderMixin,
215103
StableDiffusionXLLoraLoaderMixin,
216104
FromSingleFileMixin,

0 commit comments

Comments
 (0)