
Commit 94d2010

Merge pull request #3529 from AI-Hypercomputer:dsv32_decode_clean
PiperOrigin-RevId: 894126759
2 parents e735af1 + f4a4f21 commit 94d2010

8 files changed

Lines changed: 695 additions & 75 deletions


benchmarks/api_server/encoding/encoding_dsv32.py

Lines changed: 403 additions & 0 deletions
Large diffs are not rendered by default.
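
The new 403-line encoding module is collapsed in this view. From its call sites in maxtext_server.py and server_utils.py below, its response-parsing contract is roughly the following. This is a hypothetical sketch only, not the module's contents; the <think>/</think> delimiter is an assumption borrowed from other DeepSeek reasoning models and is not shown by this diff:

    # Hypothetical sketch of parse_message_from_completion_text -- NOT the real encoding_dsv32.py.
    # Assumes the model wraps its reasoning in <think>...</think>; the real delimiters may differ.
    def parse_message_from_completion_text(text: str, thinking_mode: str = "thinking") -> dict:
      """Splits a raw completion into a reasoning block and the final answer."""
      if thinking_mode != "thinking" or "</think>" not in text:
        return {"content": text, "reasoning_content": None}
      reasoning, _, content = text.partition("</think>")
      reasoning = reasoning.replace("<think>", "", 1)
      return {"content": content.lstrip("\n"), "reasoning_content": reasoning.strip()}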

benchmarks/api_server/maxtext_server.py

Lines changed: 19 additions & 9 deletions
@@ -1,4 +1,4 @@
-# Copyright 2023–2025 Google LLC
+# Copyright 2023–2026 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -59,6 +59,7 @@
     ChatMessage,
 )
 from benchmarks.api_server import server_utils
+from benchmarks.api_server.encoding import encoding_dsv32

 # ----------------------------
 # Init
@@ -95,10 +96,13 @@
 response_dict = {}
 response_lock = threading.Lock()

-# Batching configuration
-BATCH_TIMEOUT_S = 0.1  # 100ms
+# Batching configuration.
+BATCH_TIMEOUT_S = float(os.environ.get("MAXTEXT_BATCH_TIMEOUT_S", "0.1"))
 # Timeout for a client waiting for a response.
 REQUEST_TIMEOUT_S = int(os.environ.get("MAXTEXT_REQUEST_TIMEOUT_S", "36000"))
+# Define a maximum size for the request payload to be broadcasted.
+# This avoids broadcasting variable-sized arrays, which can be complex.
+MAX_REQUEST_SIZE = int(os.environ.get("MAXTEXT_REQUEST_SIZE", "655360"))


 async def _queue_and_wait_for_response(request: Union[CompletionRequest, ChatCompletionRequest]):
@@ -165,14 +169,11 @@ def run_server():
   uvicorn.run(app, host="0.0.0.0", port=8000)


-# Define a maximum size for the request payload to be broadcasted.
-# This avoids broadcasting variable-sized arrays, which can be complex.
-MAX_REQUEST_SIZE = 65536 * 10
-
-
 def _build_chat_completion_response(request, completion_result, llm):
   """Builds a ChatCompletionResponse from a single completion result."""
   text_out = completion_result.text
+  reasoning_out = None
+
   if "gpt-oss" in request.model and harmony_enc:
     try:
       parsed_messages = harmony_enc.parse_messages_from_completion_tokens(completion_result.tokens, role=Role.ASSISTANT)
@@ -186,6 +187,15 @@ def _build_chat_completion_response(request, completion_result, llm):
     except (ValueError, IndexError) as e:
       logger.error("Harmony parsing failed for gpt-oss: %s. Falling back to raw text.", e, exc_info=True)

+  if server_utils.is_dsv32_encoding_enabled(request.model):
+    try:
+      # DeepSeek-V3.2 models often generate thinking block.
+      parsed = encoding_dsv32.parse_message_from_completion_text(text_out, thinking_mode="thinking")
+      text_out = parsed.get("content", text_out)
+      reasoning_out = parsed.get("reasoning_content")
+    except (AssertionError, ValueError, IndexError) as e:
+      logger.error("DeepSeek-V3.2 parsing failed: %s. Falling back to raw text.", e, exc_info=True)
+
   want_top_logprobs = (
       (request.top_logprobs or 0) > 0 if isinstance(request, ChatCompletionRequest) else (request.logprobs or 0) > 0
   )
@@ -206,7 +216,7 @@
       choices=[
           ChatCompletionChoice(
               index=0,
-              message=ChatMessage(role="assistant", content=text_out),
+              message=ChatMessage(role="assistant", content=text_out, reasoning_content=reasoning_out),
              finish_reason=finish_reason,
              logprobs=lp_payload,
          )
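
A hypothetical client-side view of the new behavior. The endpoint path and model name below are assumptions made for illustration; the diff only shows that the server runs on port 8000 via uvicorn and that assistant messages now carry an optional reasoning_content field:

    # Hypothetical request -- endpoint path and model name are illustrative assumptions.
    import requests

    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",
        json={
            "model": "deepseek3.2",
            "messages": [{"role": "user", "content": "What is 17 * 24?"}],
        },
        timeout=600,
    )
    message = resp.json()["choices"][0]["message"]
    print(message.get("reasoning_content"))  # parsed thinking block, or None
    print(message["content"])                # final answer with the thinking block stripped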

benchmarks/api_server/server_models.py

Lines changed: 2 additions & 0 deletions
@@ -172,10 +172,12 @@ class ChatMessage(BaseModel):
   Attributes:
     role: The role of the message's author (e.g., 'user', 'assistant').
     content: The text content of the message.
+    reasoning_content: The text content for reasoning/thinking.
   """

   role: str
   content: str
+  reasoning_content: Optional[str] = None


 class ChatCompletionRequest(SamplingParams):
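
A minimal sketch of the new field's behavior, assuming ChatMessage stays a plain pydantic BaseModel as shown above (only the reasoning_content line is new):

    from typing import Optional

    from pydantic import BaseModel

    class ChatMessage(BaseModel):
      role: str
      content: str
      reasoning_content: Optional[str] = None

    reply = ChatMessage(role="assistant", content="408", reasoning_content="17 * 24 = 408")
    print(reply.model_dump())
    # {'role': 'assistant', 'content': '408', 'reasoning_content': '17 * 24 = 408'}

    legacy = ChatMessage(role="user", content="hi")
    print(legacy.model_dump(exclude_none=True))
    # {'role': 'user', 'content': 'hi'}

Because the field defaults to None and drops out under exclude_none, existing clients and the exclude_none path in server_utils.get_prompts_for_request are unaffected.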

benchmarks/api_server/server_utils.py

Lines changed: 19 additions & 0 deletions
@@ -32,6 +32,7 @@

 from benchmarks.api_server.maxtext_generator import MaxTextGenerator
 from benchmarks.api_server.server_models import LogProbsPayload
+from benchmarks.api_server.encoding import encoding_dsv32

 # ----------------------------
 # Debugging
@@ -41,6 +42,19 @@
 DEBUG_LOG_FILE = os.environ.get("MAXTEXT_DEBUG_LOG_FILE", "benchmarks/api_server/server_debug_log.jsonl")
 logger = logging.getLogger(__name__)

+# Indicate if we should disable specific encoding for DeepSeek-V3.2 family.
+# Encoding is needed for v3.2 and v3.2-speciale, but not deepseek v3.2-exp.
+DISABLE_DSV32_ENCODING = os.environ.get("DISABLE_DSV32_ENCODING", "0") == "1"
+
+
+def is_dsv32_encoding_enabled(model_name: str) -> bool:
+  """
+  Checks if DeepSeek-V3.2 specific encoding should be applied to the given model.
+  """
+  if DISABLE_DSV32_ENCODING:
+    return False
+  return "deepseek3.2" in model_name.lower()
+

 def log_debug_event(request_id: str, event_type: str, content: dict):
   """
@@ -114,6 +128,11 @@ def get_prompts_for_request(req: any, llm: MaxTextGenerator) -> List[str]:
     A list of string prompts.
   """
   if hasattr(req, "messages"):  # ChatCompletionRequest
+    if is_dsv32_encoding_enabled(req.model):
+      messages = [m.model_dump(exclude_none=True) for m in req.messages]
+      encode_config = {"thinking_mode": "thinking", "drop_thinking": True, "add_default_bos_token": True}
+      return [encoding_dsv32.encode_messages(messages, **encode_config)]
+
     messages = [m.model_dump() for m in req.messages]
     formatted_prompt = llm.tokenizer.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     return [formatted_prompt]
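
Illustrative use of the new gate; the model names are made up and DISABLE_DSV32_ENCODING is assumed to be unset, so the module-level default (encoding enabled) applies:

    import os

    os.environ.pop("DISABLE_DSV32_ENCODING", None)  # must happen before the module is imported
    from benchmarks.api_server import server_utils

    print(server_utils.is_dsv32_encoding_enabled("DeepSeek3.2"))      # True (case-insensitive substring match)
    print(server_utils.is_dsv32_encoding_enabled("deepseek3.2-exp"))  # True; export DISABLE_DSV32_ENCODING=1 to opt out
    print(server_utils.is_dsv32_encoding_enabled("gpt-oss-20b"))      # False

Note the substring check also matches a "deepseek3.2-exp" style name, so per the comment above, that family is opted out via the environment variable rather than by name.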

src/maxtext/layers/attention_mla.py

Lines changed: 102 additions & 10 deletions
@@ -38,6 +38,11 @@
     AxisNames,
     BATCH,
     BATCH_NO_EXP,
+    CACHE_BATCH,
+    CACHE_BATCH_PREFILL,
+    CACHE_SEQUENCE,
+    CACHE_HEADS_NONE,
+    CACHE_KV,
     Config,
     DECODE_BATCH,
     DECODE_LENGTH,
@@ -76,6 +81,9 @@
 from maxtext.utils.globals import EPS


+PLACEHOLDER_SEQ_LEN = 1
+
+
 class Indexer(nnx.Module):
   """Indexer for DeepSeek Sparse Attention (DSA).

@@ -109,6 +117,7 @@ def __init__(
     self.rngs = rngs
     self.dtype = config.dtype
     self.weight_dtype = config.weight_dtype
+    self.max_target_length = config.max_target_length

     self.n_heads = config.indexer_n_heads
     self.head_dim = config.indexer_head_dim
@@ -168,6 +177,31 @@ def __init__(
         rngs=self.rngs,
     )

+  def update_indexer_cache(self, kv_cache, k, decoder_segment_ids, model_mode, previous_chunk):
+    """Updates Indexer buffers by processing KV cache results."""
+    k_expanded = k[:, :, jnp.newaxis, :]
+    p_res, a_res = kv_cache(
+        key=k_expanded,
+        value=k_expanded,
+        decoder_segment_ids=decoder_segment_ids,
+        model_mode=model_mode,
+        use_ragged_attention=self.config.use_ragged_attention,
+        previous_chunk=previous_chunk,
+    )
+
+    # Filter out None values to handle PREFILL vs AR modes uniformly
+    active_results = [res for res in [p_res, a_res] if res is not None]
+
+    if not active_results:
+      return None, None
+
+    # Extract keys (index 0) and segment IDs (index 2)
+    keys = jnp.concatenate([res[0] for res in active_results], axis=1)
+    segs = jnp.concatenate([res[2] for res in active_results], axis=1)
+
+    # squeeze(2) removes the jnp.newaxis added above
+    return keys.squeeze(2), segs
+
   def apply_partial_rope(
       self,
       inputs: Array,
@@ -221,6 +255,10 @@ def __call__(
       inputs_kv: Array,
       inputs_positions: Optional[Array | None] = None,
       attention_mask: Optional[Array | None] = None,
+      decoder_segment_ids: Optional[Array | None] = None,
+      previous_chunk: Any = None,
+      kv_cache: Any = None,
+      model_mode: str = MODEL_MODE_TRAIN,
   ):
     """Computes the index score to determine the top-k relevant tokens.

@@ -245,6 +283,10 @@
         `DEFAULT_MASK_VALUE` (a large negative number) prevent it.
         Returns `None` if no masking is determined to be necessary based on
         the inputs and configuration.
+      decoder_segment_ids: Segment IDs for decoder masking.
+      previous_chunk: Previous chunk info for prefill.
+      kv_cache: Key-value cache used when serving models.
+      model_mode: "train", "prefill", or "autoregressive".

     Returns:
       indexer_mask: A sparse mask [b, t, s] with 0.0 for top-k selected tokens
@@ -259,10 +301,6 @@
       h: Number of Indexer Heads (indexer_n_heads)
       d: Indexer Head Dimension (indexer_head_dim)
     """
-    # NOTE: If sequence length <= topk, indexer always selects all tokens.
-    if self.config.max_target_length <= self.indexer_topk:
-      return None, None, None
-
     bsz, seqlen, _ = inputs_q.shape  # s = t = seqlen
     # ==============================================================================
     # Gradient Isolation Strategy: Main Model vs. Indexer
@@ -300,6 +338,16 @@
     k = self.apply_partial_rope(k, inputs_positions=inputs_positions)
     k = k.squeeze(2)  # [b, s, 1, d] -> [b, s, d]

+    # Update and retrieve from cache if not training
+    cached_s = None
+    if model_mode != MODEL_MODE_TRAIN:
+      k_cached, cached_s = self.update_indexer_cache(kv_cache, k, decoder_segment_ids, model_mode, previous_chunk)
+      k = k_cached if k_cached is not None else k
+
+    # NOTE: If the total available sequence length <= topk, indexer always selects all tokens.
+    if k.shape[1] <= self.indexer_topk:
+      return None, None, None
+
     # Compute Index Scores
     # QK product: relu(q @ k.T), [b, t, s, h]
     # Similar to MQA, each key is shared by h query head
@@ -313,6 +361,12 @@
     # Aggregate head-wise logits: logits @ weights
     indexer_score = jnp.einsum("btsh, bth -> bts", logits, weights, precision=self.config.matmul_precision)  # [b, t, s]

+    internal_padding_mask = None
+    if cached_s is not None:
+      # cached_s marks valid tokens from the original prefill step and all subsequent AR steps
+      internal_padding_mask = jnp.where(cached_s > 0, 0.0, DEFAULT_MASK_VALUE)
+      indexer_score += internal_padding_mask[:, None, :]
+
     # Apply attention mask before TopK
     if attention_mask is not None:
       indexer_score += attention_mask
@@ -321,12 +375,15 @@
     _, topk_indices = jax.lax.top_k(indexer_score, k=self.indexer_topk)  # topk_indices [b, t, k]

     # Create Sparse Index Mask: 0 and large negatives
-    indexer_mask = self.generate_mask(topk_indices, seqlen)  # [b, t, s]
+    indexer_mask = self.generate_mask(topk_indices, k.shape[1])  # [b, t, s]

     # Re-apply attention mask after TopK: in case number of unmasked tokens < TopK
     if attention_mask is not None:
       indexer_mask += attention_mask

+    if internal_padding_mask is not None:
+      indexer_mask += internal_padding_mask[:, None, :]
+
     return indexer_mask, topk_indices, indexer_score


@@ -645,10 +702,41 @@ def __init__(
           quant=quant,
           model_mode=model_mode,
       )
+      self.IndexerKVCache_0 = self.init_indexer_cache(inputs_kv_shape) if model_mode != MODEL_MODE_TRAIN else None
+    else:
+      self.indexer = None
+      self.IndexerKVCache_0 = None

     # Module attribute names must match names previously passed to Linen for checkpointing
     self.MlaKVCache_0 = self.init_mla_kv_caches(inputs_kv_shape) if model_mode != MODEL_MODE_TRAIN else None

+  def init_indexer_cache(self, inputs_kv_shape: Tuple):
+    """Initializes Indexer Cache."""
+    batch_size, _, _ = inputs_kv_shape
+    # Use standard KVCache to store keys. Values are unused but required by KVCache API.
+    # KVCache expects key_heads and value_heads. Since k is shared (MQA-like for Indexer),
+    # we use key_heads=1, value_heads=1.
+    return kvcache.KVCache(
+        max_prefill_length=self.max_prefill_predict_length,
+        max_target_length=self.max_target_length,
+        batch=batch_size,
+        key_seq_len=PLACEHOLDER_SEQ_LEN,
+        value_seq_len=PLACEHOLDER_SEQ_LEN,
+        key_heads=1,
+        value_heads=1,
+        key_head_size=self.config.indexer_head_dim,
+        value_head_size=self.config.indexer_head_dim,
+        dtype=self.dtype,
+        kv_quant=None,  # Quantization is not yet supported by the indexer.
+        prefill_cache_logical_axis_names=(CACHE_BATCH_PREFILL, CACHE_SEQUENCE, CACHE_HEADS_NONE, CACHE_KV),
+        cache_logical_axis_names=(CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS_NONE, CACHE_KV),
+        prefill_cache_axis_order=(1, 2, 0, 3),
+        ar_cache_axis_order=(1, 2, 0, 3),
+        use_chunked_prefill=self.config.use_chunked_prefill,
+        model_mode=self.model_mode,
+        rngs=self.rngs,
+    )
+
   def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None:
     """Initializes the MLA-specific projections."""
     # Assert required configuration parameters for MLA attention.
@@ -881,14 +969,13 @@ def init_mla_kv_caches(self, inputs_kv_shape: Tuple):
     # and max_target_length, not the passed seq_len.
     # We can use a placeholder value. The correct fix might involve refactoring
     # MlaKVCache.
-    placeholder_seq_len = 1

     return kvcache.MlaKVCache(
         max_prefill_length=self.max_prefill_predict_length,
         max_target_length=self.max_target_length,
         batch=batch_size,
-        key_seq_len=placeholder_seq_len,
-        value_seq_len=placeholder_seq_len,
+        key_seq_len=PLACEHOLDER_SEQ_LEN,
+        value_seq_len=PLACEHOLDER_SEQ_LEN,
         key_head_size=self.kv_lora_rank,
         value_head_size=self.qk_rope_head_dim,
         dtype=self.dtype,
@@ -1100,6 +1187,9 @@ def __call__(
     inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.input_axis_names)
     out_logical_name = (BATCH, LENGTH_NO_EXP, HEAD, D_KV)

+    if model_mode != MODEL_MODE_TRAIN and decoder_segment_ids is None:
+      decoder_segment_ids = jnp.ones(inputs_q.shape[:2], dtype=jnp.int32)
+
     query, low_rank_q = self.mla_query_projection(inputs_q, inputs_positions, model_mode)
     if self.config.force_q_layout:
       query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1)))
@@ -1113,8 +1203,6 @@
     # Indexer Logic
     indexer_mask = None
     if self.use_indexer:
-      if model_mode != MODEL_MODE_TRAIN:
-        raise NotImplementedError("Sparse indexer has not implemented for inference yet.")
       # generate mask: with 0 and large negative, [b, 1, 1, q_len, kv_len] -> [b, q_len, kv_len]
       attention_mask = self.attention_op.generate_attention_mask(
           query, key, decoder_segment_ids, model_mode, previous_chunk, bidirectional_mask
@@ -1128,6 +1216,10 @@
           inputs_kv=inputs_kv,
           inputs_positions=inputs_positions,
           attention_mask=attention_mask,
+          decoder_segment_ids=decoder_segment_ids,
+          previous_chunk=previous_chunk,
+          kv_cache=self.IndexerKVCache_0,
+          model_mode=model_mode,
       )

     if indexer_mask is not None and self.config.indexer_loss_scaling_factor > 0.0:
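
As a companion to the Indexer changes above, a small self-contained JAX sketch (not the MaxText implementation) of the sparse top-k mask contract this diff relies on: 0.0 for the top-k keys per query position, a large negative value elsewhere, with the mask sized by the key axis (k.shape[1]) rather than the query length; the mask constant is defined locally to stand in for MaxText's DEFAULT_MASK_VALUE:

    import jax
    import jax.numpy as jnp

    DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.float32).max)  # stand-in constant

    def sparse_topk_mask(indexer_score: jax.Array, topk: int) -> jax.Array:
      """indexer_score: [b, t, s] -> mask [b, t, s] with 0.0 on the top-k keys per query."""
      _, topk_indices = jax.lax.top_k(indexer_score, k=topk)           # [b, t, k]
      kv_len = indexer_score.shape[-1]
      one_hot = jax.nn.one_hot(topk_indices, kv_len, dtype=jnp.bool_)  # [b, t, k, s]
      keep = jnp.any(one_hot, axis=2)                                  # [b, t, s]
      return jnp.where(keep, 0.0, DEFAULT_MASK_VALUE)

    scores = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 8))
    mask = sparse_topk_mask(scores, topk=3)
    print((mask == 0.0).sum(axis=-1))  # 3 kept keys per query position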

src/maxtext/layers/attention_op.py

Lines changed: 11 additions & 2 deletions
@@ -677,7 +677,7 @@ def generate_attention_mask(
       Chunked Prefills - ArXiv:2308.16369 (https://arxiv.org/abs/2308.16369)
     """
     mask = None
-    if model_mode == MODEL_MODE_AUTOREGRESSIVE:
+    if model_mode == MODEL_MODE_AUTOREGRESSIVE and decoder_segment_ids is not None:
       mask = decoder_segment_ids[:, None, None, None, :] == DECODING_ACTIVE_SEQUENCE_INDICATOR
     elif decoder_segment_ids is not None:
       mask = decoder_segment_ids[:, :, None] == decoder_segment_ids[:, None, :]
@@ -2047,6 +2047,14 @@ def __call__(
       assert prefill_kv_cache
       key, value, decoder_segment_ids = prefill_kv_cache

+      indexer_mask_prefill = None
+      indexer_mask_ar = None
+      if indexer_mask is not None:
+        prefill_len = key.shape[1]
+        indexer_mask_prefill = indexer_mask[:, :, :prefill_len]
+        if ar_kv_cache is not None:
+          indexer_mask_ar = indexer_mask[:, :, prefill_len:]
+
       prefill_unnormalized_output, prefill_exponentials_max, prefill_exponentials_sum = self.apply_attention(
           query=query,
           key=key,
@@ -2058,7 +2066,7 @@
           previous_chunk=previous_chunk,
           bidirectional_mask=bidirectional_mask,
           sinks=sinks,
-          indexer_mask=indexer_mask,
+          indexer_mask=indexer_mask_prefill,
           record_max_logits=record_max_logits,
           qk_product_einsum=self.AqtEinsum_0,
           wv_product_einsum=self.AqtEinsum_1,
@@ -2081,6 +2089,7 @@
           model_mode=model_mode,
           use_ragged_attention=self.use_ragged_attention,
           bidirectional_mask=bidirectional_mask,
+          indexer_mask=indexer_mask_ar,
           qk_product_einsum=self.AqtEinsum_2,
           wv_product_einsum=self.AqtEinsum_3,
       )
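
A tiny illustration of the new split, with made-up shapes: indexer_mask covers the full key axis (prefill cache followed by the autoregressive cache), so it is sliced at the prefill boundary before being passed to the two apply_attention calls:

    import jax.numpy as jnp

    b, t, prefill_len, ar_len = 1, 1, 6, 2
    indexer_mask = jnp.zeros((b, t, prefill_len + ar_len))

    indexer_mask_prefill = indexer_mask[:, :, :prefill_len]   # goes with the prefill KV cache
    indexer_mask_ar = indexer_mask[:, :, prefill_len:]        # goes with the AR KV cache
    print(indexer_mask_prefill.shape, indexer_mask_ar.shape)  # (1, 1, 6) (1, 1, 2)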
