@@ -455,30 +455,40 @@ custom_mesh_and_rule: "" # replace default mesh and logical rule by specifying y
455455mesh_axes : ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
456456logical_axis_rules : [
457457 ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
458- ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
458+ # Vocab activation
459459 ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
460460 ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
461- ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']],
462- ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
463- ['activation_length', ['sequence', 'context']],
464- ['activation_length', ['context']],
465- ['activation_attn_length', ['sequence', 'context']],
466- ['activation_attn_length', ['context']],
461+ ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']],
462+ ['activation_vocab', ['tensor', 'tensor_transpose']],
463+ ['activation_vocab', 'tensor_sequence'],
464+ ['activation_vocab', ['sequence','context']],
465+ # Vocab weight
466+ ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
467+ ['embed_vocab', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
468+ # MoE activation
469+ ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
467470 ['activation_length_moe', ['sequence', 'context']],
468471 ['activation_length_moe', ['context']],
469- ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
470472 ['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
473+ ['activation_embed_moe', ['tensor', 'tensor_transpose']],
474+ ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
475+ ['activation_exp', ['expert']],
476+ # MoE weight
477+ ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
478+ ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], # should be deprecated
479+ ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
480+ ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
481+ ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
482+ ['embed_moe', ['fsdp', 'sequence', 'context']],
483+ ['embed_tensor_transpose', ['tensor_transpose']], # should be deprecated
484+ ['exp_with_fsdp', 'fsdp'], # should be deprecated
485+ # Attn activation
486+ ['activation_attn_length', ['sequence', 'context']],
487+ ['activation_attn_length', ['context']],
471488 ['activation_q_length', ['context']],
472- ['prefill_activation_length', ['sequence', 'context']],
473- ['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
474489 ['activation_kv_length', []],
475490 ['activation_attn_embed', ['tensor', 'tensor_transpose']],
476- ['activation_embed', ['tensor', 'tensor_transpose']],
477- ['activation_embed_moe', ['tensor', 'tensor_transpose']],
478- ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
479- ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
480491 ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
481- ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
482492 ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
483493 ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']],
484494 ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']],
@@ -490,21 +500,11 @@ logical_axis_rules: [
490500 ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
491501 ['decode_length', ['sequence']],
492502 ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
493- ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
494503 ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
495504 ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
496505 ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
497506 ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
498507 ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
499- ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
500- ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context' , 'expert']],
501- ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
502- ['embed', ['fsdp', 'sequence', 'context', 'expert']],
503- ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
504- ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
505- ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
506- ['embed_moe', ['fsdp', 'sequence', 'context']],
507- ['embed_tensor_transpose', ['tensor_transpose']],
508508 ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
509509 ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
510510 ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
@@ -515,29 +515,50 @@ logical_axis_rules: [
515515 ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
516516 ['kv_lora', ['fsdp', 'sequence', 'context', 'expert']],
517517 ["kv_lora_up_proj",[]],
518+ # Other activation
519+ ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']],
520+ ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
521+ ['activation_length', ['sequence', 'context']],
522+ ['activation_length', ['context']],
523+ ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
524+ ['activation_embed', ['tensor', 'tensor_transpose']],
525+ ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
526+ ['activation_stage', 'stage'],
527+ # Other weight
528+ ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
529+ ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
530+ ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context' , 'expert']],
531+ ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
532+ ['embed', ['fsdp', 'sequence', 'context', 'expert']],
518533 ['norm', ['tensor', 'tensor_transpose']],
519534 ['layers', 'stage'],
535+ # Others (inference etc.)
536+ ['prefill_activation_length', ['sequence', 'context']],
537+ ['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
538+ ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
539+ ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
540+ ['decode_length', ['sequence']],
541+ ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']],
542+ ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']],
543+ ['paged_kv_heads', ['tensor']],
544+ ['diloco', 'diloco'],
545+ ['engram_dim', ['tensor']],
546+ # Should remove following names as they duplicate shardings
520547 ['qkv', []],
521548 ['kv', []],
522549 ['kv_head_dim', []],
523550 ['cache_batch_prefill', []],
524551 ['cache_batch', []],
525552 ['cache_heads_none', []],
526- ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']],
527- ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']],
528553 ['cache_kv', []],
529554 ['cache_sequence', []],
530555 ['exp', 'expert'],
531- ['exp_with_fsdp', 'fsdp'],
532- ['paged_kv_heads', ['tensor']],
533556 ['num_pages', []],
534557 ['tokens_per_page', []],
535558 ['paged_kv_head_dim_size', []],
536559 ['dense_layers', []],
537560 ['moe_layers', []],
538- ['engram_dim', ['tensor']],
539561 ['mhc', []],
540- ['diloco', 'diloco'],
541562 ]
542563# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
543564data_sharding : [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
0 commit comments