@@ -175,12 +175,11 @@ def __init__(
             kernel_init=nnx.with_partitioning(
                 nnx.initializers.xavier_uniform(),
                 (
+                    "embed",
                     None,
                     "mlp",
-                    "embed",
                 ),
             ),
-            bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed")),
         )

     def __call__(self, x: jax.Array) -> jax.Array:
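For readers unfamiliar with the sharding annotations being reordered above: `nnx.with_partitioning` wraps an initializer so the created parameter carries one logical axis name (or `None`) per array dimension, and those names are later resolved to mesh axes by the repo's sharding rules. A minimal sketch of the mechanism with illustrative feature sizes ("embed"/"mlp" are the same logical names used in this file; the standalone layer is hypothetical):

```python
from flax import nnx

# Hypothetical standalone layer; sizes are illustrative.
layer = nnx.Linear(
    in_features=128,
    out_features=512,
    kernel_init=nnx.with_partitioning(
        nnx.initializers.xavier_uniform(),
        ("embed", "mlp"),  # one logical axis name per kernel dimension (in, out)
    ),
    rngs=nnx.Rngs(0),
)

# The annotation travels with the parameter state and can be read back
# as PartitionSpecs when sharding the weights across a mesh.
print(nnx.get_partition_spec(nnx.state(layer, nnx.Param)))
```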
@@ -217,6 +216,8 @@ def __init__(
         else:
             raise NotImplementedError(f"{activation_fn} is not implemented.")

+        self.drop_out = nnx.Dropout(dropout)
+
         self.proj_out = nnx.Linear(
             rngs=rngs,
             in_features=inner_dim,
@@ -228,15 +229,16 @@ def __init__(
             kernel_init=nnx.with_partitioning(
                 nnx.initializers.xavier_uniform(),
                 (
-                    None,
                     "embed",
                     "mlp",
+                    None,
                 ),
             ),
         )

-    def __call__(self, hidden_states: jax.Array) -> jax.Array:
+    def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array:
         hidden_states = self.act_fn(hidden_states)
+        hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
         return self.proj_out(hidden_states)


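The new `drop_out` module follows the standard Flax NNX dropout pattern: it is an identity when `deterministic=True`, and it needs a `dropout` rng stream to sample a mask when active, which is why `__call__` now takes `deterministic` and `rngs`. A small standalone sketch of that behaviour (rate and shapes are illustrative):

```python
import jax.numpy as jnp
from flax import nnx

drop = nnx.Dropout(0.1)  # rate only; the rng stream is supplied per call
x = jnp.ones((2, 16, 128))

# Eval path: dropout is disabled and the input passes through unchanged.
y_eval = drop(x, deterministic=True)
assert jnp.array_equal(y_eval, x)

# Train path: provide an rng stream named "dropout" to sample the mask.
y_train = drop(x, deterministic=False, rngs=nnx.Rngs(dropout=0))
```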
@@ -260,6 +262,7 @@ def __init__(
         weights_dtype: jnp.dtype = jnp.float32,
         precision: jax.lax.Precision = None,
         attention: str = "dot_product",
+        dropout: float = 0.0,
     ):

         # 1. Self-attention
@@ -278,6 +281,7 @@ def __init__(
             weights_dtype=weights_dtype,
             precision=precision,
             attention_kernel=attention,
+            dropout=dropout
         )

         # 1. Cross-attention
@@ -295,6 +299,7 @@ def __init__(
             weights_dtype=weights_dtype,
             precision=precision,
             attention_kernel=attention,
+            dropout=dropout
         )
         assert cross_attn_norm is True
         self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
@@ -308,13 +313,16 @@ def __init__(
             dtype=dtype,
             weights_dtype=weights_dtype,
             precision=precision,
+            dropout=dropout
         )
         self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)

         key = rngs.params()
-        self.adaln_scale_shift_table = nnx.Param(jax.random.normal(key, (1, 6, dim)) / dim**0.5)
+        self.adaln_scale_shift_table = nnx.Param(
+            jax.random.normal(key, (1, 6, dim)) / dim**0.5,
+            sharding=("embed",))

-    def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, temb: jax.Array, rotary_emb: jax.Array):
+    def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, temb: jax.Array, rotary_emb: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None):
         shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
             (self.adaln_scale_shift_table + temb), 6, axis=1
         )
@@ -324,18 +332,18 @@ def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, t
         # 1. Self-attention
         norm_hidden_states = (self.norm1(hidden_states) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype)
         attn_output = self.attn1(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_hidden_states, rotary_emb=rotary_emb
+            hidden_states=norm_hidden_states, encoder_hidden_states=norm_hidden_states, rotary_emb=rotary_emb, deterministic=deterministic, rngs=rngs
         )
         hidden_states = (hidden_states + attn_output * gate_msa).astype(hidden_states.dtype)

         # 2. Cross-attention
         norm_hidden_states = self.norm2(hidden_states)
-        attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+        attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs)
         hidden_states = hidden_states + attn_output

         # 3. Feed-forward
         norm_hidden_states = (self.norm3(hidden_states) * (1 + c_scale_msa) + c_shift_msa).astype(hidden_states.dtype)
-        ff_output = self.ffn(norm_hidden_states)
+        ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
         hidden_states = (hidden_states + ff_output * c_gate_msa).astype(hidden_states.dtype)
         return hidden_states

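As a reference for the modulation step the new arguments thread through: `adaln_scale_shift_table` has shape `(1, 6, dim)`, `temb` broadcasts against it per batch element, and `jnp.split(..., 6, axis=1)` yields the six `(batch, 1, dim)` shift/scale/gate tensors applied around self-attention and the feed-forward. A shape-only sketch with toy sizes:

```python
import jax
import jax.numpy as jnp

batch, dim = 2, 8
table = jax.random.normal(jax.random.key(0), (1, 6, dim)) / dim**0.5
temb = jnp.zeros((batch, 6, dim))  # per-timestep conditioning, broadcast against the table

shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(table + temb, 6, axis=1)
print(shift_msa.shape)  # (batch, 1, dim): broadcasts over the sequence axis
```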
@@ -356,6 +364,7 @@ def __init__(
         freq_dim: int = 256,
         ffn_dim: int = 13824,
         num_layers: int = 40,
+        dropout: float = 0.0,
         cross_attn_norm: bool = True,
         qk_norm: Optional[str] = "rms_norm_across_heads",
         eps: float = 1e-6,
@@ -424,6 +433,7 @@ def init_block(rngs):
                 weights_dtype=weights_dtype,
                 precision=precision,
                 attention=attention,
+                dropout=dropout,
             )

         self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
@@ -454,6 +464,8 @@ def __call__(
         encoder_hidden_states_image: Optional[jax.Array] = None,
         return_dict: bool = True,
         attention_kwargs: Optional[Dict[str, Any]] = None,
+        deterministic: bool = True,
+        rngs: nnx.Rngs = None,
     ) -> Union[jax.Array, Dict[str, jax.Array]]:
         batch_size, _, num_frames, height, width = hidden_states.shape
         p_t, p_h, p_w = self.config.patch_size
@@ -476,20 +488,21 @@ def __call__(
             raise NotImplementedError("img2vid is not yet implemented.")

         def scan_fn(carry, block):
-            hidden_states, encoder_hidden_states, timestep_proj, rotary_emb = carry
-            hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
-            return (hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
+            hidden_states_carry, rngs_carry = carry
+            hidden_states = block(hidden_states_carry, encoder_hidden_states, timestep_proj, rotary_emb, deterministic, rngs_carry)
+            new_carry = (hidden_states, rngs_carry)
+            return new_carry, None

-        initial_carry = (hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
         rematted_block_forward = self.gradient_checkpoint.apply(scan_fn)
-        final_carry = nnx.scan(
+        initial_carry = (hidden_states, rngs)
+        final_carry, _ = nnx.scan(
             rematted_block_forward,
             length=self.num_layers,
             in_axes=(nnx.Carry, 0),
-            out_axes=nnx.Carry,
+            out_axes=(nnx.Carry, 0),
         )(initial_carry, self.blocks)

-        hidden_states = final_carry[0]
+        hidden_states, _ = final_carry

         shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1)

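The rewritten `scan_fn` adopts the `(carry, per-step output)` contract that `nnx.scan` shares with `jax.lax.scan`: only `hidden_states` and the dropout `rngs` ride in the carry, the other block inputs are closed over, and `None` is returned as the per-step output since nothing needs to be stacked. A minimal `jax.lax.scan` illustration of that convention (toy values, unrelated to the model):

```python
import jax
import jax.numpy as jnp

def body(carry, x):
    # Carry-only scan: return None as the stacked per-step output.
    return carry + x, None

final_carry, ys = jax.lax.scan(body, jnp.zeros(()), jnp.arange(4.0))
print(final_carry)  # 6.0
print(ys)           # None
```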