Commit 9fc1ccc

refactor logical rule

1 parent a214533 commit 9fc1ccc
1 file changed: src/maxtext/configs/base.yml (72 additions, 66 deletions)
@@ -453,86 +453,96 @@ compile_xla_flags: "" # Compiler options e.g. compile_xla_flags="--xla_tpu_num_s
 shard_mode: "auto" # can be either auto or explicit
 custom_mesh_and_rule: "" # replace default mesh and logical rule by specifying yml name under config/mesh_and_rule/.
 mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
-logical_axis_rules: [
-  ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  # Vocab activation
+logical_axis_rules: [
+  # ==========================================
+  # Vocabulary Embedding
+  # ==========================================
+  # Vocab Activations
   ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
   ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['activation_vocab', ['tensor', 'tensor_transpose']],
-  ['activation_vocab', 'tensor_sequence'],
-  ['activation_vocab', ['sequence','context']],
-  # Vocab weight
+  # ['activation_vocab', ['tensor', 'tensor_transpose']],
+  # ['activation_vocab', 'tensor_sequence'],
+  # ['activation_vocab', ['sequence', 'context']],
+  # Vocab Weights
   ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
   ['embed_vocab', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  # MoE activation
-  ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
-  ['activation_length_moe', ['sequence', 'context']],
-  ['activation_length_moe', ['context']],
-  ['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
-  ['activation_embed_moe', ['tensor', 'tensor_transpose']],
-  ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['activation_exp', ['expert']],
-  # MoE weight
-  ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
-  ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], # should be deprecated
-  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
-  ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
-  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
-  ['embed_moe', ['fsdp', 'sequence', 'context']],
-  ['embed_tensor_transpose', ['tensor_transpose']], # should be deprecated
-  ['exp_with_fsdp', 'fsdp'], # should be deprecated
-  # Attn activation
+  # ==========================================
+  # Attention
+  # ==========================================
+  # Attention Activations
+  ['activation_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence', 'autoregressive']],
+  ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']],
   ['activation_attn_length', ['sequence', 'context']],
-  ['activation_attn_length', ['context']],
+  # ['activation_attn_length', ['context']],
   ['activation_q_length', ['context']],
   ['activation_kv_length', []],
   ['activation_attn_embed', ['tensor', 'tensor_transpose']],
   ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['activation_vocab', ['tensor', 'tensor_transpose']],
-  ['activation_vocab', 'tensor_sequence'],
-  ['activation_vocab', ['sequence','context']],
-  ['activation_stage', 'stage'],
-  ['activation_exp', ['expert']],
-  ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['decode_length', ['sequence']],
-  ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
-  ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
-  ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
+  # Attention Weights
   ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
   ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
   ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
+  ['qkv', []],
+  ['kv', []],
+  ['kv_head_dim', []],
   ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
-  ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
-  ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  ['q_lora', ['fsdp', 'sequence', 'context', 'expert']],
-  ["q_lora_up_proj",[]],
+  # ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
+  # ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+  # ['q_lora', ['fsdp', 'sequence', 'context', 'expert']],
+  ["q_lora_up_proj", []],
   ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
-  ['kv_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
-  ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  ['kv_lora', ['fsdp', 'sequence', 'context', 'expert']],
-  ["kv_lora_up_proj",[]],
-  # Other activation
-  ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']],
-  ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
+  # ['kv_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
+  # ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+  # ['kv_lora', ['fsdp', 'sequence', 'context', 'expert']],
+  ["kv_lora_up_proj", []],
+  # ==========================================
+  # Mixture of Experts (MoE)
+  # ==========================================
+  # MoE Activations
+  ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
+  ['activation_length_moe', ['sequence', 'context']],
+  # ['activation_length_moe', ['context']],
+  ['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
+  ['activation_embed_moe', ['tensor', 'tensor_transpose']],
+  ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
+  ['activation_exp', ['expert']],
+  # MoE Weights
+  ['exp', 'expert'],
+  ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
+  # ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
+  # ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
+  # ['embed_moe', ['fsdp', 'sequence', 'context']],
+  # ==========================================
+  # Standard MLP / Dense Layers / Model Structure
+  # ==========================================
+  # Dense Activations
+  ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
+  ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_length', ['sequence', 'context']],
-  ['activation_length', ['context']],
+  # ['activation_length', ['context']],
   ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
   ['activation_embed', ['tensor', 'tensor_transpose']],
-  ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_stage', 'stage'],
-  # Other weight
+  # General Weights
   ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
   ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
-  ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context' , 'expert']],
-  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  ['embed', ['fsdp', 'sequence', 'context', 'expert']],
+  # ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context', 'expert']],
+  # ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+  # ['embed', ['fsdp', 'sequence', 'context', 'expert']],
   ['norm', ['tensor', 'tensor_transpose']],
   ['layers', 'stage'],
-  # Others (inference etc.)
+  ['diloco', 'diloco'],
+  ['engram_dim', ['tensor']],
+  ['dense_layers', []],
+  ['moe_layers', []],
+  ['mhc', []],
+  # ==========================================
+  # Inference (Prefill, Decode, Cache)
+  # ==========================================
   ['prefill_activation_length', ['sequence', 'context']],
   ['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
   ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
@@ -541,25 +551,21 @@ logical_axis_rules: [
   ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']],
   ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']],
   ['paged_kv_heads', ['tensor']],
-  ['diloco', 'diloco'],
-  ['engram_dim', ['tensor']],
-  # Should remove following names as they duplicate shardings
-  ['qkv', []],
-  ['kv', []],
-  ['kv_head_dim', []],
   ['cache_batch_prefill', []],
   ['cache_batch', []],
   ['cache_heads_none', []],
   ['cache_kv', []],
   ['cache_sequence', []],
-  ['exp', 'expert'],
   ['num_pages', []],
   ['tokens_per_page', []],
   ['paged_kv_head_dim_size', []],
-  ['dense_layers', []],
-  ['moe_layers', []],
-  ['mhc', []],
-]
+  # ==========================================
+  # Deprecated / Scheduled for Removal
+  # ==========================================
+  ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
+  ['embed_tensor_transpose', ['tensor_transpose']],
+  ['exp_with_fsdp', 'fsdp'],
+]
 # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
 data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
 input_data_sharding_logical_axes: ['activation_embed_and_logits_batch', 'activation_norm_length']
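For context on what this list controls: each ['logical_name', [mesh_axes, ...]] rule maps a logical array dimension to the mesh axes it may be sharded over, and when several rules share a logical name, later ones act as fallbacks that apply only if an earlier rule's mesh axes are already claimed by another dimension of the same array. That first-match resolution is the behavior of Flax's flax.linen.logical_to_mesh_axes, which MaxText's sharding annotations build on, and it appears to be why the duplicate fallback entries this commit comments out existed. A minimal sketch of the mechanism, using a hypothetical two-dimension array and a tiny excerpt of the rules above (not MaxText's actual config-loading code):

    # Sketch: duplicate logical rules act as fallbacks during resolution.
    from flax import linen as nn

    rules = (
        ('activation_length', ('sequence', 'context')),
        ('embed', ('fsdp', 'sequence')),  # skipped: 'sequence' is claimed by the rule above
        ('embed', ('fsdp',)),             # fallback rule, the kind this commit comments out
    )

    # 'activation_length' takes 'sequence' first, so the first 'embed' rule
    # conflicts and resolution falls through to the fallback.
    spec = nn.logical_to_mesh_axes(('activation_length', 'embed'), rules)
    print(spec)  # PartitionSpec(('sequence', 'context'), ('fsdp',))

Rules are scanned top to bottom per array, which is why the ordering of the regrouped sections above is significant.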
