Conversation
RissyRan
left a comment
Overall LGTM. Could you add your tests to showcase the functionality? Thanks!
```python
from tests.utils.test_helpers import get_test_config_path, get_decoupled_parallelism_overrides


def make_moe(cfg, mesh):
```
Do you think we could move this to https://github.com/AI-Hypercomputer/maxtext/blob/7e3f19ff7828a75322cbf19e27374d6a6324aaad/tests/unit/moe_test.py?
Done :)
LMK what you think!
```python
['moe_mlp', ['model', 'attn_dp']],
['vocab', ['model', 'attn_dp']],
['heads', ['model']],
['heads', ['model', 'expert']],
```
I would add a comment here that `expert` is intended to act like TP for attention.
Can even explicitly say we target two all-reduces: one at the end of the attention `out_proj`, one at the end of the MLP.
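A hedged sketch of how this suggested comment could land in the sharding rules. The rule entries are copied from the quoted diff; the list name `logical_axis_rules` is an assumption for illustration, not necessarily the MaxText variable name:

```python
# Sketch only: illustrating the reviewer's suggested comment.
# `logical_axis_rules` is a hypothetical name; the entries come from the diff.
logical_axis_rules = [
    ["moe_mlp", ["model", "attn_dp"]],
    ["vocab", ["model", "attn_dp"]],
    ["heads", ["model"]],
    # 'expert' is intended to act like TP for attention: we target two
    # all-reduces, one at the end of attention out_proj, one at the end of the MLP.
    ["heads", ["model", "expert"]],
]
```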
```python
['activation_length_no_exp_moe', 'data'],
['activation_q_length', ['expert', 'attn_dp_expert']],
['activation_attn_embed', 'model'],
['activation_embed', ['model', 'attn_dp']],
```
I would add a note here for `activation_embed` that `expert` is missing explicitly despite using TP, because we are going for a replicate-AR style of TP as opposed to our typical AG-RS style. The replicate-AR style is essentially forced by the output sharding of the vLLM kernel.
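To illustrate the replicate-AR style this note refers to, here is a minimal numpy sketch (shapes and names are illustrative, not MaxText code): each shard holds a slice of the weight along the contraction dimension, contracts it against the matching slice of the replicated activation, and the "all-reduce" is just a sum of the partial outputs:

```python
import numpy as np

# Sketch of replicate-AR tensor parallelism with 2 shards.
n_shards = 2
x = np.random.default_rng(0).normal(size=(4, 8))  # replicated activations [T, D]
w = np.random.default_rng(1).normal(size=(8, 6))  # full weight [D, F]

# Each shard owns a slice of W along the contraction dim D,
# and contracts it with the matching slice of the replicated x.
w_shards = np.split(w, n_shards, axis=0)
x_shards = np.split(x, n_shards, axis=1)
partials = [xs @ ws for xs, ws in zip(x_shards, w_shards)]

# The all-reduce is a sum over the partial outputs.
out = sum(partials)
assert np.allclose(out, x @ w)
```

In AG-RS style TP the collectives are instead an all-gather of inputs and a reduce-scatter of outputs, which produces sharded rather than replicated outputs.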
```diff
  kernel_axes=self.kernel_axes,
  use_bias=self.config.routed_bias,
- score_func=self.config.routed_score_func,
+ score_func="" if self.config.attention == "vllm_rpa" else self.config.routed_score_func,
```
I would include a comment, or maybe a link to short documentation/bug/vllm code, that says the vllm mega kernel we are calling applies the score_func for us, so we don't want to apply it ourselves.
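A quick sketch of why applying the score function on both sides would be wrong: softmax is not idempotent, so running it in MaxText and again inside the kernel would change the routing weights (pure-numpy illustration, not the actual kernel code):

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax over a 1-D array.
    e = np.exp(z - z.max())
    return e / e.sum()

logits = np.array([1.0, 2.0, 3.0])
once = softmax(logits)   # what the kernel computes internally
twice = softmax(once)    # what we'd get if MaxText also applied it first

# The results disagree, so the score_func must be applied exactly once.
assert not np.allclose(once, twice)
```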
```python
w0_bias, w1_bias, wo_bias = None, None, None


if cfg.sparse_matmul:
    # vllm_rpa uses fused_moe_func from tpu_inference (highest priority)
```
what does "highest priority" mean here?
Was just trying to say it will use this by default; updating the comment for clarity.
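A small sketch of the dispatch order being described, assuming the fused path is chosen by default whenever `vllm_rpa` attention is active (the function and the returned string names are hypothetical, not the MaxText API):

```python
def select_moe_impl(attention: str, sparse_matmul: bool) -> str:
    # Hypothetical sketch: with sparse_matmul enabled and vllm_rpa attention,
    # the tpu_inference fused_moe_func path is chosen by default; otherwise
    # fall back to the existing sparse or dense matmul paths.
    if sparse_matmul and attention == "vllm_rpa":
        return "fused_moe_func"
    return "sparse_matmul" if sparse_matmul else "dense_matmul"
```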
```python
self.assertIsNone(lb_loss)
self.assertIsNone(bias_updates)


def test_fused_vs_sparse_softmax(self):
```
wow this is a great test!
```python
)

# Reshape output 2D [T, D] -> 3D [B, S, D]
output = jnp.reshape(output_2d, (batch_size, seq_len, emb_dim))
```
nit: I would add a sharding hint here for `output`, e.g.

```python
output = nn.with_logical_constraint(output, ("activation_batch", "activation_length", "activation_embed"))
```

but it is optional.
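The 2D-to-3D reshape quoted above can be sketched with numpy standing in for `jnp` (the shapes are illustrative; T = B * S tokens come out of the fused kernel flattened):

```python
import numpy as np

batch_size, seq_len, emb_dim = 2, 3, 4
output_2d = np.arange(batch_size * seq_len * emb_dim, dtype=np.float32).reshape(
    batch_size * seq_len, emb_dim
)

# Reshape output 2D [T, D] -> 3D [B, S, D], as in the snippet above.
output = output_2d.reshape(batch_size, seq_len, emb_dim)
assert output.shape == (batch_size, seq_len, emb_dim)
```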
```python
# fused_moe_func requires num_tokens * topk % 16 == 0.
# B=1, S=16, topk=2 -> T*topk = 32, divisible by 16.
_B = 1
_S = 16
```
Is it possible to move these two values inside `FusedMoeTPUTest`?
Description
This PR adds support for the `tpu-inference` `fused_moe_gmm` kernel in the MaxText MoE inference codepath. Initial results using this kernel show up to a ~4x generation throughput increase when testing with `qwen3-30b-a3b`.

Additionally, this PR introduces a second optimization to MaxText which pre-fuses the MoE weight kernels so that they can be efficiently passed into the `fused_moe_gmm` kernel. We show the impact of these optimizations below with autoregressive generation step times:

- Baseline (MaxText `sparse_matmul` MoE): 28.353 ms
- Fused MoE (`prefuse_moe_weights=False`): 20.432 ms
- Fused MoE (`prefuse_moe_weights=True`): 6.114 ms
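As a sanity check on the ~4x claim, the step times reported above imply these speedups relative to the baseline:

```python
# Step times from the benchmark list above (milliseconds).
baseline_ms = 28.353   # MaxText sparse_matmul MoE
fused_ms = 20.432      # Fused MoE, prefuse_moe_weights=False
prefused_ms = 6.114    # Fused MoE, prefuse_moe_weights=True

# Step-time ratios relative to the baseline.
speedup_fused = baseline_ms / fused_ms        # roughly 1.39x
speedup_prefused = baseline_ms / prefused_ms  # roughly 4.64x
```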
Tests
This PR adds new tests to `tests/unit/moe_test.py`. Additionally, this PR was tested e2e with both `qwen3-30b-a3b` and `qwen3-235b-a22b`.

Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.