
Commit 7e130fc

ruff + code_style
1 parent 5453b3c commit 7e130fc

10 files changed: 244 additions & 304 deletions
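The commit message says it all: these are mechanical reformats from running ruff's formatter, with no behavioral changes. Three conventions recur in every hunk below; this snippet summarizes them in isolation (the repo's exact ruff configuration, e.g. its line-length setting, is not shown on this page, so treat that as an assumption, and the dict values in item 2 are placeholders):

# 1. A call or literal that fits within the configured line length is collapsed onto one line:
item_names = ("flux_state", "flux_config", "vae_state", "vae_config", "scheduler", "scheduler_config")

# 2. A call that stays exploded gets one argument per line, a hanging indent, and a
#    trailing comma after the last argument (the "magic trailing comma", which also
#    tells the formatter to keep it exploded on future runs):
weights = dict(
    dtype="bfloat16",
    max_sequence_length=512,
)

# 3. Quote and spacing normalization: double quotes and tight dict braces, e.g.
#    params['flux_vae'] becomes params["flux_vae"], and
#    { "pipeline_jflux" : "JfluxPipeline" } becomes {"pipeline_jflux": "JfluxPipeline"}.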


src/maxdiffusion/checkpointing/checkpointing_utils.py

Lines changed: 1 addition & 3 deletions
@@ -58,9 +58,7 @@ def create_orbax_checkpoint_manager(
   p = epath.Path(checkpoint_dir)

   if checkpoint_type == FLUX_CHECKPOINT:
-    item_names = ("flux_state", "flux_config",
-                  "vae_state", "vae_config",
-                  "scheduler", "scheduler_config")
+    item_names = ("flux_state", "flux_config", "vae_state", "vae_config", "scheduler", "scheduler_config")
   else:
     item_names = (
       "unet_config",
src/maxdiffusion/checkpointing/flux_checkpointer.py

Lines changed: 72 additions & 88 deletions
@@ -18,9 +18,7 @@
 from contextlib import nullcontext
 import functools
 import json
-import os
 import jax
-import jax.numpy as jnp
 from jax.sharding import Mesh
 import orbax.checkpoint as ocp
 import grain.python as grain

@@ -32,11 +30,9 @@
 from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
 from ..pipelines.flux.flux_pipeline import FluxPipeline

-from transformers import (CLIPTokenizer, FlaxCLIPTextModel, T5EncoderModel, FlaxT5EncoderModel, AutoTokenizer)
+from transformers import (CLIPTokenizer, FlaxCLIPTextModel, FlaxT5EncoderModel, AutoTokenizer)

-from maxdiffusion.checkpointing.checkpointing_utils import (
-    create_orbax_checkpoint_manager
-)
+from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
 from maxdiffusion.models.flux.util import load_flow_model

 FLUX_CHECKPOINT = "FLUX_CHECKPOINT"

@@ -49,6 +45,7 @@
 VAE_STATE_KEY = "vae_state"
 VAE_STATE_SHARDINGS_KEY = "vae_state_shardings"

+
 class FluxCheckpointer(ABC):

   def __init__(self, config, checkpoint_type):

@@ -87,12 +84,14 @@ def create_flux_state(self, pipeline, params, checkpoint_item_name, is_training)
     tx, learning_rate_scheduler = self._create_optimizer(self.config, learning_rate)

     transformer_eval_params = transformer.init_weights(
-      rngs=self.rng, max_sequence_length=self.config.max_sequence_length, eval_only=True
+        rngs=self.rng, max_sequence_length=self.config.max_sequence_length, eval_only=True
     )

     transformer_params = load_flow_model(self.config.flux_name, transformer_eval_params, "cpu")

-    weights_init_fn = functools.partial(pipeline.flux.init_weights, rngs=self.rng, max_sequence_length=self.config.max_sequence_length)
+    weights_init_fn = functools.partial(
+        pipeline.flux.init_weights, rngs=self.rng, max_sequence_length=self.config.max_sequence_length
+    )
     flux_state, state_mesh_shardings = max_utils.setup_initial_state(
         model=pipeline.flux,
         tx=tx,

@@ -150,10 +149,11 @@ def _set_checkpoint_format(self, checkpoint_format):
   def save_checkpoint(self, train_step, pipeline, train_states):
     def config_to_json(model_or_config):
       return json.loads(model_or_config.to_json_string())
+
     items = {
         "flux_config": ocp.args.JsonSave(config_to_json(pipeline.flux)),
         "vae_config": ocp.args.JsonSave(config_to_json(pipeline.vae)),
-        "scheduler_config": ocp.args.JsonSave(config_to_json(pipeline.scheduler))
+        "scheduler_config": ocp.args.JsonSave(config_to_json(pipeline.scheduler)),
     }

     items[FLUX_STATE_KEY] = ocp.args.PyTreeSave(train_states[FLUX_STATE_KEY])

@@ -165,7 +165,7 @@ def config_to_json(model_or_config):
   def load_params(self, step=None):

     self.checkpoint_format = _CHECKPOINT_FORMAT_ORBAX
-
+
   def load_flux_configs_from_orbax(self, step):
     max_logging.log("Restoring stable diffusion configs")
     if step is None:

@@ -188,68 +188,57 @@ def load_diffusers_checkpoint(self):
       context = jax.default_device(jax.devices("cpu")[0])
     else:
       context = nullcontext()
-
+
     with context:
-      clip_encoder = FlaxCLIPTextModel.from_pretrained(
-        self.config.clip_model_name_or_path, dtype=self.config.weights_dtype
-      )
-      clip_tokenizer = CLIPTokenizer.from_pretrained(
-        self.config.clip_model_name_or_path,
-        max_length=77,
-        use_fast=True
-      )
+      clip_encoder = FlaxCLIPTextModel.from_pretrained(self.config.clip_model_name_or_path, dtype=self.config.weights_dtype)
+      clip_tokenizer = CLIPTokenizer.from_pretrained(self.config.clip_model_name_or_path, max_length=77, use_fast=True)
       t5_encoder = FlaxT5EncoderModel.from_pretrained(self.config.t5xxl_model_name_or_path, dtype=self.config.weights_dtype)
       t5_tokenizer = AutoTokenizer.from_pretrained(
-        self.config.t5xxl_model_name_or_path,
-        max_length=self.config.max_sequence_length,
-        use_fast=True
+          self.config.t5xxl_model_name_or_path, max_length=self.config.max_sequence_length, use_fast=True
       )

       vae, vae_params = FlaxAutoencoderKL.from_pretrained(
-        self.config.pretrained_model_name_or_path,
-        subfolder="vae",
-        from_pt=True,
-        use_safetensors=True,
-        dtype=self.config.weights_dtype
+          self.config.pretrained_model_name_or_path,
+          subfolder="vae",
+          from_pt=True,
+          use_safetensors=True,
+          dtype=self.config.weights_dtype,
       )

       # loading from pretrained here causes a crash when trying to compile the model
       # Failed to load HSACO: HIP_ERROR_NoBinaryForGpu
       transformer = FluxTransformer2DModel.from_config(
-        self.config.pretrained_model_name_or_path,
-        subfolder="transformer",
-        mesh=self.mesh,
-        split_head_dim=self.config.split_head_dim,
-        attention_kernel=self.config.attention,
-        flash_block_sizes=flash_block_sizes,
-        dtype=self.config.activations_dtype,
-        weights_dtype=self.config.weights_dtype,
-        precision=max_utils.get_precision(self.config),
+          self.config.pretrained_model_name_or_path,
+          subfolder="transformer",
+          mesh=self.mesh,
+          split_head_dim=self.config.split_head_dim,
+          attention_kernel=self.config.attention,
+          flash_block_sizes=flash_block_sizes,
+          dtype=self.config.activations_dtype,
+          weights_dtype=self.config.weights_dtype,
+          precision=max_utils.get_precision(self.config),
       )
       transformer_eval_params = transformer.init_weights(
-        rngs=self.rng, max_sequence_length=self.config.max_sequence_length, eval_only=True
+          rngs=self.rng, max_sequence_length=self.config.max_sequence_length, eval_only=True
       )
-
+
       transformer_params = load_flow_model(self.config.flux_name, transformer_eval_params, "cpu")

       pipeline = FluxPipeline(
-        t5_encoder,
-        clip_encoder,
-        vae,
-        t5_tokenizer,
-        clip_tokenizer,
-        transformer,
-        None,
-        dtype=self.config.activations_dtype,
-        mesh=self.mesh,
-        config=self.config,
-        rng=self.rng
+          t5_encoder,
+          clip_encoder,
+          vae,
+          t5_tokenizer,
+          clip_tokenizer,
+          transformer,
+          None,
+          dtype=self.config.activations_dtype,
+          mesh=self.mesh,
+          config=self.config,
+          rng=self.rng,
       )

-      params = {
-        FLUX_VAE_PARAMS_KEY : vae_params,
-        FLUX_TRANSFORMER_PARAMS_KEY : transformer_params
-      }
+      params = {FLUX_VAE_PARAMS_KEY: vae_params, FLUX_TRANSFORMER_PARAMS_KEY: transformer_params}

       return pipeline, params

@@ -267,55 +256,50 @@ def load_checkpoint(self, step=None, scheduler_class=None):

       with context:
         clip_encoder = FlaxCLIPTextModel.from_pretrained(
-          self.config.clip_model_name_or_path, dtype=self.config.weights_dtype
+            self.config.clip_model_name_or_path, dtype=self.config.weights_dtype
         )
-        clip_tokenizer = CLIPTokenizer.from_pretrained(
-          self.config.clip_model_name_or_path,
-          max_length=77,
-          use_fast=True
+        clip_tokenizer = CLIPTokenizer.from_pretrained(self.config.clip_model_name_or_path, max_length=77, use_fast=True)
+        t5_encoder = FlaxT5EncoderModel.from_pretrained(
+            self.config.t5xxl_model_name_or_path, dtype=self.config.weights_dtype
         )
-        t5_encoder = FlaxT5EncoderModel.from_pretrained(self.config.t5xxl_model_name_or_path, dtype=self.config.weights_dtype)
         t5_tokenizer = AutoTokenizer.from_pretrained(
-          self.config.t5xxl_model_name_or_path,
-          max_length=self.config.max_sequence_length,
-          use_fast=True
+            self.config.t5xxl_model_name_or_path, max_length=self.config.max_sequence_length, use_fast=True
         )

         vae = FlaxAutoencoderKL.from_config(
-          model_configs[0]["vae_config"],
-          dtype=self.config.activations_dtype,
-          weights_dtype=self.config.weights_dtype,
-          from_pt=self.config.from_pt,
+            model_configs[0]["vae_config"],
+            dtype=self.config.activations_dtype,
+            weights_dtype=self.config.weights_dtype,
+            from_pt=self.config.from_pt,
         )

         transformer = FluxTransformer2DModel.from_config(
-          model_configs[0]["flux_config"],
-          mesh=self.mesh,
-          split_head_dim=self.config.split_head_dim,
-          attention_kernel=self.config.attention,
-          flash_block_sizes=max_utils.get_flash_block_sizes(self.config),
-          dtype=self.config.activations_dtype,
-          weights_dtype=self.config.weights_dtype,
-          precision=max_utils.get_precision(self.config),
-          from_pt=self.config.from_pt,
+            model_configs[0]["flux_config"],
+            mesh=self.mesh,
+            split_head_dim=self.config.split_head_dim,
+            attention_kernel=self.config.attention,
+            flash_block_sizes=max_utils.get_flash_block_sizes(self.config),
+            dtype=self.config.activations_dtype,
+            weights_dtype=self.config.weights_dtype,
+            precision=max_utils.get_precision(self.config),
+            from_pt=self.config.from_pt,
         )

         pipeline = FluxPipeline(
-          t5_encoder,
-          clip_encoder,
-          vae,
-          t5_tokenizer,
-          clip_tokenizer,
-          transformer,
-          None,
-          dtype=self.config.activations_dtype,
-          mesh=self.mesh,
-          config=self.config,
-          rng=self.rng
+            t5_encoder,
+            clip_encoder,
+            vae,
+            t5_tokenizer,
+            clip_tokenizer,
+            transformer,
+            None,
+            dtype=self.config.activations_dtype,
+            mesh=self.mesh,
+            config=self.config,
+            rng=self.rng,
         )

     else:
       pipeline, params = self.load_diffusers_checkpoint()
-
-      return pipeline, params

+    return pipeline, params
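The two inline comments kept in load_diffusers_checkpoint explain the one non-obvious choice in this file: the transformer is built via from_config and its weights loaded separately, because the from_pretrained path crashed at compile time (Failed to load HSACO: HIP_ERROR_NoBinaryForGpu). A minimal sketch of that build-then-load pattern, reduced from the diff above (it uses only names that appear in the diff, but it is a sketch rather than the module's actual API surface, and it needs a maxdiffusion checkout plus weights to run):

from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
from maxdiffusion.models.flux.util import load_flow_model


def build_flux_transformer(config, mesh, rng):
  # Build the module from its config alone; this sidesteps the from_pretrained
  # path that failed to compile on the target GPUs.
  transformer = FluxTransformer2DModel.from_config(
      config.pretrained_model_name_or_path, subfolder="transformer", mesh=mesh
  )
  # Initialize eval-only params to materialize the expected parameter pytree...
  eval_params = transformer.init_weights(rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=True)
  # ...then overwrite that structure with the real weights, staged on CPU.
  params = load_flow_model(config.flux_name, eval_params, "cpu")
  return transformer, params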

src/maxdiffusion/generate_flux_pipeline.py

Lines changed: 14 additions & 16 deletions
@@ -28,8 +28,10 @@
 from maxdiffusion.checkpointing.checkpointing_utils import load_params_from_path
 from maxdiffusion.max_utils import setup_initial_state

+
 def run(config):
   from maxdiffusion.checkpointing.flux_checkpointer import FluxCheckpointer
+
   checkpoint_loader = FluxCheckpointer(config, "FLUX_CHECKPOINT")
   pipeline, params = checkpoint_loader.load_checkpoint()

@@ -47,9 +49,9 @@ def run(config):
   vae_state = {"params": vae_params}

   ## Flux
-  weights_init_fn = functools.partial(pipeline.flux.init_weights,
-                                      rngs=checkpoint_loader.rng,
-                                      max_sequence_length=config.max_sequence_length)
+  weights_init_fn = functools.partial(
+      pipeline.flux.init_weights, rngs=checkpoint_loader.rng, max_sequence_length=config.max_sequence_length
+  )

   unboxed_abstract_state, _, _ = max_utils.get_abstract_state(
       pipeline.flux, None, config, checkpoint_loader.mesh, weights_init_fn, False

@@ -61,10 +63,10 @@ def run(config):
     flux_state = {"params": flux_params}
   else:
     weights_init_fn = functools.partial(
-      pipeline.flux.init_weights,
-      rngs=checkpoint_loader.rng,
-      max_sequence_length=config.max_sequence_length,
-      eval_only=False
+        pipeline.flux.init_weights,
+        rngs=checkpoint_loader.rng,
+        max_sequence_length=config.max_sequence_length,
+        eval_only=False,
     )
     transformer_state, flux_state_shardings = setup_initial_state(
         model=pipeline.flux,

@@ -85,26 +87,22 @@ def run(config):
       config=config,
       mesh=checkpoint_loader.mesh,
       weights_init_fn=weights_init_fn,
-      model_params=params['flux_vae'],
+      model_params=params["flux_vae"],
       training=False,
   )

   vae_state = {"params": vae_state.params}
   flux_state = {"params": transformer_state.params}

   t0 = time.perf_counter()
-  with ExitStack() as stack:
-    imgs = pipeline(flux_params=flux_state,
-                    timesteps=50,
-                    vae_params=vae_state).block_until_ready()
+  with ExitStack():
+    imgs = pipeline(flux_params=flux_state, timesteps=50, vae_params=vae_state).block_until_ready()
   t1 = time.perf_counter()
   max_logging.log(f"Compile time: {t1 - t0:.1f}s.")

   t0 = time.perf_counter()
-  with ExitStack() as stack:
-    imgs = pipeline(flux_params=flux_state,
-                    timesteps=50,
-                    vae_params=vae_state).block_until_ready()
+  with ExitStack():
+    imgs = pipeline(flux_params=flux_state, timesteps=50, vae_params=vae_state).block_until_ready()
   imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True)
   t1 = time.perf_counter()
   max_logging.log(f"Inference time: {t1 - t0:.1f}s.")
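Worth noting about the timing code this hunk reformats (the `as stack` bindings it drops were unused): JAX dispatches work asynchronously, so the .block_until_ready() calls are what make the perf_counter deltas meaningful, and the first pipeline call is dominated by XLA tracing and compilation while the second measures steady-state inference against the cached executable. The pattern in isolation, as a self-contained sketch (a generic jitted function stands in for the pipeline call):

import time

import jax.numpy as jnp
from jax import jit


@jit
def step(x):
  return jnp.tanh(x @ x)  # stand-in for the real pipeline invocation


x = jnp.ones((1024, 1024))

t0 = time.perf_counter()
step(x).block_until_ready()  # first call: trace + compile + run
print(f"Compile time: {time.perf_counter() - t0:.1f}s.")

t0 = time.perf_counter()
step(x).block_until_ready()  # second call: cached executable only
print(f"Inference time: {time.perf_counter() - t0:.1f}s.")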

src/maxdiffusion/maxdiffusion_utils.py

Lines changed: 13 additions & 10 deletions
@@ -255,19 +255,20 @@ def calculate_unet_tflops(config, pipeline, batch_size, rngs, train):
       / jax.local_device_count()
   )

+
 def get_dummy_flux_inputs(config, pipeline, batch_size):
   """Returns randomly initialized flux inputs."""
   latents, latents_ids = pipeline.prepare_latents(
-    batch_size=batch_size,
-    num_channels_latents=pipeline.flux.in_channels // 4,
-    height=config.resolution,
-    width=config.resolution,
-    vae_scale_factor=pipeline.vae_scale_factor,
-    dtype=config.activations_dtype,
-    rng=pipeline.rng
+      batch_size=batch_size,
+      num_channels_latents=pipeline.flux.in_channels // 4,
+      height=config.resolution,
+      width=config.resolution,
+      vae_scale_factor=pipeline.vae_scale_factor,
+      dtype=config.activations_dtype,
+      rng=pipeline.rng,
   )
   guidance_vec = jnp.asarray([config.guidance_scale] * batch_size, dtype=config.activations_dtype)
-
+
   timesteps = jnp.ones((batch_size,), dtype=config.weights_dtype)
   t5_hidden_states_shape = (
       batch_size,

@@ -282,7 +283,7 @@ def get_dummy_flux_inputs(config, pipeline, batch_size):
       768,
   )
   clip_hidden_states = jnp.zeros(clip_hidden_states_shape, dtype=config.weights_dtype)
-
+
   return (latents, timesteps, latents_ids, guidance_vec, t5_hidden_states, t5_ids, clip_hidden_states)


@@ -293,7 +294,9 @@ def calculate_flux_tflops(config, pipeline, batch_size, rngs, train):
   cache the compilation when flash is enabled.
   """

-  (latents, timesteps, latents_ids, guidance_vec, t5_hidden_states, t5_ids, clip_hidden_states) = get_dummy_flux_inputs(config, pipeline, batch_size)
+  (latents, timesteps, latents_ids, guidance_vec, t5_hidden_states, t5_ids, clip_hidden_states) = get_dummy_flux_inputs(
+      config, pipeline, batch_size
+  )
   return (
       max_utils.calculate_model_tflops(
           pipeline.flux,
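get_dummy_flux_inputs exists so calculate_flux_tflops can cost the model without touching real data: a FLOPs estimate depends only on shapes and dtypes, so zeros and ones with the right shapes are as good as real activations. A self-contained sketch of that idea using JAX's AOT cost analysis (the return shape of cost_analysis() has varied across JAX versions, and max_utils.calculate_model_tflops may compute its estimate differently; this is an illustration under those assumptions, not the maxdiffusion implementation):

import jax
import jax.numpy as jnp


def f(x, w):
  return jnp.tanh(x @ w)


# Dummy inputs: only shape/dtype matter for the cost model.
x = jnp.zeros((8, 4096), dtype=jnp.bfloat16)
w = jnp.zeros((4096, 4096), dtype=jnp.bfloat16)

compiled = jax.jit(f).lower(x, w).compile()
ca = compiled.cost_analysis()
ca = ca[0] if isinstance(ca, list) else ca  # dict vs. [dict] across JAX versions
print(f"~{ca['flops'] / 1e12:.4f} TFLOPs per call")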
Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
-_import_structure = { "pipeline_jflux" : "JfluxPipeline" }
+_import_structure = {"pipeline_jflux": "JfluxPipeline"}

 from .flux_pipeline import (
     FluxPipeline,
-)
+)
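For context on the hunk above: _import_structure is the diffusers-style map from submodule to exported symbol, used to register lazy imports so that importing the package does not eagerly pull in heavy pipeline modules. A generic sketch of the mechanism via PEP 562's module-level __getattr__ (the mechanism is standard Python; the surrounding package may wire it differently, e.g. through a _LazyModule helper, which is an assumption here):

# __init__.py sketch
import importlib

_import_structure = {"pipeline_jflux": "JfluxPipeline"}  # submodule -> exported symbol


def __getattr__(name):
  for module, symbol in _import_structure.items():
    if name == symbol:
      # Import the submodule only when its symbol is first accessed.
      return getattr(importlib.import_module(f".{module}", __name__), name)
  raise AttributeError(f"module {__name__!r} has no attribute {name!r}")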
