3838 AxisNames ,
3939 BATCH ,
4040 BATCH_NO_EXP ,
41+ CACHE_BATCH ,
42+ CACHE_BATCH_PREFILL ,
43+ CACHE_SEQUENCE ,
44+ CACHE_HEADS_NONE ,
45+ CACHE_KV ,
4146 Config ,
4247 DECODE_BATCH ,
4348 DECODE_LENGTH ,
7681from maxtext .utils .globals import EPS
7782
7883
84+ PLACEHOLDER_SEQ_LEN = 1
85+
86+
7987class Indexer (nnx .Module ):
8088 """Indexer for DeepSeek Sparse Attention (DSA).
8189
@@ -109,6 +117,7 @@ def __init__(
109117 self .rngs = rngs
110118 self .dtype = config .dtype
111119 self .weight_dtype = config .weight_dtype
120+ self .max_target_length = config .max_target_length
112121
113122 self .n_heads = config .indexer_n_heads
114123 self .head_dim = config .indexer_head_dim
@@ -168,6 +177,31 @@ def __init__(
168177 rngs = self .rngs ,
169178 )
170179
180+ def update_indexer_cache (self , kv_cache , k , decoder_segment_ids , model_mode , previous_chunk ):
181+ """Updates Indexer buffers by processing KV cache results."""
182+ k_expanded = k [:, :, jnp .newaxis , :]
183+ p_res , a_res = kv_cache (
184+ key = k_expanded ,
185+ value = k_expanded ,
186+ decoder_segment_ids = decoder_segment_ids ,
187+ model_mode = model_mode ,
188+ use_ragged_attention = self .config .use_ragged_attention ,
189+ previous_chunk = previous_chunk ,
190+ )
191+
192+ # Filter out None values to handle PREFILL vs AR modes uniformly
193+ active_results = [res for res in [p_res , a_res ] if res is not None ]
194+
195+ if not active_results :
196+ return None , None
197+
198+ # Extract keys (index 0) and segment IDs (index 2)
199+ keys = jnp .concatenate ([res [0 ] for res in active_results ], axis = 1 )
200+ segs = jnp .concatenate ([res [2 ] for res in active_results ], axis = 1 )
201+
202+ # squeeze(2) removes the jnp.newaxis added above
203+ return keys .squeeze (2 ), segs
204+
171205 def apply_partial_rope (
172206 self ,
173207 inputs : Array ,
@@ -221,6 +255,10 @@ def __call__(
221255 inputs_kv : Array ,
222256 inputs_positions : Optional [Array | None ] = None ,
223257 attention_mask : Optional [Array | None ] = None ,
258+ decoder_segment_ids : Optional [Array | None ] = None ,
259+ previous_chunk : Any = None ,
260+ kv_cache : Any = None ,
261+ model_mode : str = MODEL_MODE_TRAIN ,
224262 ):
225263 """Computes the index score to determine the top-k relevant tokens.
226264
@@ -245,6 +283,10 @@ def __call__(
245283 `DEFAULT_MASK_VALUE` (a large negative number) prevent it.
246284 Returns `None` if no masking is determined to be necessary based on
247285 the inputs and configuration.
286+ decoder_segment_ids: Segment IDs for decoder masking.
287+ previous_chunk: Previous chunk info for prefill.
288+ kv_cache: Key-value cache used when serving models.
289+ model_mode: "train", "prefill", or "autoregressive".
248290
249291 Returns:
250292 indexer_mask: A sparse mask [b, t, s] with 0.0 for top-k selected tokens
@@ -259,10 +301,6 @@ def __call__(
259301 h: Number of Indexer Heads (indexer_n_heads)
260302 d: Indexer Head Dimension (indexer_head_dim)
261303 """
262- # NOTE: If sequence length <= topk, indexer always selects all tokens.
263- if self .config .max_target_length <= self .indexer_topk :
264- return None , None , None
265-
266304 bsz , seqlen , _ = inputs_q .shape # s = t = seqlen
267305 # ==============================================================================
268306 # Gradient Isolation Strategy: Main Model vs. Indexer
@@ -300,6 +338,16 @@ def __call__(
300338 k = self .apply_partial_rope (k , inputs_positions = inputs_positions )
301339 k = k .squeeze (2 ) # [b, s, 1, d] -> [b, s, d]
302340
341+ # Update and retrieve from cache if not training
342+ cached_s = None
343+ if model_mode != MODEL_MODE_TRAIN :
344+ k_cached , cached_s = self .update_indexer_cache (kv_cache , k , decoder_segment_ids , model_mode , previous_chunk )
345+ k = k_cached if k_cached is not None else k
346+
347+ # NOTE: If the total available sequence length <= topk, indexer always selects all tokens.
348+ if k .shape [1 ] <= self .indexer_topk :
349+ return None , None , None
350+
303351 # Compute Index Scores
304352 # QK product: relu(q @ k.T), [b, t, s, h]
305353 # Similar to MQA, each key is shared by h query head
@@ -313,6 +361,12 @@ def __call__(
313361 # Aggregate head-wise logits: logits @ weights
314362 indexer_score = jnp .einsum ("btsh, bth -> bts" , logits , weights , precision = self .config .matmul_precision ) # [b, t, s]
315363
364+ internal_padding_mask = None
365+ if cached_s is not None :
366+ # cached_s marks valid tokens from the original prefill step and all subsequent AR steps
367+ internal_padding_mask = jnp .where (cached_s > 0 , 0.0 , DEFAULT_MASK_VALUE )
368+ indexer_score += internal_padding_mask [:, None , :]
369+
316370 # Apply attention mask before TopK
317371 if attention_mask is not None :
318372 indexer_score += attention_mask
@@ -321,12 +375,15 @@ def __call__(
321375 _ , topk_indices = jax .lax .top_k (indexer_score , k = self .indexer_topk ) # topk_indices [b, t, k]
322376
323377 # Create Sparse Index Mask: 0 and large negatives
324- indexer_mask = self .generate_mask (topk_indices , seqlen ) # [b, t, s]
378+ indexer_mask = self .generate_mask (topk_indices , k . shape [ 1 ] ) # [b, t, s]
325379
326380 # Re-apply attention mask after TopK: in case number of unmasked tokens < TopK
327381 if attention_mask is not None :
328382 indexer_mask += attention_mask
329383
384+ if internal_padding_mask is not None :
385+ indexer_mask += internal_padding_mask [:, None , :]
386+
330387 return indexer_mask , topk_indices , indexer_score
331388
332389
@@ -645,10 +702,41 @@ def __init__(
645702 quant = quant ,
646703 model_mode = model_mode ,
647704 )
705+ self .IndexerKVCache_0 = self .init_indexer_cache (inputs_kv_shape ) if model_mode != MODEL_MODE_TRAIN else None
706+ else :
707+ self .indexer = None
708+ self .IndexerKVCache_0 = None
648709
649710 # Module attribute names must match names previously passed to Linen for checkpointing
650711 self .MlaKVCache_0 = self .init_mla_kv_caches (inputs_kv_shape ) if model_mode != MODEL_MODE_TRAIN else None
651712
713+ def init_indexer_cache (self , inputs_kv_shape : Tuple ):
714+ """Initializes Indexer Cache."""
715+ batch_size , _ , _ = inputs_kv_shape
716+ # Use standard KVCache to store keys. Values are unused but required by KVCache API.
717+ # KVCache expects key_heads and value_heads. Since k is shared (MQA-like for Indexer),
718+ # we use key_heads=1, value_heads=1.
719+ return kvcache .KVCache (
720+ max_prefill_length = self .max_prefill_predict_length ,
721+ max_target_length = self .max_target_length ,
722+ batch = batch_size ,
723+ key_seq_len = PLACEHOLDER_SEQ_LEN ,
724+ value_seq_len = PLACEHOLDER_SEQ_LEN ,
725+ key_heads = 1 ,
726+ value_heads = 1 ,
727+ key_head_size = self .config .indexer_head_dim ,
728+ value_head_size = self .config .indexer_head_dim ,
729+ dtype = self .dtype ,
730+ kv_quant = None , # Quantization is not yet supported by the indexer.
731+ prefill_cache_logical_axis_names = (CACHE_BATCH_PREFILL , CACHE_SEQUENCE , CACHE_HEADS_NONE , CACHE_KV ),
732+ cache_logical_axis_names = (CACHE_BATCH , CACHE_SEQUENCE , CACHE_HEADS_NONE , CACHE_KV ),
733+ prefill_cache_axis_order = (1 , 2 , 0 , 3 ),
734+ ar_cache_axis_order = (1 , 2 , 0 , 3 ),
735+ use_chunked_prefill = self .config .use_chunked_prefill ,
736+ model_mode = self .model_mode ,
737+ rngs = self .rngs ,
738+ )
739+
652740 def _init_projections (self , inputs_q_shape : Tuple , inputs_kv_shape : Tuple ) -> None :
653741 """Initializes the MLA-specific projections."""
654742 # Assert required configuration parameters for MLA attention.
@@ -881,14 +969,13 @@ def init_mla_kv_caches(self, inputs_kv_shape: Tuple):
881969 # and max_target_length, not the passed seq_len.
882970 # We can use a placeholder value. The correct fix might involve refactoring
883971 # MlaKVCache.
884- placeholder_seq_len = 1
885972
886973 return kvcache .MlaKVCache (
887974 max_prefill_length = self .max_prefill_predict_length ,
888975 max_target_length = self .max_target_length ,
889976 batch = batch_size ,
890- key_seq_len = placeholder_seq_len ,
891- value_seq_len = placeholder_seq_len ,
977+ key_seq_len = PLACEHOLDER_SEQ_LEN ,
978+ value_seq_len = PLACEHOLDER_SEQ_LEN ,
892979 key_head_size = self .kv_lora_rank ,
893980 value_head_size = self .qk_rope_head_dim ,
894981 dtype = self .dtype ,
@@ -1100,6 +1187,9 @@ def __call__(
11001187 inputs_kv = self ._maybe_shard_with_logical (inputs_kv , self .input_axis_names )
11011188 out_logical_name = (BATCH , LENGTH_NO_EXP , HEAD , D_KV )
11021189
1190+ if model_mode != MODEL_MODE_TRAIN and decoder_segment_ids is None :
1191+ decoder_segment_ids = jnp .ones (inputs_q .shape [:2 ], dtype = jnp .int32 )
1192+
11031193 query , low_rank_q = self .mla_query_projection (inputs_q , inputs_positions , model_mode )
11041194 if self .config .force_q_layout :
11051195 query = layout .with_layout_constraint (query , DLL (major_to_minor = (0 , 2 , 3 , 1 )))
@@ -1113,8 +1203,6 @@ def __call__(
11131203 # Indexer Logic
11141204 indexer_mask = None
11151205 if self .use_indexer :
1116- if model_mode != MODEL_MODE_TRAIN :
1117- raise NotImplementedError ("Sparse indexer has not implemented for inference yet." )
11181206 # generate mask: with 0 and large negative, [b, 1, 1, q_len, kv_len] -> [b, q_len, kv_len]
11191207 attention_mask = self .attention_op .generate_attention_mask (
11201208 query , key , decoder_segment_ids , model_mode , previous_chunk , bidirectional_mask
@@ -1128,6 +1216,10 @@ def __call__(
11281216 inputs_kv = inputs_kv ,
11291217 inputs_positions = inputs_positions ,
11301218 attention_mask = attention_mask ,
1219+ decoder_segment_ids = decoder_segment_ids ,
1220+ previous_chunk = previous_chunk ,
1221+ kv_cache = self .IndexerKVCache_0 ,
1222+ model_mode = model_mode ,
11311223 )
11321224
11331225 if indexer_mask is not None and self .config .indexer_loss_scaling_factor > 0.0 :
0 commit comments