5757 ```
5858"""
5959
60+
6061# Adapted from
6162# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L68
6263def format_input (
@@ -510,12 +511,13 @@ def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch
510511
511512 def upsample_prompt (
512513 self ,
513- prompts : List [str ],
514+ prompt : Union [ str , List [str ] ],
514515 images : Union [List [PIL .Image .Image ], List [List [PIL .Image .Image ]]] = None ,
515516 temperature : float = 0.15 ,
516517 device : torch .device = None ,
517518 ) -> List [str ]:
518- device = device or self ._execution_device
519+ prompt = [prompt ] if isinstance (prompt , str ) else prompt
520+ device = self .text_encoder .device if device is None else device
519521
520522 # Set system message based on whether images are provided
521523 if images is None or len (images ) == 0 or images [0 ] is None :
@@ -524,7 +526,7 @@ def upsample_prompt(
524526 system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I
525527
526528 # Format input messages
527- messages_batch = format_input (prompts = prompts , system_message = system_message , images = images )
529+ messages_batch = format_input (prompts = prompt , system_message = system_message , images = images )
528530
529531 # Process all messages at once
530532 # With image processing, a max length that is too short can raise an error here.
@@ -560,10 +562,10 @@ def upsample_prompt(
560562 input_length = inputs ["input_ids" ].shape [1 ]
561563 generated_tokens = generated_ids [:, input_length :]
562564
563- raw_txt = self .tokenizer .tokenizer .batch_decode (
565+ upsampled_prompt = self .tokenizer .tokenizer .batch_decode (
564566 generated_tokens , skip_special_tokens = True , clean_up_tokenization_spaces = True
565567 )
566- return raw_txt
568+ return upsampled_prompt
567569
568570 def encode_prompt (
569571 self ,
@@ -775,11 +777,11 @@ def __call__(
775777 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
776778 instead.
777779 guidance_scale (`float`, *optional*, defaults to 1.0):
778- Guidance scale as defined in [Classifier-Free Diffusion
779- Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2 .
780- of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
781- `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
782- the text `prompt`, usually at the expense of lower image quality .
780+ Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
781+ a model to generate images more aligned with `prompt` at the expense of lower image quality .
782+
783+ Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to
784+ the [paper](https://huggingface.co/papers/2210.03142) to learn more .
783785 height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
784786 The height in pixels of the generated image. This is set to 1024 by default for the best results.
785787 width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -865,8 +867,6 @@ def __call__(
865867 prompt = self .upsample_prompt (
866868 prompt , images = image , temperature = caption_upsample_temperature , device = device
867869 )
868- print (f"{ prompt = } " )
869-
870870 prompt_embeds , text_ids = self .encode_prompt (
871871 prompt = prompt ,
872872 prompt_embeds = prompt_embeds ,
0 commit comments