Skip to content

Commit ffc9562

Browse files
committed
Merge branch 'main' into qwen-pipeline-mixin
2 parents b3b11b5 + e6d4612 commit ffc9562

11 files changed

Lines changed: 665 additions & 80 deletions

src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,10 @@ def timestep_embedding(t, dim, max_period=10000):
6969

7070
def forward(self, t):
7171
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
72-
t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
72+
weight_dtype = self.mlp[0].weight.dtype
73+
if weight_dtype.is_floating_point:
74+
t_freq = t_freq.to(weight_dtype)
75+
t_emb = self.mlp(t_freq)
7376
return t_emb
7477

7578

@@ -126,6 +129,10 @@ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tenso
126129
dtype = query.dtype
127130
query, key = query.to(dtype), key.to(dtype)
128131

132+
# From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
133+
if attention_mask is not None and attention_mask.ndim == 2:
134+
attention_mask = attention_mask[:, None, None, :]
135+
129136
# Compute joint attention
130137
hidden_states = dispatch_attention_fn(
131138
query,
@@ -306,6 +313,10 @@ def __call__(self, ids: torch.Tensor):
306313
if self.freqs_cis is None:
307314
self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
308315
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
316+
else:
317+
# Ensure freqs_cis are on the same device as ids
318+
if self.freqs_cis[0].device != device:
319+
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
309320

310321
result = []
311322
for i in range(len(self.axes_dims)):
@@ -317,6 +328,7 @@ def __call__(self, ids: torch.Tensor):
317328
class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
318329
_supports_gradient_checkpointing = True
319330
_no_split_modules = ["ZImageTransformerBlock"]
331+
_skip_layerwise_casting_patterns = ["t_embedder", "cap_embedder"] # precision sensitive layers
320332

321333
@register_to_config
322334
def __init__(
@@ -553,8 +565,6 @@ def forward(
553565
t = t * self.t_scale
554566
t = self.t_embedder(t)
555567

556-
adaln_input = t
557-
558568
(
559569
x,
560570
cap_feats,
@@ -572,6 +582,9 @@ def forward(
572582

573583
x = torch.cat(x, dim=0)
574584
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
585+
586+
# Match t_embedder output dtype to x for layerwise casting compatibility
587+
adaln_input = t.type_as(x)
575588
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
576589
x = list(x.split(x_item_seqlens, dim=0))
577590
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))

src/diffusers/pipelines/z_image/pipeline_z_image.py

Lines changed: 14 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -165,21 +165,16 @@ def encode_prompt(
165165
self,
166166
prompt: Union[str, List[str]],
167167
device: Optional[torch.device] = None,
168-
dtype: Optional[torch.dtype] = None,
169-
num_images_per_prompt: int = 1,
170168
do_classifier_free_guidance: bool = True,
171169
negative_prompt: Optional[Union[str, List[str]]] = None,
172170
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
173171
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
174172
max_sequence_length: int = 512,
175-
lora_scale: Optional[float] = None,
176173
):
177174
prompt = [prompt] if isinstance(prompt, str) else prompt
178175
prompt_embeds = self._encode_prompt(
179176
prompt=prompt,
180177
device=device,
181-
dtype=dtype,
182-
num_images_per_prompt=num_images_per_prompt,
183178
prompt_embeds=prompt_embeds,
184179
max_sequence_length=max_sequence_length,
185180
)
@@ -193,8 +188,6 @@ def encode_prompt(
193188
negative_prompt_embeds = self._encode_prompt(
194189
prompt=negative_prompt,
195190
device=device,
196-
dtype=dtype,
197-
num_images_per_prompt=num_images_per_prompt,
198191
prompt_embeds=negative_prompt_embeds,
199192
max_sequence_length=max_sequence_length,
200193
)
@@ -206,12 +199,9 @@ def _encode_prompt(
206199
self,
207200
prompt: Union[str, List[str]],
208201
device: Optional[torch.device] = None,
209-
dtype: Optional[torch.dtype] = None,
210-
num_images_per_prompt: int = 1,
211202
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
212203
max_sequence_length: int = 512,
213204
) -> List[torch.FloatTensor]:
214-
assert num_images_per_prompt == 1
215205
device = device or self._execution_device
216206

217207
if prompt_embeds is not None:
@@ -417,8 +407,6 @@ def __call__(
417407
f"Please adjust the width to a multiple of {vae_scale}."
418408
)
419409

420-
assert self.dtype == torch.bfloat16
421-
dtype = self.dtype
422410
device = self._execution_device
423411

424412
self._guidance_scale = guidance_scale
@@ -434,10 +422,6 @@ def __call__(
434422
else:
435423
batch_size = len(prompt_embeds)
436424

437-
lora_scale = (
438-
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
439-
)
440-
441425
# If prompt_embeds is provided and prompt is None, skip encoding
442426
if prompt_embeds is not None and prompt is None:
443427
if self.do_classifier_free_guidance and negative_prompt_embeds is None:
@@ -455,11 +439,8 @@ def __call__(
455439
do_classifier_free_guidance=self.do_classifier_free_guidance,
456440
prompt_embeds=prompt_embeds,
457441
negative_prompt_embeds=negative_prompt_embeds,
458-
dtype=dtype,
459442
device=device,
460-
num_images_per_prompt=num_images_per_prompt,
461443
max_sequence_length=max_sequence_length,
462-
lora_scale=lora_scale,
463444
)
464445

465446
# 4. Prepare latent variables
@@ -475,6 +456,14 @@ def __call__(
475456
generator,
476457
latents,
477458
)
459+
460+
# Repeat prompt_embeds for num_images_per_prompt
461+
if num_images_per_prompt > 1:
462+
prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
463+
if self.do_classifier_free_guidance and negative_prompt_embeds:
464+
negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
465+
466+
actual_batch_size = batch_size * num_images_per_prompt
478467
image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
479468

480469
# 5. Prepare timesteps
@@ -523,12 +512,12 @@ def __call__(
523512
apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
524513

525514
if apply_cfg:
526-
latents_typed = latents if latents.dtype == dtype else latents.to(dtype)
515+
latents_typed = latents.to(self.transformer.dtype)
527516
latent_model_input = latents_typed.repeat(2, 1, 1, 1)
528517
prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
529518
timestep_model_input = timestep.repeat(2)
530519
else:
531-
latent_model_input = latents if latents.dtype == dtype else latents.to(dtype)
520+
latent_model_input = latents.to(self.transformer.dtype)
532521
prompt_embeds_model_input = prompt_embeds
533522
timestep_model_input = timestep
534523

@@ -543,11 +532,11 @@ def __call__(
543532

544533
if apply_cfg:
545534
# Perform CFG
546-
pos_out = model_out_list[:batch_size]
547-
neg_out = model_out_list[batch_size:]
535+
pos_out = model_out_list[:actual_batch_size]
536+
neg_out = model_out_list[actual_batch_size:]
548537

549538
noise_pred = []
550-
for j in range(batch_size):
539+
for j in range(actual_batch_size):
551540
pos = pos_out[j].float()
552541
neg = neg_out[j].float()
553542

@@ -588,11 +577,11 @@ def __call__(
588577
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
589578
progress_bar.update()
590579

591-
latents = latents.to(dtype)
592580
if output_type == "latent":
593581
image = latents
594582

595583
else:
584+
latents = latents.to(self.vae.dtype)
596585
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
597586

598587
image = self.vae.decode(latents, return_dict=False)[0]

src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,22 @@ def multistep_dpm_solver_second_order_update(
429429
return x_t
430430

431431
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
432-
def index_for_timestep(self, timestep, schedule_timesteps=None):
432+
def index_for_timestep(
433+
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
434+
) -> int:
435+
"""
436+
Find the index for a given timestep in the schedule.
437+
438+
Args:
439+
timestep (`int` or `torch.Tensor`):
440+
The timestep for which to find the index.
441+
schedule_timesteps (`torch.Tensor`, *optional*):
442+
The timestep schedule to search in. If `None`, uses `self.timesteps`.
443+
444+
Returns:
445+
`int`:
446+
The index of the timestep in the schedule.
447+
"""
433448
if schedule_timesteps is None:
434449
schedule_timesteps = self.timesteps
435450

@@ -452,6 +467,10 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
452467
def _init_step_index(self, timestep):
453468
"""
454469
Initialize the step_index counter for the scheduler.
470+
471+
Args:
472+
timestep (`int` or `torch.Tensor`):
473+
The current timestep for which to initialize the step index.
455474
"""
456475

457476
if self.begin_index is None:

src/diffusers/schedulers/scheduling_deis_multistep.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,17 @@ def _sigma_to_t(self, sigma, log_sigmas):
401401

402402
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
403403
def _sigma_to_alpha_sigma_t(self, sigma):
404+
"""
405+
Convert sigma values to alpha_t and sigma_t values.
406+
407+
Args:
408+
sigma (`torch.Tensor`):
409+
The sigma value(s) to convert.
410+
411+
Returns:
412+
`Tuple[torch.Tensor, torch.Tensor]`:
413+
A tuple containing (alpha_t, sigma_t) values.
414+
"""
404415
if self.config.use_flow_sigmas:
405416
alpha_t = 1 - sigma
406417
sigma_t = sigma
@@ -808,7 +819,22 @@ def ind_fn(t, b, c, d):
808819
raise NotImplementedError("only support log-rho multistep deis now")
809820

810821
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
811-
def index_for_timestep(self, timestep, schedule_timesteps=None):
822+
def index_for_timestep(
823+
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
824+
) -> int:
825+
"""
826+
Find the index for a given timestep in the schedule.
827+
828+
Args:
829+
timestep (`int` or `torch.Tensor`):
830+
The timestep for which to find the index.
831+
schedule_timesteps (`torch.Tensor`, *optional*):
832+
The timestep schedule to search in. If `None`, uses `self.timesteps`.
833+
834+
Returns:
835+
`int`:
836+
The index of the timestep in the schedule.
837+
"""
812838
if schedule_timesteps is None:
813839
schedule_timesteps = self.timesteps
814840

@@ -831,6 +857,10 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
831857
def _init_step_index(self, timestep):
832858
"""
833859
Initialize the step_index counter for the scheduler.
860+
861+
Args:
862+
timestep (`int` or `torch.Tensor`):
863+
The current timestep for which to initialize the step index.
834864
"""
835865

836866
if self.begin_index is None:
@@ -927,6 +957,21 @@ def add_noise(
927957
noise: torch.Tensor,
928958
timesteps: torch.IntTensor,
929959
) -> torch.Tensor:
960+
"""
961+
Add noise to the original samples according to the noise schedule at the specified timesteps.
962+
963+
Args:
964+
original_samples (`torch.Tensor`):
965+
The original samples without noise.
966+
noise (`torch.Tensor`):
967+
The noise to add to the samples.
968+
timesteps (`torch.IntTensor`):
969+
The timesteps at which to add noise to the samples.
970+
971+
Returns:
972+
`torch.Tensor`:
973+
The noisy samples.
974+
"""
930975
# Make sure sigmas and timesteps have the same device and dtype as original_samples
931976
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
932977
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):

0 commit comments

Comments
 (0)