
Commit a214533

refactor rule order and add vocab embed
1 parent d643a1e commit a214533

8 files changed: 61 additions & 34 deletions


src/maxtext/configs/base.yml

Lines changed: 52 additions & 31 deletions
@@ -455,30 +455,40 @@ custom_mesh_and_rule: "" # replace default mesh and logical rule by specifying y
 mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
 logical_axis_rules: [
   ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
+  # Vocab activation
   ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']],
-  ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
-  ['activation_length', ['sequence', 'context']],
-  ['activation_length', ['context']],
-  ['activation_attn_length', ['sequence', 'context']],
-  ['activation_attn_length', ['context']],
+  ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']],
+  ['activation_vocab', ['tensor', 'tensor_transpose']],
+  ['activation_vocab', 'tensor_sequence'],
+  ['activation_vocab', ['sequence','context']],
+  # Vocab weight
+  ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
+  ['embed_vocab', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+  # MoE activation
+  ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
   ['activation_length_moe', ['sequence', 'context']],
   ['activation_length_moe', ['context']],
-  ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
   ['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
+  ['activation_embed_moe', ['tensor', 'tensor_transpose']],
+  ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
+  ['activation_exp', ['expert']],
+  # MoE weight
+  ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
+  ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], # should be deprecated
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
+  ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
+  ['embed_moe', ['fsdp', 'sequence', 'context']],
+  ['embed_tensor_transpose', ['tensor_transpose']], # should be deprecated
+  ['exp_with_fsdp', 'fsdp'], # should be deprecated
+  # Attn activation
+  ['activation_attn_length', ['sequence', 'context']],
+  ['activation_attn_length', ['context']],
   ['activation_q_length', ['context']],
-  ['prefill_activation_length', ['sequence', 'context']],
-  ['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
   ['activation_kv_length', []],
   ['activation_attn_embed', ['tensor', 'tensor_transpose']],
-  ['activation_embed', ['tensor', 'tensor_transpose']],
-  ['activation_embed_moe', ['tensor', 'tensor_transpose']],
-  ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']],
@@ -490,21 +500,11 @@ logical_axis_rules: [
   ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
   ['decode_length', ['sequence']],
   ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
-  ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
   ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
   ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
   ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
   ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
   ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
-  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
-  ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context' , 'expert']],
-  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  ['embed', ['fsdp', 'sequence', 'context', 'expert']],
-  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
-  ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
-  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
-  ['embed_moe', ['fsdp', 'sequence', 'context']],
-  ['embed_tensor_transpose', ['tensor_transpose']],
   ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
   ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
   ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
@@ -515,29 +515,50 @@ logical_axis_rules: [
   ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
   ['kv_lora', ['fsdp', 'sequence', 'context', 'expert']],
   ["kv_lora_up_proj",[]],
+  # Other activation
+  ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']],
+  ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
+  ['activation_length', ['sequence', 'context']],
+  ['activation_length', ['context']],
+  ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
+  ['activation_embed', ['tensor', 'tensor_transpose']],
+  ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
+  ['activation_stage', 'stage'],
+  # Other weight
+  ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
+  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
+  ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context' , 'expert']],
+  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+  ['embed', ['fsdp', 'sequence', 'context', 'expert']],
   ['norm', ['tensor', 'tensor_transpose']],
   ['layers', 'stage'],
+  # Others (inference etc.)
+  ['prefill_activation_length', ['sequence', 'context']],
+  ['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
+  ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
+  ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
+  ['decode_length', ['sequence']],
+  ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']],
+  ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']],
+  ['paged_kv_heads', ['tensor']],
+  ['diloco', 'diloco'],
+  ['engram_dim', ['tensor']],
+  # Should remove following names as they duplicate shardings
   ['qkv', []],
   ['kv', []],
   ['kv_head_dim', []],
   ['cache_batch_prefill', []],
   ['cache_batch', []],
   ['cache_heads_none', []],
-  ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']],
   ['cache_kv', []],
   ['cache_sequence', []],
   ['exp', 'expert'],
-  ['exp_with_fsdp', 'fsdp'],
-  ['paged_kv_heads', ['tensor']],
   ['num_pages', []],
   ['tokens_per_page', []],
   ['paged_kv_head_dim_size', []],
   ['dense_layers', []],
   ['moe_layers', []],
-  ['engram_dim', ['tensor']],
   ['mhc', []],
-  ['diloco', 'diloco'],
 ]
 # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
 data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
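
The regrouped rules above are order-sensitive, not just cosmetic: in Flax's rule matching, which these logical_axis_rules feed into, rules are scanned top-down, and for each array dimension the first rule whose mesh axes are not already claimed by another dimension of the same array wins. Duplicate entries such as the four ['embed', ...] lines therefore act as fallbacks. Below is a minimal sketch of that resolution using flax.linen.logical_to_mesh_axes (a real Flax helper); the rule subset is copied from this diff, and the two example kernels are illustrative.

# Sketch of logical-rule resolution; rule subset copied from the diff above.
from flax.linen import logical_to_mesh_axes

rules = (
    ('vocab', ('tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive')),
    ('embed_vocab', ('fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert')),
    ('mlp', ('fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive')),
    ('embed', ('fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert')),
    ('embed', ('fsdp', 'sequence', 'tensor_transpose', 'context', 'expert')),
)

# New logits kernel ('embed_vocab', 'vocab'): embed_vocab takes the
# fsdp-style axes, vocab the tensor-style axes; the two dims never
# compete for a mesh axis.
print(logical_to_mesh_axes(('embed_vocab', 'vocab'), rules))

# Fallback in action for an ('embed', 'mlp') kernel: the 'mlp' rule claims
# 'fsdp_transpose' first, so the first 'embed' rule (which also lists
# 'fsdp_transpose') is rejected and the second 'embed' rule applies.
print(logical_to_mesh_axes(('embed', 'mlp'), rules))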

src/maxtext/configs/inference/inference.yml

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@ logical_axis_rules: [
   ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
   ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
   ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive','context_autoregressive']],
+  ['embed_vocab', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
   ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
   ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
   ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],

src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml

Lines changed: 2 additions & 0 deletions
@@ -70,6 +70,8 @@ logical_axis_rules: [
   ['activation_stage', 'stage'],
   ['embed', ['fsdp']],
   ['embed_moe', ['fsdp']],
+  ['embed_no_exp', ['fsdp']],
+  ['embed_no_exp_moe', ['fsdp']],
   ['q_lora', ['fsdp']],
   ['kv_lora', ['fsdp']],
   ['layers', 'stage'],

src/maxtext/configs/models/deepseek3-671b-batchsplit.yml

Lines changed: 2 additions & 0 deletions
@@ -71,6 +71,8 @@ logical_axis_rules: [
   ['activation_stage', 'stage'],
   ['embed', ['fsdp']],
   ['embed_moe', ['fsdp']],
+  ['embed_no_exp', ['fsdp']],
+  ['embed_no_exp_moe', ['fsdp']],
   ['q_lora', ['fsdp']],
   ['kv_lora', ['fsdp']],
   ['layers', 'stage'],

src/maxtext/configs/post_train/rl_mt_jt.yml

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ logical_axis_rules: [
   ['decode_length', []],
   ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
   ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive','context_autoregressive']],
+  ['embed_vocab', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
   ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
   ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
   ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],

src/maxtext/layers/decoders.py

Lines changed: 1 addition & 1 deletion
@@ -736,7 +736,7 @@ def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, determi
         out_features_shape=cfg.vocab_size,
         weight_dtype=cfg.weight_dtype,
         dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype, # for logit training stability
-        kernel_axes=("embed", "vocab"),
+        kernel_axes=("embed_vocab", "vocab"),
         shard_mode=cfg.shard_mode,
         name="logits_dense",
         matmul_precision=self.config.matmul_precision,
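
For context on the kernel_axes change: in MaxText these logical names are attached to the parameter itself and only later resolved against logical_axis_rules. A minimal plain-Flax sketch of that mechanism, using the standard flax.linen APIs; LogitsHead is an illustrative stand-in, not MaxText's DenseGeneral:

import flax.linen as nn

class LogitsHead(nn.Module):
  vocab_size: int

  @nn.compact
  def __call__(self, x):
    # Tag the kernel dims with logical names; with a mesh and rules in
    # scope, Flax resolves them via logical_axis_rules (embed_vocab ->
    # fsdp-style axes, vocab -> tensor-style axes under this commit).
    kernel = self.param(
        'kernel',
        nn.with_logical_partitioning(
            nn.initializers.lecun_normal(), ('embed_vocab', 'vocab')
        ),
        (x.shape[-1], self.vocab_size),
    )
    return x @ kernel

Splitting embed_vocab off from the plain embed name lets the two vocab-sized matrices (input embedding table and output head) be sharded independently of the model-internal embed dimension.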

src/maxtext/layers/embeddings.py

Lines changed: 1 addition & 1 deletion
@@ -132,7 +132,7 @@ def __init__(
         (self.num_embeddings, self.num_features),
         self.config.weight_dtype,
       ),
-      sharding=("vocab", "embed"),
+      sharding=("vocab", "embed_vocab"),
     )

   def __call__(self, inputs: Array, model_mode: str = MODEL_MODE_TRAIN) -> Array:
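
Note that the embedding table's ('vocab', 'embed_vocab') is the transpose of the logits head's ('embed_vocab', 'vocab') above, so under the base.yml rules both vocab-sized matrices should resolve to transposed versions of the same physical layout. A quick self-contained check of that expectation (assuming Flax's resolution semantics):

from flax.linen import logical_to_mesh_axes

# Rules copied from the base.yml diff above.
rules = (
    ('vocab', ('tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive')),
    ('embed_vocab', ('fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert')),
)

table = logical_to_mesh_axes(('vocab', 'embed_vocab'), rules)  # embeddings.py
head = logical_to_mesh_axes(('embed_vocab', 'vocab'), rules)   # decoders.py
assert (table[0], table[1]) == (head[1], head[0])  # transposed layouts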

src/maxtext/layers/nnx_decoders.py

Lines changed: 1 addition & 1 deletion
@@ -287,7 +287,7 @@ def __init__(
       out_features_shape=config.vocab_size,
       weight_dtype=config.weight_dtype,
       dtype=jnp.float32 if config.logits_dot_in_fp32 else config.dtype,
-      kernel_axes=("embed", "vocab"),
+      kernel_axes=("embed_vocab", "vocab"),
       shard_mode=config.shard_mode,
       matmul_precision=self.config.matmul_precision,
       parameter_memory_host_offload=config.parameter_memory_host_offload,
