Adding support for fused_moe_gmm#3627

Open
NicoGrande wants to merge 1 commit into main from nicogrande/fused-moe-gmm

Conversation

Collaborator

@NicoGrande NicoGrande commented Apr 9, 2026

Description

This PR adds support for the tpu-inference fused_moe_gmm kernel in the MaxText MoE inference codepath. Initial results using this kernel show up to a ~4x increase in generation throughput when testing with qwen3-30b-a3b.

Additionally, this PR introduces a second optimization to MaxText which pre-fuses the MoE weight kernels such that they can be efficiently passed into the fused_moe_gmm kernel. We show the impact of these optimizations below with autoregressive generation step times:

Baseline (MaxText sparse_matmul MoE): 28.353 ms

Fused MoE (prefuse_moe_weights=False): 20.432 ms

Fused MoE (prefuse_moe_weights=True): 6.114 ms
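
To illustrate what the pre-fusing optimization amounts to, here is a standalone NumPy sketch (not MaxText code; the layout, shapes, and the choice of stacking gate and up projections along the feature axis are assumptions for illustration):

```python
import numpy as np

num_experts, d_model, d_ff = 4, 8, 32
rng = np.random.default_rng(0)
w0 = rng.normal(size=(num_experts, d_model, d_ff))  # gate projection, per expert
w1 = rng.normal(size=(num_experts, d_model, d_ff))  # up projection, per expert

# Pre-fuse: stack the gate and up projections along the feature axis so the
# grouped matmul reads a single [E, D, 2*F] operand instead of two [E, D, F]
# operands. Done once at weight-load time, this removes per-step weight
# reshuffling from the autoregressive generation loop.
w_fused = np.concatenate([w0, w1], axis=-1)
assert w_fused.shape == (num_experts, d_model, 2 * d_ff)
```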

Tests

This PR adds new tests to tests/unit/moe_test.py. Additionally this PR was tested e2e with both qwen3-30b-a3b and qwen3-235b-a22b.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.


codecov bot commented Apr 9, 2026

Codecov Report

❌ Patch coverage is 45.07042% with 39 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/maxtext/layers/moe.py | 23.33% | 20 Missing and 3 partials ⚠️ |
| src/maxtext/utils/model_creation_utils.py | 60.97% | 13 Missing and 3 partials ⚠️ |


@NicoGrande NicoGrande force-pushed the nicogrande/fused-moe-gmm branch 2 times, most recently from 4a22680 to 814348f Compare April 9, 2026 22:32
Collaborator

@RissyRan RissyRan left a comment


Overall LGTM. Could you add your tests to showcase the functionality? Thanks!

Comment thread tests/unit/fused_moe_test.py Outdated
from tests.utils.test_helpers import get_test_config_path, get_decoupled_parallelism_overrides


def make_moe(cfg, mesh):
Collaborator Author


Done :)

LMK what you think!

@NicoGrande NicoGrande force-pushed the nicogrande/fused-moe-gmm branch 4 times, most recently from e2a24f7 to b5d61be Compare April 10, 2026 17:56
['moe_mlp', ['model', 'attn_dp']],
['vocab', ['model', 'attn_dp']],
['heads', ['model']],
['heads', ['model', 'expert']],
Collaborator


I would add a comment here that expert is intended to act like TP for attention

Collaborator


Can even explicitly say we target two all reduces, one at end of attention out_proj, one at end of mlp

Collaborator Author


Done!

['activation_length_no_exp_moe', 'data'],
['activation_q_length', ['expert', 'attn_dp_expert']],
['activation_attn_embed', 'model'],
['activation_embed', ['model', 'attn_dp']],
Collaborator

@gobbleturk gobbleturk Apr 10, 2026


I would add a note here for activation_embed that expert is explicitly missing despite using TP, because we are going for a replicate-AR style of TP as opposed to our typical AG-RS style of TP — the replicate-AR is sort of forced by the output sharding of the vllm kernel
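
To make the replicate-AR vs. AG-RS distinction concrete, here is a standalone NumPy sketch (not MaxText code; shapes and the nonlinearity-free MLP are invented for illustration). In replicate-AR tensor parallelism, every device holds the full activations and computes a full-shape partial output; one all-reduce (sum) at the end recovers the result, rather than all-gathering inputs and reduce-scattering outputs:

```python
import numpy as np

rng = np.random.default_rng(0)
D, H, shards = 8, 16, 4
x = rng.normal(size=(2, D))        # activations, replicated on every shard
w1 = rng.normal(size=(D, H))       # first MLP weight, split over columns
w2 = rng.normal(size=(H, D))       # second MLP weight, split over rows

# Reference: unsharded MLP (no nonlinearity, for simplicity).
ref = x @ w1 @ w2

# Replicate-AR style: each shard multiplies through its column/row slice,
# producing a full-shape [2, D] partial result.
chunk = H // shards
partials = [
    (x @ w1[:, i * chunk:(i + 1) * chunk]) @ w2[i * chunk:(i + 1) * chunk, :]
    for i in range(shards)
]

# The "AR" step: one all-reduce (sum) over shards recovers the answer.
out = np.sum(partials, axis=0)
assert np.allclose(out, ref)
```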

Collaborator Author


Done!

Comment thread src/maxtext/layers/moe.py
kernel_axes=self.kernel_axes,
use_bias=self.config.routed_bias,
score_func=self.config.routed_score_func,
score_func="" if self.config.attention == "vllm_rpa" else self.config.routed_score_func,
Collaborator


I would include a comment or maybe a link to short documentation/bug/vllm code that says that the vllm mega kernel we are calling does the score_func for us, so we don't want to apply it ourselves
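
The guard in the diff above can be restated as a small helper (a sketch, not MaxText code; the function name is hypothetical, and the rationale in the docstring is the reviewer's point that the vllm kernel applies the score function internally):

```python
def effective_score_func(attention: str, routed_score_func: str) -> str:
    """Return the score function MaxText itself should apply to router logits.

    When attention == "vllm_rpa", the fused vllm MoE kernel applies the
    routing score function (e.g. softmax) internally, so MaxText must pass
    an empty string to avoid transforming the router logits twice.
    """
    if attention == "vllm_rpa":
        return ""  # kernel handles scoring; disable MaxText-side scoring
    return routed_score_func

# e.g. effective_score_func('vllm_rpa', 'softmax') == ''
#      effective_score_func('flash', 'softmax') == 'softmax'
```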

Comment thread src/maxtext/layers/moe.py Outdated
w0_bias, w1_bias, wo_bias = None, None, None

if cfg.sparse_matmul:
# vllm_rpa uses fused_moe_func from tpu_inference (highest priority)
Collaborator


what does "highest priority" mean here?

Collaborator Author


Was just trying to say it will use this by default - updating the comment for clarity.

Comment thread tests/unit/moe_test.py
self.assertIsNone(lb_loss)
self.assertIsNone(bias_updates)

def test_fused_vs_sparse_softmax(self):
Collaborator


wow this is a great test!

Collaborator

@gobbleturk gobbleturk left a comment


Awesome!

@NicoGrande NicoGrande force-pushed the nicogrande/fused-moe-gmm branch from b5d61be to a6cfd99 Compare April 11, 2026 22:21
Comment thread src/maxtext/configs/inference/vllm.yml
@NicoGrande NicoGrande force-pushed the nicogrande/fused-moe-gmm branch 2 times, most recently from 14e5854 to 57d1a60 Compare April 14, 2026 00:27
@NicoGrande NicoGrande force-pushed the nicogrande/fused-moe-gmm branch from 57d1a60 to be95bcf Compare April 14, 2026 00:30
Comment thread src/maxtext/layers/moe.py
)

# Reshape output 2D [T, D] -> 3D [B, S, D]
output = jnp.reshape(output_2d, (batch_size, seq_len, emb_dim))
Collaborator


nit: I would add a sharding hint here for output, e.g.

output = nn.with_logical_constraint(output, ("activation_batch", "activation_length", "activation_embed"))

but it is optional.

Comment thread tests/unit/moe_test.py
# fused_moe_func requires num_tokens * topk % 16 == 0.
# B=1, S=16, topk=2 -> T*topk = 32, divisible by 16.
_B = 1
_S = 16
Collaborator


is it possible to move these two values inside FusedMoeTPUTest?
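
On the divisibility constraint quoted in the test above (num_tokens * topk % 16 == 0), here is a sketch of how a caller might pad the token count to satisfy it (helper name and the padding strategy are hypothetical, not part of this PR):

```python
def pad_tokens_for_fused_moe(num_tokens: int, topk: int, multiple: int = 16) -> int:
    """Return the smallest padded token count T' >= num_tokens such that
    T' * topk is divisible by `multiple`, matching the fused_moe_func
    requirement noted in the test."""
    t = num_tokens
    while (t * topk) % multiple != 0:
        t += 1
    return t

# With the test's values, B=1, S=16, topk=2: 16 * 2 = 32 is already
# divisible by 16, so no padding is needed.
```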

Collaborator

@NuojCheng NuojCheng left a comment


LGTM
