1818from contextlib import nullcontext
1919import functools
2020import json
21- import os
2221import jax
23- import jax .numpy as jnp
2422from jax .sharding import Mesh
2523import orbax .checkpoint as ocp
2624import grain .python as grain
3230from maxdiffusion .models .flux .transformers .transformer_flux_flax import FluxTransformer2DModel
3331from ..pipelines .flux .flux_pipeline import FluxPipeline
3432
35- from transformers import (CLIPTokenizer , FlaxCLIPTextModel , T5EncoderModel , FlaxT5EncoderModel , AutoTokenizer )
33+ from transformers import (CLIPTokenizer , FlaxCLIPTextModel , FlaxT5EncoderModel , AutoTokenizer )
3634
37- from maxdiffusion .checkpointing .checkpointing_utils import (
38- create_orbax_checkpoint_manager
39- )
35+ from maxdiffusion .checkpointing .checkpointing_utils import (create_orbax_checkpoint_manager )
4036from maxdiffusion .models .flux .util import load_flow_model
4137
4238FLUX_CHECKPOINT = "FLUX_CHECKPOINT"
4945VAE_STATE_KEY = "vae_state"
5046VAE_STATE_SHARDINGS_KEY = "vae_state_shardings"
5147
48+
5249class FluxCheckpointer (ABC ):
5350
5451 def __init__ (self , config , checkpoint_type ):
@@ -87,12 +84,14 @@ def create_flux_state(self, pipeline, params, checkpoint_item_name, is_training)
8784 tx , learning_rate_scheduler = self ._create_optimizer (self .config , learning_rate )
8885
8986 transformer_eval_params = transformer .init_weights (
90- rngs = self .rng , max_sequence_length = self .config .max_sequence_length , eval_only = True
87+ rngs = self .rng , max_sequence_length = self .config .max_sequence_length , eval_only = True
9188 )
9289
9390 transformer_params = load_flow_model (self .config .flux_name , transformer_eval_params , "cpu" )
9491
95- weights_init_fn = functools .partial (pipeline .flux .init_weights , rngs = self .rng , max_sequence_length = self .config .max_sequence_length )
92+ weights_init_fn = functools .partial (
93+ pipeline .flux .init_weights , rngs = self .rng , max_sequence_length = self .config .max_sequence_length
94+ )
9695 flux_state , state_mesh_shardings = max_utils .setup_initial_state (
9796 model = pipeline .flux ,
9897 tx = tx ,
@@ -150,10 +149,11 @@ def _set_checkpoint_format(self, checkpoint_format):
150149 def save_checkpoint (self , train_step , pipeline , train_states ):
151150 def config_to_json (model_or_config ):
152151 return json .loads (model_or_config .to_json_string ())
152+
153153 items = {
154154 "flux_config" : ocp .args .JsonSave (config_to_json (pipeline .flux )),
155155 "vae_config" : ocp .args .JsonSave (config_to_json (pipeline .vae )),
156- "scheduler_config" : ocp .args .JsonSave (config_to_json (pipeline .scheduler ))
156+ "scheduler_config" : ocp .args .JsonSave (config_to_json (pipeline .scheduler )),
157157 }
158158
159159 items [FLUX_STATE_KEY ] = ocp .args .PyTreeSave (train_states [FLUX_STATE_KEY ])
@@ -165,7 +165,7 @@ def config_to_json(model_or_config):
165165 def load_params (self , step = None ):
166166
167167 self .checkpoint_format = _CHECKPOINT_FORMAT_ORBAX
168-
168+
169169 def load_flux_configs_from_orbax (self , step ):
170170 max_logging .log ("Restoring stable diffusion configs" )
171171 if step is None :
@@ -188,68 +188,57 @@ def load_diffusers_checkpoint(self):
188188 context = jax .default_device (jax .devices ("cpu" )[0 ])
189189 else :
190190 context = nullcontext ()
191-
191+
192192 with context :
193- clip_encoder = FlaxCLIPTextModel .from_pretrained (
194- self .config .clip_model_name_or_path , dtype = self .config .weights_dtype
195- )
196- clip_tokenizer = CLIPTokenizer .from_pretrained (
197- self .config .clip_model_name_or_path ,
198- max_length = 77 ,
199- use_fast = True
200- )
193+ clip_encoder = FlaxCLIPTextModel .from_pretrained (self .config .clip_model_name_or_path , dtype = self .config .weights_dtype )
194+ clip_tokenizer = CLIPTokenizer .from_pretrained (self .config .clip_model_name_or_path , max_length = 77 , use_fast = True )
201195 t5_encoder = FlaxT5EncoderModel .from_pretrained (self .config .t5xxl_model_name_or_path , dtype = self .config .weights_dtype )
202196 t5_tokenizer = AutoTokenizer .from_pretrained (
203- self .config .t5xxl_model_name_or_path ,
204- max_length = self .config .max_sequence_length ,
205- use_fast = True
197+ self .config .t5xxl_model_name_or_path , max_length = self .config .max_sequence_length , use_fast = True
206198 )
207199
208200 vae , vae_params = FlaxAutoencoderKL .from_pretrained (
209- self .config .pretrained_model_name_or_path ,
210- subfolder = "vae" ,
211- from_pt = True ,
212- use_safetensors = True ,
213- dtype = self .config .weights_dtype
201+ self .config .pretrained_model_name_or_path ,
202+ subfolder = "vae" ,
203+ from_pt = True ,
204+ use_safetensors = True ,
205+ dtype = self .config .weights_dtype ,
214206 )
215207
216208 # loading from pretrained here causes a crash when trying to compile the model
217209 # Failed to load HSACO: HIP_ERROR_NoBinaryForGpu
218210 transformer = FluxTransformer2DModel .from_config (
219- self .config .pretrained_model_name_or_path ,
220- subfolder = "transformer" ,
221- mesh = self .mesh ,
222- split_head_dim = self .config .split_head_dim ,
223- attention_kernel = self .config .attention ,
224- flash_block_sizes = flash_block_sizes ,
225- dtype = self .config .activations_dtype ,
226- weights_dtype = self .config .weights_dtype ,
227- precision = max_utils .get_precision (self .config ),
211+ self .config .pretrained_model_name_or_path ,
212+ subfolder = "transformer" ,
213+ mesh = self .mesh ,
214+ split_head_dim = self .config .split_head_dim ,
215+ attention_kernel = self .config .attention ,
216+ flash_block_sizes = flash_block_sizes ,
217+ dtype = self .config .activations_dtype ,
218+ weights_dtype = self .config .weights_dtype ,
219+ precision = max_utils .get_precision (self .config ),
228220 )
229221 transformer_eval_params = transformer .init_weights (
230- rngs = self .rng , max_sequence_length = self .config .max_sequence_length , eval_only = True
222+ rngs = self .rng , max_sequence_length = self .config .max_sequence_length , eval_only = True
231223 )
232-
224+
233225 transformer_params = load_flow_model (self .config .flux_name , transformer_eval_params , "cpu" )
234226
235227 pipeline = FluxPipeline (
236- t5_encoder ,
237- clip_encoder ,
238- vae ,
239- t5_tokenizer ,
240- clip_tokenizer ,
241- transformer ,
242- None ,
243- dtype = self .config .activations_dtype ,
244- mesh = self .mesh ,
245- config = self .config ,
246- rng = self .rng
228+ t5_encoder ,
229+ clip_encoder ,
230+ vae ,
231+ t5_tokenizer ,
232+ clip_tokenizer ,
233+ transformer ,
234+ None ,
235+ dtype = self .config .activations_dtype ,
236+ mesh = self .mesh ,
237+ config = self .config ,
238+ rng = self .rng ,
247239 )
248240
249- params = {
250- FLUX_VAE_PARAMS_KEY : vae_params ,
251- FLUX_TRANSFORMER_PARAMS_KEY : transformer_params
252- }
241+ params = {FLUX_VAE_PARAMS_KEY : vae_params , FLUX_TRANSFORMER_PARAMS_KEY : transformer_params }
253242
254243 return pipeline , params
255244
@@ -267,55 +256,50 @@ def load_checkpoint(self, step=None, scheduler_class=None):
267256
268257 with context :
269258 clip_encoder = FlaxCLIPTextModel .from_pretrained (
270- self .config .clip_model_name_or_path , dtype = self .config .weights_dtype
259+ self .config .clip_model_name_or_path , dtype = self .config .weights_dtype
271260 )
272- clip_tokenizer = CLIPTokenizer .from_pretrained (
273- self .config .clip_model_name_or_path ,
274- max_length = 77 ,
275- use_fast = True
261+ clip_tokenizer = CLIPTokenizer .from_pretrained (self .config .clip_model_name_or_path , max_length = 77 , use_fast = True )
262+ t5_encoder = FlaxT5EncoderModel .from_pretrained (
263+ self .config .t5xxl_model_name_or_path , dtype = self .config .weights_dtype
276264 )
277- t5_encoder = FlaxT5EncoderModel .from_pretrained (self .config .t5xxl_model_name_or_path , dtype = self .config .weights_dtype )
278265 t5_tokenizer = AutoTokenizer .from_pretrained (
279- self .config .t5xxl_model_name_or_path ,
280- max_length = self .config .max_sequence_length ,
281- use_fast = True
266+ self .config .t5xxl_model_name_or_path , max_length = self .config .max_sequence_length , use_fast = True
282267 )
283268
284269 vae = FlaxAutoencoderKL .from_config (
285- model_configs [0 ]["vae_config" ],
286- dtype = self .config .activations_dtype ,
287- weights_dtype = self .config .weights_dtype ,
288- from_pt = self .config .from_pt ,
270+ model_configs [0 ]["vae_config" ],
271+ dtype = self .config .activations_dtype ,
272+ weights_dtype = self .config .weights_dtype ,
273+ from_pt = self .config .from_pt ,
289274 )
290275
291276 transformer = FluxTransformer2DModel .from_config (
292- model_configs [0 ]["flux_config" ],
293- mesh = self .mesh ,
294- split_head_dim = self .config .split_head_dim ,
295- attention_kernel = self .config .attention ,
296- flash_block_sizes = max_utils .get_flash_block_sizes (self .config ),
297- dtype = self .config .activations_dtype ,
298- weights_dtype = self .config .weights_dtype ,
299- precision = max_utils .get_precision (self .config ),
300- from_pt = self .config .from_pt ,
277+ model_configs [0 ]["flux_config" ],
278+ mesh = self .mesh ,
279+ split_head_dim = self .config .split_head_dim ,
280+ attention_kernel = self .config .attention ,
281+ flash_block_sizes = max_utils .get_flash_block_sizes (self .config ),
282+ dtype = self .config .activations_dtype ,
283+ weights_dtype = self .config .weights_dtype ,
284+ precision = max_utils .get_precision (self .config ),
285+ from_pt = self .config .from_pt ,
301286 )
302287
303288 pipeline = FluxPipeline (
304- t5_encoder ,
305- clip_encoder ,
306- vae ,
307- t5_tokenizer ,
308- clip_tokenizer ,
309- transformer ,
310- None ,
311- dtype = self .config .activations_dtype ,
312- mesh = self .mesh ,
313- config = self .config ,
314- rng = self .rng
289+ t5_encoder ,
290+ clip_encoder ,
291+ vae ,
292+ t5_tokenizer ,
293+ clip_tokenizer ,
294+ transformer ,
295+ None ,
296+ dtype = self .config .activations_dtype ,
297+ mesh = self .mesh ,
298+ config = self .config ,
299+ rng = self .rng ,
315300 )
316301
317302 else :
318303 pipeline , params = self .load_diffusers_checkpoint ()
319-
320- return pipeline , params
321304
305+ return pipeline , params
0 commit comments