Skip to content

Commit 8048623

Browse files
committed
move some methods to pipeline specific stuff.
1 parent 72fc6ad commit 8048623

13 files changed

Lines changed: 109 additions & 1068 deletions

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 1 addition & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import inspect
1615
from typing import Any, Callable, Dict, List, Optional, Union
1716

1817
import numpy as np
@@ -37,7 +36,7 @@
3736
)
3837
from ...utils.torch_utils import randn_tensor
3938
from ..pipeline_utils import DiffusionPipeline
40-
from .pipeline_flux_utils import FluxMixin
39+
from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_timesteps
4140
from .pipeline_output import FluxPipelineOutput
4241

4342

@@ -68,79 +67,6 @@
6867
"""
6968

7069

71-
def calculate_shift(
72-
image_seq_len,
73-
base_seq_len: int = 256,
74-
max_seq_len: int = 4096,
75-
base_shift: float = 0.5,
76-
max_shift: float = 1.15,
77-
):
78-
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
79-
b = base_shift - m * base_seq_len
80-
mu = image_seq_len * m + b
81-
return mu
82-
83-
84-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
85-
def retrieve_timesteps(
86-
scheduler,
87-
num_inference_steps: Optional[int] = None,
88-
device: Optional[Union[str, torch.device]] = None,
89-
timesteps: Optional[List[int]] = None,
90-
sigmas: Optional[List[float]] = None,
91-
**kwargs,
92-
):
93-
r"""
94-
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
95-
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
96-
97-
Args:
98-
scheduler (`SchedulerMixin`):
99-
The scheduler to get timesteps from.
100-
num_inference_steps (`int`):
101-
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
102-
must be `None`.
103-
device (`str` or `torch.device`, *optional*):
104-
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
105-
timesteps (`List[int]`, *optional*):
106-
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
107-
`num_inference_steps` and `sigmas` must be `None`.
108-
sigmas (`List[float]`, *optional*):
109-
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
110-
`num_inference_steps` and `timesteps` must be `None`.
111-
112-
Returns:
113-
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
114-
second element is the number of inference steps.
115-
"""
116-
if timesteps is not None and sigmas is not None:
117-
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
118-
if timesteps is not None:
119-
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
120-
if not accepts_timesteps:
121-
raise ValueError(
122-
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
123-
f" timestep schedules. Please check whether you are using the correct scheduler."
124-
)
125-
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
126-
timesteps = scheduler.timesteps
127-
num_inference_steps = len(timesteps)
128-
elif sigmas is not None:
129-
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
130-
if not accept_sigmas:
131-
raise ValueError(
132-
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
133-
f" sigmas schedules. Please check whether you are using the correct scheduler."
134-
)
135-
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
136-
timesteps = scheduler.timesteps
137-
num_inference_steps = len(timesteps)
138-
else:
139-
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
140-
timesteps = scheduler.timesteps
141-
return timesteps, num_inference_steps
142-
143-
14470
class FluxPipeline(
14571
DiffusionPipeline,
14672
FluxMixin,

src/diffusers/pipelines/flux/pipeline_flux_control.py

Lines changed: 1 addition & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import inspect
1615
from typing import Any, Callable, Dict, List, Optional, Union
1716

1817
import numpy as np
@@ -31,7 +30,7 @@
3130
)
3231
from ...utils.torch_utils import randn_tensor
3332
from ..pipeline_utils import DiffusionPipeline
34-
from .pipeline_flux_utils import FluxControlMixin
33+
from .pipeline_flux_utils import FluxControlMixin, calculate_shift, retrieve_timesteps
3534
from .pipeline_output import FluxPipelineOutput
3635

3736

@@ -80,80 +79,6 @@
8079
"""
8180

8281

83-
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
84-
def calculate_shift(
85-
image_seq_len,
86-
base_seq_len: int = 256,
87-
max_seq_len: int = 4096,
88-
base_shift: float = 0.5,
89-
max_shift: float = 1.15,
90-
):
91-
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
92-
b = base_shift - m * base_seq_len
93-
mu = image_seq_len * m + b
94-
return mu
95-
96-
97-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
98-
def retrieve_timesteps(
99-
scheduler,
100-
num_inference_steps: Optional[int] = None,
101-
device: Optional[Union[str, torch.device]] = None,
102-
timesteps: Optional[List[int]] = None,
103-
sigmas: Optional[List[float]] = None,
104-
**kwargs,
105-
):
106-
r"""
107-
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
108-
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
109-
110-
Args:
111-
scheduler (`SchedulerMixin`):
112-
The scheduler to get timesteps from.
113-
num_inference_steps (`int`):
114-
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
115-
must be `None`.
116-
device (`str` or `torch.device`, *optional*):
117-
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
118-
timesteps (`List[int]`, *optional*):
119-
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
120-
`num_inference_steps` and `sigmas` must be `None`.
121-
sigmas (`List[float]`, *optional*):
122-
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
123-
`num_inference_steps` and `timesteps` must be `None`.
124-
125-
Returns:
126-
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
127-
second element is the number of inference steps.
128-
"""
129-
if timesteps is not None and sigmas is not None:
130-
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
131-
if timesteps is not None:
132-
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
133-
if not accepts_timesteps:
134-
raise ValueError(
135-
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
136-
f" timestep schedules. Please check whether you are using the correct scheduler."
137-
)
138-
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
139-
timesteps = scheduler.timesteps
140-
num_inference_steps = len(timesteps)
141-
elif sigmas is not None:
142-
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
143-
if not accept_sigmas:
144-
raise ValueError(
145-
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
146-
f" sigmas schedules. Please check whether you are using the correct scheduler."
147-
)
148-
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
149-
timesteps = scheduler.timesteps
150-
num_inference_steps = len(timesteps)
151-
else:
152-
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
153-
timesteps = scheduler.timesteps
154-
return timesteps, num_inference_steps
155-
156-
15782
class FluxControlPipeline(
15883
DiffusionPipeline,
15984
FluxControlMixin,

src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py

Lines changed: 2 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import inspect
1615
from typing import Any, Callable, Dict, List, Optional, Union
1716

1817
import numpy as np
@@ -24,14 +23,10 @@
2423
from ...models.autoencoders import AutoencoderKL
2524
from ...models.transformers import FluxTransformer2DModel
2625
from ...schedulers import FlowMatchEulerDiscreteScheduler
27-
from ...utils import (
28-
is_torch_xla_available,
29-
logging,
30-
replace_example_docstring,
31-
)
26+
from ...utils import is_torch_xla_available, logging, replace_example_docstring
3227
from ...utils.torch_utils import randn_tensor
3328
from ..pipeline_utils import DiffusionPipeline
34-
from .pipeline_flux_utils import FluxControlMixin
29+
from .pipeline_flux_utils import FluxControlMixin, calculate_shift, retrieve_latents, retrieve_timesteps
3530
from .pipeline_output import FluxPipelineOutput
3631

3732

@@ -85,94 +80,6 @@
8580
"""
8681

8782

88-
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
89-
def calculate_shift(
90-
image_seq_len,
91-
base_seq_len: int = 256,
92-
max_seq_len: int = 4096,
93-
base_shift: float = 0.5,
94-
max_shift: float = 1.15,
95-
):
96-
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
97-
b = base_shift - m * base_seq_len
98-
mu = image_seq_len * m + b
99-
return mu
100-
101-
102-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
103-
def retrieve_latents(
104-
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
105-
):
106-
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
107-
return encoder_output.latent_dist.sample(generator)
108-
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
109-
return encoder_output.latent_dist.mode()
110-
elif hasattr(encoder_output, "latents"):
111-
return encoder_output.latents
112-
else:
113-
raise AttributeError("Could not access latents of provided encoder_output")
114-
115-
116-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
117-
def retrieve_timesteps(
118-
scheduler,
119-
num_inference_steps: Optional[int] = None,
120-
device: Optional[Union[str, torch.device]] = None,
121-
timesteps: Optional[List[int]] = None,
122-
sigmas: Optional[List[float]] = None,
123-
**kwargs,
124-
):
125-
r"""
126-
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
127-
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
128-
129-
Args:
130-
scheduler (`SchedulerMixin`):
131-
The scheduler to get timesteps from.
132-
num_inference_steps (`int`):
133-
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
134-
must be `None`.
135-
device (`str` or `torch.device`, *optional*):
136-
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
137-
timesteps (`List[int]`, *optional*):
138-
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
139-
`num_inference_steps` and `sigmas` must be `None`.
140-
sigmas (`List[float]`, *optional*):
141-
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
142-
`num_inference_steps` and `timesteps` must be `None`.
143-
144-
Returns:
145-
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
146-
second element is the number of inference steps.
147-
"""
148-
if timesteps is not None and sigmas is not None:
149-
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
150-
if timesteps is not None:
151-
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
152-
if not accepts_timesteps:
153-
raise ValueError(
154-
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
155-
f" timestep schedules. Please check whether you are using the correct scheduler."
156-
)
157-
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
158-
timesteps = scheduler.timesteps
159-
num_inference_steps = len(timesteps)
160-
elif sigmas is not None:
161-
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
162-
if not accept_sigmas:
163-
raise ValueError(
164-
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
165-
f" sigmas schedules. Please check whether you are using the correct scheduler."
166-
)
167-
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
168-
timesteps = scheduler.timesteps
169-
num_inference_steps = len(timesteps)
170-
else:
171-
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
172-
timesteps = scheduler.timesteps
173-
return timesteps, num_inference_steps
174-
175-
17683
class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxControlMixin, FluxLoraLoaderMixin, FromSingleFileMixin):
17784
r"""
17885
The Flux pipeline for image inpainting.

0 commit comments

Comments
 (0)