2727from jetstream .engine import mock_utils
2828from jetstream .engine import tokenizer_api
2929from jetstream .engine import tokenizer_pb2
30+ from jetstream .third_party .llama3 import llama3_tokenizer
3031
3132# ResultToken class to store tokens ids.
3233ResultTokens = Any
@@ -40,9 +41,10 @@ def take_nearest_length(lengths: list[int], length: int) -> int:
4041 return lengths [pos ]
4142
4243
43- def tokenize_and_pad (
44- s : str ,
45- vocab : Vocabulary ,
44+ def pad_tokens (
45+ tokens : np .ndarray ,
46+ bos_id : int ,
47+ pad_id : int ,
4648 is_bos : bool = True ,
4749 prefill_lengths : Optional [List [int ]] = None ,
4850 max_prefill_length : Optional [int ] = None ,
@@ -84,14 +86,13 @@ def tokenize_and_pad(
8486 ] + [
8587 max_prefill_length ,
8688 ]
87- tokens = np .array (vocab .encode_tf (s )) # [Length]
8889 # Add a beginning of sequence token if this is the beginning.
8990 if is_bos :
9091 tokens = np .concatenate (
9192 [
9293 np .array (
9394 [
94- vocab . bos_id ,
95+ bos_id ,
9596 ]
9697 ),
9798 tokens ,
@@ -101,13 +102,12 @@ def tokenize_and_pad(
101102 true_length = tokens .shape [- 1 ]
102103 padded_length = take_nearest_length (prefill_lengths , true_length )
103104 padding = padded_length - true_length
104- assert vocab .pad_id == 0 , "Further logic required if pad_id not 0."
105105 if padding < 0 :
106106 logging .warning ("Provided sequence longer than available." )
107107 # Take the last N tokens if we have too many.
108108 padded_tokens = tokens [- padded_length :]
109109 else :
110- padded_tokens = np .pad (tokens , (0 , padding ))
110+ padded_tokens = np .pad (tokens , (0 , padding ), constant_values = ( pad_id ,) )
111111 if jax_padding :
112112 padded_tokens = jnp .array (padded_tokens )
113113 return padded_tokens , true_length
@@ -117,7 +117,8 @@ def process_result_tokens(
117117 slot : int ,
118118 slot_max_length : int ,
119119 result_tokens : ResultTokens ,
120- vocab : Vocabulary ,
120+ eos_id : int ,
121+ pad_id : int ,
121122 complete : np .ndarray ,
122123 debug : bool = False ,
123124) -> Tuple [List [List [int ]], np .ndarray ]:
@@ -128,7 +129,8 @@ def process_result_tokens(
128129 slot: The slot at which to draw tokens from.
129130 slot_max_length: Max length for a sample in the slot.
130131 result_tokens: The tokens to access by slot.
131- vocab: For the detokenizer.
132+ eos_id: Id for EOS token.
133+ pad_id: Id for pad token.
132134 complete: Array representing the completion status of each sample in the
133135 slot.
134136 debug: Whether to log step by step detokenisation.
@@ -143,7 +145,7 @@ def process_result_tokens(
143145 slot_valid = slot_data .valid
144146 slot_lengths = slot_data .lengths
145147 samples , speculations = slot_tokens .shape
146- stop_tokens = [vocab . eos_id , vocab . pad_id ]
148+ stop_tokens = [eos_id , pad_id ]
147149 # Stop anything which has reached it's max length.
148150 complete = complete | (slot_lengths > slot_max_length )
149151 if debug :
@@ -212,11 +214,9 @@ def encode(
212214 self , s : str , ** kwargs
213215 ) -> Tuple [Union [jax .Array , np .ndarray ], int ]:
214216 """Tokenize a string.
215-
216217 Args:
217218 s: String to tokenize.
218219 **kwargs: Additional keyword arguments
219-
220220 Returns:
221221 tokens: Tokenized into integers.
222222 true_length: Actual length of the non-padded sequence
@@ -225,13 +225,18 @@ def encode(
225225 is_bos = kwargs .pop ("is_bos" , True )
226226 prefill_lengths = kwargs .pop ("prefill_lengths" , None )
227227 max_prefill_length = kwargs .pop ("max_prefill_length" , None )
228+ jax_padding = kwargs .pop ("jax_padding" , True )
229+
230+ tokens = np .array (self .vocab .encode_tf (s ))
228231
229- tokens , true_length = tokenize_and_pad (
230- s ,
231- self .vocab ,
232+ tokens , true_length = pad_tokens (
233+ tokens ,
234+ self .bos_id ,
235+ self .pad_id ,
232236 is_bos = is_bos ,
233237 prefill_lengths = prefill_lengths ,
234238 max_prefill_length = max_prefill_length ,
239+ jax_padding = jax_padding ,
235240 )
236241 return tokens , true_length
237242
@@ -245,15 +250,13 @@ def decode(
245250 ) -> Tuple [List [List [int ]], np .ndarray ]:
246251 """Processes a result tokens into a list of strings, handling multiple
247252 samples.
248-
249253 Args:
250254 slot: The slot at which to draw tokens from.
251255 slot_max_length: Max length for a sample in the slot.
252256 result_tokens: The tokens to access by slot.
253257 complete: Array representing the completion status of each sample in the
254258 slot.
255259 kwargs: Additional keyword arguments.
256-
257260 Returns:
258261 sample_return: List of strings, one per sample.
259262 complete: Updated complete.
@@ -263,12 +266,22 @@ def decode(
263266 slot = slot ,
264267 slot_max_length = slot_max_length ,
265268 result_tokens = result_tokens ,
266- vocab = self .vocab ,
269+ eos_id = self .eos_id ,
270+ pad_id = self .pad_id ,
267271 complete = complete ,
268272 debug = debug ,
269273 )
270274 return results , complete
271275
276+ def decode_str (self , token_ids : list [int ]) -> str :
277+ """Processess input token ids to generate a string.
278+ Args:
279+ token_ids: List of token ids.
280+ Returns:
281+ str: String generated from the token ids.
282+ """
283+ return self .vocab .tokenizer .decode (token_ids )
284+
272285 @property
273286 def pad_id (self ) -> int :
274287 """ID of the pad token."""
@@ -278,3 +291,102 @@ def pad_id(self) -> int:
278291 def eos_id (self ) -> int :
279292 """ID of EOS token."""
280293 return self .vocab .eos_id
294+
295+ @property
296+ def bos_id (self ) -> int :
297+ """ID of the BOS token."""
298+ return self .vocab .bos_id
299+
300+
class TikToken(tokenizer_api.Tokenizer):
  """Tokenizer to convert strings to token ids and vice-versa.

  Backed by the llama3 tiktoken-based tokenizer loaded from the path in
  the tokenizer metadata.
  """

  def __init__(self, metadata: tokenizer_pb2.TokenizerParameters):
    self.tokenizer = llama3_tokenizer.Tokenizer(metadata.path)

  def encode(
      self, s: str, **kwargs
  ) -> Tuple[Union[jax.Array, np.ndarray], int]:
    """Tokenize a string.

    Args:
      s: String to tokenize.
      **kwargs: Additional keyword arguments:
        is_bos: Prepend a BOS token if True (default True).
        prefill_lengths: Optional list of lengths to pad to.
        max_prefill_length: Optional cap on the padded length.
        jax_padding: Return a jnp array if True (default True).

    Returns:
      tokens: Tokenized into integers.
      true_length: Actual length of the non-padded sequence
        if padding is used.
    """
    is_bos = kwargs.pop("is_bos", True)
    prefill_lengths = kwargs.pop("prefill_lengths", None)
    max_prefill_length = kwargs.pop("max_prefill_length", None)
    jax_padding = kwargs.pop("jax_padding", True)

    # BOS insertion is handled by pad_tokens (via is_bos), so the
    # underlying tokenizer must not add BOS/EOS itself.
    tokens = np.array(self.tokenizer.encode(s, bos=False, eos=False))

    tokens, true_length = pad_tokens(
        tokens,
        self.bos_id,
        self.pad_id,
        is_bos=is_bos,
        prefill_lengths=prefill_lengths,
        max_prefill_length=max_prefill_length,
        jax_padding=jax_padding,
    )
    return tokens, true_length

  def decode(
      self,
      slot: int,
      slot_max_length: int,
      result_tokens: ResultTokens,
      complete: np.ndarray,
      **kwargs,
  ) -> Tuple[List[List[int]], np.ndarray]:
    """Processes result tokens into per-sample token-id lists.

    Args:
      slot: The slot at which to draw tokens from.
      slot_max_length: Max length for a sample in the slot.
      result_tokens: The tokens to access by slot.
      complete: Array representing the completion status of each sample
        in the slot.
      **kwargs: Additional keyword arguments (supports `debug`).

    Returns:
      results: Per the return annotation, one list of token ids per
        sample (use decode_str to turn ids into text).
      complete: Updated complete.
    """
    debug = kwargs.pop("debug", False)
    results, complete = process_result_tokens(
        slot=slot,
        slot_max_length=slot_max_length,
        result_tokens=result_tokens,
        eos_id=self.eos_id,
        pad_id=self.pad_id,
        complete=complete,
        debug=debug,
    )
    return results, complete

  def decode_str(self, token_ids: list[int]) -> str:
    """Processes input token ids to generate a string.

    Args:
      token_ids: List of token ids.

    Returns:
      str: String generated from the token ids.
    """
    return self.tokenizer.decode(token_ids)

  @property
  def pad_id(self) -> int:
    """ID of the pad token."""
    return self.tokenizer.pad_id

  @property
  def eos_id(self) -> int:
    """ID of EOS token."""
    return self.tokenizer.eos_id

  @property
  def bos_id(self) -> int:
    """ID of the BOS token."""
    return self.tokenizer.bos_id
0 commit comments