@@ -384,7 +384,7 @@ def __init__(
384384 kernel_init = self .kernel_init ,
385385 kernel_axes = self .kernel_axes ,
386386 use_bias = self .config .routed_bias ,
387- score_func = self .config .routed_score_func ,
387+ score_func = "" if self . config . attention == "vllm_rpa" else self .config .routed_score_func ,
388388 matmul_precision = self .config .matmul_precision ,
389389 shard_mode = config .shard_mode ,
390390 rngs = self .rngs ,
@@ -403,6 +403,27 @@ def __init__(
403403 self .wi_0 = jnp .zeros ((num_experts , self .config .emb_dim , intermediate_dim ))
404404 self .wi_1 = jnp .zeros ((num_experts , self .config .emb_dim , intermediate_dim ))
405405 self .wo = jnp .zeros ((num_experts , intermediate_dim , self .config .emb_dim ))
406+ elif self .config .prefuse_moe_weights and self .config .attention == "vllm_rpa" :
407+ self .wi = nnx .Param (
408+ self .kernel_init (
409+ self .rngs .params (),
410+ (num_experts , self .config .emb_dim , intermediate_dim * 2 ),
411+ weight_dtype ,
412+ kernel_in_axis ,
413+ kernel_out_axis ,
414+ ),
415+ sharding = self .wi_kernel_axes ,
416+ )
417+ self .wo = nnx .Param (
418+ self .kernel_init (
419+ self .rngs .params (),
420+ (self .num_experts , self .intermediate_dim , self .config .emb_dim ),
421+ self .weight_dtype ,
422+ kernel_in_axis ,
423+ kernel_out_axis ,
424+ ),
425+ sharding = self .wo_kernel_axes ,
426+ )
406427 else :
407428 self .wi_0 = nnx .Param (
408429 self .kernel_init (
@@ -1988,6 +2009,68 @@ def dense_matmul(
19882009 ).astype (self .dtype )
19892010 return output , lb_loss , bias_updates
19902011
def fused_moe_matmul(
    self,
    inputs,
    gate_logits,
    wo_kernel,
    w0_kernel=None,
    w1_kernel=None,
    fused_kernel=None,
) -> tuple[jax.Array, None, None]:
    """Fused MoE via tpu_inference fused_moe_func (vllm_rpa path only).

    fused_moe_func handles routing, GMM, and weighted combination internally.
    It does not compute lb_loss or bias_updates (inference-only), so both are
    returned as None.

    Args:
      inputs: 3D activations; unpacked as (batch, seq, emb) and flattened to
        (batch * seq, emb) before the fused call.
      gate_logits: Router output; reshaped to (batch * seq, num_experts).
        NOTE(review): presumably unscored logits, since scoring is delegated
        to fused_moe_func via `scoring_fn` -- confirm against the gate config.
      wo_kernel: Per-expert output (down) projection, forwarded as `w2`.
      w0_kernel: Per-expert gate projection. Required together with
        `w1_kernel` when `fused_kernel` is not given.
      w1_kernel: Per-expert up projection. Required together with
        `w0_kernel` when `fused_kernel` is not given.
      fused_kernel: Optional pre-fused gate/up projection, forwarded as `w1`.
        When provided, `w0_kernel` and `w1_kernel` are ignored.

    Returns:
      A tuple `(output, None, None)` where `output` has the same 3D shape as
      `inputs`.

    Raises:
      ValueError: If `fused_kernel` is None and `w0_kernel` or `w1_kernel`
        is missing.
      ImportError: If the tpu-inference package is not installed.
    """
    # Fail fast with an actionable message instead of letting
    # jnp.concatenate choke on a None operand further down.
    if fused_kernel is None and (w0_kernel is None or w1_kernel is None):
        raise ValueError(
            "fused_moe_matmul requires either fused_kernel or both w0_kernel and w1_kernel."
        )

    try:
        # pylint: disable=import-outside-toplevel
        # pytype: disable=import-error
        from tpu_inference.layers.common.fused_moe_gmm import fused_moe_func
    except ImportError as e:
        raise ImportError("fused_moe_matmul requires the tpu-inference package.") from e

    # Reshape 3D [B, S, D] -> 2D [T, D] (fused_moe_func expects 2D input)
    batch_size, seq_len, emb_dim = inputs.shape
    num_tokens = batch_size * seq_len
    hidden_states = jnp.reshape(inputs, (num_tokens, emb_dim))
    gating_output = jnp.reshape(gate_logits, (num_tokens, self.num_experts))

    # Concatenate gate and up projections: [E, D, H] + [E, D, H] -> [E, D, 2H]
    # fused_moe_func splits this internally: gate=w1[..., :H], up=w1[..., H:]
    if fused_kernel is None:
        fused_kernel = jnp.concatenate([w0_kernel, w1_kernel], axis=-1)

    # Use expert parallelism if the expert axis has size > 1
    use_ep = self.get_expert_parallelism_size() > 1

    # Map MaxText config fields to fused_moe_func args
    activation = self.config.mlp_activations[0]  # e.g. "silu"
    # An empty/unset routed_score_func falls back to softmax scoring.
    scoring_fn = self.config.routed_score_func or "softmax"
    renormalize = self.config.norm_topk_prob

    output_2d = fused_moe_func(
        hidden_states=hidden_states,
        w1=fused_kernel,
        w2=wo_kernel,
        # No quantization scales or expert biases are forwarded on this path.
        w1_scale=None,
        w2_scale=None,
        w1_bias=None,
        w2_bias=None,
        gating_output=gating_output,
        topk=self.num_experts_per_tok,
        renormalize=renormalize,
        mesh=self.mesh,
        use_ep=use_ep,
        activation=activation,
        scoring_fn=scoring_fn,
        # Kernel tuning values passed through verbatim; their exact semantics
        # are defined by tpu_inference -- TODO confirm before changing.
        sc_kernel_threshold=16777216,
        sc_kernel_col_chunk_size=1024,
    )

    # Reshape output 2D [T, D] -> 3D [B, S, D]
    output = jnp.reshape(output_2d, (batch_size, seq_len, emb_dim))
    return output, None, None
2073+
19912074 def retrieve_quantized_weight (
19922075 self ,
19932076 inputs ,
@@ -2026,10 +2109,17 @@ def __call__(
20262109 routing_inputs = inputs if gate_inputs is None else gate_inputs .astype (gate_dtype )
20272110 gate_logits , pre_bias_logits = self .gate (routing_inputs )
20282111
2029- w0_kernel = jnp .asarray (self .wi_0 [...], self .dtype )
2030- w1_kernel = jnp .asarray (self .wi_1 [...], self .dtype )
20312112 wo_kernel = jnp .asarray (self .wo [...], self .dtype )
20322113
2114+ fused_kernel = None
2115+ w0_kernel = None
2116+ w1_kernel = None
2117+ if cfg .prefuse_moe_weights and cfg .attention == "vllm_rpa" :
2118+ fused_kernel = jnp .asarray (self .wi [...], self .dtype )
2119+ else :
2120+ w0_kernel = jnp .asarray (self .wi_0 [...], self .dtype )
2121+ w1_kernel = jnp .asarray (self .wi_1 [...], self .dtype )
2122+
20332123 if self .per_expert_scale is not None :
20342124 wo_kernel = wo_kernel * jnp .asarray (self .per_expert_scale [...], self .dtype )[:, None , None ]
20352125
@@ -2040,7 +2130,12 @@ def __call__(
20402130 else :
20412131 w0_bias , w1_bias , wo_bias = None , None , None
20422132
2043- if cfg .sparse_matmul :
2133+ # vllm_rpa uses fused_moe_func from tpu_inference (highest priority)
2134+ if cfg .attention == "vllm_rpa" :
2135+ output , lb_loss , bias_updates = self .fused_moe_matmul (
2136+ inputs , gate_logits , wo_kernel , w0_kernel = w0_kernel , w1_kernel = w1_kernel , fused_kernel = fused_kernel
2137+ )
2138+ elif cfg .sparse_matmul :
20442139 if quantizations .in_serve_mode (self .quant ):
20452140 w0_kernel , w1_kernel , wo_kernel = self .retrieve_quantized_weight (
20462141 inputs ,
0 commit comments