From 3897879b59584ba72c0d1a485dfcf9e1eab1c1ce Mon Sep 17 00:00:00 2001 From: Haixin Liu Date: Tue, 17 Jun 2025 23:02:50 -0700 Subject: [PATCH 1/2] fix multiple issues on GPU --- .../checkpointing/flux_checkpointer.py | 4 ++-- .../input_pipeline/_tfds_data_processing.py | 11 ++++++--- src/maxdiffusion/max_utils.py | 14 ++++++++--- .../transformers/transformer_flux_flax.py | 24 ++++++++++++------- src/maxdiffusion/trainers/flux_trainer.py | 7 +++++- 5 files changed, 42 insertions(+), 18 deletions(-) diff --git a/src/maxdiffusion/checkpointing/flux_checkpointer.py b/src/maxdiffusion/checkpointing/flux_checkpointer.py index a5e1bfc2f..89ac3764c 100644 --- a/src/maxdiffusion/checkpointing/flux_checkpointer.py +++ b/src/maxdiffusion/checkpointing/flux_checkpointer.py @@ -194,7 +194,7 @@ def load_diffusers_checkpoint(self): clip_tokenizer = CLIPTokenizer.from_pretrained(self.config.clip_model_name_or_path, max_length=77, use_fast=True) t5_encoder = FlaxT5EncoderModel.from_pretrained(self.config.t5xxl_model_name_or_path, dtype=self.config.weights_dtype) t5_tokenizer = AutoTokenizer.from_pretrained( - self.config.t5xxl_model_name_or_path, max_length=self.config.max_sequence_length, use_fast=True + self.config.t5xxl_model_name_or_path, model_max_length=self.config.max_sequence_length, use_fast=True ) vae, vae_params = FlaxAutoencoderKL.from_pretrained( @@ -263,7 +263,7 @@ def load_checkpoint(self, step=None, scheduler_class=None): self.config.t5xxl_model_name_or_path, dtype=self.config.weights_dtype ) t5_tokenizer = AutoTokenizer.from_pretrained( - self.config.t5xxl_model_name_or_path, max_length=self.config.max_sequence_length, use_fast=True + self.config.t5xxl_model_name_or_path, model_max_length=self.config.max_sequence_length, use_fast=True ) vae = FlaxAutoencoderKL.from_config( diff --git a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py index 454f65785..ce0ae5169 100644 --- a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py +++ b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py @@ -18,7 +18,7 @@ import tensorflow as tf import tensorflow.experimental.numpy as tnp from datasets import load_dataset, load_from_disk - +import jax from maxdiffusion import multihost_dataloading AUTOTUNE = tf.data.AUTOTUNE @@ -65,8 +65,13 @@ def make_tf_iterator( ) if config.cache_latents_text_encoder_outputs: train_ds.save_to_disk(config.dataset_save_location) - train_ds.cleanup_cache_files() - + # Only process 0 should attempt to clean up cache files + if jax.process_index() == 0: + try: + train_ds.cleanup_cache_files() + except FileNotFoundError: + # Ignore FileNotFoundError as files may have been cleaned up by another process + pass train_ds = load_as_tf_dataset(train_ds, global_batch_size, True, dataloading_host_count) train_ds = train_ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index aaa929c59..6d6ab5c48 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -26,7 +26,7 @@ import os from pathlib import Path import subprocess - +from ctypes import cdll import numpy as np import flax @@ -63,6 +63,8 @@ from google.cloud import storage +libcudart = cdll.LoadLibrary("libcudart.so") + FrozenDict = core.frozen_dict.FrozenDict @@ -79,12 +81,18 @@ def l2norm_pytree(x): def activate_profiler(config): if jax.process_index() == 0 and config.enable_profiler: - jax.profiler.start_trace(config.tensorboard_dir) + if config.profiler == 'nsys': + libcudart.cudaProfilerStart() + else: + jax.profiler.start_trace(config.tensorboard_dir) def deactivate_profiler(config): if jax.process_index() == 0 and config.enable_profiler: - jax.profiler.stop_trace() + if config.profiler == 'nsys': + libcudart.cudaProfilerStop() + else: + jax.profiler.stop_trace() def initialize_summary_writer(config): diff --git a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py index 97a31ebe9..505a4f4ff 100644 --- a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py +++ b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py @@ -86,7 +86,7 @@ def setup(self): self.linear1 = nn.Dense( self.dim * 3 + self.mlp_hidden_dim, kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), dtype=self.dtype, param_dtype=self.weights_dtype, precision=self.precision, @@ -96,7 +96,7 @@ def setup(self): self.linear2 = nn.Dense( self.dim, kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("embed",)), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), dtype=self.dtype, param_dtype=self.weights_dtype, precision=self.precision, @@ -209,7 +209,7 @@ def setup(self): int(self.dim * self.mlp_ratio), use_bias=True, kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), dtype=self.dtype, param_dtype=self.weights_dtype, precision=self.precision, @@ -218,8 +218,8 @@ def setup(self): nn.Dense( self.dim, use_bias=True, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), dtype=self.dtype, param_dtype=self.weights_dtype, precision=self.precision, @@ -240,7 +240,7 @@ def setup(self): int(self.dim * self.mlp_ratio), use_bias=True, kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), dtype=self.dtype, param_dtype=self.weights_dtype, precision=self.precision, @@ -249,8 +249,8 @@ def setup(self): nn.Dense( self.dim, use_bias=True, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), dtype=self.dtype, param_dtype=self.weights_dtype, precision=self.precision, @@ -483,6 +483,9 @@ def __call__( ): hidden_states = self.img_in(hidden_states) timestep = self.timestep_embedding(timestep, 256) + + timestep = nn.with_logical_constraint(timestep, ("activation_batch", None)) + if self.guidance_embeds: guidance = self.timestep_embedding(guidance, 256) else: @@ -492,6 +495,9 @@ def __call__( if guidance is None else self.time_text_embed(timestep, guidance, pooled_projections) ) + + temb = nn.with_logical_constraint(temb, ("activation_batch", None)) + encoder_hidden_states = self.txt_in(encoder_hidden_states) if txt_ids.ndim == 3: txt_ids = txt_ids[0] @@ -501,7 +507,7 @@ def __call__( ids = jnp.concatenate((txt_ids, img_ids), axis=0) ids = nn.with_logical_constraint(ids, ("activation_batch", None)) image_rotary_emb = self.pe_embedder(ids) - image_rotary_emb = nn.with_logical_constraint(image_rotary_emb, ("activation_batch", "activation_embed")) + image_rotary_emb = nn.with_logical_constraint(image_rotary_emb, (None, None)) for double_block in self.double_blocks: hidden_states, encoder_hidden_states = double_block( diff --git a/src/maxdiffusion/trainers/flux_trainer.py b/src/maxdiffusion/trainers/flux_trainer.py index 32139faef..f3ca99cb2 100644 --- a/src/maxdiffusion/trainers/flux_trainer.py +++ b/src/maxdiffusion/trainers/flux_trainer.py @@ -252,6 +252,7 @@ def load_dataset(self, pipeline, params, train_states): t5_tokenizer=pipeline.t5_tokenizer, clip_text_encoder=pipeline.clip_encoder, t5_text_encoder=pipeline.t5_encoder, + max_sequence_length=config.max_sequence_length, encode_in_batches=True, encode_batch_size=16, ) @@ -348,9 +349,13 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera example_batch = load_next_batch(data_iterator, example_batch, self.config) example_batch = {key: jnp.asarray(value, dtype=self.config.activations_dtype) for key, value in example_batch.items()} - with jax.profiler.StepTraceAnnotation("train", step_num=step): + if self.config.profiler == 'nsys': with self.mesh: flux_state, train_metric, train_rngs = p_train_step(flux_state, example_batch, train_rngs) + else: + with jax.profiler.StepTraceAnnotation("train", step_num=step): + with self.mesh: + flux_state, train_metric, train_rngs = p_train_step(flux_state, example_batch, train_rngs) samples_count = self.total_train_batch_size * (step + 1) new_time = datetime.datetime.now() From 105b8fe3d3d8fa34c15957411a32cfa0c11caaa0 Mon Sep 17 00:00:00 2001 From: Haixin Liu Date: Tue, 15 Jul 2025 16:22:06 -0700 Subject: [PATCH 2/2] remove nsys --- src/maxdiffusion/max_utils.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 6d6ab5c48..fe6cc09af 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -26,7 +26,6 @@ import os from pathlib import Path import subprocess -from ctypes import cdll import numpy as np import flax @@ -63,8 +62,6 @@ from google.cloud import storage -libcudart = cdll.LoadLibrary("libcudart.so") - FrozenDict = core.frozen_dict.FrozenDict @@ -81,18 +78,12 @@ def l2norm_pytree(x): def activate_profiler(config): if jax.process_index() == 0 and config.enable_profiler: - if config.profiler == 'nsys': - libcudart.cudaProfilerStart() - else: - jax.profiler.start_trace(config.tensorboard_dir) + jax.profiler.start_trace(config.tensorboard_dir) def deactivate_profiler(config): if jax.process_index() == 0 and config.enable_profiler: - if config.profiler == 'nsys': - libcudart.cudaProfilerStop() - else: - jax.profiler.stop_trace() + jax.profiler.stop_trace() def initialize_summary_writer(config):