from threading import current_thread
from typing import Any, Iterable, TYPE_CHECKING

+from jinja2 import TemplateError
+
if TYPE_CHECKING:
  import datasets
  import tensorflow as tf
@@ -175,6 +177,103 @@ def is_conversational(features, data_columns):
  return False


+def _extract_token_ids(tokens):
+  """Extracts token IDs from various tokenizer output formats.
+
+  This helper function standardizes the extraction of tokenized integer IDs
+  from common return types of Hugging Face tokenizers, including
+  `BatchEncoding` objects, dictionaries, and plain lists.
+
+  Args:
+    tokens: The object containing token IDs. Supported types include:
+      - A list of integers.
+      - A dictionary containing the `INPUT_TOKENS_KEY` key.
+      - An object (e.g., `BatchEncoding`) with an attribute named `INPUT_TOKENS_KEY`.
+
+  Returns:
+    A list of integer token IDs.
+
+  Raises:
+    ValueError: If the input type is not supported or does not contain the expected key.
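+
+  Example:
+    A minimal sketch, assuming `INPUT_TOKENS_KEY == "input_ids"`:
+
+    >>> _extract_token_ids([101, 2023, 102])
+    [101, 2023, 102]
+    >>> _extract_token_ids({"input_ids": [101, 2023, 102]})
+    [101, 2023, 102]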
198+ """
199+ # attention masks in BatchEncoding are effectively ignored
200+ if hasattr (tokens , INPUT_TOKENS_KEY ):
201+ return getattr (tokens , INPUT_TOKENS_KEY )
202+ elif isinstance (tokens , dict ) and INPUT_TOKENS_KEY in tokens :
203+ return tokens [INPUT_TOKENS_KEY ]
204+ elif isinstance (tokens , list ):
205+ return tokens
206+ else :
207+ raise ValueError (f"Can't extract token_ids from type { type (tokens )} " )
208+
209+
+def verify_chat_template_generation_prompt_logic(tokenizer_model):
+  """Verifies the tokenizer's chat template for correct SFT loss masking.
+
+  This function ensures that the tokens added by `add_generation_prompt=True`
+  are identical to the tokens that begin an assistant's turn in a complete
+  conversation, which is critical for masking prompt tokens during SFT loss
+  calculation.
+
+  Example of a mismatch:
+    A `ValueError` is raised if the generation prompt and the actual
+    assistant prefix do not match. For example:
+
+    - `add_generation_prompt=True` on a user message produces a prompt ending in:
+      `...<|im_start|>generation\n`
+    - A full turn with an assistant message starts the reply with:
+      `...<|im_start|>assistant\n...`
+
+    This function would fail because the tokens for "generation" do not
+    match the tokens for "assistant".
+
+  Args:
+    tokenizer_model: The Hugging Face tokenizer instance to verify.
+
+  Raises:
+    ValueError: If the `add_generation_prompt` tokens do not exactly
+      match the beginning of an assistant message in the template.
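+
+  Example:
+    A minimal usage sketch (the checkpoint name below is illustrative):
+
+    >>> from transformers import AutoTokenizer
+    >>> tok = AutoTokenizer.from_pretrained("some-org/some-chat-model")
+    >>> verify_chat_template_generation_prompt_logic(tok)  # raises ValueError on a mismatch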
236+ """
237+ dummy_msgs = [{"role" : "system" , "content" : "System message" }, {"role" : "user" , "content" : "Test message" }]
238+
+  try:
+    prompt_wo_gen_tokens = tokenizer_model.apply_chat_template(dummy_msgs, add_generation_prompt=False, tokenize=True)
+  except TemplateError:
+    max_logging.info(
+        "Tokenizer failed to apply chat template with 'system' role. "
+        "Falling back to 'user' role only for chat template verification."
+    )
+    dummy_msgs.pop(0)
+    prompt_wo_gen_tokens = tokenizer_model.apply_chat_template(dummy_msgs, add_generation_prompt=False, tokenize=True)
+  prompt_wo_gen_ids = _extract_token_ids(prompt_wo_gen_tokens)
+
+  prompt_w_gen_tokens = tokenizer_model.apply_chat_template(dummy_msgs, add_generation_prompt=True, tokenize=True)
+  prompt_w_gen_ids = _extract_token_ids(prompt_w_gen_tokens)
+
+  if prompt_w_gen_ids[: len(prompt_wo_gen_ids)] != prompt_wo_gen_ids:
+    raise ValueError(
+        "Unable to extract generation prompt tokens: the prompt tokenized with "
+        "add_generation_prompt=True does not start with the tokens produced without it."
+    )
+  # Extract the tokenized generation prompt (the expected assistant prefix).
+  assistant_prefix = prompt_w_gen_ids[len(prompt_wo_gen_ids) :]
+  full_turn_tokens = tokenizer_model.apply_chat_template(
+      dummy_msgs + [{"role": "assistant", "content": "Dummy response"}], add_generation_prompt=False, tokenize=True
+  )
+  full_turn_ids = _extract_token_ids(full_turn_tokens)
+  # Extract the actual tokens that appear right after the user message in the full turn.
+  actual_prefix_in_full_turn = full_turn_ids[len(prompt_wo_gen_ids) : len(prompt_wo_gen_ids) + len(assistant_prefix)]
+
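+  # Worked example with hypothetical IDs: if prompt_wo_gen_ids == [1, 2, 3] and
+  # prompt_w_gen_ids == [1, 2, 3, 7, 8], then assistant_prefix == [7, 8], and a
+  # full turn must tokenize to [1, 2, 3, 7, 8, ...] for the masking to line up.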
+  if actual_prefix_in_full_turn != assistant_prefix:
+    expected_str = tokenizer_model.decode(assistant_prefix)
+    actual_str = tokenizer_model.decode(actual_prefix_in_full_turn)
+    raise ValueError(
+        "Chat template generation prompt mismatch!\n"
+        f"Expected assistant prefix tokens: {assistant_prefix} ('{expected_str}')\n"
+        f"Actual prefix tokens found: {actual_prefix_in_full_turn} ('{actual_str}')\n"
+        "This means the tokenizer's chat template will break the SFT masking logic."
+    )
+
+
def _get_completion_in_chat_template(tokenizer_model, round_msgs):
  """
  Calculates the completion part of a conversation turn when formatted with a chat template.
@@ -193,18 +292,8 @@ def _get_completion_in_chat_template(tokenizer_model, round_msgs):
  # include generation_prompt as part of the prompt tokens
  prompt_tokens = tokenizer_model.apply_chat_template(round_msgs[:-1], add_generation_prompt=True, tokenize=True)

-  # attention masks in BatchEncoding are effectively ignored
-  if hasattr(prompt_completion_tokens, INPUT_TOKENS_KEY):
-    prompt_completion_ids = getattr(prompt_completion_tokens, INPUT_TOKENS_KEY)
-    prompt_ids = getattr(prompt_tokens, INPUT_TOKENS_KEY)
-  elif isinstance(prompt_completion_tokens, dict) and INPUT_TOKENS_KEY in prompt_completion_tokens:
-    prompt_completion_ids = prompt_completion_tokens[INPUT_TOKENS_KEY]
-    prompt_ids = prompt_tokens[INPUT_TOKENS_KEY]
-  elif isinstance(prompt_completion_tokens, list):
-    prompt_completion_ids = prompt_completion_tokens
-    prompt_ids = prompt_tokens
-  else:
-    raise ValueError(f"Can't handle the chat template output of type {type(prompt_completion_tokens)}")
+  prompt_completion_ids = _extract_token_ids(prompt_completion_tokens)
+  prompt_ids = _extract_token_ids(prompt_tokens)

  completion_tokens = prompt_completion_ids[len(prompt_ids) :]
  completion_in_chat_template = tokenizer_model.decode(completion_tokens, skip_special_tokens=False)