Skip to content

Commit a6cfd99

Browse files
committed
adding support for fused_moe_gmm
1 parent 1b47668 commit a6cfd99

7 files changed

Lines changed: 470 additions & 13 deletions

File tree

src/maxtext/configs/inference/vllm.yml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,26 +50,25 @@ logical_axis_rules: [
5050
['activation_mlp', ['model', 'attn_dp']],
5151
['activation_kv', ['model']],
5252
['activation_prefill_kv_batch', ['expert', 'attn_dp_expert']],
53-
['activation_kv_batch', ['data', 'expert', 'attn_dp_expert']],
53+
['activation_kv_batch', ['data']],
5454
['activation_kv_batch_no_exp', ['data']],
5555
['activation_kv_head_dim', ['model']],
5656
['activation_vocab', ['model', 'attn_dp']],
5757
['activation_norm_length', []],
5858
['activation_norm_length_moe', []],
5959
['activation_exp', ['expert', 'attn_dp_expert']],
60-
['decode_batch', ['expert', 'attn_dp_expert']],
60+
['decode_batch', ['data']],
6161
['decode_length', []],
6262
['mlp', ['model', 'attn_dp']],
6363
['mlp_no_fsdp', ['model', 'attn_dp']],
6464
['moe_mlp', ['model', 'attn_dp']],
6565
['vocab', ['model', 'attn_dp']],
66-
['heads', ['model']],
66+
['heads', ['model', 'expert']],
6767
['q_heads', ['model', 'expert']],
6868
['kv_heads', ['model', 'expert']],
6969
['kv_head_dim', []],
7070
['kv', []],
71-
['embed', ['expert', 'attn_dp_expert']],
72-
['embed', ['attn_dp_expert']],
71+
['embed', []],
7372
['embed_moe', ['expert', 'attn_dp_expert']],
7473
['embed_moe', ['attn_dp_expert']],
7574
['embed_tensor_transpose', ['attn_dp', 'model']],

src/maxtext/configs/types.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,11 @@ class MoEGeneral(BaseModel):
686686
False,
687687
description="Whether to cast inputs to fp32 to compute MoE gate logits for numerical stability.",
688688
)
689+
prefuse_moe_weights: bool = Field(
690+
False,
691+
description="Whether to pre-fuse MoE weights (w0 and w1) during initialization. "
692+
"This is useful for inference performance in vllm_rpa mode.",
693+
)
689694

690695

691696
class MoEKernels(BaseModel):

src/maxtext/inference/vllm_decode.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def decode_with_vllm(config: Config) -> None:
8282
"weight_dtype": "bfloat16",
8383
"allow_split_physical_axes": True,
8484
"debug_sharding": config.debug_sharding,
85+
"prefuse_moe_weights": config.prefuse_moe_weights,
8586
},
8687
"sharding": {
8788
"sharding_strategy": {

src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,9 @@ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
124124
# Model creation
125125
self.model: nnx.Module | None = None
126126

127+
# Indicates that the model handles its own sharding logic
128+
self._self_manages_sharding = True
129+
127130
# Handle dummy weight loading during initialization
128131
if vllm_config.load_config.load_format == "dummy":
129132
self.load_weights(rng_key)
@@ -161,8 +164,8 @@ def __call__(
161164
raise ValueError("Model must be an instance of type nnx.Module.")
162165

163166
# Insert a singleton axis at position 1 of the token ids and positions (e.g. shape (N,) becomes (N, 1))
164-
input_ids = jnp.atleast_2d(input_ids)
165-
input_positions = jnp.atleast_2d(attention_metadata.input_positions)
167+
input_ids = jnp.expand_dims(input_ids, axis=1)
168+
input_positions = jnp.expand_dims(attention_metadata.input_positions, axis=1)
166169

167170
with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
168171
aux_hidden_states = []
@@ -233,7 +236,7 @@ def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
233236

234237
with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
235238
# Reshape to (num_tokens, 1, hidden_dim) for decoder output head
236-
y = hidden_states[:, jnp.newaxis, :]
239+
y = jnp.expand_dims(hidden_states, axis=1)
237240

238241
# Compute logits using the MaxText decoder's output head
239242
logits = self.model.decoder.apply_output_head(self.model.token_embedder, y, True, self.model_mode)
@@ -250,7 +253,7 @@ def load_weights(self, rng_key: jax.Array) -> None:
250253
if self.model is not None:
251254
return
252255

253-
with self.mesh, nn.logical_axis_rules(""):
256+
with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
254257
model, _ = model_creation_utils.create_nnx_model(
255258
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
256259
)

src/maxtext/layers/moe.py

Lines changed: 103 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ def __init__(
384384
kernel_init=self.kernel_init,
385385
kernel_axes=self.kernel_axes,
386386
use_bias=self.config.routed_bias,
387-
score_func=self.config.routed_score_func,
387+
score_func="" if self.config.attention == "vllm_rpa" else self.config.routed_score_func,
388388
matmul_precision=self.config.matmul_precision,
389389
shard_mode=config.shard_mode,
390390
rngs=self.rngs,
@@ -403,6 +403,27 @@ def __init__(
403403
self.wi_0 = jnp.zeros((num_experts, self.config.emb_dim, intermediate_dim))
404404
self.wi_1 = jnp.zeros((num_experts, self.config.emb_dim, intermediate_dim))
405405
self.wo = jnp.zeros((num_experts, intermediate_dim, self.config.emb_dim))
406+
elif self.config.prefuse_moe_weights and self.config.attention == "vllm_rpa":
407+
self.wi = nnx.Param(
408+
self.kernel_init(
409+
self.rngs.params(),
410+
(num_experts, self.config.emb_dim, intermediate_dim * 2),
411+
weight_dtype,
412+
kernel_in_axis,
413+
kernel_out_axis,
414+
),
415+
sharding=self.wi_kernel_axes,
416+
)
417+
self.wo = nnx.Param(
418+
self.kernel_init(
419+
self.rngs.params(),
420+
(self.num_experts, self.intermediate_dim, self.config.emb_dim),
421+
self.weight_dtype,
422+
kernel_in_axis,
423+
kernel_out_axis,
424+
),
425+
sharding=self.wo_kernel_axes,
426+
)
406427
else:
407428
self.wi_0 = nnx.Param(
408429
self.kernel_init(
@@ -1988,6 +2009,72 @@ def dense_matmul(
19882009
).astype(self.dtype)
19892010
return output, lb_loss, bias_updates
19902011

2012+
def fused_moe_matmul(
    self,
    inputs,
    gate_logits,
    wo_kernel,
    w0_kernel=None,
    w1_kernel=None,
    fused_kernel=None,
) -> tuple[jax.Array, None, None]:
  """Fused MoE via tpu_inference fused_moe_func (vllm_rpa path only).

  fused_moe_func handles routing, GMM, and weighted combination internally.
  It does not compute lb_loss or bias_updates (inference-only), so those two
  slots of the returned tuple are always None.

  Args:
    inputs: 3D activations; unpacked as (batch, seq, emb) and flattened to 2D
      before the kernel call.
    gate_logits: Router output, reshaped to (batch * seq, self.num_experts).
    wo_kernel: Output projection weights, forwarded as `w2`.
    w0_kernel: Gate projection weights. Required (with `w1_kernel`) when
      `fused_kernel` is not supplied.
    w1_kernel: Up projection weights. Required (with `w0_kernel`) when
      `fused_kernel` is not supplied.
    fused_kernel: Pre-fused gate/up weights. When given, `w0_kernel` and
      `w1_kernel` are ignored.

  Returns:
    A tuple `(output, None, None)` where `output` has the same 3D shape as
    `inputs`.

  Raises:
    ValueError: If neither `fused_kernel` nor both of `w0_kernel` and
      `w1_kernel` are provided.
    ImportError: If the tpu-inference package cannot be imported.
  """
  # Validate up front: without this, a missing kernel only surfaces as an
  # opaque failure inside jnp.concatenate further down.
  if fused_kernel is None and (w0_kernel is None or w1_kernel is None):
    raise ValueError("fused_moe_matmul requires either fused_kernel or both w0_kernel and w1_kernel.")

  try:
    # pylint: disable=import-outside-toplevel
    # pytype: disable=import-error
    from tpu_inference.layers.common.fused_moe_gmm import fused_moe_func
  except ImportError as e:
    raise ImportError("fused_moe_matmul requires the tpu-inference package.") from e

  # Reshape 3D [B, S, D] -> 2D [T, D] (fused_moe_func expects 2D input)
  batch_size, seq_len, emb_dim = inputs.shape
  num_tokens = batch_size * seq_len
  hidden_states = jnp.reshape(inputs, (num_tokens, emb_dim))
  gating_output = jnp.reshape(gate_logits, (num_tokens, self.num_experts))

  # Concatenate gate and up projections: [E, D, H] + [E, D, H] -> [E, D, 2H]
  # fused_moe_func splits this internally: gate=w1[..., :H], up=w1[..., H:]
  if fused_kernel is None:
    fused_kernel = jnp.concatenate([w0_kernel, w1_kernel], axis=-1)

  # Use expert parallelism if the expert axis has size > 1
  use_ep = self.get_expert_parallelism_size() > 1

  # Map MaxText config fields to fused_moe_func args
  activation = self.config.mlp_activations[0]  # e.g. "silu"
  # Fall back to "softmax" when the config leaves the score function empty.
  scoring_fn = self.config.routed_score_func if self.config.routed_score_func else "softmax"

  # Check if the model architecture intrinsically renormalizes weights:
  # every decoder block except LLAMA4 and GEMMA4 renormalizes regardless of
  # norm_topk_prob.
  renormalize = self.config.norm_topk_prob or (
      self.config.decoder_block not in (ctypes.DecoderBlockType.LLAMA4, ctypes.DecoderBlockType.GEMMA4)
  )

  # NOTE(review): scales and biases are always passed as None here, so any
  # quantization scales or MLP biases are not applied on this path - confirm
  # that is intended for every vllm_rpa configuration.
  output_2d = fused_moe_func(
      hidden_states=hidden_states,
      w1=fused_kernel,
      w2=wo_kernel,
      w1_scale=None,
      w2_scale=None,
      w1_bias=None,
      w2_bias=None,
      gating_output=gating_output,
      topk=self.num_experts_per_tok,
      renormalize=renormalize,
      mesh=self.mesh,
      use_ep=use_ep,
      activation=activation,
      scoring_fn=scoring_fn,
      # Kernel tuning knobs (16777216 == 2**24); their exact meaning is
      # defined by tpu_inference - TODO confirm before changing.
      sc_kernel_threshold=16777216,
      sc_kernel_col_chunk_size=1024,
  )

  # Reshape output 2D [T, D] -> 3D [B, S, D]
  output = jnp.reshape(output_2d, (batch_size, seq_len, emb_dim))
  return output, None, None
2077+
19912078
def retrieve_quantized_weight(
19922079
self,
19932080
inputs,
@@ -2026,10 +2113,17 @@ def __call__(
20262113
routing_inputs = inputs if gate_inputs is None else gate_inputs.astype(gate_dtype)
20272114
gate_logits, pre_bias_logits = self.gate(routing_inputs)
20282115

2029-
w0_kernel = jnp.asarray(self.wi_0[...], self.dtype)
2030-
w1_kernel = jnp.asarray(self.wi_1[...], self.dtype)
20312116
wo_kernel = jnp.asarray(self.wo[...], self.dtype)
20322117

2118+
fused_kernel = None
2119+
w0_kernel = None
2120+
w1_kernel = None
2121+
if cfg.prefuse_moe_weights and cfg.attention == "vllm_rpa":
2122+
fused_kernel = jnp.asarray(self.wi[...], self.dtype)
2123+
else:
2124+
w0_kernel = jnp.asarray(self.wi_0[...], self.dtype)
2125+
w1_kernel = jnp.asarray(self.wi_1[...], self.dtype)
2126+
20332127
if self.per_expert_scale is not None:
20342128
wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], self.dtype)[:, None, None]
20352129

@@ -2040,7 +2134,12 @@ def __call__(
20402134
else:
20412135
w0_bias, w1_bias, wo_bias = None, None, None
20422136

2043-
if cfg.sparse_matmul:
2137+
# vllm_rpa uses fused_moe_func from tpu_inference (highest priority)
2138+
if cfg.attention == "vllm_rpa":
2139+
output, lb_loss, bias_updates = self.fused_moe_matmul(
2140+
inputs, gate_logits, wo_kernel, w0_kernel=w0_kernel, w1_kernel=w1_kernel, fused_kernel=fused_kernel
2141+
)
2142+
elif cfg.sparse_matmul:
20442143
if quantizations.in_serve_mode(self.quant):
20452144
w0_kernel, w1_kernel, wo_kernel = self.retrieve_quantized_weight(
20462145
inputs,

src/maxtext/utils/model_creation_utils.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import flax.linen as nn
2626
import jax
2727
import jax.numpy as jnp
28+
import numpy as np
2829
from jax.sharding import AxisType, Mesh
2930
from maxtext.configs import pyconfig
3031
from maxtext.common.common_types import MODEL_MODE_TRAIN, ShardMode
@@ -298,6 +299,35 @@ def create_sharded_state():
298299
# Get the structure of checkpoint in `config.load_parameters_path`
299300
metadata = ckptr.metadata(config.load_parameters_path)
300301

302+
def _adjust_target_for_moe_fusion(target, meta_tree, is_nnx):
303+
if not hasattr(target, "items") or not hasattr(meta_tree, "items"):
304+
return target
305+
new_target = {}
306+
for k, v in target.items():
307+
if k == "wi" and "wi" not in meta_tree and "wi_0" in meta_tree and "wi_1" in meta_tree:
308+
if not is_nnx:
309+
arr = v
310+
half_dim = arr.shape[-1] // 2
311+
new_target["wi_0"] = jax.ShapeDtypeStruct(
312+
shape=arr.shape[:-1] + (half_dim,), dtype=arr.dtype, sharding=arr.sharding
313+
)
314+
new_target["wi_1"] = jax.ShapeDtypeStruct(
315+
shape=arr.shape[:-1] + (half_dim,), dtype=arr.dtype, sharding=arr.sharding
316+
)
317+
else:
318+
arr = v["value"]
319+
half_dim = arr.shape[-1] // 2
320+
new_target["wi_0"] = {
321+
"value": jax.ShapeDtypeStruct(shape=arr.shape[:-1] + (half_dim,), dtype=arr.dtype, sharding=arr.sharding)
322+
}
323+
new_target["wi_1"] = {
324+
"value": jax.ShapeDtypeStruct(shape=arr.shape[:-1] + (half_dim,), dtype=arr.dtype, sharding=arr.sharding)
325+
}
326+
else:
327+
new_target[k] = _adjust_target_for_moe_fusion(v, meta_tree.get(k, {}), is_nnx)
328+
329+
return new_target
330+
301331
is_nnx_checkpoint = True
302332
if (
303333
"params" in metadata.item_metadata.tree.keys()
@@ -311,6 +341,10 @@ def create_sharded_state():
311341
is_leaf=lambda n: hasattr(n, "value"),
312342
)
313343

344+
target_for_restore = _adjust_target_for_moe_fusion(
345+
target_for_restore, metadata.item_metadata.tree["params"]["params"], False
346+
)
347+
314348
item_to_restore = {"params": {"params": target_for_restore}}
315349
base_restore_args = ocp.checkpoint_utils.construct_restore_args(target_for_restore)
316350
restore_args = {
@@ -329,6 +363,7 @@ def create_sharded_state():
329363
sharded_state,
330364
is_leaf=lambda n: isinstance(n, nnx.Variable),
331365
)
366+
target_for_restore = _adjust_target_for_moe_fusion(target_for_restore, metadata.item_metadata.tree, True)
332367
item_to_restore = target_for_restore
333368
base_restore_args = ocp.checkpoint_utils.construct_restore_args(target_for_restore)
334369
restore_args = _fix_restore_args_for_shape_mismatch(
@@ -368,6 +403,36 @@ def create_sharded_state():
368403
sharded_state,
369404
is_leaf=lambda n: isinstance(n, nnx.Variable),
370405
)
406+
407+
def to_dict(tree):
  """Recursively copy a mapping-like tree into plain nested dicts.

  Any node exposing `.items()` becomes a `dict` whose values are converted
  the same way; every other object is returned unchanged as a leaf.
  """
  if not hasattr(tree, "items"):
    return tree
  return {key: to_dict(child) for key, child in tree.items()}
411+
412+
model_arrays = to_dict(model_arrays)
413+
checkpoint = to_dict(checkpoint)
414+
415+
def _fuse_moe_weights(ckpt_tree, model_arrays_tree):
416+
if not hasattr(ckpt_tree, "items") or not hasattr(model_arrays_tree, "items"):
417+
return ckpt_tree
418+
new_ckpt = {}
419+
for k, v in ckpt_tree.items():
420+
if k in ("wi_0", "wi_1") and "wi" in model_arrays_tree:
421+
continue
422+
new_ckpt[k] = _fuse_moe_weights(v, model_arrays_tree.get(k, {}))
423+
424+
if "wi" in model_arrays_tree and "wi_0" in ckpt_tree and "wi_1" in ckpt_tree:
425+
wi_0 = ckpt_tree["wi_0"]
426+
wi_1 = ckpt_tree["wi_1"]
427+
new_ckpt["wi"] = np.concatenate([wi_0, wi_1], axis=-1)
428+
429+
return new_ckpt
430+
431+
checkpoint = _fuse_moe_weights(checkpoint, model_arrays)
432+
# Release the raw restored buffers now that wi_0/wi_1 have been fused (if needed).
433+
# This prevents the replicated intermediate copies from persisting until function return.
434+
del restored
435+
371436
checkpoint = jax.tree.map(_expand_checkpoint_to_model_shapes, checkpoint, model_arrays)
372437
nnx.update(model, checkpoint)
373438

0 commit comments

Comments
 (0)