@@ -165,21 +165,16 @@ def encode_prompt(
165165 self ,
166166 prompt : Union [str , List [str ]],
167167 device : Optional [torch .device ] = None ,
168- dtype : Optional [torch .dtype ] = None ,
169- num_images_per_prompt : int = 1 ,
170168 do_classifier_free_guidance : bool = True ,
171169 negative_prompt : Optional [Union [str , List [str ]]] = None ,
172170 prompt_embeds : Optional [List [torch .FloatTensor ]] = None ,
173171 negative_prompt_embeds : Optional [torch .FloatTensor ] = None ,
174172 max_sequence_length : int = 512 ,
175- lora_scale : Optional [float ] = None ,
176173 ):
177174 prompt = [prompt ] if isinstance (prompt , str ) else prompt
178175 prompt_embeds = self ._encode_prompt (
179176 prompt = prompt ,
180177 device = device ,
181- dtype = dtype ,
182- num_images_per_prompt = num_images_per_prompt ,
183178 prompt_embeds = prompt_embeds ,
184179 max_sequence_length = max_sequence_length ,
185180 )
@@ -193,8 +188,6 @@ def encode_prompt(
193188 negative_prompt_embeds = self ._encode_prompt (
194189 prompt = negative_prompt ,
195190 device = device ,
196- dtype = dtype ,
197- num_images_per_prompt = num_images_per_prompt ,
198191 prompt_embeds = negative_prompt_embeds ,
199192 max_sequence_length = max_sequence_length ,
200193 )
@@ -206,12 +199,9 @@ def _encode_prompt(
206199 self ,
207200 prompt : Union [str , List [str ]],
208201 device : Optional [torch .device ] = None ,
209- dtype : Optional [torch .dtype ] = None ,
210- num_images_per_prompt : int = 1 ,
211202 prompt_embeds : Optional [List [torch .FloatTensor ]] = None ,
212203 max_sequence_length : int = 512 ,
213204 ) -> List [torch .FloatTensor ]:
214- assert num_images_per_prompt == 1
215205 device = device or self ._execution_device
216206
217207 if prompt_embeds is not None :
@@ -417,8 +407,6 @@ def __call__(
417407 f"Please adjust the width to a multiple of { vae_scale } ."
418408 )
419409
420- assert self .dtype == torch .bfloat16
421- dtype = self .dtype
422410 device = self ._execution_device
423411
424412 self ._guidance_scale = guidance_scale
@@ -434,10 +422,6 @@ def __call__(
434422 else :
435423 batch_size = len (prompt_embeds )
436424
437- lora_scale = (
438- self .joint_attention_kwargs .get ("scale" , None ) if self .joint_attention_kwargs is not None else None
439- )
440-
441425 # If prompt_embeds is provided and prompt is None, skip encoding
442426 if prompt_embeds is not None and prompt is None :
443427 if self .do_classifier_free_guidance and negative_prompt_embeds is None :
@@ -455,11 +439,8 @@ def __call__(
455439 do_classifier_free_guidance = self .do_classifier_free_guidance ,
456440 prompt_embeds = prompt_embeds ,
457441 negative_prompt_embeds = negative_prompt_embeds ,
458- dtype = dtype ,
459442 device = device ,
460- num_images_per_prompt = num_images_per_prompt ,
461443 max_sequence_length = max_sequence_length ,
462- lora_scale = lora_scale ,
463444 )
464445
465446 # 4. Prepare latent variables
@@ -475,6 +456,14 @@ def __call__(
475456 generator ,
476457 latents ,
477458 )
459+
460+ # Repeat prompt_embeds for num_images_per_prompt
461+ if num_images_per_prompt > 1 :
462+ prompt_embeds = [pe for pe in prompt_embeds for _ in range (num_images_per_prompt )]
463+ if self .do_classifier_free_guidance and negative_prompt_embeds :
464+ negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range (num_images_per_prompt )]
465+
466+ actual_batch_size = batch_size * num_images_per_prompt
478467 image_seq_len = (latents .shape [2 ] // 2 ) * (latents .shape [3 ] // 2 )
479468
480469 # 5. Prepare timesteps
@@ -523,12 +512,12 @@ def __call__(
523512 apply_cfg = self .do_classifier_free_guidance and current_guidance_scale > 0
524513
525514 if apply_cfg :
526- latents_typed = latents if latents . dtype == dtype else latents . to (dtype )
515+ latents_typed = latents . to (self . transformer . dtype )
527516 latent_model_input = latents_typed .repeat (2 , 1 , 1 , 1 )
528517 prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
529518 timestep_model_input = timestep .repeat (2 )
530519 else :
531- latent_model_input = latents if latents . dtype == dtype else latents . to (dtype )
520+ latent_model_input = latents . to (self . transformer . dtype )
532521 prompt_embeds_model_input = prompt_embeds
533522 timestep_model_input = timestep
534523
@@ -543,11 +532,11 @@ def __call__(
543532
544533 if apply_cfg :
545534 # Perform CFG
546- pos_out = model_out_list [:batch_size ]
547- neg_out = model_out_list [batch_size :]
535+ pos_out = model_out_list [:actual_batch_size ]
536+ neg_out = model_out_list [actual_batch_size :]
548537
549538 noise_pred = []
550- for j in range (batch_size ):
539+ for j in range (actual_batch_size ):
551540 pos = pos_out [j ].float ()
552541 neg = neg_out [j ].float ()
553542
@@ -588,11 +577,11 @@ def __call__(
588577 if i == len (timesteps ) - 1 or ((i + 1 ) > num_warmup_steps and (i + 1 ) % self .scheduler .order == 0 ):
589578 progress_bar .update ()
590579
591- latents = latents .to (dtype )
592580 if output_type == "latent" :
593581 image = latents
594582
595583 else :
584+ latents = latents .to (self .vae .dtype )
596585 latents = (latents / self .vae .config .scaling_factor ) + self .vae .config .shift_factor
597586
598587 image = self .vae .decode (latents , return_dict = False )[0 ]
0 commit comments