
Commit 87335ad

deprecate expert_shard_attention_option flag
1 parent 9fc1ccc commit 87335ad

24 files changed

Lines changed: 107 additions & 135 deletions


docs/reference/core_concepts/moe_configuration.md

Lines changed: 0 additions & 5 deletions
@@ -96,11 +96,6 @@ Dropping:
 
 ## 2. Sharding
 
-`expert_shard_attention_option`: Determines how the "expert" axis is interpreted when sharding attention layers. Options include:
-
-- `fsdp`: Treats the expert axis as a FSDP axis.
-- `context`: Treats the expert axis as a context parallelism axis, useful for long context.
-
 `use_ring_of_experts` (experimental): This feature requires expert parallelism. If enabled, it replaces the standard two All-to-All communications with All-Gather in dispatch and Reduce-Scatter in collect. By gathering inputs across all shards, it allows for local routing and Top-K calculations, followed by result aggregation via Reduce-Scatter. This approach is particularly effective for models with a large Top-K, as it gathers activations before they are replicated k times to reduce communication.
 
 `moe_fsdp_use_two_stage_all_gather`: If enabled, split the All-Gather operation for MoE weights into two separate stages when using FSDP/FSDP-transpose sharding. This is preferred when 3D All-Gather support is unavailable.
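The `use_ring_of_experts` paragraph above is kept by this commit; for context, the communication pattern it describes is roughly: All-Gather the token activations across the expert axis, route and compute locally, then Reduce-Scatter the expert outputs. A minimal sketch under simplifying assumptions (one expert per shard, a top-1 stand-in for the real router; names and shapes are illustrative, not MaxText's implementation):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), axis_names=("expert",))

def ring_of_experts_block(x, w):
  # x: this shard's tokens, [tokens_per_shard, d_model]
  # w: this shard's expert weight, [1, d_model, d_model]
  # 1) Dispatch: All-Gather token activations across the expert axis.
  x_full = jax.lax.all_gather(x, "expert", tiled=True)        # [tokens, d_model]
  # 2) Route locally; the real code does Top-K routing here, this stand-in
  #    simply applies the local expert to every token.
  y = x_full @ w[0]                                            # [tokens, d_model]
  # 3) Collect: Reduce-Scatter expert outputs back to per-shard token ranges.
  return jax.lax.psum_scatter(y, "expert", tiled=True)        # [tokens_per_shard, d_model]

n_dev = len(jax.devices())
tokens_per_shard, d_model = 8, 16
x = jnp.ones((n_dev * tokens_per_shard, d_model))
w = jnp.ones((n_dev, d_model, d_model))

y = shard_map(ring_of_experts_block, mesh=mesh,
              in_specs=(P("expert", None), P("expert", None, None)),
              out_specs=P("expert", None))(x, w)
```

The payoff claimed in the docs comes from gathering activations once, before they would otherwise be replicated k times by Top-K dispatch over All-to-All.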

src/maxtext/common/common_types.py

Lines changed: 0 additions & 4 deletions
@@ -66,10 +66,6 @@
 MODEL_MODE_PREFILL = "prefill"
 MODEL_MODE_TRAIN = "train"
 
-# expert_shard_attention_option
-EP_AS_CONTEXT = "context"
-EP_AS_FSDP = "fsdp"
-
 DECODING_ACTIVE_SEQUENCE_INDICATOR = 1
 
 # A large negative mask value is used for masking to ensure that the

src/maxtext/configs/base.yml

Lines changed: 15 additions & 20 deletions
@@ -237,11 +237,6 @@ merge_gating_gmm: False
 
 norm_topk_prob: false # boolean to enable the top-k probability normalization. qwen3-specific normalization of router weights.
 
-# how the expert axis is used to shard attention weights and activations
-# "fsdp" (ep acts as fsdp parallelism)
-# "context" (ep acts as context parallelism, training only)
-expert_shard_attention_option: "fsdp"
-
 # when moe weight matrices are sharded on both fsdp and fsdp-transpose axes, use two separate all-gather calls
 moe_fsdp_use_two_stage_all_gather: false
 # Shard the expert dimension of the MLP weights on the FSDP axis.
@@ -461,9 +456,9 @@ logical_axis_rules: [
   ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
   ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-  # ['activation_vocab', ['tensor', 'tensor_transpose']],
-  # ['activation_vocab', 'tensor_sequence'],
-  # ['activation_vocab', ['sequence', 'context']],
+  ['activation_vocab', ['tensor', 'tensor_transpose']],
+  ['activation_vocab', 'tensor_sequence'],
+  ['activation_vocab', ['sequence', 'context']],
   # Vocab Weights
   ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
   ['embed_vocab', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
@@ -489,14 +484,14 @@ logical_axis_rules: [
   ['kv', []],
   ['kv_head_dim', []],
   ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
-  # ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
-  # ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  # ['q_lora', ['fsdp', 'sequence', 'context', 'expert']],
+  ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
+  ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+  ['q_lora', ['fsdp', 'sequence', 'context', 'expert']],
   ["q_lora_up_proj", []],
   ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
-  # ['kv_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
-  # ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  # ['kv_lora', ['fsdp', 'sequence', 'context', 'expert']],
+  ['kv_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
+  ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+  ['kv_lora', ['fsdp', 'sequence', 'context', 'expert']],
   ["kv_lora_up_proj", []],
   # ==========================================
   # Mixture of Experts (MoE)
@@ -513,9 +508,9 @@ logical_axis_rules: [
   ['exp', 'expert'],
   ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
   ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
-  # ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
-  # ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
-  # ['embed_moe', ['fsdp', 'sequence', 'context']],
+  ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
+  ['embed_moe', ['fsdp', 'sequence', 'context']],
   # ==========================================
   # Standard MLP / Dense Layers / Model Structure
   # ==========================================
@@ -530,9 +525,9 @@ logical_axis_rules: [
   # General Weights
   ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
   ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
-  # ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context', 'expert']],
-  # ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  # ['embed', ['fsdp', 'sequence', 'context', 'expert']],
+  ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context', 'expert']],
+  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+  ['embed', ['fsdp', 'sequence', 'context', 'expert']],
   ['norm', ['tensor', 'tensor_transpose']],
   ['layers', 'stage'],
   ['diloco', 'diloco'],
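With the flag removed, the alternative shardings it used to select become ordinary entries in `logical_axis_rules`. Flax resolves these rules in priority order, so duplicate entries for the same logical axis effectively act as fallbacks: a later rule is used when an earlier rule's mesh axes are already claimed by another dimension of the same tensor. A minimal sketch with `flax.linen.logical_to_mesh_axes` (rule subset copied from the hunk above; the two-dimensional ("embed", "mlp") kernel is only an illustration):

```python
import flax.linen as nn

# Rule subset from the hunk above; order matters.
rules = (
    ("mlp", ("fsdp_transpose", "tensor", "tensor_sequence", "autoregressive")),
    ("embed", ("fsdp", "fsdp_transpose", "sequence", "tensor_transpose", "context", "expert")),
    ("embed", ("fsdp", "sequence", "tensor_transpose", "context", "expert")),  # newly uncommented fallback
    ("embed", ("fsdp", "fsdp_transpose", "sequence", "context", "expert")),
    ("embed", ("fsdp", "sequence", "context", "expert")),
)

# "mlp" claims "fsdp_transpose" first, so the first "embed" rule conflicts and
# the fsdp_transpose-free fallback is chosen for the "embed" dimension.
print(nn.logical_to_mesh_axes(("embed", "mlp"), rules))
```

This appears to be why the previously commented-out alternatives are activated here: the static priority list now covers the layouts the flag used to switch between at runtime.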

src/maxtext/configs/types.py

Lines changed: 0 additions & 6 deletions
@@ -661,10 +661,6 @@ class MoEGeneral(BaseModel):
   )
   use_random_routing: bool = Field(False, description="Whether to use random routing for debugging.")
   interleave_moe_layer_step: int = Field(1, description="Frequency of MoE layers, e.g., 2 means every 2nd layer is MoE.")
-  expert_shard_attention_option: Literal["fsdp", "context"] = Field(
-      "fsdp",
-      description="How the expert axis is used to shard attention weights and activations.",
-  )
   moe_fsdp_use_two_stage_all_gather: bool = Field(
       False,
       description="Use two separate All-Gather calls for MoE weights sharded on both FSDP and FSDP-transpose.",
@@ -2396,8 +2392,6 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
     self.tensors_to_offload = [t for t in tensors if getattr(self, t) == "offload"]
 
     cp_size = self.ici_context_parallelism * self.dcn_context_parallelism
-    if self.expert_shard_attention_option == "context":
-      cp_size *= self.ici_expert_parallelism * self.dcn_expert_parallelism
     self.context_parallel_size = cp_size
     if self.pipeline_parallel_layers == -1:
       if self.decoder_block == DecoderBlockType.DEEPSEEK:
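The second hunk stops folding the expert-parallel degree into `context_parallel_size`. A small numeric sketch of the before/after computation (values are illustrative, not repository defaults):

```python
# Illustrative values, not repository defaults.
ici_context_parallelism, dcn_context_parallelism = 2, 1
ici_expert_parallelism, dcn_expert_parallelism = 4, 1

# Before this commit, with expert_shard_attention_option == "context":
cp_before = (ici_context_parallelism * dcn_context_parallelism
             * ici_expert_parallelism * dcn_expert_parallelism)  # 8

# After this commit, the expert axis no longer contributes:
cp_after = ici_context_parallelism * dcn_context_parallelism  # 2
```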

src/maxtext/layers/attention_mla.py

Lines changed: 1 addition & 8 deletions
@@ -48,7 +48,6 @@
     D_KV,
     DType,
     EMBED,
-    EP_AS_CONTEXT,
     HEAD,
     Q_LORA_UP_PROJ,
     KV_BATCH,
@@ -901,9 +900,6 @@ def mla_get_key_value(self, low_rank_main, key_rope, model_mode):
     if model_mode == MODEL_MODE_PREFILL:
       key_logical_name = self.prefill_key_axis_names
       value_logical_name = self.prefill_value_axis_names
-    elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
-      key_logical_name = self.ep_key_axis_names
-      value_logical_name = self.ep_value_axis_names
     else:
       key_logical_name = self.key_axis_names
       value_logical_name = self.value_axis_names
@@ -1224,10 +1220,7 @@ def __call__(
     )
 
     out = jax.ad_checkpoint.checkpoint_name(out, "attention_out")
-    if model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
-      out = self._maybe_shard_with_logical(out, self.ep_out_axis_names)
-    else:
-      out = self._maybe_shard_with_logical(out, self.out_axis_names)
+    out = self._maybe_shard_with_logical(out, self.out_axis_names)
 
     out_sharding = create_sharding(self.mesh, out_logical_name)
     out = self.out_projection(out, out_sharding=out_sharding)

src/maxtext/layers/attention_op.py

Lines changed: 1 addition & 2 deletions
@@ -55,7 +55,6 @@
     DEFAULT_MASK_VALUE,
     DType,
     D_KV,
-    EP_AS_FSDP,
     HEAD,
     KV_LENGTH,
     LENGTH,
@@ -1270,7 +1269,7 @@ def wrap_splash_kernel(single_head_mask):
 
       splash_kernel = wrap_splash_kernel(single_head_mask)
       segment_axis_names_splash_kernel = self._logical_to_mesh_axes((Q_LENGTH,))
-    elif self.config.use_jax_splash and self.config.expert_shard_attention_option == EP_AS_FSDP:
+    elif self.config.use_jax_splash:
       if self.config.use_max_logit_estimate > 0:
         sa_config = dataclasses.replace(sa_config, max_logit_const=self.config.use_max_logit_estimate)
       segment_axis_names_splash_kernel = nn.logical_to_mesh_axes((Q_LENGTH,))

tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/logical_shardings.json

Lines changed: 6 additions & 6 deletions
@@ -133,7 +133,7 @@
   },
   ".params/['params']/['decoder']/['logits_dense']/['kernel']": {
     "partition_spec": [
-      "embed",
+      "embed_vocab",
       "vocab"
     ],
     "shape": [
@@ -318,7 +318,7 @@
   ".params/['params']/['token_embedder']/['embedding']": {
     "partition_spec": [
       "vocab",
-      "embed"
+      "embed_vocab"
     ],
     "shape": [
       102400,
@@ -459,7 +459,7 @@
   },
   ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": {
     "partition_spec": [
-      "embed",
+      "embed_vocab",
       "vocab"
     ],
     "shape": [
@@ -644,7 +644,7 @@
   ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": {
     "partition_spec": [
       "vocab",
-      "embed"
+      "embed_vocab"
     ],
     "shape": [
       102400,
@@ -781,7 +781,7 @@
   },
   ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": {
     "partition_spec": [
-      "embed",
+      "embed_vocab",
       "vocab"
     ],
     "shape": [
@@ -966,7 +966,7 @@
   ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": {
     "partition_spec": [
       "vocab",
-      "embed"
+      "embed_vocab"
     ],
     "shape": [
       102400,

tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/logical_shardings.json

Lines changed: 6 additions & 6 deletions
@@ -133,7 +133,7 @@
   },
   ".params/['params']/['decoder']/['logits_dense']/['kernel']": {
     "partition_spec": [
-      "embed",
+      "embed_vocab",
       "vocab"
     ],
     "shape": [
@@ -318,7 +318,7 @@
   ".params/['params']/['token_embedder']/['embedding']": {
     "partition_spec": [
       "vocab",
-      "embed"
+      "embed_vocab"
     ],
     "shape": [
       102400,
@@ -459,7 +459,7 @@
   },
   ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": {
     "partition_spec": [
-      "embed",
+      "embed_vocab",
       "vocab"
     ],
     "shape": [
@@ -644,7 +644,7 @@
   ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": {
     "partition_spec": [
       "vocab",
-      "embed"
+      "embed_vocab"
    ],
     "shape": [
       102400,
@@ -781,7 +781,7 @@
   },
   ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": {
     "partition_spec": [
-      "embed",
+      "embed_vocab",
       "vocab"
     ],
     "shape": [
@@ -966,7 +966,7 @@
   ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": {
     "partition_spec": [
       "vocab",
-      "embed"
+      "embed_vocab"
     ],
     "shape": [
       102400,

tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/logical_shardings.json

Lines changed: 6 additions & 6 deletions
@@ -133,7 +133,7 @@
   },
   ".params/['params']/['decoder']/['logits_dense']/['kernel']": {
     "partition_spec": [
-      "embed",
+      "embed_vocab",
       "vocab"
     ],
     "shape": [
@@ -318,7 +318,7 @@
   ".params/['params']/['token_embedder']/['embedding']": {
     "partition_spec": [
       "vocab",
-      "embed"
+      "embed_vocab"
     ],
     "shape": [
       102400,
@@ -459,7 +459,7 @@
   },
   ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": {
     "partition_spec": [
-      "embed",
+      "embed_vocab",
       "vocab"
     ],
     "shape": [
@@ -644,7 +644,7 @@
   ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": {
     "partition_spec": [
       "vocab",
-      "embed"
+      "embed_vocab"
     ],
     "shape": [
       102400,
@@ -781,7 +781,7 @@
   },
   ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": {
     "partition_spec": [
-      "embed",
+      "embed_vocab",
       "vocab"
     ],
     "shape": [
@@ -966,7 +966,7 @@
   ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": {
     "partition_spec": [
       "vocab",
-      "embed"
+      "embed_vocab"
     ],
     "shape": [
       102400,

tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/logical_shardings.json

Lines changed: 6 additions & 6 deletions
@@ -133,7 +133,7 @@
   },
   ".params/['params']/['decoder']/['logits_dense']/['kernel']": {
     "partition_spec": [
-      "embed",
+      "embed_vocab",
       "vocab"
     ],
     "shape": [
@@ -318,7 +318,7 @@
   ".params/['params']/['token_embedder']/['embedding']": {
     "partition_spec": [
       "vocab",
-      "embed"
+      "embed_vocab"
     ],
     "shape": [
       102400,
@@ -459,7 +459,7 @@
   },
   ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": {
     "partition_spec": [
-      "embed",
+      "embed_vocab",
       "vocab"
     ],
     "shape": [
@@ -644,7 +644,7 @@
   ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": {
     "partition_spec": [
       "vocab",
-      "embed"
+      "embed_vocab"
     ],
     "shape": [
       102400,
@@ -781,7 +781,7 @@
   },
   ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": {
     "partition_spec": [
-      "embed",
+      "embed_vocab",
       "vocab"
     ],
     "shape": [
@@ -966,7 +966,7 @@
   ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": {
     "partition_spec": [
       "vocab",
-      "embed"
+      "embed_vocab"
     ],
     "shape": [
       102400,
