Skip to content

Commit f5ce0df

Browse files
Merge pull request #3578 from AI-Hypercomputer:chengnuojin-no-exp
PiperOrigin-RevId: 896774782
2 parents df49109 + ef90d9b commit f5ce0df

36 files changed

Lines changed: 147 additions & 342 deletions

src/maxtext/common/common_types.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,11 @@
3232
AxisIdxes = tuple[int, ...]
3333

3434
BATCH = "activation_batch"
35-
BATCH_NO_EXP = "activation_batch_no_exp"
3635

3736
ATTN_LENGTH = "activation_attn_length"
3837
ATTN_LENGTH_NO_EXP = "activation_attn_length_no_exp"
3938

4039
LENGTH = "activation_length"
41-
LENGTH_NO_EXP = "activation_length_no_exp"
4240
PREFILL_LENGTH = "prefill_activation_length"
4341
Q_LENGTH = "activation_q_length"
4442
Q_LENGTH_NO_EXP = "activation_q_length_no_exp"

src/maxtext/configs/base.yml

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -455,22 +455,17 @@ mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'co
455455
logical_axis_rules: [
456456
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
457457
['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
458-
['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
459458
['activation_batch_no_exp_moe', ['data', 'fsdp', 'fsdp_transpose']],
460459
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
461460
['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
462461
['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']],
463462
['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
464-
['activation_length', ['sequence', 'context', 'expert']],
465-
['activation_length', ['context', 'expert']],
466-
['activation_attn_length', ['sequence', 'context', 'expert']],
467-
['activation_attn_length', ['context', 'expert']],
468-
['activation_attn_length_no_exp', ['sequence', 'context']],
469-
['activation_attn_length_no_exp', ['context']],
470-
['activation_length_no_exp', ['sequence', 'context']],
471-
['activation_length_no_exp', ['context']],
472-
['activation_length_no_exp_moe', ['sequence', 'context']],
473-
['activation_length_no_exp_moe', ['context']],
463+
['activation_length', ['sequence', 'context']],
464+
['activation_length', ['context']],
465+
['activation_attn_length', ['sequence', 'context']],
466+
['activation_attn_length', ['context']],
467+
['activation_length_moe', ['sequence', 'context']],
468+
['activation_length_moe', ['context']],
474469
['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
475470
['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
476471
['activation_q_length', ['context', 'expert']],

src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ data_sharding: [['data', 'stage', 'fsdp', 'context', 'tensor', 'expert']]
3131
logical_axis_rules: [
3232
['activation_batch', ['data', 'fsdp', 'expert']],
3333
['activation_batch_moe', ['data', 'fsdp', 'expert']],
34-
['activation_batch_no_exp', ['data', 'fsdp']],
3534
['activation_batch_no_exp_moe', ['data', 'fsdp']],
3635
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'expert']],
3736
['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'context', 'expert']],

src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ mesh_axes: ['fsdp']
1818
data_sharding: [['fsdp']]
1919
logical_axis_rules: [
2020
['activation_batch', ['fsdp']],
21-
['activation_batch_no_exp', ['fsdp']],
2221
['activation_batch_moe', ['fsdp']],
2322
['activation_batch_no_exp_moe', ['fsdp']],
2423
['activation_embed_and_logits_batch', ['fsdp']],

src/maxtext/configs/inference/inference.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ base_config: "base.yml"
22

33
logical_axis_rules: [
44
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
5-
['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
65
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
76
['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
87
['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],

src/maxtext/configs/inference/vllm.yml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,16 @@ mesh_axes: ['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']
3131
logical_axis_rules: [
3232
['activation_batch', ['data']],
3333
['activation_batch_moe', []],
34-
['activation_batch_no_exp', []],
3534
['activation_batch_no_exp_moe', []],
3635
['activation_embed_and_logits_batch', ['data', 'expert']],
3736
['activation_embed_and_logits_batch_sequence', ['data', 'expert']],
3837
['activation_heads', ['model', 'expert']],
3938
['activation_kv_heads', ['model', 'expert']],
4039
['activation_attn_length', ['expert']],
4140
['activation_attn_length_no_exp', []],
42-
['activation_length', ['data', 'expert']],
41+
['activation_length', ['data']],
4342
['activation_length_moe', ['data', 'expert']],
44-
['activation_length_no_exp', 'data'],
45-
['activation_length_no_exp_moe', 'data'],
43+
['activation_length_moe', 'data'],
4644
['activation_q_length', ['expert', 'attn_dp_expert']],
4745
['activation_attn_embed', 'model'],
4846
['activation_embed', ['model', 'attn_dp']],

src/maxtext/configs/post_train/rl_mt_jt.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ logical_axis_rules: [
1818
['prefill_activation_length', ['data']],
1919
['prefill_activation_norm_length', ['data']],
2020
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
21-
['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
2221
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
2322
['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
2423
['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],

src/maxtext/layers/attention_mla.py

Lines changed: 13 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
AxisIdxes,
3838
AxisNames,
3939
BATCH,
40-
BATCH_NO_EXP,
4140
CACHE_BATCH,
4241
CACHE_BATCH_PREFILL,
4342
CACHE_SEQUENCE,
@@ -53,12 +52,10 @@
5352
HEAD,
5453
Q_LORA_UP_PROJ,
5554
KV_BATCH,
56-
KV_BATCH_NO_EXP,
5755
KV_HEAD,
5856
KV_HEAD_DIM,
5957
KV_LORA_UP_PROJ,
6058
LENGTH,
61-
LENGTH_NO_EXP,
6259
MODEL_MODE_PREFILL,
6360
MODEL_MODE_TRAIN,
6461
PREFILL_KV_BATCH,
@@ -425,16 +422,11 @@ def mla_as_linen(
425422
prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
426423
prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
427424
prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
428-
query_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
429-
key_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
430-
value_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
431-
ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
432-
ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
433-
ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
434-
input_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, EMBED),
435-
ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, EMBED),
436-
out_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV),
437-
ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, HEAD, D_KV),
425+
query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
426+
key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
427+
value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
428+
input_axis_names: AxisNames = (BATCH, LENGTH, EMBED),
429+
out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
438430
prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED),
439431
decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED),
440432
prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
@@ -499,13 +491,8 @@ def mla_as_linen(
499491
query_axis_names=query_axis_names,
500492
key_axis_names=key_axis_names,
501493
value_axis_names=value_axis_names,
502-
ep_query_axis_names=ep_query_axis_names,
503-
ep_key_axis_names=ep_key_axis_names,
504-
ep_value_axis_names=ep_value_axis_names,
505494
input_axis_names=input_axis_names,
506-
ep_input_axis_names=ep_input_axis_names,
507495
out_axis_names=out_axis_names,
508-
ep_out_axis_names=ep_out_axis_names,
509496
prefill_input_axis_names=prefill_input_axis_names,
510497
decode_input_axis_names=decode_input_axis_names,
511498
prefill_out_axis_names=prefill_out_axis_names,
@@ -573,16 +560,11 @@ def __init__(
573560
prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
574561
prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
575562
prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
576-
query_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
577-
key_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
578-
value_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
579-
ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
580-
ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
581-
ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
582-
input_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, EMBED),
583-
ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, EMBED),
584-
out_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV),
585-
ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, HEAD, D_KV),
563+
query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
564+
key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
565+
value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
566+
input_axis_names: AxisNames = (BATCH, LENGTH, EMBED),
567+
out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
586568
prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED),
587569
decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED),
588570
prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
@@ -664,13 +646,8 @@ def __init__(
664646
query_axis_names=query_axis_names,
665647
key_axis_names=key_axis_names,
666648
value_axis_names=value_axis_names,
667-
ep_query_axis_names=ep_query_axis_names,
668-
ep_key_axis_names=ep_key_axis_names,
669-
ep_value_axis_names=ep_value_axis_names,
670649
input_axis_names=input_axis_names,
671-
ep_input_axis_names=ep_input_axis_names,
672650
out_axis_names=out_axis_names,
673-
ep_out_axis_names=ep_out_axis_names,
674651
prefill_input_axis_names=prefill_input_axis_names,
675652
decode_input_axis_names=decode_input_axis_names,
676653
prefill_out_axis_names=prefill_out_axis_names,
@@ -882,12 +859,9 @@ def mla_query_projection(
882859
if model_mode == MODEL_MODE_PREFILL:
883860
query_logical_name = self.prefill_query_axis_names
884861
wqa_logical_name = (PREFILL_KV_BATCH, PREFILL_LENGTH, Q_LORA_UP_PROJ)
885-
elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
886-
query_logical_name = self.ep_query_axis_names
887-
wqa_logical_name = (KV_BATCH_NO_EXP, LENGTH, Q_LORA_UP_PROJ)
888862
else:
889863
query_logical_name = self.query_axis_names
890-
wqa_logical_name = (KV_BATCH, LENGTH_NO_EXP, Q_LORA_UP_PROJ)
864+
wqa_logical_name = (KV_BATCH, LENGTH, Q_LORA_UP_PROJ)
891865
query_sharding = create_sharding(self.mesh, query_logical_name)
892866
wqa_out_sharding = create_sharding(self.mesh, wqa_logical_name)
893867
# Set softmax scaling.
@@ -1038,10 +1012,8 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm
10381012
"""MLA key/value projection with integrated rotary embedding."""
10391013
if model_mode == MODEL_MODE_PREFILL:
10401014
wka_logical_name = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_LORA_UP_PROJ)
1041-
elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
1042-
wka_logical_name = (KV_BATCH_NO_EXP, LENGTH, KV_LORA_UP_PROJ)
10431015
else:
1044-
wka_logical_name = (KV_BATCH, LENGTH_NO_EXP, KV_LORA_UP_PROJ)
1016+
wka_logical_name = (KV_BATCH, LENGTH, KV_LORA_UP_PROJ)
10451017
wkva_out_sharding = create_sharding(self.mesh, wka_logical_name)
10461018
low_rank = self.wkv_a(inputs, out_sharding=wkva_out_sharding)
10471019
low_rank = checkpoint_name(low_rank, "kv_wa_proj")
@@ -1178,14 +1150,10 @@ def __call__(
11781150
inputs_q = self._maybe_shard_with_logical(inputs_q, self.prefill_input_axis_names)
11791151
inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.prefill_input_axis_names)
11801152
out_logical_name = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV)
1181-
elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
1182-
inputs_q = self._maybe_shard_with_logical(inputs_q, self.ep_input_axis_names)
1183-
inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.ep_input_axis_names)
1184-
out_logical_name = (BATCH_NO_EXP, LENGTH, HEAD, D_KV)
11851153
else:
11861154
inputs_q = self._maybe_shard_with_logical(inputs_q, self.input_axis_names)
11871155
inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.input_axis_names)
1188-
out_logical_name = (BATCH, LENGTH_NO_EXP, HEAD, D_KV)
1156+
out_logical_name = (BATCH, LENGTH, HEAD, D_KV)
11891157

11901158
if model_mode != MODEL_MODE_TRAIN and decoder_segment_ids is None:
11911159
decoder_segment_ids = jnp.ones(inputs_q.shape[:2], dtype=jnp.int32)

0 commit comments

Comments (0)