Skip to content

Commit ef90d9b

Browse files
committed
refresh dump sharding test
1 parent 83f4af0 commit ef90d9b

15 files changed

Lines changed: 81 additions & 85 deletions

File tree

src/maxtext/configs/inference/vllm.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@ logical_axis_rules: [
3636
['activation_embed_and_logits_batch_sequence', ['data', 'expert']],
3737
['activation_heads', ['model', 'expert']],
3838
['activation_kv_heads', ['model', 'expert']],
39-
['activation_attn_length', ["expert"]],
39+
['activation_attn_length', ['expert']],
4040
['activation_attn_length_no_exp', []],
4141
['activation_length', ['data']],
4242
['activation_length_moe', ['data', 'expert']],

src/maxtext/layers/attention_op.py

Lines changed: 1 addition & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -55,7 +55,6 @@
5555
DEFAULT_MASK_VALUE,
5656
DType,
5757
D_KV,
58-
EP_AS_CONTEXT,
5958
EP_AS_FSDP,
6059
HEAD,
6160
KV_LENGTH,
@@ -64,7 +63,6 @@
6463
MODEL_MODE_PREFILL,
6564
MODEL_MODE_TRAIN,
6665
PREFILL_LENGTH,
67-
Q_LENGTH,
6866
Q_LENGTH_NO_EXP,
6967
)
7068
from maxtext.inference import page_manager
@@ -1271,10 +1269,7 @@ def wrap_splash_kernel(single_head_mask):
12711269
return splash_kernel
12721270

12731271
splash_kernel = wrap_splash_kernel(single_head_mask)
1274-
if self.config.expert_shard_attention_option == EP_AS_CONTEXT:
1275-
segment_axis_names_splash_kernel = self._logical_to_mesh_axes((Q_LENGTH,))
1276-
else:
1277-
segment_axis_names_splash_kernel = self._logical_to_mesh_axes((Q_LENGTH_NO_EXP,))
1272+
segment_axis_names_splash_kernel = self._logical_to_mesh_axes((Q_LENGTH_NO_EXP,))
12781273
elif self.config.use_jax_splash and self.config.expert_shard_attention_option == EP_AS_FSDP:
12791274
if self.config.use_max_logit_estimate > 0:
12801275
sa_config = dataclasses.replace(sa_config, max_logit_const=self.config.use_max_logit_estimate)

tests/unit/custom_mesh_and_rule_test.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -63,6 +63,7 @@ def test_ds3_large_pp(self):
6363
"base_emb_dim=256",
6464
"base_mlp_dim=256",
6565
"base_num_decoder_layers=4",
66+
"use_tokamax_splash=true",
6667
"custom_mesh_and_rule=pipeline-large-moe",
6768
)
6869
)

tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/input_shardings.json

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -14,55 +14,55 @@
1414
},
1515
{
1616
"attention_mla/inputs_q: bfloat16[192,2048,2048]": {
17-
"logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_embed')",
17+
"logic_axes": "('activation_batch', 'activation_length', 'activation_embed')",
1818
"PartitionSpec": "P('fsdp', None, None)"
1919
}
2020
},
2121
{
2222
"attention_mla/inputs_kv: bfloat16[192,2048,2048]": {
23-
"logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_embed')",
23+
"logic_axes": "('activation_batch', 'activation_length', 'activation_embed')",
2424
"PartitionSpec": "P('fsdp', None, None)"
2525
}
2626
},
2727
{
2828
"attention_mla/q_nope: bfloat16[192,2048,16,128]": {
29-
"logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
29+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
3030
"PartitionSpec": "P('fsdp', None, None, None)"
3131
}
3232
},
3333
{
3434
"attention_mla/q_pe: bfloat16[192,2048,16,64]": {
35-
"logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
35+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
3636
"PartitionSpec": "P('fsdp', None, None, None)"
3737
}
3838
},
3939
{
4040
"attention_mla/query: bfloat16[192,2048,16,192]": {
41-
"logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
41+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
4242
"PartitionSpec": "P('fsdp', None, None, None)"
4343
}
4444
},
4545
{
4646
"attention_mla/key_nope: bfloat16[192,2048,16,128]": {
47-
"logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
47+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
4848
"PartitionSpec": "P('fsdp', None, None, None)"
4949
}
5050
},
5151
{
5252
"attention_mla/key_rope: bfloat16[192,2048,16,64]": {
53-
"logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
53+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
5454
"PartitionSpec": "P('fsdp', None, None, None)"
5555
}
5656
},
5757
{
5858
"attention_mla/key: bfloat16[192,2048,16,192]": {
59-
"logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
59+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
6060
"PartitionSpec": "P('fsdp', None, None, None)"
6161
}
6262
},
6363
{
6464
"attention_mla/value: bfloat16[192,2048,16,128]": {
65-
"logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
65+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
6666
"PartitionSpec": "P('fsdp', None, None, None)"
6767
}
6868
},
@@ -86,7 +86,7 @@
8686
},
8787
{
8888
"attention_mla/out: bfloat16[192,2048,16,128]": {
89-
"logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_heads', 'activation_kv')",
89+
"logic_axes": "('activation_batch', 'activation_length', 'activation_heads', 'activation_kv')",
9090
"PartitionSpec": "P('fsdp', None, None, None)"
9191
}
9292
},
@@ -104,7 +104,7 @@
104104
},
105105
{
106106
"linears/x: bfloat16[192,2048,10944]": {
107-
"logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')",
107+
"logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')",
108108
"PartitionSpec": "P('fsdp', None, None)"
109109
}
110110
},
@@ -134,7 +134,7 @@
134134
},
135135
{
136136
"linears/x: bfloat16[192,2048,2816]": {
137-
"logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')",
137+
"logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')",
138138
"PartitionSpec": "P('fsdp', None, None)"
139139
}
140140
},

tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/input_shardings.json

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -14,55 +14,55 @@
1414
},
1515
{
1616
"attention_mla/inputs_q: bfloat16[768,2048,2048]": {
17-
"logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_embed')",
17+
"logic_axes": "('activation_batch', 'activation_length', 'activation_embed')",
1818
"PartitionSpec": "P(('data', 'fsdp'), None, None)"
1919
}
2020
},
2121
{
2222
"attention_mla/inputs_kv: bfloat16[768,2048,2048]": {
23-
"logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_embed')",
23+
"logic_axes": "('activation_batch', 'activation_length', 'activation_embed')",
2424
"PartitionSpec": "P(('data', 'fsdp'), None, None)"
2525
}
2626
},
2727
{
2828
"attention_mla/q_nope: bfloat16[768,2048,16,128]": {
29-
"logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
29+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
3030
"PartitionSpec": "P(('data', 'fsdp'), None, None, None)"
3131
}
3232
},
3333
{
3434
"attention_mla/q_pe: bfloat16[768,2048,16,64]": {
35-
"logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
35+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
3636
"PartitionSpec": "P(('data', 'fsdp'), None, None, None)"
3737
}
3838
},
3939
{
4040
"attention_mla/query: bfloat16[768,2048,16,192]": {
41-
"logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
41+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
4242
"PartitionSpec": "P(('data', 'fsdp'), None, None, None)"
4343
}
4444
},
4545
{
4646
"attention_mla/key_nope: bfloat16[768,2048,16,128]": {
47-
"logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
47+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
4848
"PartitionSpec": "P(('data', 'fsdp'), None, None, None)"
4949
}
5050
},
5151
{
5252
"attention_mla/key_rope: bfloat16[768,2048,16,64]": {
53-
"logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
53+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
5454
"PartitionSpec": "P(('data', 'fsdp'), None, None, None)"
5555
}
5656
},
5757
{
5858
"attention_mla/key: bfloat16[768,2048,16,192]": {
59-
"logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
59+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
6060
"PartitionSpec": "P(('data', 'fsdp'), None, None, None)"
6161
}
6262
},
6363
{
6464
"attention_mla/value: bfloat16[768,2048,16,128]": {
65-
"logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
65+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
6666
"PartitionSpec": "P(('data', 'fsdp'), None, None, None)"
6767
}
6868
},
@@ -86,7 +86,7 @@
8686
},
8787
{
8888
"attention_mla/out: bfloat16[768,2048,16,128]": {
89-
"logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_heads', 'activation_kv')",
89+
"logic_axes": "('activation_batch', 'activation_length', 'activation_heads', 'activation_kv')",
9090
"PartitionSpec": "P(('data', 'fsdp'), None, None, None)"
9191
}
9292
},
@@ -104,7 +104,7 @@
104104
},
105105
{
106106
"linears/x: bfloat16[768,2048,10944]": {
107-
"logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')",
107+
"logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')",
108108
"PartitionSpec": "P(('data', 'fsdp'), None, None)"
109109
}
110110
},
@@ -134,7 +134,7 @@
134134
},
135135
{
136136
"linears/x: bfloat16[768,2048,2816]": {
137-
"logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')",
137+
"logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')",
138138
"PartitionSpec": "P(('data', 'fsdp'), None, None)"
139139
}
140140
},

tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/input_shardings.json

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -14,55 +14,55 @@
1414
},
1515
{
1616
"attention_mla/inputs_q: bfloat16[96,2048,2048]": {
17-
"logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_embed')",
17+
"logic_axes": "('activation_batch', 'activation_length', 'activation_embed')",
1818
"PartitionSpec": "P('fsdp', None, None)"
1919
}
2020
},
2121
{
2222
"attention_mla/inputs_kv: bfloat16[96,2048,2048]": {
23-
"logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_embed')",
23+
"logic_axes": "('activation_batch', 'activation_length', 'activation_embed')",
2424
"PartitionSpec": "P('fsdp', None, None)"
2525
}
2626
},
2727
{
2828
"attention_mla/q_nope: bfloat16[96,2048,16,128]": {
29-
"logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
29+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
3030
"PartitionSpec": "P('fsdp', None, None, None)"
3131
}
3232
},
3333
{
3434
"attention_mla/q_pe: bfloat16[96,2048,16,64]": {
35-
"logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
35+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
3636
"PartitionSpec": "P('fsdp', None, None, None)"
3737
}
3838
},
3939
{
4040
"attention_mla/query: bfloat16[96,2048,16,192]": {
41-
"logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
41+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
4242
"PartitionSpec": "P('fsdp', None, None, None)"
4343
}
4444
},
4545
{
4646
"attention_mla/key_nope: bfloat16[96,2048,16,128]": {
47-
"logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
47+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
4848
"PartitionSpec": "P('fsdp', None, None, None)"
4949
}
5050
},
5151
{
5252
"attention_mla/key_rope: bfloat16[96,2048,16,64]": {
53-
"logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
53+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
5454
"PartitionSpec": "P('fsdp', None, None, None)"
5555
}
5656
},
5757
{
5858
"attention_mla/key: bfloat16[96,2048,16,192]": {
59-
"logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
59+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
6060
"PartitionSpec": "P('fsdp', None, None, None)"
6161
}
6262
},
6363
{
6464
"attention_mla/value: bfloat16[96,2048,16,128]": {
65-
"logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
65+
"logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')",
6666
"PartitionSpec": "P('fsdp', None, None, None)"
6767
}
6868
},
@@ -86,7 +86,7 @@
8686
},
8787
{
8888
"attention_mla/out: bfloat16[96,2048,16,128]": {
89-
"logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_heads', 'activation_kv')",
89+
"logic_axes": "('activation_batch', 'activation_length', 'activation_heads', 'activation_kv')",
9090
"PartitionSpec": "P('fsdp', None, None, None)"
9191
}
9292
},
@@ -104,7 +104,7 @@
104104
},
105105
{
106106
"linears/x: bfloat16[96,2048,10944]": {
107-
"logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')",
107+
"logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')",
108108
"PartitionSpec": "P('fsdp', None, None)"
109109
}
110110
},
@@ -134,7 +134,7 @@
134134
},
135135
{
136136
"linears/x: bfloat16[96,2048,2816]": {
137-
"logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')",
137+
"logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')",
138138
"PartitionSpec": "P('fsdp', None, None)"
139139
}
140140
},

0 commit comments

Comments
 (0)