Commit 8752555

remove no exp
1 parent 9777a4c

16 files changed, 22 additions and 98 deletions

src/maxtext/common/common_types.py (0 additions, 1 deletion)

@@ -32,7 +32,6 @@
 AxisIdxes = tuple[int, ...]
 
 BATCH = "activation_batch"
-BATCH_NO_EXP = "activation_batch_no_exp"
 
 ATTN_LENGTH = "activation_attn_length"
 ATTN_LENGTH_NO_EXP = "activation_attn_length_no_exp"
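
Note on what this deletes: MaxText tags activation dimensions with logical axis names such as BATCH, and Flax resolves those names to physical mesh axes through the logical_axis_rules edited in the configs below. A minimal sketch of that mechanism using Flax's public with_logical_constraint API; the LENGTH and EMBED string values are assumed for illustration, only BATCH's value is confirmed by this hunk:

from flax import linen as nn

BATCH = "activation_batch"    # value confirmed by the hunk above
LENGTH = "activation_length"  # assumed value, illustration only
EMBED = "activation_embed"    # assumed value, illustration only

class Block(nn.Module):
  @nn.compact
  def __call__(self, x):
    # Tag each activation dimension with a logical name; under an active mesh
    # and logical_axis_rules, Flax turns the names into a sharding constraint.
    return nn.with_logical_constraint(x, (BATCH, LENGTH, EMBED))

With BATCH_NO_EXP gone, call sites that used it must tag that dimension with BATCH (or another surviving name) instead, which is what the layer changes below do.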

src/maxtext/configs/base.yml (0 additions, 1 deletion)

@@ -450,7 +450,6 @@ mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'co
 logical_axis_rules: [
   ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
   ['activation_batch_no_exp_moe', ['data', 'fsdp', 'fsdp_transpose']],
   ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
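
To see what deleting a rule row does, here is a hedged sketch using Flax's spmd helpers (not MaxText's own wrapper): logical_to_mesh_axes looks each logical name up in the rules and builds a PartitionSpec, and a name with no matching rule should resolve to None, i.e. unsharded.

from flax.linen import logical_to_mesh_axes

# Two rows in the post-change shape of the base.yml rules above.
rules = (
    ("activation_batch", ("data", "fsdp", "fsdp_transpose", "expert")),
    ("activation_batch_no_exp_moe", ("data", "fsdp", "fsdp_transpose")),
)

# Still listed, still sharded:
print(logical_to_mesh_axes(("activation_batch",), rules))
# -> PartitionSpec(('data', 'fsdp', 'fsdp_transpose', 'expert'),)

# The deleted name no longer matches any rule and resolves to unsharded:
print(logical_to_mesh_axes(("activation_batch_no_exp",), rules))
# -> PartitionSpec(None,)

So dropping the 'activation_batch_no_exp' row is only safe once no axis-name tuple still references it, which the layer changes below complete.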

src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml (0 additions, 1 deletion)

@@ -31,7 +31,6 @@ data_sharding: [['data', 'stage', 'fsdp', 'context', 'tensor', 'expert']]
 logical_axis_rules: [
   ['activation_batch', ['data', 'fsdp', 'expert']],
   ['activation_batch_moe', ['data', 'fsdp', 'expert']],
-  ['activation_batch_no_exp', ['data', 'fsdp']],
   ['activation_batch_no_exp_moe', ['data', 'fsdp']],
   ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'expert']],
   ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'context', 'expert']],

src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml (0 additions, 1 deletion)

@@ -18,7 +18,6 @@ mesh_axes: ['fsdp']
 data_sharding: [['fsdp']]
 logical_axis_rules: [
   ['activation_batch', ['fsdp']],
-  ['activation_batch_no_exp', ['fsdp']],
   ['activation_batch_moe', ['fsdp']],
   ['activation_batch_no_exp_moe', ['fsdp']],
   ['activation_embed_and_logits_batch', ['fsdp']],

src/maxtext/configs/inference/inference.yml (0 additions, 1 deletion)

@@ -2,7 +2,6 @@ base_config: "base.yml"
 
 logical_axis_rules: [
   ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
   ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
   ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],

src/maxtext/configs/inference/vllm.yml (0 additions, 1 deletion)

@@ -31,7 +31,6 @@ mesh_axes: ['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']
 logical_axis_rules: [
   ['activation_batch', ['data']],
   ['activation_batch_moe', []],
-  ['activation_batch_no_exp', []],
   ['activation_batch_no_exp_moe', []],
   ['activation_embed_and_logits_batch', ['data', 'expert']],
   ['activation_embed_and_logits_batch_sequence', ['data', 'expert']],

src/maxtext/configs/post_train/rl_mt_jt.yml (0 additions, 1 deletion)

@@ -18,7 +18,6 @@ logical_axis_rules: [
   ['prefill_activation_length', ['data']],
   ['prefill_activation_norm_length', ['data']],
   ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
   ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
   ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],

src/maxtext/layers/attention_mla.py (0 additions, 13 deletions)

@@ -37,7 +37,6 @@
     AxisIdxes,
     AxisNames,
     BATCH,
-    BATCH_NO_EXP,
     CACHE_BATCH,
     CACHE_BATCH_PREFILL,
     CACHE_SEQUENCE,
@@ -432,9 +431,7 @@ def mla_as_linen(
     ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
     ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
     input_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, EMBED),
-    ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, EMBED),
     out_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV),
-    ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, HEAD, D_KV),
     prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED),
     decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED),
     prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
@@ -503,9 +500,7 @@ def mla_as_linen(
       ep_key_axis_names=ep_key_axis_names,
       ep_value_axis_names=ep_value_axis_names,
       input_axis_names=input_axis_names,
-      ep_input_axis_names=ep_input_axis_names,
       out_axis_names=out_axis_names,
-      ep_out_axis_names=ep_out_axis_names,
       prefill_input_axis_names=prefill_input_axis_names,
       decode_input_axis_names=decode_input_axis_names,
       prefill_out_axis_names=prefill_out_axis_names,
@@ -580,9 +575,7 @@ def __init__(
     ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
     ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
     input_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, EMBED),
-    ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, EMBED),
     out_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV),
-    ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, HEAD, D_KV),
     prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED),
     decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED),
     prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
@@ -668,9 +661,7 @@ def __init__(
       ep_key_axis_names=ep_key_axis_names,
      ep_value_axis_names=ep_value_axis_names,
       input_axis_names=input_axis_names,
-      ep_input_axis_names=ep_input_axis_names,
       out_axis_names=out_axis_names,
-      ep_out_axis_names=ep_out_axis_names,
       prefill_input_axis_names=prefill_input_axis_names,
       decode_input_axis_names=decode_input_axis_names,
       prefill_out_axis_names=prefill_out_axis_names,
@@ -1178,10 +1169,6 @@ def __call__(
       inputs_q = self._maybe_shard_with_logical(inputs_q, self.prefill_input_axis_names)
       inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.prefill_input_axis_names)
       out_logical_name = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV)
-    elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
-      inputs_q = self._maybe_shard_with_logical(inputs_q, self.ep_input_axis_names)
-      inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.ep_input_axis_names)
-      out_logical_name = (BATCH_NO_EXP, LENGTH, HEAD, D_KV)
     else:
       inputs_q = self._maybe_shard_with_logical(inputs_q, self.input_axis_names)
       inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.input_axis_names)

src/maxtext/layers/attention_op.py (11 additions, 35 deletions)

@@ -39,7 +39,6 @@
     AxisIdxes,
     AxisNames,
     BATCH,
-    BATCH_NO_EXP,
     CACHE_BATCH,
     CACHE_BATCH_PREFILL,
     CACHE_HEADS,
@@ -61,7 +60,6 @@
     HEAD,
     KV_LENGTH,
     LENGTH,
-    LENGTH_NO_EXP,
     MODEL_MODE_AUTOREGRESSIVE,
     MODEL_MODE_PREFILL,
     MODEL_MODE_TRAIN,
@@ -302,12 +300,9 @@ def attention_op_as_linen(
     float32_qk_product: bool = False,
     max_prefill_predict_length: int = -1,
     float32_logits: bool = False,
-    flash_axis_names_q: AxisNames = (BATCH, HEAD, LENGTH_NO_EXP, D_KV),
-    flash_axis_names_q_ep: AxisNames = (BATCH_NO_EXP, HEAD, LENGTH, D_KV),
+    flash_axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV),
     flash_axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV),
-    flash_axis_names_kv_ep: AxisNames = (BATCH_NO_EXP, HEAD, KV_LENGTH, D_KV),
-    flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH_NO_EXP),
-    flash_axis_names_splash_kernel_ep: AxisNames = (HEAD, LENGTH),
+    flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH),
     prefill_cache_logical_axis_names: AxisNames = (
         CACHE_BATCH_PREFILL,
         CACHE_SEQUENCE,
@@ -364,11 +359,8 @@ def attention_op_as_linen(
       max_prefill_predict_length=max_prefill_predict_length,
       float32_logits=float32_logits,
       flash_axis_names_q=flash_axis_names_q,
-      flash_axis_names_q_ep=flash_axis_names_q_ep,
       flash_axis_names_kv=flash_axis_names_kv,
-      flash_axis_names_kv_ep=flash_axis_names_kv_ep,
       flash_axis_names_splash_kernel=flash_axis_names_splash_kernel,
-      flash_axis_names_splash_kernel_ep=flash_axis_names_splash_kernel_ep,
       prefill_cache_logical_axis_names=prefill_cache_logical_axis_names,
       cache_logical_axis_names=cache_logical_axis_names,
       cache_scale_logical_axis_names=cache_scale_logical_axis_names,
@@ -405,12 +397,9 @@ def __init__(
     float32_qk_product: bool = False,
     max_prefill_predict_length: int = -1,
     float32_logits: bool = False,
-    flash_axis_names_q: AxisNames = (BATCH, HEAD, LENGTH_NO_EXP, D_KV),
-    flash_axis_names_q_ep: AxisNames = (BATCH_NO_EXP, HEAD, LENGTH, D_KV),
+    flash_axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV),
     flash_axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV),
-    flash_axis_names_kv_ep: AxisNames = (BATCH_NO_EXP, HEAD, KV_LENGTH, D_KV),
-    flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH_NO_EXP),
-    flash_axis_names_splash_kernel_ep: AxisNames = (HEAD, LENGTH),
+    flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH),
     prefill_cache_logical_axis_names: AxisNames = (
         CACHE_BATCH_PREFILL,
         CACHE_SEQUENCE,
@@ -492,11 +481,8 @@ def __init__(
     self.max_prefill_predict_length = max_prefill_predict_length
     self.float32_logits = float32_logits
     self.flash_axis_names_q = flash_axis_names_q
-    self.flash_axis_names_q_ep = flash_axis_names_q_ep
     self.flash_axis_names_kv = flash_axis_names_kv
-    self.flash_axis_names_kv_ep = flash_axis_names_kv_ep
     self.flash_axis_names_splash_kernel = flash_axis_names_splash_kernel
-    self.flash_axis_names_splash_kernel_ep = flash_axis_names_splash_kernel_ep
     self.prefill_cache_logical_axis_names = prefill_cache_logical_axis_names
     self.cache_logical_axis_names = cache_logical_axis_names
     self.cache_scale_logical_axis_names = cache_scale_logical_axis_names
@@ -1150,23 +1136,13 @@ def tpu_flash_attention(
     segment_axis_names_kv = None
     sink_axis_names = self._logical_to_mesh_axes((HEAD,))
     if decoder_segment_ids is not None:
-      if self.config.expert_shard_attention_option == EP_AS_CONTEXT:
-        segment_axis_names_q = self._logical_to_mesh_axes((BATCH_NO_EXP, Q_LENGTH))
-        segment_axis_names_kv = self._logical_to_mesh_axes((BATCH_NO_EXP, KV_LENGTH))
-      else:
-        segment_axis_names_q = self._logical_to_mesh_axes((BATCH, Q_LENGTH_NO_EXP))
-        segment_axis_names_kv = self._logical_to_mesh_axes((BATCH, KV_LENGTH))
-
-    if self.config.expert_shard_attention_option == EP_AS_CONTEXT:
-      axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel_ep)
-      axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q_ep)
-      axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv_ep)
-      indexer_mask_axis_names = self._logical_to_mesh_axes((BATCH_NO_EXP, Q_LENGTH, KV_LENGTH))
-    else:
-      axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel)
-      axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q)
-      axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv)
-      indexer_mask_axis_names = self._logical_to_mesh_axes((BATCH, Q_LENGTH, KV_LENGTH))
+      segment_axis_names_q = self._logical_to_mesh_axes((BATCH, Q_LENGTH_NO_EXP))
+      segment_axis_names_kv = self._logical_to_mesh_axes((BATCH, KV_LENGTH))
+
+    axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel)
+    axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q)
+    axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv)
+    indexer_mask_axis_names = self._logical_to_mesh_axes((BATCH, Q_LENGTH, KV_LENGTH))
 
     global global_block_q, global_block_kv, global_block_kv_compute, global_block_q_dkv, global_block_kv_dkv
     global global_block_kv_dkv_compute, global_block_q_dq, global_block_kv_dq, global_use_fused_bwd_kernel
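
As a concrete illustration of what the surviving specs do to the flash-attention inputs, here is a self-contained JAX sketch; the 2x2 mesh, its axis names, and the array shape are invented (it assumes at least four local devices), since MaxText's real mesh shape and rule lookup come from its config machinery:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 2x2 mesh over the two axes the configs above use most.
devices = np.array(jax.devices()[:4]).reshape(2, 2)
mesh = Mesh(devices, axis_names=("data", "fsdp"))

# What a rule like ['activation_batch', ['data', 'fsdp']] would yield for the
# unified (BATCH, HEAD, LENGTH, D_KV) query layout: batch sharded across both
# mesh axes, heads/length/head_dim replicated.
q_spec = P(("data", "fsdp"), None, None, None)

q = jax.device_put(np.zeros((8, 4, 128, 64), np.float32), NamedSharding(mesh, q_spec))
print(q.sharding.spec)  # PartitionSpec(('data', 'fsdp'), None, None, None)

After this commit the same (BATCH, HEAD, LENGTH, D_KV) layout is used whether or not expert parallelism shards attention as context, so the EP-specific *_ep spec variables above are dead and removed.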

src/maxtext/layers/attentions.py (0 additions, 13 deletions)

@@ -28,7 +28,6 @@
 from maxtext.common.common_types import (
     DecoderBlockType,
     BATCH,
-    BATCH_NO_EXP,
     HEAD,
     PREFILL_LENGTH,
     D_KV,
@@ -149,9 +148,7 @@ def attention_as_linen(
     ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
     ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
     input_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, ATTN_EMBED),
-    ep_input_axis_names: AxisNames = (BATCH_NO_EXP, ATTN_LENGTH, ATTN_EMBED),
     out_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, HEAD, D_KV),
-    ep_out_axis_names: AxisNames = (BATCH_NO_EXP, ATTN_LENGTH, HEAD, D_KV),
     prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, ATTN_EMBED),
     decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, ATTN_EMBED),
     prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
@@ -215,9 +212,7 @@ def attention_as_linen(
       ep_key_axis_names=ep_key_axis_names,
       ep_value_axis_names=ep_value_axis_names,
       input_axis_names=input_axis_names,
-      ep_input_axis_names=ep_input_axis_names,
       out_axis_names=out_axis_names,
-      ep_out_axis_names=ep_out_axis_names,
       prefill_input_axis_names=prefill_input_axis_names,
       decode_input_axis_names=decode_input_axis_names,
       prefill_out_axis_names=prefill_out_axis_names,
@@ -316,9 +311,7 @@ def __init__(
     ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
     ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
     input_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, ATTN_EMBED),
-    ep_input_axis_names: AxisNames = (BATCH_NO_EXP, ATTN_LENGTH, ATTN_EMBED),
     out_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, HEAD, D_KV),
-    ep_out_axis_names: AxisNames = (BATCH_NO_EXP, ATTN_LENGTH, HEAD, D_KV),
     prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, ATTN_EMBED),
     decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, ATTN_EMBED),
     prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
@@ -424,9 +417,7 @@ def __init__(
     self.ep_key_axis_names = ep_key_axis_names
     self.ep_value_axis_names = ep_value_axis_names
     self.input_axis_names = input_axis_names
-    self.ep_input_axis_names = ep_input_axis_names
     self.out_axis_names = out_axis_names
-    self.ep_out_axis_names = ep_out_axis_names
     self.prefill_input_axis_names = prefill_input_axis_names
     self.decode_input_axis_names = decode_input_axis_names
     self.prefill_out_axis_names = prefill_out_axis_names
@@ -1100,8 +1091,6 @@ def __call__(
     """
     if model_mode == MODEL_MODE_PREFILL:
       input_axis_names = self.prefill_input_axis_names
-    elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
-      input_axis_names = self.ep_input_axis_names
     elif model_mode == MODEL_MODE_TRAIN:
       input_axis_names = self.input_axis_names
     else:
@@ -1219,8 +1208,6 @@ def __call__(
       out = jax.ad_checkpoint.checkpoint_name(out, "attention_out")
       if model_mode == MODEL_MODE_PREFILL:
         out = self._maybe_shard_with_logical(out, self.prefill_out_axis_names)
-      elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
-        out = self._maybe_shard_with_logical(out, self.ep_out_axis_names)
       elif model_mode == MODEL_MODE_TRAIN:
         out = self._maybe_shard_with_logical(out, self.out_axis_names)
       else:
