Skip to content

Commit 0b1f884

Browse files
committed
fix
1 parent b4a8406 commit 0b1f884

1 file changed

Lines changed: 12 additions & 12 deletions

File tree

src/diffusers/pipelines/flux2/pipeline_flux2.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
```
5858
"""
5959

60+
6061
# Adapted from
6162
# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L68
6263
def format_input(
@@ -510,12 +511,13 @@ def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch
510511

511512
def upsample_prompt(
512513
self,
513-
prompts: List[str],
514+
prompt: Union[str, List[str]],
514515
images: Union[List[PIL.Image.Image], List[List[PIL.Image.Image]]] = None,
515516
temperature: float = 0.15,
516517
device: torch.device = None,
517518
) -> List[str]:
518-
device = device or self._execution_device
519+
prompt = [prompt] if isinstance(prompt, str) else prompt
520+
device = self.text_encoder.device if device is None else device
519521

520522
# Set system message based on whether images are provided
521523
if images is None or len(images) == 0 or images[0] is None:
@@ -524,7 +526,7 @@ def upsample_prompt(
524526
system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I
525527

526528
# Format input messages
527-
messages_batch = format_input(prompts=prompts, system_message=system_message, images=images)
529+
messages_batch = format_input(prompts=prompt, system_message=system_message, images=images)
528530

529531
# Process all messages at once
530532
# with image processing a too short max length can throw an error in here.
@@ -560,10 +562,10 @@ def upsample_prompt(
560562
input_length = inputs["input_ids"].shape[1]
561563
generated_tokens = generated_ids[:, input_length:]
562564

563-
raw_txt = self.tokenizer.tokenizer.batch_decode(
565+
upsampled_prompt = self.tokenizer.tokenizer.batch_decode(
564566
generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True
565567
)
566-
return raw_txt
568+
return upsampled_prompt
567569

568570
def encode_prompt(
569571
self,
@@ -775,11 +777,11 @@ def __call__(
775777
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
776778
instead.
777779
guidance_scale (`float`, *optional*, defaults to 1.0):
778-
Guidance scale as defined in [Classifier-Free Diffusion
779-
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
780-
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
781-
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
782-
the text `prompt`, usually at the expense of lower image quality.
780+
Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
781+
a model to generate images more aligned with `prompt` at the expense of lower image quality.
782+
783+
Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to
784+
the [paper](https://huggingface.co/papers/2210.03142) to learn more.
783785
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
784786
The height in pixels of the generated image. This is set to 1024 by default for the best results.
785787
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -865,8 +867,6 @@ def __call__(
865867
prompt = self.upsample_prompt(
866868
prompt, images=image, temperature=caption_upsample_temperature, device=device
867869
)
868-
print(f"{prompt=}")
869-
870870
prompt_embeds, text_ids = self.encode_prompt(
871871
prompt=prompt,
872872
prompt_embeds=prompt_embeds,

0 commit comments

Comments
 (0)