@@ -453,86 +453,96 @@ compile_xla_flags: "" # Compiler options e.g. compile_xla_flags="--xla_tpu_num_s
 shard_mode: "auto" # can be either auto or explicit
 custom_mesh_and_rule: "" # replace default mesh and logical rule by specifying yml name under config/mesh_and_rule/.
 mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
-logical_axis_rules: [
-  ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  # Vocab activation
+logical_axis_rules: [
+  # ==========================================
+  # Vocabulary Embedding
+  # ==========================================
+  # Vocab Activations
   ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
   ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['activation_vocab', ['tensor', 'tensor_transpose']],
-  ['activation_vocab', 'tensor_sequence'],
-  ['activation_vocab', ['sequence','context']],
-  # Vocab weight
+  # ['activation_vocab', ['tensor', 'tensor_transpose']],
+  # ['activation_vocab', 'tensor_sequence'],
+  # ['activation_vocab', ['sequence', 'context']],
+  # Vocab Weights
   ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
   ['embed_vocab', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  # MoE activation
-  ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
-  ['activation_length_moe', ['sequence', 'context']],
-  ['activation_length_moe', ['context']],
-  ['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
-  ['activation_embed_moe', ['tensor', 'tensor_transpose']],
-  ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['activation_exp', ['expert']],
-  # MoE weight
-  ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
-  ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], # should be deprecated
-  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
-  ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
-  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
-  ['embed_moe', ['fsdp', 'sequence', 'context']],
-  ['embed_tensor_transpose', ['tensor_transpose']], # should be deprecated
-  ['exp_with_fsdp', 'fsdp'], # should be deprecated
-  # Attn activation
+  # ==========================================
+  # Attention
+  # ==========================================
+  # Attention Activations
+  ['activation_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence', 'autoregressive']],
+  ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']],
   ['activation_attn_length', ['sequence', 'context']],
-  ['activation_attn_length', ['context']],
+  # ['activation_attn_length', ['context']],
   ['activation_q_length', ['context']],
   ['activation_kv_length', []],
   ['activation_attn_embed', ['tensor', 'tensor_transpose']],
   ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['activation_vocab', ['tensor', 'tensor_transpose']],
-  ['activation_vocab', 'tensor_sequence'],
-  ['activation_vocab', ['sequence','context']],
-  ['activation_stage', 'stage'],
-  ['activation_exp', ['expert']],
-  ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['decode_length', ['sequence']],
-  ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
-  ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
-  ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
+  # Attention Weights
   ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
   ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
   ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
+  ['qkv', []],
+  ['kv', []],
+  ['kv_head_dim', []],
   ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
-  ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
-  ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  ['q_lora', ['fsdp', 'sequence', 'context', 'expert']],
-  ["q_lora_up_proj",[]],
+  # ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
+  # ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+  # ['q_lora', ['fsdp', 'sequence', 'context', 'expert']],
+  ["q_lora_up_proj", []],
   ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
-  ['kv_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
-  ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  ['kv_lora', ['fsdp', 'sequence', 'context', 'expert']],
-  ["kv_lora_up_proj",[]],
-  # Other activation
-  ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']],
-  ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
+  # ['kv_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
+  # ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+  # ['kv_lora', ['fsdp', 'sequence', 'context', 'expert']],
+  ["kv_lora_up_proj", []],
+  # ==========================================
+  # Mixture of Experts (MoE)
+  # ==========================================
+  # MoE Activations
+  ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
+  ['activation_length_moe', ['sequence', 'context']],
+  # ['activation_length_moe', ['context']],
+  ['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
+  ['activation_embed_moe', ['tensor', 'tensor_transpose']],
+  ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
+  ['activation_exp', ['expert']],
+  # MoE Weights
+  ['exp', 'expert'],
+  ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
+  # ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
+  # ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
+  # ['embed_moe', ['fsdp', 'sequence', 'context']],
+  # ==========================================
+  # Standard MLP / Dense Layers / Model Structure
+  # ==========================================
+  # Dense Activations
+  ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
+  ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_length', ['sequence', 'context']],
-  ['activation_length', ['context']],
+  # ['activation_length', ['context']],
   ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
   ['activation_embed', ['tensor', 'tensor_transpose']],
-  ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_stage', 'stage'],
-  # Other weight
+  # General Weights
   ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
   ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
-  ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context', 'expert']],
-  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  ['embed', ['fsdp', 'sequence', 'context', 'expert']],
+  # ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context', 'expert']],
+  # ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+  # ['embed', ['fsdp', 'sequence', 'context', 'expert']],
   ['norm', ['tensor', 'tensor_transpose']],
   ['layers', 'stage'],
-  # Others (inference etc.)
+  ['diloco', 'diloco'],
+  ['engram_dim', ['tensor']],
+  ['dense_layers', []],
+  ['moe_layers', []],
+  ['mhc', []],
+  # ==========================================
+  # Inference (Prefill, Decode, Cache)
+  # ==========================================
   ['prefill_activation_length', ['sequence', 'context']],
   ['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
   ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
@@ -541,25 +551,21 @@ logical_axis_rules: [
   ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']],
   ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']],
   ['paged_kv_heads', ['tensor']],
-  ['diloco', 'diloco'],
-  ['engram_dim', ['tensor']],
-  # Should remove following names as they duplicate shardings
-  ['qkv', []],
-  ['kv', []],
-  ['kv_head_dim', []],
   ['cache_batch_prefill', []],
   ['cache_batch', []],
   ['cache_heads_none', []],
   ['cache_kv', []],
   ['cache_sequence', []],
-  ['exp', 'expert'],
   ['num_pages', []],
   ['tokens_per_page', []],
   ['paged_kv_head_dim_size', []],
-  ['dense_layers', []],
-  ['moe_layers', []],
-  ['mhc', []],
-]
+  # ==========================================
+  # Deprecated / Scheduled for Removal
+  # ==========================================
+  ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
+  ['embed_tensor_transpose', ['tensor_transpose']],
+  ['exp_with_fsdp', 'fsdp'],
+]
 # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
 data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
 input_data_sharding_logical_axes: ['activation_embed_and_logits_batch', 'activation_norm_length']
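
Note on how these rules are consumed: each ['logical_axis', [mesh_axes]] entry maps a named array dimension onto one or more mesh axes, and rules are matched in order; under Flax's logical-partitioning semantics (which MaxText builds on), a rule is skipped when its mesh axes are already claimed by an earlier dimension of the same array, so the duplicate entries commented out above act as ordered fallbacks. A minimal sketch of that first-match-with-fallback behavior, using Flax's public flax.linen.logical_to_mesh_axes helper rather than MaxText's own wrappers; the three-axis mesh and toy rules are invented for illustration:

# Illustrative sketch, not MaxText code: ordered logical_axis_rules
# resolve to a PartitionSpec via flax.linen.logical_to_mesh_axes.
import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from jax.sharding import Mesh, NamedSharding

# Toy 3-axis mesh standing in for the 13-axis mesh_axes above;
# all available devices are placed on the 'tensor' axis.
mesh = Mesh(np.array(jax.devices()).reshape(1, 1, -1), ('data', 'fsdp', 'tensor'))

# Duplicate logical names act as fallbacks: a mesh axis already claimed
# by an earlier dimension of the same array cannot be reused in one spec.
rules = (
    ('vocab', ('tensor',)),
    ('embed', ('tensor',)),  # preferred, but 'tensor' is claimed by 'vocab'
    ('embed', ('fsdp',)),    # fallback, like the commented-out entries above
)

# For a [vocab, embed] weight, 'vocab' claims 'tensor' first, so the first
# 'embed' rule is skipped and 'embed' falls back to 'fsdp'.
spec = nn.logical_to_mesh_axes(('vocab', 'embed'), rules)
w = jax.device_put(jnp.zeros((32, 16)), NamedSharding(mesh, spec))
print(spec, w.sharding)

Because only the first matching, non-conflicting rule applies, the ordering of entries in logical_axis_rules is significant: moving a rule above another changes which mesh axes a given tensor dimension can receive.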
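
The DCN-ordering comment above reflects how hybrid meshes spanning multiple slices are laid out: axes used for slower DCN (data-center network) parallelism sit on the leading mesh dimensions, and axes used for faster ICI parallelism within a slice sit on the trailing ones. A rough sketch of that construction, assuming JAX's jax.experimental.mesh_utils.create_hybrid_device_mesh and invented slice counts; this is not MaxText's actual mesh-building code and requires a multi-slice environment to run:

# Rough sketch with invented sizes: DCN-parallel axes lead, ICI-parallel
# axes trail, matching the ordering required by data_sharding above.
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

ici_parallelism = (1, 4)  # within a slice: shard over 'fsdp' on fast ICI
dcn_parallelism = (2, 1)  # across slices: data parallelism over slow DCN

devices = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism)
mesh = Mesh(devices, ('data', 'fsdp'))  # 'data' spans DCN, 'fsdp' spans ICI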