1+ from json import encoder
12from absl import app
23from typing import Sequence
34import jax
5+ from flax import linen as nn
46import json
7+ from flax .linen import partitioning as nn_partitioning
58from maxdiffusion .models .ltx_video .transformers .transformer3d import Transformer3DModel
69import os
710import functools
811import jax .numpy as jnp
9- from maxdiffusion import FlaxAutoencoderKL , pyconfig , max_logging
12+ from maxdiffusion import pyconfig
1013from maxdiffusion .max_utils import (
1114 create_device_mesh ,
1215 setup_initial_state ,
16+ get_memory_allocations ,
1317)
14- from jax .sharding import Mesh , PositionalSharding , PartitionSpec as P
18+ from jax .sharding import Mesh , PartitionSpec as P
19+ import orbax .checkpoint as ocp
1520
1621
def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids):
  """Print the shape and dtype of every transformer input as a sanity check.

  Purely diagnostic — prints one line per input and returns None.
  Fix: the original printed noise_cond twice; the duplicate line is removed.
  """
  print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype)
  print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype)
  print("latents.shape: ", latents.shape, latents.dtype)
  print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype)
  print("segment_ids.shape: ", segment_ids.shape, segment_ids.dtype)
  print("encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype)
30+
31+
def loop_body(
    step,
    args,
    transformer,
    fractional_cords,
    prompt_embeds,
    segment_ids,
    encoder_attention_segment_ids
):
  """One denoising step for jax.lax.fori_loop.

  Args:
    step: loop index supplied by fori_loop (currently unused by the body).
    args: carry tuple (latents, state, noise_cond); state.params holds the
      transformer weights.
    transformer / fractional_cords / prompt_embeds / segment_ids /
      encoder_attention_segment_ids: closed over via functools.partial in
      run_inference.

  Returns:
    (noise_pred, state, noise_cond) — the raw model output is carried in the
    latents slot for now; see the TODO below.

  Fix: removed a leftover `import pdb; pdb.set_trace()` breakpoint — a
  debugger trap inside a fori_loop body also fails under jit tracing.
  """
  latents, state, noise_cond = args
  noise_pred = transformer.apply(
      {"params": state.params},
      hidden_states=latents,
      indices_grid=fractional_cords,
      encoder_hidden_states=prompt_embeds,
      timestep=noise_cond,
      segment_ids=segment_ids,
      encoder_attention_segment_ids=encoder_attention_segment_ids
  )
  # TODO: latents should be updated from noise_pred via a scheduler step;
  # returning noise_pred in the latents slot for now.
  return noise_pred, state, noise_cond
53+
54+
55+
def run_inference(
    states, transformer, config, mesh, latents, fractional_cords, prompt_embeds, timestep, segment_ids, encoder_attention_segment_ids
):
  """Run the (single-iteration, for now) denoising loop under the mesh.

  Binds the static inputs into loop_body, then drives it with
  jax.lax.fori_loop inside the mesh / logical-axis-rules context.
  Returns whatever the final carry holds in the latents slot.
  """
  tstate = states["transformer"]
  step_fn = functools.partial(
      loop_body,
      transformer=transformer,
      fractional_cords=fractional_cords,
      prompt_embeds=prompt_embeds,
      segment_ids=segment_ids,
      encoder_attention_segment_ids=encoder_attention_segment_ids,
  )
  ## TODO: add vae decode step
  ## TODO: add loop
  carry = (latents, tstate, timestep)
  with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
    latents, tstate, _ = jax.lax.fori_loop(0, 1, step_fn, carry)
  return latents
73+
2274
def run(config):
  """Build the LTX-Video Transformer3DModel, restore its checkpoint, and run
  one jitted denoising step on dummy inputs as an end-to-end smoke test.

  Args:
    config: pyconfig-style config; this function reads config.mesh_axes,
      config.logical_axis_rules and config.data_sharding (plus whatever
      create_device_mesh / setup_initial_state read from it).
  """
  # NOTE(review): key is no longer consumed below (init_weights is called
  # without an rng here) — presumably leftover; confirm before removing.
  key = jax.random.PRNGKey(0)

  # Device mesh used for both weight sharding and data sharding.
  devices_array = create_device_mesh(config)
  mesh = Mesh(devices_array, config.mesh_axes)

  base_dir = os.path.dirname(__file__)

  ##load in model config
  config_path = os.path.join(base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json")
  with open(config_path, "r") as f:
    model_config = json.load(f)
  relative_ckpt_path = model_config["ckpt_path"]

  # Strip metadata keys the Transformer3DModel constructor does not accept;
  # in_channels and ckpt_path are consumed separately above/below.
  ignored_keys = ["_class_name", "_diffusers_version", "_name_or_path", "causal_temporal_positioning", "in_channels", "ckpt_path"]
  in_channels = model_config["in_channels"]
  for name in ignored_keys:
    if name in model_config:
      del model_config[name]

  transformer = Transformer3DModel(**model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh)
  # NOTE(review): transformer_param_shapes is unused below — kept only for
  # ad-hoc inspection per the original comment.
  transformer_param_shapes = transformer.init_weights(in_channels, model_config['caption_channels'], eval_only=True)  #use this to test!

  # Deferred weight initializer handed to setup_initial_state.
  weights_init_fn = functools.partial(
      transformer.init_weights,
      in_channels,
      model_config['caption_channels'],
      eval_only=True
  )

  absolute_ckpt_path = os.path.abspath(relative_ckpt_path)

  checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path)
  # NOTE(review): checkpoint_item is a single space — looks like a placeholder
  # item name for the restore path; confirm the intended value.
  transformer_state, transformer_state_shardings = setup_initial_state(
      model=transformer,
      tx=None,
      config=config,
      mesh=mesh,
      weights_init_fn=weights_init_fn,
      checkpoint_manager=checkpoint_manager,
      checkpoint_item=" ",
      model_params=None,
      training=False,
  )

  # Shard the restored state across the mesh, then log device memory usage.
  transformer_state = jax.device_put(transformer_state, transformer_state_shardings)
  get_memory_allocations()

  states = {}
  state_shardings = {}

  state_shardings["transformer"] = transformer_state_shardings
  states["transformer"] = transformer_state

  #create dummy inputs:
  example_inputs = {}
  batch_size, num_tokens = 4, 256
  input_shapes = {
      "latents": (batch_size, num_tokens, in_channels),
      "fractional_coords": (batch_size, 3, num_tokens),
      "prompt_embeds": (batch_size, 128, model_config["caption_channels"]),
      "timestep": (batch_size, 256),  #TODO: add in the segment id stuff
      "segment_ids": (batch_size, 256),
      "encoder_attention_segment_ids": (batch_size, 128),
  }
  for name, shape in input_shapes.items():
    # NOTE(review): no "attention_mask"/"encoder_attention_mask" keys exist in
    # input_shapes, so the jnp.bool branch is currently never taken.
    example_inputs[name] = jnp.ones(
        shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool
    )

  # Place every dummy input on the mesh with the configured data sharding.
  data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))
  latents = jax.device_put(example_inputs["latents"], data_sharding)
  prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding)
  fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding)
  noise_cond = jax.device_put(example_inputs["timestep"], data_sharding)
  segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding)
  encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding)

  validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids)
  # Close over everything except the states pytree, so states is the only jit
  # argument and hence the only thing needing in_shardings.
  p_run_inference = jax.jit(
      functools.partial(
          run_inference,
          transformer=transformer,
          config=config,
          mesh=mesh,
          latents=latents,
          fractional_cords=fractional_coords,
          prompt_embeds=prompt_embeds,
          timestep=noise_cond,
          segment_ids=segment_ids,
          encoder_attention_segment_ids=encoder_attention_segment_ids
      ),
      in_shardings=(state_shardings,),
      out_shardings=None,
  )
  noise_pred = p_run_inference(states).block_until_ready()
  print(noise_pred)  #(4, 256, 128)
176+
177+
178+
179+
180+
181+
182+
183+
184+
185+
186+
187+
188+
189+
190+
61191
192+
193+
194+
195+
196+
197+
198+
199+
200+
201+
202+
203+
204+
205+
206+
207+
208+
62209
63210
64211def main (argv : Sequence [str ]) -> None :
@@ -71,3 +218,13 @@ def main(argv: Sequence[str]) -> None:
71218
72219
73220
221+
222+ ###setup_initial_state, can optionally load from checkpoint
223+
224+
225+
226+
227+
228+
229+
230+ #end to end steps from ltx repo: pipeline_ltx_video.py
# (removed web-page artifact "0 commit comments" — it was a syntax error)