
Commit 484ac77

Merge pull request #3473 from AI-Hypercomputer:chengnuojin-separate-moe
PiperOrigin-RevId: 888698420
2 parents: ff34f1b + 79c778a

30 files changed: 338 additions & 298 deletions (five of the changed files are shown below)

src/maxtext/configs/base.yml

Lines changed: 14 additions & 0 deletions
@@ -435,7 +435,9 @@ custom_mesh_and_rule: "" # replace default mesh and logical rule by specifying y
 mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
 logical_axis_rules: [
   ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
+  ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
+  ['activation_batch_no_exp_moe', ['data', 'fsdp', 'fsdp_transpose']],
   ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
   ['activation_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence', 'autoregressive']],
@@ -448,14 +450,18 @@ logical_axis_rules: [
   ['activation_attn_length_no_exp', ['context']],
   ['activation_length_no_exp', ['sequence', 'context']],
   ['activation_length_no_exp', ['context']],
+  ['activation_length_no_exp_moe', ['sequence', 'context']],
+  ['activation_length_no_exp_moe', ['context']],
   ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
+  ['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
   ['activation_q_length', ['context', 'expert']],
   ['activation_q_length_no_exp', ['context']],
   ['prefill_activation_length', ['sequence', 'context']],
   ['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
   ['activation_kv_length', []],
   ['activation_attn_embed', ['tensor', 'tensor_transpose']],
   ['activation_embed', ['tensor', 'tensor_transpose']],
+  ['activation_embed_moe', ['tensor', 'tensor_transpose']],
   ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
@@ -484,6 +490,14 @@ logical_axis_rules: [
   ['embed_no_exp', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
   ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
   ['embed_no_exp', ['fsdp', 'sequence', 'context']],
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
+  ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context', 'expert']],
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+  ['embed_moe', ['fsdp', 'sequence', 'context', 'expert']],
+  ['embed_no_exp_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
+  ['embed_no_exp_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
+  ['embed_no_exp_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
+  ['embed_no_exp_moe', ['fsdp', 'sequence', 'context']],
   ['embed_tensor_transpose', ['tensor_transpose']],
   ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
   ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
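
The pattern, repeated in every config touched by this commit, is visible in the base.yml hunks above: each logical axis name used by MoE tensors gains a `_moe` twin that initially maps to the same mesh axes as the original. Below is a minimal sketch of how such rule lists resolve to shardings, assuming flax's logical-partitioning helper `flax.linen.logical_to_mesh_axes` (the flax/JAX primitive that MaxText's sharding machinery builds on); the rule subset is copied from the hunk above.

# Sketch only: resolving logical names against a rule subset from the
# base.yml hunk above, via flax's logical_to_mesh_axes.
from flax import linen as nn

rules = (
    ("activation_batch", ("data", "fsdp", "fsdp_transpose", "expert")),
    ("activation_batch_moe", ("data", "fsdp", "fsdp_transpose", "expert")),
    ("activation_embed", ("tensor", "tensor_transpose")),
    ("activation_embed_moe", ("tensor", "tensor_transpose")),
)

# Dense and MoE activations now carry distinct logical names, even though
# both currently resolve to the same mesh axes.
dense = nn.logical_to_mesh_axes(("activation_batch", "activation_embed"), rules)
moe = nn.logical_to_mesh_axes(("activation_batch_moe", "activation_embed_moe"), rules)
print(dense)  # PartitionSpec(('data', 'fsdp', 'fsdp_transpose', 'expert'), ('tensor', 'tensor_transpose'))
print(moe)    # identical today, but independently overridable per config

Because the twins start out identical, resolved shardings are unchanged; the point of the split is that the two names can now diverge per config.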

src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml

Lines changed: 6 additions & 0 deletions
@@ -28,7 +28,9 @@ mesh_axes: ['data', 'stage', 'fsdp', 'tensor', 'expert']
 data_sharding: [['data', 'stage', 'fsdp', 'tensor', 'expert']]
 logical_axis_rules: [
   ['activation_batch', ['data', 'fsdp', 'expert']],
+  ['activation_batch_moe', ['data', 'fsdp', 'expert']],
   ['activation_batch_no_exp', ['data', 'fsdp']],
+  ['activation_batch_no_exp_moe', ['data', 'fsdp']],
   ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'expert']],
   ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'expert']],
   ['activation_heads', ['tensor']],
@@ -38,6 +40,7 @@ logical_axis_rules: [
   ['activation_q_length', ['expert']],
   ['activation_attn_embed', ['tensor']],
   ['activation_embed', ['tensor']],
+  ['activation_embed_moe', ['tensor']],
   ['activation_mlp', ['tensor']],
   ['activation_kv', ['tensor']],
   ['activation_prefill_kv_batch', ['data', 'fsdp', 'expert']],
@@ -55,7 +58,10 @@ logical_axis_rules: [
   ['q_heads', ['tensor']],
   ['kv_heads', ['tensor']],
   ['embed', ['fsdp', 'expert']],
+  ['embed_moe', ['fsdp', 'expert']],
   ['embed_no_exp', ['fsdp']],
+  ['embed_no_exp_moe', ['fsdp']],
+  ['embed_moe', ['fsdp']],
   ['q_lora', ['fsdp']],
   ['kv_lora', ['fsdp']],
   ['norm', ['tensor']],
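
Note the two `embed_moe` rules in the last hunk, one mapping to `['fsdp', 'expert']` and one to `['fsdp']`. As I read flax's rule resolution, duplicate rules for one logical name act as ordered fallbacks: a rule is skipped for an array when any of its mesh axes is already claimed by another dimension of that array. A sketch of that behavior; the `('exp', ('expert',))` rule is assumed for illustration and is not part of this diff.

# Sketch of ordered-fallback resolution (my reading of flax's
# logical_to_mesh_axes): a rule is skipped when one of its mesh axes is
# already assigned to another dimension of the same array.
from flax import linen as nn

rules = (
    ("exp", ("expert",)),               # assumed rule, for illustration
    ("embed_moe", ("fsdp", "expert")),  # preferred sharding for embed_moe
    ("embed_moe", ("fsdp",)),           # fallback when 'expert' is taken
    ("mlp", ("tensor",)),
)

# For a per-expert weight, 'exp' claims the 'expert' mesh axis first, so
# 'embed_moe' falls through to its second rule and shards on 'fsdp' alone.
spec = nn.logical_to_mesh_axes(("exp", "embed_moe", "mlp"), rules)
print(spec)  # roughly PartitionSpec(('expert',), ('fsdp',), ('tensor',))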

src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml

Lines changed: 4 additions & 0 deletions
@@ -19,6 +19,8 @@ data_sharding: [['fsdp']]
 logical_axis_rules: [
   ['activation_batch', ['fsdp']],
   ['activation_batch_no_exp', ['fsdp']],
+  ['activation_batch_moe', ['fsdp']],
+  ['activation_batch_no_exp_moe', ['fsdp']],
   ['activation_embed_and_logits_batch', ['fsdp']],
   ['activation_embed_and_logits_batch_sequence', ['fsdp']],
   ['activation_prefill_kv_batch', ['fsdp']],
@@ -27,6 +29,8 @@ logical_axis_rules: [
   ['decode_batch', ['fsdp']],
   ['embed', ['fsdp']],
   ['embed_no_exp', ['fsdp']],
+  ['embed_moe', ['fsdp']],
+  ['embed_no_exp_moe', ['fsdp']],
   ['q_lora', ['fsdp']],
   ['kv_lora', ['fsdp']],
   ['exp_with_fsdp', 'fsdp'],
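
In this pure-FSDP config the `_moe` rules are exact copies of their originals, so resolution is unchanged here until some config maps a `_moe` name to different axes. A quick sanity sketch, with the same assumed flax helper as above:

# Sanity sketch: with the pure-fsdp rules above, a *_moe name resolves to
# the same sharding as its original.
from flax import linen as nn

rules = (
    ("activation_batch", ("fsdp",)),
    ("activation_batch_no_exp", ("fsdp",)),
    ("activation_batch_moe", ("fsdp",)),
    ("activation_batch_no_exp_moe", ("fsdp",)),
)

assert nn.logical_to_mesh_axes(("activation_batch",), rules) == nn.logical_to_mesh_axes(
    ("activation_batch_moe",), rules
)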

src/maxtext/configs/inference/vllm.yml

Lines changed: 8 additions & 0 deletions
@@ -30,18 +30,23 @@ weight_dtype: bfloat16
 mesh_axes: ['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']
 logical_axis_rules: [
   ['activation_batch', ['expert']],
+  ['activation_batch_moe', ['expert']],
   ['activation_batch_no_exp', []],
+  ['activation_batch_no_exp_moe', []],
   ['activation_embed_and_logits_batch', ['expert']],
   ['activation_embed_and_logits_batch_sequence', ['expert']],
   ['activation_heads', ['model']],
   ['activation_kv_heads', ['model']],
   ['activation_attn_length', ['expert']],
   ['activation_attn_length_no_exp', []],
   ['activation_length', ['data', 'expert']],
+  ['activation_length_moe', ['data', 'expert']],
   ['activation_length_no_exp', 'data'],
+  ['activation_length_no_exp_moe', 'data'],
   ['activation_q_length', ['expert', 'attn_dp_expert']],
   ['activation_attn_embed', 'model'],
   ['activation_embed', ['model', 'attn_dp']],
+  ['activation_embed_moe', ['model', 'attn_dp']],
   ['activation_mlp', ['model', 'attn_dp']],
   ['activation_kv', ['model']],
   ['activation_prefill_kv_batch', ['expert', 'attn_dp_expert']],
@@ -50,6 +55,7 @@ logical_axis_rules: [
   ['activation_kv_head_dim', ['model']],
   ['activation_vocab', ['model', 'attn_dp']],
   ['activation_norm_length', []],
+  ['activation_norm_length_moe', []],
   ['activation_exp', ['expert', 'attn_dp_expert']],
   ['decode_batch', ['expert', 'attn_dp_expert']],
   ['decode_length', []],
@@ -63,8 +69,10 @@ logical_axis_rules: [
   ['kv_head_dim', []],
   ['kv', []],
   ['embed', ['expert', 'attn_dp_expert']],
+  ['embed_moe', ['expert', 'attn_dp_expert']],
   ['embed_tensor_transpose', ['attn_dp', 'model']],
   ['embed_no_exp', []],
+  ['embed_no_exp_moe', []],
   ['q_lora', ['expert', 'attn_dp_expert']],
   ['kv_lora', ['expert', 'attn_dp_expert']],
   ['norm', []],
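
On the model side, the split is consumed by annotating MoE activations with the `_moe` logical names. A toy sketch using flax's `nn.with_logical_constraint`; the module and its dimension names are illustrative, not MaxText's actual MoE code.

# Toy consumer of the split names (illustrative module only): dense
# activations keep the original logical names, MoE activations switch to
# the *_moe twins, so the rules above can steer each path separately.
import jax.numpy as jnp
from flax import linen as nn


class ToyMoEBlock(nn.Module):

  @nn.compact
  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    # Dense path: constrained via the original names.
    x = nn.with_logical_constraint(
        x, ("activation_batch", "activation_length", "activation_embed"))
    x = x * 2.0  # stand-in for the dense computation
    # MoE path: constrained via the *_moe names; same mesh axes today,
    # but overridable per config without touching the dense path.
    x = nn.with_logical_constraint(
        x, ("activation_batch_moe", "activation_length_moe", "activation_embed_moe"))
    return x

In practice this runs inside a `jax.sharding.Mesh` and a `nn.logical_axis_rules(...)` context built from the config's rule list, so each constraint binds to real mesh axes.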

src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml

Lines changed: 4 additions & 0 deletions
@@ -60,14 +60,18 @@ mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']
 data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']]
 logical_axis_rules: [
   ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
+  ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
   ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
   ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
   ['activation_embed_and_logits_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_norm_length', ['context']],
+  ['activation_norm_length_moe', ['context']],
   ['activation_heads', []],
   ['activation_stage', 'stage'],
   ['embed', ['fsdp']],
+  ['embed_moe', ['fsdp']],
   ['embed_no_exp', ['fsdp']],
+  ['embed_no_exp_moe', ['fsdp']],
   ['q_lora', ['fsdp']],
   ['kv_lora', ['fsdp']],
   ['layers', 'stage'],
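
Finally, what the separation buys, shown with hypothetical rules that appear in no config in this commit: once MoE tensors carry their own logical names, a single config can shard the MoE path differently from the dense path without rewriting any model code.

# Hypothetical rules (not from this commit) showing the payoff of the
# name split: divergent dense vs. MoE sharding from config alone.
from flax import linen as nn

rules = (
    ("activation_embed", ("tensor",)),  # dense activations: tensor-parallel
    ("activation_embed_moe", None),     # MoE activations: replicated, an illustrative choice
)

print(nn.logical_to_mesh_axes(("activation_embed",), rules))      # ~PartitionSpec(('tensor',))
print(nn.logical_to_mesh_axes(("activation_embed_moe",), rules))  # ~PartitionSpec(None)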
