Commit 83f4af0

remove activation_length_no_exp
1 parent 8752555 commit 83f4af0

18 files changed

Lines changed: 50 additions & 165 deletions

src/maxtext/common/common_types.py

Lines changed: 0 additions & 1 deletion
@@ -37,7 +37,6 @@
 ATTN_LENGTH_NO_EXP = "activation_attn_length_no_exp"
 
 LENGTH = "activation_length"
-LENGTH_NO_EXP = "activation_length_no_exp"
 PREFILL_LENGTH = "prefill_activation_length"
 Q_LENGTH = "activation_q_length"
 Q_LENGTH_NO_EXP = "activation_q_length_no_exp"

src/maxtext/configs/base.yml

Lines changed: 6 additions & 10 deletions
@@ -455,16 +455,12 @@ logical_axis_rules: [
 ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
 ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']],
 ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
-['activation_length', ['sequence', 'context', 'expert']],
-['activation_length', ['context', 'expert']],
-['activation_attn_length', ['sequence', 'context', 'expert']],
-['activation_attn_length', ['context', 'expert']],
-['activation_attn_length_no_exp', ['sequence', 'context']],
-['activation_attn_length_no_exp', ['context']],
-['activation_length_no_exp', ['sequence', 'context']],
-['activation_length_no_exp', ['context']],
-['activation_length_no_exp_moe', ['sequence', 'context']],
-['activation_length_no_exp_moe', ['context']],
+['activation_length', ['sequence', 'context']],
+['activation_length', ['context']],
+['activation_attn_length', ['sequence', 'context']],
+['activation_attn_length', ['context']],
+['activation_length_moe', ['sequence', 'context']],
+['activation_length_moe', ['context']],
 ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
 ['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
 ['activation_q_length', ['context', 'expert']],
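
Note on the rules above: each entry maps a logical activation axis (such as 'activation_length') to one or more mesh axes, and repeated entries for the same logical axis act as ordered fallbacks, used when the preferred mesh axes are already claimed by another dimension. A minimal sketch of how such rules resolve, using Flax's public logical_to_mesh_axes rather than MaxText's own helpers (the rule subset below is illustrative, not the full table above):

import flax.linen as nn

# Illustrative subset of the rules above; the ('context',) entry is the fallback
# used when ('sequence', 'context') cannot be assigned to the length dimension.
rules = (
    ("activation_batch", ("data", "fsdp")),
    ("activation_length", ("sequence", "context")),
    ("activation_length", ("context",)),
    ("activation_embed", "tensor"),
)

# Resolve a (batch, length, embed) activation to a PartitionSpec under these rules.
spec = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"), rules)
print(spec)  # roughly PartitionSpec(('data', 'fsdp'), ('sequence', 'context'), 'tensor')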

src/maxtext/configs/inference/vllm.yml

Lines changed: 3 additions & 4 deletions
@@ -36,12 +36,11 @@ logical_axis_rules: [
 ['activation_embed_and_logits_batch_sequence', ['data', 'expert']],
 ['activation_heads', ['model', 'expert']],
 ['activation_kv_heads', ['model', 'expert']],
-['activation_attn_length', ['expert']],
+['activation_attn_length', ["expert"]],
 ['activation_attn_length_no_exp', []],
-['activation_length', ['data', 'expert']],
+['activation_length', ['data']],
 ['activation_length_moe', ['data', 'expert']],
-['activation_length_no_exp', 'data'],
-['activation_length_no_exp_moe', 'data'],
+['activation_length_moe', 'data'],
 ['activation_q_length', ['expert', 'attn_dp_expert']],
 ['activation_attn_embed', 'model'],
 ['activation_embed', ['model', 'attn_dp']],

src/maxtext/layers/attention_mla.py

Lines changed: 13 additions & 32 deletions
@@ -52,12 +52,10 @@
 HEAD,
 Q_LORA_UP_PROJ,
 KV_BATCH,
-KV_BATCH_NO_EXP,
 KV_HEAD,
 KV_HEAD_DIM,
 KV_LORA_UP_PROJ,
 LENGTH,
-LENGTH_NO_EXP,
 MODEL_MODE_PREFILL,
 MODEL_MODE_TRAIN,
 PREFILL_KV_BATCH,
@@ -424,14 +422,11 @@ def mla_as_linen(
 prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
 prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
 prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
-query_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-key_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-value_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-input_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, EMBED),
-out_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV),
+query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
+key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
+value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
+input_axis_names: AxisNames = (BATCH, LENGTH, EMBED),
+out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
 prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED),
 decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED),
 prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
@@ -496,9 +491,6 @@ def mla_as_linen(
 query_axis_names=query_axis_names,
 key_axis_names=key_axis_names,
 value_axis_names=value_axis_names,
-ep_query_axis_names=ep_query_axis_names,
-ep_key_axis_names=ep_key_axis_names,
-ep_value_axis_names=ep_value_axis_names,
 input_axis_names=input_axis_names,
 out_axis_names=out_axis_names,
 prefill_input_axis_names=prefill_input_axis_names,
@@ -568,14 +560,11 @@ def __init__(
 prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
 prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
 prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
-query_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-key_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-value_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-input_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, EMBED),
-out_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV),
+query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
+key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
+value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
+input_axis_names: AxisNames = (BATCH, LENGTH, EMBED),
+out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
 prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED),
 decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED),
 prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
@@ -657,9 +646,6 @@ def __init__(
 query_axis_names=query_axis_names,
 key_axis_names=key_axis_names,
 value_axis_names=value_axis_names,
-ep_query_axis_names=ep_query_axis_names,
-ep_key_axis_names=ep_key_axis_names,
-ep_value_axis_names=ep_value_axis_names,
 input_axis_names=input_axis_names,
 out_axis_names=out_axis_names,
 prefill_input_axis_names=prefill_input_axis_names,
@@ -873,12 +859,9 @@ def mla_query_projection(
 if model_mode == MODEL_MODE_PREFILL:
 query_logical_name = self.prefill_query_axis_names
 wqa_logical_name = (PREFILL_KV_BATCH, PREFILL_LENGTH, Q_LORA_UP_PROJ)
-elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
-query_logical_name = self.ep_query_axis_names
-wqa_logical_name = (KV_BATCH_NO_EXP, LENGTH, Q_LORA_UP_PROJ)
 else:
 query_logical_name = self.query_axis_names
-wqa_logical_name = (KV_BATCH, LENGTH_NO_EXP, Q_LORA_UP_PROJ)
+wqa_logical_name = (KV_BATCH, LENGTH, Q_LORA_UP_PROJ)
 query_sharding = create_sharding(self.mesh, query_logical_name)
 wqa_out_sharding = create_sharding(self.mesh, wqa_logical_name)
 # Set softmax scaling.
@@ -1029,10 +1012,8 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm
 """MLA key/value projection with integrated rotary embedding."""
 if model_mode == MODEL_MODE_PREFILL:
 wka_logical_name = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_LORA_UP_PROJ)
-elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
-wka_logical_name = (KV_BATCH_NO_EXP, LENGTH, KV_LORA_UP_PROJ)
 else:
-wka_logical_name = (KV_BATCH, LENGTH_NO_EXP, KV_LORA_UP_PROJ)
+wka_logical_name = (KV_BATCH, LENGTH, KV_LORA_UP_PROJ)
 wkva_out_sharding = create_sharding(self.mesh, wka_logical_name)
 low_rank = self.wkv_a(inputs, out_sharding=wkva_out_sharding)
 low_rank = checkpoint_name(low_rank, "kv_wa_proj")
@@ -1172,7 +1153,7 @@ def __call__(
 else:
 inputs_q = self._maybe_shard_with_logical(inputs_q, self.input_axis_names)
 inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.input_axis_names)
-out_logical_name = (BATCH, LENGTH_NO_EXP, HEAD, D_KV)
+out_logical_name = (BATCH, LENGTH, HEAD, D_KV)
 
 if model_mode != MODEL_MODE_TRAIN and decoder_segment_ids is None:
 decoder_segment_ids = jnp.ones(inputs_q.shape[:2], dtype=jnp.int32)
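
The create_sharding and _maybe_shard_with_logical calls above are MaxText helpers whose definitions are not part of this diff; the sketch below is only a rough public-API guess at what create_sharding(self.mesh, logical_name) produces: resolve logical axis names (e.g. (KV_BATCH, LENGTH, Q_LORA_UP_PROJ)) to mesh axes through the axis rules, then wrap the result in a NamedSharding. The rules and logical names in the code are illustrative assumptions.

import jax
import numpy as np
import flax.linen as nn
from jax.sharding import Mesh, NamedSharding

# Hypothetical rules; the real ones come from logical_axis_rules in base.yml.
RULES = (("activation_batch", "data"), ("activation_length", "sequence"))

def create_sharding_sketch(mesh: Mesh, logical_name: tuple) -> NamedSharding:
  # Stand-in for MaxText's create_sharding: logical names -> PartitionSpec -> NamedSharding.
  spec = nn.logical_to_mesh_axes(logical_name, RULES)
  return NamedSharding(mesh, spec)

# Single-device toy mesh; real meshes are built from the config's parallelism settings.
mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ("data", "sequence"))
out_sharding = create_sharding_sketch(mesh, ("activation_batch", "activation_length", "activation_embed"))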

src/maxtext/layers/attention_op.py

Lines changed: 1 addition & 1 deletion
@@ -1142,7 +1142,7 @@ def tpu_flash_attention(
 axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel)
 axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q)
 axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv)
-indexer_mask_axis_names = self._logical_to_mesh_axes((BATCH, Q_LENGTH, KV_LENGTH))
+indexer_mask_axis_names = self._logical_to_mesh_axes((BATCH, Q_LENGTH_NO_EXP, KV_LENGTH))
 
 global global_block_q, global_block_kv, global_block_kv_compute, global_block_q_dkv, global_block_kv_dkv
 global global_block_kv_dkv_compute, global_block_q_dq, global_block_kv_dq, global_use_fused_bwd_kernel

src/maxtext/layers/decoders.py

Lines changed: 3 additions & 3 deletions
@@ -107,7 +107,7 @@ def __call__(
 if self.model_mode == MODEL_MODE_PREFILL:
 logical_axis_names = ("activation_batch", "prefill_activation_length", "activation_embed")
 else:
-logical_axis_names = ("activation_batch", "activation_length_no_exp", "activation_embed")
+logical_axis_names = ("activation_batch", "activation_length", "activation_embed")
 
 if model_mode == MODEL_MODE_PREFILL:
 inputs = _maybe_shard_with_logical(inputs, logical_axis_names)
@@ -690,7 +690,7 @@ def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, determi
 
 cfg = self.config
 if cfg.shard_mode == ShardMode.EXPLICIT:
-norm_out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", "activation_embed"))
+norm_out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length", "activation_embed"))
 else:
 norm_out_sharding = None
 
@@ -708,7 +708,7 @@ def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, determi
 out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab"))
 else:
 out_sharding = create_sharding(
-self.mesh, ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab")
+self.mesh, ("activation_embed_and_logits_batch", "activation_length", "activation_vocab")
 )
 
 # [batch, length, emb_dim] -> [batch, length, vocab_size]

src/maxtext/layers/embeddings.py

Lines changed: 3 additions & 3 deletions
@@ -165,7 +165,7 @@ def __call__(self, inputs: Array, model_mode: str = MODEL_MODE_TRAIN) -> Array:
 if model_mode == MODEL_MODE_PREFILL
 else (
 "activation_embed_and_logits_batch",
-"activation_length_no_exp",
+"activation_length",
 "activation_embed",
 )
 )
@@ -850,7 +850,7 @@ def __init__(
 self.attention_scaling = attention_scaling
 
 self.freqs_sharding = (
-create_sharding(mesh, ("activation_batch", "activation_length_no_exp", "q_heads"))
+create_sharding(mesh, ("activation_batch", "activation_length", "q_heads"))
 if shard_mode == ShardMode.EXPLICIT
 else None
 )
@@ -976,7 +976,7 @@ def __call__(self, inputs: Array, position: None | Array = None) -> Array:
 inputs_complex = first_half + 1j * second_half # shape: [B, S, N, half_dim]
 # Apply the rotary transformation via complex multiplication.
 rotated_sharding = (
-create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", None, None))
+create_sharding(self.mesh, ("activation_batch", "activation_length", None, None))
 if self.shard_mode == ShardMode.EXPLICIT
 else None
 )
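
The rotary code touched above treats each (first_half, second_half) feature pair as one complex number, so the position-dependent rotation becomes a single complex multiply. A self-contained sketch of that trick, with illustrative shapes and a conventional 10000 timescale (not the module's actual defaults):

import jax.numpy as jnp

def apply_rope_sketch(x, position, timescale=10000.0):
  # x: [batch, seq, heads, head_dim]; position: [batch, seq].
  half = x.shape[-1] // 2
  freqs = 1.0 / (timescale ** (jnp.arange(half) / half))           # [half]
  angles = position[:, :, None, None] * freqs                      # [batch, seq, 1, half]
  first_half, second_half = x[..., :half], x[..., half:]
  inputs_complex = first_half + 1j * second_half                   # pair features as complex numbers
  rotated = inputs_complex * jnp.exp(1j * angles)                  # complex multiply == 2D rotation
  return jnp.concatenate([jnp.real(rotated), jnp.imag(rotated)], axis=-1)

x = jnp.ones((2, 4, 1, 8))
position = jnp.broadcast_to(jnp.arange(4), (2, 4))
print(apply_rope_sketch(x, position).shape)  # (2, 4, 1, 8)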

src/maxtext/layers/linears.py

Lines changed: 1 addition & 1 deletion
@@ -405,7 +405,7 @@ def __init__(
 if self.model_mode == MODEL_MODE_PREFILL:
 self.intermediate_logical = ("activation_batch", "prefill_activation_length", "activation_mlp")
 else:
-self.intermediate_logical = ("activation_batch", "activation_length_no_exp", "activation_mlp")
+self.intermediate_logical = ("activation_batch", "activation_length", "activation_mlp")
 
 if config.fused_mlp:
 self.wi = DenseGeneral(

src/maxtext/layers/moe.py

Lines changed: 5 additions & 7 deletions
@@ -278,7 +278,7 @@ def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax.
 
 contract_ind = tuple(range(0, len(norm_axis)))
 output_sharding = (
-create_sharding(self.mesh, ("activation_batch_no_exp_moe", "activation_length_no_exp_moe", None))
+create_sharding(self.mesh, ("activation_batch_no_exp_moe", "activation_length_moe", None))
 if self.shard_mode == ShardMode.EXPLICIT
 else None
 )
@@ -1452,11 +1452,11 @@ def reshape_and_update_weights(self, weights, indices):
 self._maybe_shard_with_logical(
 jnp.arange(weights.shape[0])[:, None, None], ("activation_batch_no_exp_moe", None, None)
 ),
-self._maybe_shard_with_logical(jnp.arange(weights.shape[1])[:, None], ("activation_length_no_exp_moe", None)),
+self._maybe_shard_with_logical(jnp.arange(weights.shape[1])[:, None], ("activation_length_moe", None)),
 indices,
 )
 weight_sharding = (
-create_sharding(self.mesh, ("activation_batch_no_exp_moe", "activation_length_no_exp_moe", None))
+create_sharding(self.mesh, ("activation_batch_no_exp_moe", "activation_length_moe", None))
 if self.config.shard_mode == ShardMode.EXPLICIT
 else None
 )
@@ -1705,13 +1705,11 @@ def dense_matmul(
 ) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]:
 """Dense matrix multiplication."""
 # gate_logits: batch, length, expert
-gate_logits = self._maybe_shard_with_logical(
-gate_logits, ("activation_batch_moe", "activation_length_no_exp_moe", None)
-)
+gate_logits = self._maybe_shard_with_logical(gate_logits, ("activation_batch_moe", "activation_length_moe", None))
 if self.config.model_name.startswith("deepseek3"):
 # pre_bias_logits is None for non-DeepSeek v3 models
 pre_bias_logits = self._maybe_shard_with_logical(
-pre_bias_logits, ("activation_batch_moe", "activation_length_no_exp_moe", None)
+pre_bias_logits, ("activation_batch_moe", "activation_length_moe", None)
 )
 top_k_weights, top_k_indices = self.get_topk(gate_logits, pre_bias_logits, self.rngs)
 is_llama4_decoder_layer = self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4
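
For readers unfamiliar with the routing step referenced above, get_topk (a MaxText method not shown in this diff) selects the top-k experts per token from the gate logits; the snippet below is a generic sketch of that pattern using plain jax.lax.top_k with a softmax over the selected logits, an assumed behavior rather than MaxText's exact implementation:

import jax
import jax.numpy as jnp

def topk_gating_sketch(gate_logits, k=2):
  # gate_logits: [batch, length, num_experts]; keep the k highest-scoring experts
  # per token and renormalize their router probabilities.
  top_k_logits, top_k_indices = jax.lax.top_k(gate_logits, k)
  top_k_weights = jax.nn.softmax(top_k_logits, axis=-1)
  return top_k_weights, top_k_indices

logits = jnp.array([[[0.1, 2.0, -1.0, 0.5]]])   # batch=1, length=1, experts=4
weights, indices = topk_gating_sketch(logits)
print(indices)  # [[[1 3]]] -> experts 1 and 3 selected
print(weights)  # softmax over the two selected logits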

src/maxtext/layers/nnx_decoders.py

Lines changed: 3 additions & 3 deletions
@@ -169,7 +169,7 @@ def __call__(
 if self.model_mode == MODEL_MODE_PREFILL:
 logical_axis_names = ("activation_batch", "prefill_activation_length", "activation_embed")
 else:
-logical_axis_names = ("activation_batch", "activation_length_no_exp", "activation_embed")
+logical_axis_names = ("activation_batch", "activation_length", "activation_embed")
 
 inputs = _maybe_shard_with_logical(inputs, logical_axis_names)
 inputs = checkpoint_name(inputs, "decoder_layer_input")
@@ -736,7 +736,7 @@ def apply_output_head(self, shared_embedding, y, deterministic, model_mode):
 
 cfg = self.config
 if cfg.shard_mode == ShardMode.EXPLICIT:
-norm_out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", "activation_embed"))
+norm_out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length", "activation_embed"))
 else:
 norm_out_sharding = None
 
@@ -747,7 +747,7 @@ def apply_output_head(self, shared_embedding, y, deterministic, model_mode):
 out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab"))
 else:
 out_sharding = create_sharding(
-self.mesh, ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab")
+self.mesh, ("activation_embed_and_logits_batch", "activation_length", "activation_vocab")
 )
 
 # [batch, length, emb_dim] -> [batch, length, vocab_size]
