@@ -384,7 +384,7 @@ def __init__(
384384 kernel_init = self .kernel_init ,
385385 kernel_axes = self .kernel_axes ,
386386 use_bias = self .config .routed_bias ,
387- score_func = self .config .routed_score_func ,
387+ score_func = "" if self . config . attention == "vllm_rpa" else self .config .routed_score_func ,
388388 matmul_precision = self .config .matmul_precision ,
389389 shard_mode = config .shard_mode ,
390390 rngs = self .rngs ,
@@ -403,6 +403,27 @@ def __init__(
403403 self .wi_0 = jnp .zeros ((num_experts , self .config .emb_dim , intermediate_dim ))
404404 self .wi_1 = jnp .zeros ((num_experts , self .config .emb_dim , intermediate_dim ))
405405 self .wo = jnp .zeros ((num_experts , intermediate_dim , self .config .emb_dim ))
406+ elif self .config .prefuse_moe_weights and self .config .attention == "vllm_rpa" :
407+ self .wi = nnx .Param (
408+ self .kernel_init (
409+ self .rngs .params (),
410+ (num_experts , self .config .emb_dim , intermediate_dim * 2 ),
411+ weight_dtype ,
412+ kernel_in_axis ,
413+ kernel_out_axis ,
414+ ),
415+ sharding = self .wi_kernel_axes ,
416+ )
417+ self .wo = nnx .Param (
418+ self .kernel_init (
419+ self .rngs .params (),
420+ (self .num_experts , self .intermediate_dim , self .config .emb_dim ),
421+ self .weight_dtype ,
422+ kernel_in_axis ,
423+ kernel_out_axis ,
424+ ),
425+ sharding = self .wo_kernel_axes ,
426+ )
406427 else :
407428 self .wi_0 = nnx .Param (
408429 self .kernel_init (
@@ -1985,6 +2006,72 @@ def dense_matmul(
19852006 ).astype (self .dtype )
19862007 return output , lb_loss , bias_updates
19872008
def fused_moe_matmul(
    self,
    inputs,
    gate_logits,
    wo_kernel,
    w0_kernel=None,
    w1_kernel=None,
    fused_kernel=None,
) -> tuple[jax.Array, None, None]:
    """Fused MoE forward pass via tpu_inference's fused_moe_func (vllm_rpa path only).

    fused_moe_func performs routing, grouped matmul, and weighted combination
    internally. It does not compute a load-balancing loss or router-bias
    updates, so those result slots are returned as None (inference-only path).

    Args:
      inputs: activations of shape [batch, seq, emb_dim].
      gate_logits: router logits reshapeable to [batch * seq, num_experts].
      wo_kernel: down-projection weights [num_experts, intermediate, emb_dim].
      w0_kernel: gate-projection weights [num_experts, emb_dim, intermediate];
        ignored when fused_kernel is provided.
      w1_kernel: up-projection weights [num_experts, emb_dim, intermediate];
        ignored when fused_kernel is provided.
      fused_kernel: pre-fused [num_experts, emb_dim, 2 * intermediate] weights
        (gate in the first half of the last axis, up in the second half).

    Returns:
      Tuple of (output, None, None) where output has the same shape as inputs.

    Raises:
      ImportError: if the tpu-inference package is not installed.
      ValueError: if neither fused_kernel nor both w0_kernel and w1_kernel
        are provided.
    """
    try:
        # pylint: disable=import-outside-toplevel
        # pytype: disable=import-error
        from tpu_inference.layers.common.fused_moe_gmm import fused_moe_func
    except ImportError as e:
        raise ImportError("fused_moe_matmul requires the tpu-inference package.") from e

    # Build the fused gate/up kernel when the caller did not pre-fuse it:
    # [E, D, H] ++ [E, D, H] -> [E, D, 2H]. fused_moe_func splits it back
    # internally as gate = w1[..., :H], up = w1[..., H:].
    if fused_kernel is None:
        if w0_kernel is None or w1_kernel is None:
            # All three kernel args default to None; without this guard a
            # missing pair would surface as an opaque TypeError inside jnp.
            raise ValueError(
                "fused_moe_matmul requires either fused_kernel or both w0_kernel and w1_kernel."
            )
        fused_kernel = jnp.concatenate([w0_kernel, w1_kernel], axis=-1)

    # fused_moe_func expects 2D activations: flatten [B, S, D] -> [B*S, D].
    batch_size, seq_len, emb_dim = inputs.shape
    hidden_states = jnp.reshape(inputs, (batch_size * seq_len, emb_dim))
    gating_output = jnp.reshape(gate_logits, (batch_size * seq_len, self.num_experts))

    # Use expert parallelism only when the expert mesh axis is actually split.
    use_ep = self.get_expert_parallelism_size() > 1

    # Map MaxText config fields onto fused_moe_func's argument names.
    activation = self.config.mlp_activations[0]  # e.g. "silu"
    scoring_fn = self.config.routed_score_func if self.config.routed_score_func else "softmax"

    # LLAMA4/GEMMA4-style blocks apply routing weights without renormalizing;
    # all other blocks (or an explicit norm_topk_prob) renormalize the top-k
    # routing weights.
    renormalize = self.config.norm_topk_prob or (
        self.config.decoder_block not in (ctypes.DecoderBlockType.LLAMA4, ctypes.DecoderBlockType.GEMMA4)
    )

    output_2d = fused_moe_func(
        hidden_states=hidden_states,
        w1=fused_kernel,
        w2=wo_kernel,
        w1_scale=None,
        w2_scale=None,
        w1_bias=None,
        w2_bias=None,
        gating_output=gating_output,
        topk=self.num_experts_per_tok,
        renormalize=renormalize,
        mesh=self.mesh,
        use_ep=use_ep,
        activation=activation,
        scoring_fn=scoring_fn,
        # Kernel tuning knobs (16 MiB threshold, 1024-column chunks).
        # NOTE(review): values taken as-is from the original; confirm against
        # tpu_inference's fused_moe_func defaults before retuning.
        sc_kernel_threshold=16777216,
        sc_kernel_col_chunk_size=1024,
    )

    # Restore the original [B, S, D] shape.
    output = jnp.reshape(output_2d, (batch_size, seq_len, emb_dim))
    return output, None, None
19882075 def retrieve_quantized_weight (
19892076 self ,
19902077 inputs ,
@@ -2023,10 +2110,17 @@ def __call__(
20232110 routing_inputs = inputs if gate_inputs is None else gate_inputs .astype (gate_dtype )
20242111 gate_logits , pre_bias_logits = self .gate (routing_inputs )
20252112
2026- w0_kernel = jnp .asarray (self .wi_0 [...], self .dtype )
2027- w1_kernel = jnp .asarray (self .wi_1 [...], self .dtype )
20282113 wo_kernel = jnp .asarray (self .wo [...], self .dtype )
20292114
2115+ fused_kernel = None
2116+ w0_kernel = None
2117+ w1_kernel = None
2118+ if cfg .prefuse_moe_weights and cfg .attention == "vllm_rpa" :
2119+ fused_kernel = jnp .asarray (self .wi [...], self .dtype )
2120+ else :
2121+ w0_kernel = jnp .asarray (self .wi_0 [...], self .dtype )
2122+ w1_kernel = jnp .asarray (self .wi_1 [...], self .dtype )
2123+
20302124 if self .per_expert_scale is not None :
20312125 wo_kernel = wo_kernel * jnp .asarray (self .per_expert_scale [...], self .dtype )[:, None , None ]
20322126
@@ -2037,7 +2131,12 @@ def __call__(
20372131 else :
20382132 w0_bias , w1_bias , wo_bias = None , None , None
20392133
2040- if cfg .sparse_matmul :
2134+ # vllm_rpa uses fused_moe_func from tpu_inference (highest priority)
2135+ if cfg .attention == "vllm_rpa" :
2136+ output , lb_loss , bias_updates = self .fused_moe_matmul (
2137+ inputs , gate_logits , wo_kernel , w0_kernel = w0_kernel , w1_kernel = w1_kernel , fused_kernel = fused_kernel
2138+ )
2139+ elif cfg .sparse_matmul :
20412140 if quantizations .in_serve_mode (self .quant ):
20422141 w0_kernel , w1_kernel , wo_kernel = self .retrieve_quantized_weight (
20432142 inputs ,
0 commit comments