Commit d7fd385

Add MoE and MLA remat policies
1 parent 8b86c71 commit d7fd385

Showing 5 changed files with 47 additions and 9 deletions.

src/maxtext/configs/base.yml

Lines changed: 6 additions & 0 deletions
```diff
@@ -317,11 +317,17 @@ mlpwi: 'remat'
 mlpwi_0: 'remat'
 mlpwi_1: 'remat'
 mlpwo: 'remat'
+moe_mlpwi: 'remat'
+moe_mlpwi_0: 'remat'
+moe_mlpwi_1: 'remat'
+moe_mlpwo: 'remat'
 query_proj: 'remat'
 key_proj: 'remat'
 value_proj: 'remat'
 qkv_proj: 'remat'
 out_proj: 'remat'
+query_wa_proj: 'remat'
+kv_wa_proj: 'remat'
 mla_q: 'remat'
 mla_kv: 'remat'
 attention_out: 'remat'
```
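These YAML keys name activation tensors for MaxText's configurable remat policy; each corresponds to a `checkpoint_name` tag in the layer code (see the layer diffs below). For context, a minimal self-contained sketch of the JAX mechanism the names feed into — `toy_ffn` and its weights are made up for illustration, but `checkpoint_name` and `save_only_these_names` are the actual JAX APIs:

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def toy_ffn(x, w_in, w_out):
  h = jnp.dot(x, w_in)
  # Tag the intermediate so a remat policy can refer to it by name,
  # just as the layer code below tags "moe_mlpwo" etc.
  h = checkpoint_name(h, "moe_mlpwo")
  return jnp.dot(jax.nn.relu(h), w_out)

# Keep only tensors tagged "moe_mlpwo" for the backward pass; everything
# else is rematerialized (recomputed) under grad.
policy = jax.checkpoint_policies.save_only_these_names("moe_mlpwo")
toy_ffn_remat = jax.checkpoint(toy_ffn, policy=policy)

x = jnp.ones((4, 8))
grads = jax.grad(lambda inputs: toy_ffn_remat(inputs, jnp.ones((8, 16)), jnp.ones((16, 4))).sum())(x)
print(grads.shape)  # (4, 8)
```

The `'remat'` default above means the tensor is recomputed; the other locations a key can take are modeled by `RematLocation` in types.py below.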

src/maxtext/configs/pyconfig_deprecated.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -516,9 +516,15 @@ def validate_and_assign_remat_tensors(keys):
       "mlpwi_0",
       "mlpwi_1",
       "mlpwo",
+      "moe_mlpwi",
+      "moe_mlpwi_0",
+      "moe_mlpwi_1",
+      "moe_mlpwo",
       "query_proj",
       "key_proj",
       "value_proj",
+      "query_wa_proj",
+      "kv_wa_proj",
       "out_proj",
   ]
   assert keys["decoder_layer_input"] != "remat", "Cannot remeterialize this tensor with scan_layers=True"
```
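The role of this list: user-supplied remat keys are validated against the known tensor names, so the new MoE and MLA tags must be registered here to be accepted. A hedged sketch of that kind of check — the helper name and error text below are hypothetical, not taken from this file:

```python
# Hypothetical whitelist check; the real validate_and_assign_remat_tensors
# may differ in structure and messages.
def check_remat_keys(keys, known_tensors):
  for name in keys:
    if name not in known_tensors:
      raise ValueError(f"Unknown remat tensor: {name!r}")
```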

src/maxtext/configs/types.py

Lines changed: 24 additions & 0 deletions
```diff
@@ -896,9 +896,33 @@ class RematAndOffload(BaseModel):
       RematLocation.REMAT,
       description="Remat policy for the second MLP layer's output.",
   )
+  moe_mlpwi: RematLocation = Field(
+      RematLocation.REMAT,
+      description="Remat policy for the first MoE layer's intermediate output.",
+  )
+  moe_mlpwi_0: RematLocation = Field(
+      RematLocation.REMAT,
+      description="Remat policy for the first part of a gated MoE's output.",
+  )
+  moe_mlpwi_1: RematLocation = Field(
+      RematLocation.REMAT,
+      description="Remat policy for the second part of a gated MoE's output.",
+  )
+  moe_mlpwo: RematLocation = Field(
+      RematLocation.REMAT,
+      description="Remat policy for the second MoE layer's output.",
+  )
   query_proj: RematLocation = Field(RematLocation.REMAT, description="Remat policy for the query projection.")
   key_proj: RematLocation = Field(RematLocation.REMAT, description="Remat policy for the key projection.")
   value_proj: RematLocation = Field(RematLocation.REMAT, description="Remat policy for the value projection.")
+  query_wa_proj: RematLocation = Field(
+      RematLocation.REMAT,
+      description="Remat policy for the MLA query weighted attention projection.",
+  )
+  kv_wa_proj: RematLocation = Field(
+      RematLocation.REMAT,
+      description="Remat policy for the MLA key and value weighted attention projection.",
+  )
   qkv_proj: RematLocation = Field(RematLocation.REMAT, description="Remat policy for fused QKV projection.")
   out_proj: RematLocation = Field(
       RematLocation.REMAT,
```
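Since `RematAndOffload` is a Pydantic model, each new field can be overridden individually while the rest keep their `'remat'` defaults. A sketch, assuming `RematLocation` also defines a `DEVICE` member (only `REMAT` appears in this diff):

```python
# Hypothetical override: keep the MoE intermediates rematerialized but hold
# the MoE output on device. RematLocation.DEVICE is an assumption here.
cfg = RematAndOffload(moe_mlpwo=RematLocation.DEVICE)
assert cfg.moe_mlpwi == RematLocation.REMAT  # untouched fields keep defaults
```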

src/maxtext/layers/attention_mla.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -794,6 +794,7 @@ def mla_query_projection(
     else:
       # LoRA path
       low_rank_q = self.wq_a(inputs_q, out_sharding=wqa_out_sharding)  # [B, L, q_lora_rank]
+      low_rank_q = checkpoint_name(low_rank_q, "query_wa_proj")
       low_rank_q = self.q_norm(low_rank_q)  # RMSNorm on low rank
       low_rank_q = checkpoint_name(low_rank_q, "mla_q")
       q = self.wq_b(low_rank_q, out_sharding=query_sharding)  # [B, L, n_heads, qk_head_dim]
@@ -933,6 +934,7 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm
     wka_logical_name = (KV_BATCH, LENGTH_NO_EXP, KV_LORA_UP_PROJ)
     wkva_out_sharding = create_sharding(self.mesh, wka_logical_name)
     low_rank = self.wkv_a(inputs, out_sharding=wkva_out_sharding)
+    low_rank = checkpoint_name(low_rank, "kv_wa_proj")
     low_rank_main, low_rank_rope = jnp.split(low_rank, [self.kv_lora_rank], axis=-1)
     low_rank_main = self.kv_norm(low_rank_main)
     low_rank_main = checkpoint_name(low_rank_main, "mla_kv")
```
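Worth noting why the new tags sit on the low-rank projections: under a non-remat policy these are cheap MLA tensors to keep, because the LoRA bottleneck is much narrower than the model dimension. A rough per-token sizing sketch with hypothetical DeepSeek-style dimensions (illustrative assumptions, not taken from this diff):

```python
# All dimensions below are illustrative assumptions; bf16 = 2 bytes/element.
hidden, q_lora_rank, kv_lora_rank, rope_dim = 7168, 1536, 512, 64
query_wa = q_lora_rank * 2                 # "query_wa_proj" output per token
kv_wa = (kv_lora_rank + rope_dim) * 2      # "kv_wa_proj" output (split point
                                           # matches the jnp.split above)
full = hidden * 2                          # a full-width hidden vector, for scale
print(query_wa / full, kv_wa / full)       # ~0.21 and ~0.08 of a hidden vector
```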

src/maxtext/layers/moe.py

Lines changed: 9 additions & 9 deletions
```diff
@@ -1274,7 +1274,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
         layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
       if self.config.mlp_bias:
         layer_w0 = layer_w0 + w0_bias
-      layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
+      layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0")

       layer_w1 = gmm_fn(
           x,
@@ -1288,7 +1288,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
         layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose")
       if self.config.mlp_bias:
         layer_w1 = layer_w1 + w1_bias
-      layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1")
+      layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
       intermediate_layer = self.apply_ffn_activation(layer_w0, layer_w1)

       intermediate_output = gmm_fn(
@@ -1305,7 +1305,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
       )
       if self.config.mlp_bias:
         intermediate_output = intermediate_output + wo_bias
-      intermediate_output = adc.checkpoint_name(intermediate_output, "mlpwo")
+      intermediate_output = adc.checkpoint_name(intermediate_output, "moe_mlpwo")

       if self.config.use_ring_of_experts:
         # Set the outputs of tokens which were not processed to 0.
@@ -1860,7 +1860,7 @@ def dense_matmul(
             layer_w0,
             mlp_axis,
         )
-      layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
+      layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0")
       with jax.named_scope("wi_1"):
         w1_kernel_axes = ("exp", None, "mlp")
         w1_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w1_kernel, w1_kernel_axes)
@@ -1876,7 +1876,7 @@ def dense_matmul(
             layer_w1,
             mlp_axis,
         )
-      layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1")
+      layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
       layer_multiply = self.apply_ffn_activation(layer_w0, layer_w1)
       with jax.named_scope("wo"):
         wo_kernel_axes = ("exp", "mlp", None)
@@ -1902,7 +1902,7 @@ def dense_matmul(
               "activation_embed",
           ),
       )
-      intermediate_layer = adc.checkpoint_name(intermediate_layer, "mlpwo")
+      intermediate_layer = adc.checkpoint_name(intermediate_layer, "moe_mlpwo")
       with jax.named_scope("combine"):
         # Matmul & element wise operation
         output = self.get_einsum(rhs_mesh_axes=mask_axes, einsum_name=COMBINE)(
@@ -1931,7 +1931,7 @@ def dense_matmul(
           layer_w0 = layer_w0 + w0_bias[None, None, :, :]
         if self.config.activations_in_float32:
           layer_w0 = layer_w0.astype(jnp.float32)
-        layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
+        layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0")
       with jax.named_scope("wi_1"):
         layer_w1 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)(
             "BSM,EMH -> BSEH", inputs, w1_kernel, precision=matmul_precision
@@ -1940,7 +1940,7 @@ def dense_matmul(
           layer_w1 = layer_w1 + w1_bias[None, None, :, :]
         if self.config.activations_in_float32:
           layer_w1 = layer_w1.astype(jnp.float32)
-        layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1")
+        layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
       layer_multiply = self.apply_ffn_activation(layer_w0, layer_w1)

       with jax.named_scope("wo"):
@@ -1954,7 +1954,7 @@ def dense_matmul(
           intermediate_layer = intermediate_layer + wo_bias[None, None, :, :]
         if self.config.activations_in_float32:
           intermediate_layer = intermediate_layer.astype(jnp.float32)
-        intermediate_layer = adc.checkpoint_name(intermediate_layer, "mlpwo")
+        intermediate_layer = adc.checkpoint_name(intermediate_layer, "moe_mlpwo")
       with jax.named_scope("weight_sum"):
         if is_llama4_decoder_layer:
           weights = self.reshape_and_update_weights(jnp.ones_like(top_k_weights), top_k_indices)
```
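The rename is the substantive change in this file: `checkpoint_name` tags are matched by string, so while the MoE layers reused the dense-MLP names `mlpwi_0`/`mlpwi_1`/`mlpwo`, a custom policy could not treat dense and expert activations differently. With distinct names it can. A minimal sketch using the same JAX machinery as the base.yml example above (`block` and its weights are made up for illustration):

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def block(x, w_dense, w_moe):
  dense_out = checkpoint_name(jnp.dot(x, w_dense), "mlpwo")     # dense MLP output
  moe_out = checkpoint_name(jnp.dot(x, w_moe), "moe_mlpwo")     # MoE output
  return dense_out + moe_out

# Save the dense-MLP output while rematerializing the MoE output --
# not expressible when both tensors shared the "mlpwo" tag.
policy = jax.checkpoint_policies.save_only_these_names("mlpwo")
block_remat = jax.checkpoint(block, policy=policy)
```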
