@@ -12,14 +12,16 @@
 import pandas as pd
 from elastic_transport import ObjectApiResponse
 from mypy_boto3_s3 import S3Client
-from transformers import AutoTokenizer
 
 from rigging.chat import Chat
 from rigging.error import TokenizeWarning
 from rigging.message import Message
 from rigging.tokenize import find_in_tokens
 from rigging.tokenize.base import TokenizedChat, TokenSlice
 
+if t.TYPE_CHECKING:
+    from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+
 
 def flatten_chats(chats: Chat | t.Sequence[Chat]) -> list[dict[t.Any, t.Any]]:
     """
@@ -295,7 +297,7 @@ async def chats_to_elastic(
 
 async def chats_to_tokens(
     chat: Chat,
-    tokenizer: AutoTokenizer,
+    tokenizer: "PreTrainedTokenizerBase",
     *,
     apply_chat_template_kwargs: dict[str, t.Any] | None = None,
     encode_kwargs: dict[str, t.Any] | None = None,
@@ -331,8 +333,9 @@ async def chats_to_tokens(
         if chat.params and chat.params.tools
         else None
     )
+    # the tools above return dict[str, Any], but Transformers expects list[dict[Any, Any]]
 
-    chat_text = tokenizer.apply_chat_template(messages, tools=tools, **apply_chat_template_kwargs)
+    chat_text = tokenizer.apply_chat_template(messages, tools=tools, **apply_chat_template_kwargs)  # type: ignore[arg-type]
     chat_tokens = tokenizer.encode(chat_text, **encode_kwargs)
 
     slices: list[TokenSlice] = []
@@ -342,7 +345,13 @@ async def chats_to_tokens(
     for message in chat.all:
         # Find this message
         if not (
-            match := find_in_tokens(message.content, chat_tokens, tokenizer.decode, 0, search_start)
+            match := find_in_tokens(
+                message.content,
+                chat_tokens,
+                lambda tokens: tokenizer.decode(tokens),
+                0,
+                search_start,
+            )
         ):
             warnings.warn(
                 f"Warning: Could not find message '{message.content[:50]}...' in chat tokens",
@@ -378,7 +387,7 @@ async def chats_to_tokens(
             part_match = find_in_tokens(
                 part_text,
                 message_tokens,
-                tokenizer.decode,
+                lambda tokens: tokenizer.decode(tokens),
                 msg_start,
                 part_search_start,
             )
@@ -407,8 +416,9 @@ async def chats_to_tokens(
         # Continue searching after this message
         search_start = msg_end
 
+    # we ask for a string by default in apply_chat_template_kwargs with the tokenize=False
     return TokenizedChat(
-        text=chat_text,
+        text=chat_text,  # type: ignore[arg-type]
         tokens=chat_tokens,
         slices=slices,
         obj=chat,
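For reference, a minimal usage sketch of the updated `chats_to_tokens` signature with a Hugging Face tokenizer. This is illustrative only: the `rigging.data` import path, the `Chat`/`Message` constructor arguments, and the model name are assumptions, not taken from this diff.

```python
import asyncio

from transformers import AutoTokenizer

from rigging.chat import Chat
from rigging.data import chats_to_tokens  # import path assumed
from rigging.message import Message


async def main() -> None:
    # Any chat-templated model works here; this name is a placeholder.
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

    # Constructor arguments assumed for illustration.
    chat = Chat(
        messages=[
            Message(role="user", content="What is the capital of France?"),
            Message(role="assistant", content="Paris."),
        ]
    )

    # apply_chat_template is expected to return a string (tokenize=False),
    # which chats_to_tokens then encodes and maps back into TokenSlice ranges.
    tokenized = await chats_to_tokens(
        chat,
        tokenizer,
        apply_chat_template_kwargs={"tokenize": False},
    )
    print(len(tokenized.tokens), len(tokenized.slices))


asyncio.run(main())
```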