
Commit 5bd32d7

add chat template check for sft

1 parent 093ab89

3 files changed

Lines changed: 151 additions & 13 deletions


src/maxtext/input_pipeline/hf_data_processing.py

Lines changed: 1 addition & 0 deletions
@@ -307,6 +307,7 @@ def preprocessing_pipeline(
   )
   operations = []
   if use_sft:
+    input_pipeline_utils.verify_chat_template_generation_prompt_logic(tokenizer)
     operations.append(
         input_pipeline_utils.SFTPromptMasking(
             text_column_name=data_column_names[0],
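
The new call runs once per SFT input pipeline, before the SFTPromptMasking operation is constructed, so a bad chat template fails fast instead of silently corrupting loss masks. Below is a minimal sketch of the invariant the check enforces; it uses the same Qwen3 tokenizer as the new unit tests and assumes a Hugging Face tokenizer whose apply_chat_template(..., tokenize=True) returns a plain list of token IDs:

# Sketch only: Qwen/Qwen3-4B is the tokenizer used by the new unit tests;
# any chat tokenizer with the same apply_chat_template behavior works.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")
msgs = [{"role": "user", "content": "Hi"}]

without_gen = tok.apply_chat_template(msgs, add_generation_prompt=False, tokenize=True)
with_gen = tok.apply_chat_template(msgs, add_generation_prompt=True, tokenize=True)
# The generation prompt is whatever add_generation_prompt=True appended.
gen_prompt = with_gen[len(without_gen):]

full_turn = tok.apply_chat_template(
    msgs + [{"role": "assistant", "content": "Hello"}],
    add_generation_prompt=False,
    tokenize=True,
)
# SFT prompt masking is only correct if the assistant turn in a full
# conversation starts with exactly those generation-prompt tokens.
assert full_turn[len(without_gen):len(with_gen)] == gen_prompt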

src/maxtext/input_pipeline/input_pipeline_utils.py

Lines changed: 101 additions & 12 deletions
@@ -19,6 +19,8 @@
 from threading import current_thread
 from typing import Any, Iterable, TYPE_CHECKING

+from jinja2 import TemplateError
+
 if TYPE_CHECKING:
   import datasets
   import tensorflow as tf
@@ -175,6 +177,103 @@ def is_conversational(features, data_columns):
   return False


+def _extract_token_ids(tokens):
+  """Extracts token IDs from various tokenizer output formats.
+
+  This helper function standardizes the extraction of tokenized integer IDs
+  from common return types of Hugging Face tokenizers, including
+  `BatchEncoding` objects, dictionaries, or simple lists.
+
+  Args:
+    tokens: The object containing token IDs. Supported types include:
+      - A list of integers.
+      - A dictionary containing the `INPUT_TOKENS_KEY`.
+      - An object (e.g., `BatchEncoding`) with an attribute named `INPUT_TOKENS_KEY`.
+
+  Returns:
+    A list of integer token IDs.
+
+  Raises:
+    ValueError: If the input type is not supported or does not contain the expected key.
+  """
+  # attention masks in BatchEncoding are effectively ignored
+  if hasattr(tokens, INPUT_TOKENS_KEY):
+    return getattr(tokens, INPUT_TOKENS_KEY)
+  elif isinstance(tokens, dict) and INPUT_TOKENS_KEY in tokens:
+    return tokens[INPUT_TOKENS_KEY]
+  elif isinstance(tokens, list):
+    return tokens
+  else:
+    raise ValueError(f"Can't extract token_ids from type {type(tokens)}")
+
+
+def verify_chat_template_generation_prompt_logic(tokenizer_model):
+  """Verifies the tokenizer's chat template for correct SFT loss masking.
+
+  This function ensures that the tokens added by `add_generation_prompt=True`
+  are identical to the tokens that begin an assistant's turn in a complete
+  conversation, which is critical for masking prompt tokens during SFT loss
+  calculation.
+
+  Example of a mismatch:
+    A `ValueError` is raised if the generation prompt and the actual
+    assistant prefix do not match. For example:
+
+    - `add_generation_prompt=True` on a user message produces a prompt ending in:
+      `...<|im_start|>generation\n`
+    - A full turn with an assistant message starts the reply with:
+      `...<|im_start|>assistant\n...`
+
+    This function would fail because the tokens for "generation" do not
+    match the tokens for "assistant".
+
+  Args:
+    tokenizer_model: The Hugging Face tokenizer instance to verify.
+
+  Raises:
+    ValueError: If the `add_generation_prompt` tokens do not exactly
+      match the beginning of an assistant message in the template.
+  """
+  dummy_msgs = [{"role": "system", "content": "System message"}, {"role": "user", "content": "Test message"}]
+
+  try:
+    prompt_wo_gen_tokens = tokenizer_model.apply_chat_template(dummy_msgs, add_generation_prompt=False, tokenize=True)
+  except TemplateError:
+    max_logging.info(
+        "Tokenizer failed to apply chat template with 'system' role. "
+        "Falling back to 'user' role only for chat template verification."
+    )
+    dummy_msgs.pop(0)
+    prompt_wo_gen_tokens = tokenizer_model.apply_chat_template(dummy_msgs, add_generation_prompt=False, tokenize=True)
+  prompt_wo_gen_ids = _extract_token_ids(prompt_wo_gen_tokens)
+
+  prompt_w_gen_tokens = tokenizer_model.apply_chat_template(dummy_msgs, add_generation_prompt=True, tokenize=True)
+  prompt_w_gen_ids = _extract_token_ids(prompt_w_gen_tokens)
+
+  if prompt_w_gen_ids[: len(prompt_wo_gen_ids)] != prompt_wo_gen_ids:
+    raise ValueError("Unable to extract generation prompt tokens.")
+  # Extract the tokenized generation prompt (the expected assistant prefix)
+  assistant_prefix = prompt_w_gen_ids[len(prompt_wo_gen_ids) :]
+  full_turn_tokens = tokenizer_model.apply_chat_template(
+      dummy_msgs + [{"role": "assistant", "content": "Dummy response"}],
+      add_generation_prompt=False,
+      tokenize=True,
+  )
+  full_turn_ids = _extract_token_ids(full_turn_tokens)
+  # Extract the actual tokens that appear right after the user message in the full turn
+  actual_prefix_in_full_turn = full_turn_ids[len(prompt_wo_gen_ids) : len(prompt_wo_gen_ids) + len(assistant_prefix)]
+
+  if actual_prefix_in_full_turn != assistant_prefix:
+    expected_str = tokenizer_model.decode(assistant_prefix)
+    actual_str = tokenizer_model.decode(actual_prefix_in_full_turn)
+    raise ValueError(
+        "Chat template generation prompt mismatch!\n"
+        f"Expected assistant prefix tokens: {assistant_prefix} ('{expected_str}')\n"
+        f"Actual prefix tokens found: {actual_prefix_in_full_turn} ('{actual_str}')\n"
+        "This means the tokenizer's chat template will break the SFT masking logic."
+    )
+
+
 def _get_completion_in_chat_template(tokenizer_model, round_msgs):
   """
   Calculates the completion part of a conversation turn when formatted with a chat template.
@@ -193,18 +292,8 @@ def _get_completion_in_chat_template(tokenizer_model, round_msgs):
   # include generation_prompt as part of the prompt tokens
   prompt_tokens = tokenizer_model.apply_chat_template(round_msgs[:-1], add_generation_prompt=True, tokenize=True)

-  # attention masks in BatchEncoding are effectively ignored
-  if hasattr(prompt_completion_tokens, INPUT_TOKENS_KEY):
-    prompt_completion_ids = getattr(prompt_completion_tokens, INPUT_TOKENS_KEY)
-    prompt_ids = getattr(prompt_tokens, INPUT_TOKENS_KEY)
-  elif isinstance(prompt_completion_tokens, dict) and INPUT_TOKENS_KEY in prompt_completion_tokens:
-    prompt_completion_ids = prompt_completion_tokens[INPUT_TOKENS_KEY]
-    prompt_ids = prompt_tokens[INPUT_TOKENS_KEY]
-  elif isinstance(prompt_completion_tokens, list):
-    prompt_completion_ids = prompt_completion_tokens
-    prompt_ids = prompt_tokens
-  else:
-    raise ValueError(f"Can't handle the chat template output of type {type(prompt_completion_tokens)}")
+  prompt_completion_ids = _extract_token_ids(prompt_completion_tokens)
+  prompt_ids = _extract_token_ids(prompt_tokens)

   completion_tokens = prompt_completion_ids[len(prompt_ids) :]
   completion_in_chat_template = tokenizer_model.decode(completion_tokens, skip_special_tokens=False)
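
A short usage sketch for the new helper: it returns nothing on success and raises ValueError on a mismatch, so callers outside the pipeline can wrap it in a try/except (the tokenizer choice here is illustrative):

from transformers import AutoTokenizer

from maxtext.input_pipeline import input_pipeline_utils

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")
try:
    input_pipeline_utils.verify_chat_template_generation_prompt_logic(tokenizer)
except ValueError as err:
    # A mismatch means SFTPromptMasking would mislabel prompt vs. completion tokens.
    print(f"Chat template is unsafe for SFT masking: {err}")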

tests/post_training/unit/sft_data_processing_test.py

Lines changed: 49 additions & 1 deletion
@@ -20,20 +20,22 @@
 import subprocess
 import unittest
 import os.path
 import pytest
 import numpy as np
 import jax
+import re
 from jax.sharding import Mesh
 from jax.experimental import mesh_utils
 from datasets import Dataset
 import transformers
 from parameterized import parameterized_class
-
+from unittest.mock import patch
 from maxtext.configs import pyconfig
 from maxtext.utils.globals import MAXTEXT_PKG_DIR, MAXTEXT_CONFIGS_DIR, MAXTEXT_ASSETS_ROOT
 from maxtext.input_pipeline import hf_data_processing
 from maxtext.input_pipeline import input_pipeline_interface
 from maxtext.input_pipeline.hf_data_processing import _get_pad_id
+from maxtext.input_pipeline.input_pipeline_utils import verify_chat_template_generation_prompt_logic

 PROMPT_DATA = [
     [
@@ -484,5 +486,51 @@ def test_system_message_not_at_beginning(self):
     self.get_data_iterator(dataset, ["messages"])


+@pytest.mark.external_training
+class SFTChatTemplateLogicTest(unittest.TestCase):
+  LLAMA_TOKENIZER_PATH = os.path.join(MAXTEXT_ASSETS_ROOT, "llama2-chat-tokenizer")
+
+  @classmethod
+  def setUpClass(cls):
+    super().setUpClass()
+    if not os.path.exists(cls.LLAMA_TOKENIZER_PATH):
+      exit_code = subprocess.call(
+          [
+              "gsutil",
+              "cp",
+              "-r",
+              "gs://maxtext-dataset/hf/llama2-chat-tokenizer",
+              os.path.join(MAXTEXT_ASSETS_ROOT, ""),
+          ]
+      )
+      if exit_code != 0:
+        raise ValueError("Failed to download llama tokenizer")
+
+  def setUp(self):
+    super().setUp()
+    self.qwen3_tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")
+    self.llama2_tokenizer = transformers.AutoTokenizer.from_pretrained(self.LLAMA_TOKENIZER_PATH)
+
+  def test_tokenizer_w_generation_prompt(self):
+    verify_chat_template_generation_prompt_logic(self.qwen3_tokenizer)
+
+  def test_tokenizer_wo_generation_prompt(self):
+    verify_chat_template_generation_prompt_logic(self.llama2_tokenizer)
+
+  def test_failure_path_with_modified_template(self):
+    """Verifies the function correctly raises a ValueError on a bad template."""
+    # Replace the role within the existing add_generation_prompt block with a deliberately faulty one.
+    fault_chat_template = re.sub(
+        r"(\{%-?\s*if add_generation_prompt\s*%\}.*?<\|im_start\|>)assistant(.*?\{%-?\s*endif\s*%\})",
+        r"\1wrong_role\2",
+        self.qwen3_tokenizer.chat_template,
+        flags=re.DOTALL,
+    )
+    with patch.object(self.qwen3_tokenizer, "chat_template", fault_chat_template):
+      # Verify that our function catches the mismatch and raises the expected error
+      with self.assertRaisesRegex(ValueError, "Chat template generation prompt mismatch!"):
+        verify_chat_template_generation_prompt_logic(self.qwen3_tokenizer)
+
+
 if __name__ == "__main__":
   unittest.main()
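
These tests pull external assets (the llama2 chat tokenizer from GCS via gsutil, and Qwen/Qwen3-4B from the Hugging Face Hub), hence the external_training marker. Assuming that marker is registered in the project's pytest configuration, the new class can be run in isolation with:

python -m pytest tests/post_training/unit/sft_data_processing_test.py -m external_training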
