Skip to content

Commit 14e5854

Browse files
committed
adding support for fused_moe_gmm
1 parent 2b057eb commit 14e5854

7 files changed

Lines changed: 473 additions & 18 deletions

File tree

src/maxtext/configs/inference/vllm.yml

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,14 @@ weight_dtype: bfloat16
3030
mesh_axes: ['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']
3131
logical_axis_rules: [
3232
['activation_batch', ['data']],
33-
['activation_batch_moe', []],
33+
['activation_batch_moe', ['data']],
3434
['activation_embed_and_logits_batch', ['data', 'expert']],
3535
['activation_embed_and_logits_batch_sequence', ['data', 'expert']],
3636
['activation_heads', ['model', 'expert']],
3737
['activation_kv_heads', ['model', 'expert']],
3838
['activation_attn_length', []],
39-
['activation_length', ['data']],
40-
['activation_length_moe', ['data', 'expert']],
41-
['activation_length_moe', 'data'],
39+
['activation_length', []],
40+
['activation_length_moe', []],
4241
['activation_q_length', ['expert', 'attn_dp_expert']],
4342
['activation_attn_embed', 'model'],
4443
['activation_embed', ['model', 'attn_dp']],
@@ -53,21 +52,19 @@ logical_axis_rules: [
5352
['activation_norm_length', []],
5453
['activation_norm_length_moe', []],
5554
['activation_exp', ['expert', 'attn_dp_expert']],
56-
['decode_batch', ['expert', 'attn_dp_expert']],
57-
['decode_batch_moe', []],
55+
['decode_batch', ['data']],
56+
['decode_batch_moe', ['data']],
5857
['decode_length', []],
5958
['mlp', ['model', 'attn_dp']],
6059
['mlp_moe', ['model', 'attn_dp']],
6160
['mlp_no_fsdp', ['model', 'attn_dp']],
6261
['vocab', ['model', 'attn_dp']],
63-
['heads', ['model']],
62+
['heads', ['model', 'expert']],
6463
['q_heads', ['model', 'expert']],
6564
['kv_heads', ['model', 'expert']],
6665
['kv_head_dim', []],
6766
['kv', []],
68-
['embed', ['expert', 'attn_dp_expert']],
69-
['embed', ['attn_dp_expert']],
70-
['embed_moe', []],
67+
['embed', []],
7168
['embed_moe', []],
7269
['embed_tensor_transpose', ['attn_dp', 'model']],
7370
['q_lora', ['expert', 'attn_dp_expert']],

src/maxtext/configs/types.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,11 @@ class MoEGeneral(BaseModel):
690690
False,
691691
description="Whether to cast inputs to fp32 to compute MoE gate logits for numerical stability.",
692692
)
693+
prefuse_moe_weights: bool = Field(
694+
False,
695+
description="Whether to pre-fuse MoE weights (w0 and w1) during initialization. "
696+
"This is useful for inference performance in vllm_rpa mode.",
697+
)
693698

694699

695700
class MoEKernels(BaseModel):

src/maxtext/inference/vllm_decode.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def decode_with_vllm(config: Config) -> None:
8282
"weight_dtype": "bfloat16",
8383
"allow_split_physical_axes": True,
8484
"debug_sharding": config.debug_sharding,
85+
"prefuse_moe_weights": config.prefuse_moe_weights,
8586
},
8687
"sharding": {
8788
"sharding_strategy": {

src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,9 @@ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
124124
# Model creation
125125
self.model: nnx.Module | None = None
126126

127+
# Indicates that the model handles its own sharding logic
128+
self._self_manages_sharding = True
129+
127130
# Handle dummy weight loading during initialization
128131
if vllm_config.load_config.load_format == "dummy":
129132
self.load_weights(rng_key)
@@ -161,8 +164,8 @@ def __call__(
161164
raise ValueError("Model must be an instance of type nnx.Module.")
162165

163166
# Ensure inputs are at least 2D with a batch dimension
164-
input_ids = jnp.atleast_2d(input_ids)
165-
input_positions = jnp.atleast_2d(attention_metadata.input_positions)
167+
input_ids = jnp.expand_dims(input_ids, axis=1)
168+
input_positions = jnp.expand_dims(attention_metadata.input_positions, axis=1)
166169

167170
with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
168171
aux_hidden_states = []
@@ -233,7 +236,7 @@ def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
233236

234237
with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
235238
# Reshape to (num_tokens, 1, hidden_dim) for decoder output head
236-
y = hidden_states[:, jnp.newaxis, :]
239+
y = jnp.expand_dims(hidden_states, axis=1)
237240

238241
# Compute logits using the MaxText decoder's output head
239242
logits = self.model.decoder.apply_output_head(self.model.token_embedder, y, True, self.model_mode)
@@ -250,7 +253,7 @@ def load_weights(self, rng_key: jax.Array) -> None:
250253
if self.model is not None:
251254
return
252255

253-
with self.mesh, nn.logical_axis_rules(""):
256+
with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
254257
model, _ = model_creation_utils.create_nnx_model(
255258
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
256259
)

src/maxtext/layers/moe.py

Lines changed: 103 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ def __init__(
384384
kernel_init=self.kernel_init,
385385
kernel_axes=self.kernel_axes,
386386
use_bias=self.config.routed_bias,
387-
score_func=self.config.routed_score_func,
387+
score_func="" if self.config.attention == "vllm_rpa" else self.config.routed_score_func,
388388
matmul_precision=self.config.matmul_precision,
389389
shard_mode=config.shard_mode,
390390
rngs=self.rngs,
@@ -403,6 +403,27 @@ def __init__(
403403
self.wi_0 = jnp.zeros((num_experts, self.config.emb_dim, intermediate_dim))
404404
self.wi_1 = jnp.zeros((num_experts, self.config.emb_dim, intermediate_dim))
405405
self.wo = jnp.zeros((num_experts, intermediate_dim, self.config.emb_dim))
406+
elif self.config.prefuse_moe_weights and self.config.attention == "vllm_rpa":
407+
self.wi = nnx.Param(
408+
self.kernel_init(
409+
self.rngs.params(),
410+
(num_experts, self.config.emb_dim, intermediate_dim * 2),
411+
weight_dtype,
412+
kernel_in_axis,
413+
kernel_out_axis,
414+
),
415+
sharding=self.wi_kernel_axes,
416+
)
417+
self.wo = nnx.Param(
418+
self.kernel_init(
419+
self.rngs.params(),
420+
(self.num_experts, self.intermediate_dim, self.config.emb_dim),
421+
self.weight_dtype,
422+
kernel_in_axis,
423+
kernel_out_axis,
424+
),
425+
sharding=self.wo_kernel_axes,
426+
)
406427
else:
407428
self.wi_0 = nnx.Param(
408429
self.kernel_init(
@@ -1985,6 +2006,72 @@ def dense_matmul(
19852006
).astype(self.dtype)
19862007
return output, lb_loss, bias_updates
19872008

2009+
def fused_moe_matmul(
2010+
self,
2011+
inputs,
2012+
gate_logits,
2013+
wo_kernel,
2014+
w0_kernel=None,
2015+
w1_kernel=None,
2016+
fused_kernel=None,
2017+
) -> tuple[jax.Array, None, None]:
2018+
"""Fused MoE via tpu_inference fused_moe_func (vllm_rpa path only).
2019+
2020+
fused_moe_func handles routing, GMM, and weighted combination internally.
2021+
It does not compute lb_loss or bias_updates (inference-only).
2022+
"""
2023+
try:
2024+
# pylint: disable=import-outside-toplevel
2025+
# pytype: disable=import-error
2026+
from tpu_inference.layers.common.fused_moe_gmm import fused_moe_func
2027+
except ImportError as e:
2028+
raise ImportError("fused_moe_matmul requires the tpu-inference package.") from e
2029+
2030+
# Reshape 3D [B, S, D] -> 2D [T, D] (fused_moe_func expects 2D input)
2031+
batch_size, seq_len, emb_dim = inputs.shape
2032+
hidden_states = jnp.reshape(inputs, (batch_size * seq_len, emb_dim))
2033+
gating_output = jnp.reshape(gate_logits, (batch_size * seq_len, self.num_experts))
2034+
2035+
# Concatenate gate and up projections: [E, D, H] + [E, D, H] -> [E, D, 2H]
2036+
# fused_moe_func splits this internally: gate=w1[..., :H], up=w1[..., H:]
2037+
if fused_kernel is None:
2038+
fused_kernel = jnp.concatenate([w0_kernel, w1_kernel], axis=-1)
2039+
2040+
# Use expert parallelism if the expert axis has size > 1
2041+
use_ep = self.get_expert_parallelism_size() > 1
2042+
2043+
# Map MaxText config fields to fused_moe_func args
2044+
activation = self.config.mlp_activations[0] # e.g. "silu"
2045+
scoring_fn = self.config.routed_score_func if self.config.routed_score_func else "softmax"
2046+
2047+
# Check if the model architecture intrinsically renormalizes weights
2048+
renormalize = self.config.norm_topk_prob or (
2049+
self.config.decoder_block not in (ctypes.DecoderBlockType.LLAMA4, ctypes.DecoderBlockType.GEMMA4)
2050+
)
2051+
2052+
output_2d = fused_moe_func(
2053+
hidden_states=hidden_states,
2054+
w1=fused_kernel,
2055+
w2=wo_kernel,
2056+
w1_scale=None,
2057+
w2_scale=None,
2058+
w1_bias=None,
2059+
w2_bias=None,
2060+
gating_output=gating_output,
2061+
topk=self.num_experts_per_tok,
2062+
renormalize=renormalize,
2063+
mesh=self.mesh,
2064+
use_ep=use_ep,
2065+
activation=activation,
2066+
scoring_fn=scoring_fn,
2067+
sc_kernel_threshold=16777216,
2068+
sc_kernel_col_chunk_size=1024,
2069+
)
2070+
2071+
# Reshape output 2D [T, D] -> 3D [B, S, D]
2072+
output = jnp.reshape(output_2d, (batch_size, seq_len, emb_dim))
2073+
return output, None, None
2074+
19882075
def retrieve_quantized_weight(
19892076
self,
19902077
inputs,
@@ -2023,10 +2110,17 @@ def __call__(
20232110
routing_inputs = inputs if gate_inputs is None else gate_inputs.astype(gate_dtype)
20242111
gate_logits, pre_bias_logits = self.gate(routing_inputs)
20252112

2026-
w0_kernel = jnp.asarray(self.wi_0[...], self.dtype)
2027-
w1_kernel = jnp.asarray(self.wi_1[...], self.dtype)
20282113
wo_kernel = jnp.asarray(self.wo[...], self.dtype)
20292114

2115+
fused_kernel = None
2116+
w0_kernel = None
2117+
w1_kernel = None
2118+
if cfg.prefuse_moe_weights and cfg.attention == "vllm_rpa":
2119+
fused_kernel = jnp.asarray(self.wi[...], self.dtype)
2120+
else:
2121+
w0_kernel = jnp.asarray(self.wi_0[...], self.dtype)
2122+
w1_kernel = jnp.asarray(self.wi_1[...], self.dtype)
2123+
20302124
if self.per_expert_scale is not None:
20312125
wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], self.dtype)[:, None, None]
20322126

@@ -2037,7 +2131,12 @@ def __call__(
20372131
else:
20382132
w0_bias, w1_bias, wo_bias = None, None, None
20392133

2040-
if cfg.sparse_matmul:
2134+
# vllm_rpa uses fused_moe_func from tpu_inference (highest priority)
2135+
if cfg.attention == "vllm_rpa":
2136+
output, lb_loss, bias_updates = self.fused_moe_matmul(
2137+
inputs, gate_logits, wo_kernel, w0_kernel=w0_kernel, w1_kernel=w1_kernel, fused_kernel=fused_kernel
2138+
)
2139+
elif cfg.sparse_matmul:
20412140
if quantizations.in_serve_mode(self.quant):
20422141
w0_kernel, w1_kernel, wo_kernel = self.retrieve_quantized_weight(
20432142
inputs,

src/maxtext/utils/model_creation_utils.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import flax.linen as nn
2626
import jax
2727
import jax.numpy as jnp
28+
import numpy as np
2829
from jax.sharding import AxisType, Mesh
2930
from maxtext.configs import pyconfig
3031
from maxtext.common.common_types import MODEL_MODE_TRAIN, ShardMode
@@ -330,6 +331,35 @@ def create_sharded_state():
330331
# Get the structure of checkpoint in `config.load_parameters_path`
331332
metadata = ckptr.metadata(config.load_parameters_path)
332333

334+
def _adjust_target_for_moe_fusion(target, meta_tree, is_nnx):
335+
if not hasattr(target, "items") or not hasattr(meta_tree, "items"):
336+
return target
337+
new_target = {}
338+
for k, v in target.items():
339+
if k == "wi" and "wi" not in meta_tree and "wi_0" in meta_tree and "wi_1" in meta_tree:
340+
if not is_nnx:
341+
arr = v
342+
half_dim = arr.shape[-1] // 2
343+
new_target["wi_0"] = jax.ShapeDtypeStruct(
344+
shape=arr.shape[:-1] + (half_dim,), dtype=arr.dtype, sharding=arr.sharding
345+
)
346+
new_target["wi_1"] = jax.ShapeDtypeStruct(
347+
shape=arr.shape[:-1] + (half_dim,), dtype=arr.dtype, sharding=arr.sharding
348+
)
349+
else:
350+
arr = v["value"]
351+
half_dim = arr.shape[-1] // 2
352+
new_target["wi_0"] = {
353+
"value": jax.ShapeDtypeStruct(shape=arr.shape[:-1] + (half_dim,), dtype=arr.dtype, sharding=arr.sharding)
354+
}
355+
new_target["wi_1"] = {
356+
"value": jax.ShapeDtypeStruct(shape=arr.shape[:-1] + (half_dim,), dtype=arr.dtype, sharding=arr.sharding)
357+
}
358+
else:
359+
new_target[k] = _adjust_target_for_moe_fusion(v, meta_tree.get(k, {}), is_nnx)
360+
361+
return new_target
362+
333363
is_nnx_checkpoint = True
334364
if (
335365
"params" in metadata.item_metadata.tree.keys()
@@ -343,6 +373,10 @@ def create_sharded_state():
343373
is_leaf=lambda n: hasattr(n, "value"),
344374
)
345375

376+
target_for_restore = _adjust_target_for_moe_fusion(
377+
target_for_restore, metadata.item_metadata.tree["params"]["params"], False
378+
)
379+
346380
item_to_restore = {"params": {"params": target_for_restore}}
347381
base_restore_args = ocp.checkpoint_utils.construct_restore_args(target_for_restore)
348382
restore_args = {
@@ -361,6 +395,7 @@ def create_sharded_state():
361395
sharded_state,
362396
is_leaf=lambda n: isinstance(n, nnx.Variable),
363397
)
398+
target_for_restore = _adjust_target_for_moe_fusion(target_for_restore, metadata.item_metadata.tree, True)
364399
item_to_restore = target_for_restore
365400
base_restore_args = ocp.checkpoint_utils.construct_restore_args(target_for_restore)
366401
restore_args = _fix_restore_args_for_shape_mismatch(
@@ -400,6 +435,36 @@ def create_sharded_state():
400435
sharded_state,
401436
is_leaf=lambda n: isinstance(n, nnx.Variable),
402437
)
438+
439+
def to_dict(tree):
440+
if hasattr(tree, "items"):
441+
return {k: to_dict(v) for k, v in tree.items()}
442+
return tree
443+
444+
model_arrays = to_dict(model_arrays)
445+
checkpoint = to_dict(checkpoint)
446+
447+
def _fuse_moe_weights(ckpt_tree, model_arrays_tree):
448+
if not hasattr(ckpt_tree, "items") or not hasattr(model_arrays_tree, "items"):
449+
return ckpt_tree
450+
new_ckpt = {}
451+
for k, v in ckpt_tree.items():
452+
if k in ("wi_0", "wi_1") and "wi" in model_arrays_tree:
453+
continue
454+
new_ckpt[k] = _fuse_moe_weights(v, model_arrays_tree.get(k, {}))
455+
456+
if "wi" in model_arrays_tree and "wi_0" in ckpt_tree and "wi_1" in ckpt_tree:
457+
wi_0 = ckpt_tree["wi_0"]
458+
wi_1 = ckpt_tree["wi_1"]
459+
new_ckpt["wi"] = np.concatenate([wi_0, wi_1], axis=-1)
460+
461+
return new_ckpt
462+
463+
checkpoint = _fuse_moe_weights(checkpoint, model_arrays)
464+
# Release the raw restored buffers now that wi_0/wi_1 have been fused (if needed).
465+
# This prevents the replicated intermediate copies from persisting until function return.
466+
del restored
467+
403468
checkpoint = jax.tree.map(_expand_checkpoint_to_model_shapes, checkpoint, model_arrays)
404469
nnx.update(model, checkpoint)
405470

0 commit comments

Comments (0)