@@ -171,6 +171,17 @@ def __init__(
171171 dtype = dtype ,
172172 param_dtype = weights_dtype ,
173173 precision = precision ,
174+ # kernel_init=nnx.with_partitioning(
175+ # nnx.initializers.xavier_uniform(),
176+ # ("blockwise", None, None),
177+ # ),
178+ # bias_init=nnx.with_partitioning(
179+ # nnx.initializers.zeros,
180+ # (
181+ # "blockwise",
182+ # None,
183+ # ),
184+ # ),
174185 )
175186
176187 def __call__ (self , x : jax .Array ) -> jax .Array :
@@ -218,10 +229,18 @@ def __init__(
218229 kernel_init = nnx .with_partitioning (
219230 nnx .initializers .xavier_uniform (),
220231 (
232+ # "blockwise",
221233 "mlp" ,
222234 "embed" ,
223235 ),
224236 ),
237+ # bias_init=nnx.with_partitioning(
238+ # nnx.initializers.zeros,
239+ # (
240+ # "blockwise",
241+ # "mlp",
242+ # ),
243+ # ),
225244 )
226245
227246 def __call__ (self , hidden_states : jax .Array ) -> jax .Array :
@@ -389,10 +408,9 @@ def __init__(
389408 )
390409
391410 # 3. Transformer blocks
392- @nnx .split_rngs (splits = num_layers )
393- @nnx .vmap (in_axes = 0 , out_axes = 0 )
394- def init_block (rngs ):
395- return WanTransformerBlock (
411+ blocks = []
412+ for _ in range (num_layers ):
413+ block = WanTransformerBlock (
396414 rngs = rngs ,
397415 dim = inner_dim ,
398416 ffn_dim = ffn_dim ,
@@ -408,10 +426,15 @@ def init_block(rngs):
408426 precision = precision ,
409427 attention = attention ,
410428 )
429+ blocks .append (block )
430+ self .blocks = blocks
411431
412- self .gradient_checkpoint = GradientCheckpointType .from_str (remat_policy )
432+ # 2. Use a predicate to create a "state-free" version.
433+ # The lambda function `lambda _: False` simply tells nnx.state_if
434+ # to filter out ALL state components (params, variables, etc.).
435+ # self.block_template = nnx.state_if(lambda _: False, template_block_with_state)
413436
414- self .blocks = init_block ( rngs )
437+ self .gradient_checkpoint = GradientCheckpointType . from_str ( remat_policy )
415438
416439 self .norm_out = FP32LayerNorm (rngs = rngs , dim = inner_dim , eps = eps , elementwise_affine = False )
417440 self .proj_out = nnx .Linear (
@@ -426,7 +449,7 @@ def init_block(rngs):
426449 key = rngs .params ()
427450 self .scale_shift_table = nnx .Param (
428451 jax .random .normal (key , (1 , 2 , inner_dim )) / inner_dim ** 0.5 ,
429- kernel_init = nnx .with_partitioning (nnx .initializers .xavier_uniform (), (None , None , "embed" )),
452+ kernel_init = nnx .with_partitioning (nnx .initializers .xavier_uniform (), (None , "embed" )),
430453 )
431454
432455 def __call__ (
@@ -456,22 +479,12 @@ def __call__(
456479
457480 if encoder_hidden_states_image is not None :
458481 raise NotImplementedError ("img2vid is not yet implemented." )
459-
460- def scan_fn (carry , block ):
461- hidden_states , encoder_hidden_states , timestep_proj , rotary_emb = carry
462- hidden_states = block (hidden_states , encoder_hidden_states , timestep_proj , rotary_emb )
463- return (hidden_states , encoder_hidden_states , timestep_proj , rotary_emb )
464-
465- initial_carry = (hidden_states , encoder_hidden_states , timestep_proj , rotary_emb )
466- rematted_block_forward = self .gradient_checkpoint .apply (scan_fn )
467- final_carry = nnx .scan (
468- rematted_block_forward ,
469- length = self .num_layers ,
470- in_axes = (nnx .Carry , 0 ),
471- out_axes = nnx .Carry ,
472- )(initial_carry , self .blocks )
473-
474- hidden_states = final_carry [0 ]
482+
483+ for block in self .blocks :
484+ def block_forward (hidden_states , encoder_hidden_states , timestep_proj , rotary_emb ):
485+ return block (hidden_states , encoder_hidden_states , timestep_proj , rotary_emb )
486+ rematted_block_forward = self .gradient_checkpoint .apply (block_forward )
487+ hidden_states = rematted_block_forward (hidden_states , encoder_hidden_states , timestep_proj , rotary_emb )
475488
476489 shift , scale = jnp .split (self .scale_shift_table + jnp .expand_dims (temb , axis = 1 ), 2 , axis = 1 )
477490