Skip to content

Commit 2db6c14

Browse files
authored
Support llama3 tokenizer (#67)
* Support llama3 tokenizer * Add tiktoken to requirements * Add blobfile to requirements * Fix unit tests * Fix linting issues * Fix pytype errors * Move llama3 tokenizer to third_party directory * Fix pytype error * Update pytype command
1 parent f8fc0a0 commit 2db6c14

11 files changed

Lines changed: 128676 additions & 53 deletions

File tree

.github/workflows/unit_tests.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ jobs:
5050
pip install -r benchmarks/requirements.in
5151
- name: Typecheck the code with pytype
5252
run: |
53-
pytype --jobs auto --disable import-error --disable module-attr jetstream/ benchmarks/
53+
pytype --jobs auto --disable=import-error,module-attr jetstream/ benchmarks/
5454
- name: Analysing the code with pylint
5555
run: |
5656
pylint jetstream/ benchmarks/

jetstream/engine/token_utils.py

Lines changed: 130 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from jetstream.engine import mock_utils
2828
from jetstream.engine import tokenizer_api
2929
from jetstream.engine import tokenizer_pb2
30+
from jetstream.third_party.llama3 import llama3_tokenizer
3031

3132
# ResultToken class to store tokens ids.
3233
ResultTokens = Any
@@ -40,9 +41,10 @@ def take_nearest_length(lengths: list[int], length: int) -> int:
4041
return lengths[pos]
4142

4243

43-
def tokenize_and_pad(
44-
s: str,
45-
vocab: Vocabulary,
44+
def pad_tokens(
45+
tokens: np.ndarray,
46+
bos_id: int,
47+
pad_id: int,
4648
is_bos: bool = True,
4749
prefill_lengths: Optional[List[int]] = None,
4850
max_prefill_length: Optional[int] = None,
@@ -84,14 +86,13 @@ def tokenize_and_pad(
8486
] + [
8587
max_prefill_length,
8688
]
87-
tokens = np.array(vocab.encode_tf(s)) # [Length]
8889
# Add a beginning of sequence token if this is the beginning.
8990
if is_bos:
9091
tokens = np.concatenate(
9192
[
9293
np.array(
9394
[
94-
vocab.bos_id,
95+
bos_id,
9596
]
9697
),
9798
tokens,
@@ -101,13 +102,12 @@ def tokenize_and_pad(
101102
true_length = tokens.shape[-1]
102103
padded_length = take_nearest_length(prefill_lengths, true_length)
103104
padding = padded_length - true_length
104-
assert vocab.pad_id == 0, "Further logic required if pad_id not 0."
105105
if padding < 0:
106106
logging.warning("Provided sequence longer than available.")
107107
# Take the last N tokens if we have too many.
108108
padded_tokens = tokens[-padded_length:]
109109
else:
110-
padded_tokens = np.pad(tokens, (0, padding))
110+
padded_tokens = np.pad(tokens, (0, padding), constant_values=(pad_id,))
111111
if jax_padding:
112112
padded_tokens = jnp.array(padded_tokens)
113113
return padded_tokens, true_length
@@ -117,7 +117,8 @@ def process_result_tokens(
117117
slot: int,
118118
slot_max_length: int,
119119
result_tokens: ResultTokens,
120-
vocab: Vocabulary,
120+
eos_id: int,
121+
pad_id: int,
121122
complete: np.ndarray,
122123
debug: bool = False,
123124
) -> Tuple[List[List[int]], np.ndarray]:
@@ -128,7 +129,8 @@ def process_result_tokens(
128129
slot: The slot at which to draw tokens from.
129130
slot_max_length: Max length for a sample in the slot.
130131
result_tokens: The tokens to access by slot.
131-
vocab: For the detokenizer.
132+
eos_id: Id for EOS token.
133+
pad_id: Id for pad token.
132134
complete: Array representing the completion status of each sample in the
133135
slot.
134136
debug: Whether to log step by step detokenisation.
@@ -143,7 +145,7 @@ def process_result_tokens(
143145
slot_valid = slot_data.valid
144146
slot_lengths = slot_data.lengths
145147
samples, speculations = slot_tokens.shape
146-
stop_tokens = [vocab.eos_id, vocab.pad_id]
148+
stop_tokens = [eos_id, pad_id]
147149
# Stop anything which has reached its max length.
148150
complete = complete | (slot_lengths > slot_max_length)
149151
if debug:
@@ -212,11 +214,9 @@ def encode(
212214
self, s: str, **kwargs
213215
) -> Tuple[Union[jax.Array, np.ndarray], int]:
214216
"""Tokenize a string.
215-
216217
Args:
217218
s: String to tokenize.
218219
**kwargs: Additional keyword arguments
219-
220220
Returns:
221221
tokens: Tokenized into integers.
222222
true_length: Actual length of the non-padded sequence
@@ -225,13 +225,18 @@ def encode(
225225
is_bos = kwargs.pop("is_bos", True)
226226
prefill_lengths = kwargs.pop("prefill_lengths", None)
227227
max_prefill_length = kwargs.pop("max_prefill_length", None)
228+
jax_padding = kwargs.pop("jax_padding", True)
229+
230+
tokens = np.array(self.vocab.encode_tf(s))
228231

229-
tokens, true_length = tokenize_and_pad(
230-
s,
231-
self.vocab,
232+
tokens, true_length = pad_tokens(
233+
tokens,
234+
self.bos_id,
235+
self.pad_id,
232236
is_bos=is_bos,
233237
prefill_lengths=prefill_lengths,
234238
max_prefill_length=max_prefill_length,
239+
jax_padding=jax_padding,
235240
)
236241
return tokens, true_length
237242

@@ -245,15 +250,13 @@ def decode(
245250
) -> Tuple[List[List[int]], np.ndarray]:
246251
"""Processes a result tokens into a list of strings, handling multiple
247252
samples.
248-
249253
Args:
250254
slot: The slot at which to draw tokens from.
251255
slot_max_length: Max length for a sample in the slot.
252256
result_tokens: The tokens to access by slot.
253257
complete: Array representing the completion status of each sample in the
254258
slot.
255259
kwargs: Additional keyword arguments.
256-
257260
Returns:
258261
sample_return: List of strings, one per sample.
259262
complete: Updated complete.
@@ -263,12 +266,22 @@ def decode(
263266
slot=slot,
264267
slot_max_length=slot_max_length,
265268
result_tokens=result_tokens,
266-
vocab=self.vocab,
269+
eos_id=self.eos_id,
270+
pad_id=self.pad_id,
267271
complete=complete,
268272
debug=debug,
269273
)
270274
return results, complete
271275

276+
def decode_str(self, token_ids: list[int]) -> str:
277+
"""Processess input token ids to generate a string.
278+
Args:
279+
token_ids: List of token ids.
280+
Returns:
281+
str: String generated from the token ids.
282+
"""
283+
return self.vocab.tokenizer.decode(token_ids)
284+
272285
@property
273286
def pad_id(self) -> int:
274287
"""ID of the pad token."""
@@ -278,3 +291,102 @@ def pad_id(self) -> int:
278291
def eos_id(self) -> int:
279292
"""ID of EOS token."""
280293
return self.vocab.eos_id
294+
295+
@property
296+
def bos_id(self) -> int:
297+
"""ID of the BOS token."""
298+
return self.vocab.bos_id
299+
300+
301+
class TikToken(tokenizer_api.Tokenizer):
302+
"""Tokenizer to convert strings to token ids and vice-versa."""
303+
304+
def __init__(self, metadata: tokenizer_pb2.TokenizerParameters):
305+
self.tokenizer = llama3_tokenizer.Tokenizer(metadata.path)
306+
307+
def encode(
308+
self, s: str, **kwargs
309+
) -> Tuple[Union[jax.Array, np.ndarray], int]:
310+
"""Tokenize a string.
311+
Args:
312+
s: String to tokenize.
313+
**kwargs: Additional keyword arguments
314+
Returns:
315+
tokens: Tokenized into integers.
316+
true_length: Actual length of the non-padded sequence
317+
if padding is used.
318+
"""
319+
is_bos = kwargs.pop("is_bos", True)
320+
prefill_lengths = kwargs.pop("prefill_lengths", None)
321+
max_prefill_length = kwargs.pop("max_prefill_length", None)
322+
jax_padding = kwargs.pop("jax_padding", True)
323+
324+
tokens = np.array(self.tokenizer.encode(s, bos=False, eos=False))
325+
326+
tokens, true_length = pad_tokens(
327+
tokens,
328+
self.bos_id,
329+
self.pad_id,
330+
is_bos=is_bos,
331+
prefill_lengths=prefill_lengths,
332+
max_prefill_length=max_prefill_length,
333+
jax_padding=jax_padding,
334+
)
335+
return tokens, true_length
336+
337+
def decode(
338+
self,
339+
slot: int,
340+
slot_max_length: int,
341+
result_tokens: ResultTokens,
342+
complete: np.ndarray,
343+
**kwargs,
344+
) -> Tuple[List[List[int]], np.ndarray]:
345+
"""Processes a result tokens into a list of strings, handling multiple
346+
samples.
347+
Args:
348+
slot: The slot at which to draw tokens from.
349+
slot_max_length: Max length for a sample in the slot.
350+
result_tokens: The tokens to access by slot.
351+
complete: Array representing the completion status of each sample in the
352+
slot.
353+
kwargs: Additional keyword arguments.
354+
Returns:
355+
sample_return: List of strings, one per sample.
356+
complete: Updated complete.
357+
"""
358+
debug = kwargs.pop("debug", False)
359+
results, complete = process_result_tokens(
360+
slot=slot,
361+
slot_max_length=slot_max_length,
362+
result_tokens=result_tokens,
363+
eos_id=self.eos_id,
364+
pad_id=self.pad_id,
365+
complete=complete,
366+
debug=debug,
367+
)
368+
return results, complete
369+
370+
def decode_str(self, token_ids: list[int]) -> str:
371+
"""Processess input token ids to generate a string.
372+
Args:
373+
token_ids: List of token ids.
374+
Returns:
375+
str: String generated from the token ids.
376+
"""
377+
return self.tokenizer.decode(token_ids)
378+
379+
@property
380+
def pad_id(self) -> int:
381+
"""ID of the pad token."""
382+
return self.tokenizer.pad_id
383+
384+
@property
385+
def eos_id(self) -> int:
386+
"""ID of EOS token."""
387+
return self.tokenizer.eos_id
388+
389+
@property
390+
def bos_id(self) -> int:
391+
"""ID of the BOS token."""
392+
return self.tokenizer.bos_id

jetstream/engine/tokenizer_api.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,9 @@ def encode(
3232
self, s: str, **kwargs
3333
) -> Tuple[Union[jax.Array, np.ndarray], int]:
3434
"""Tokenize a string.
35-
3635
Args:
3736
s: String to tokenize.
3837
**kwargs: Additional keyword arguments
39-
4038
Returns:
4139
tokens: Tokenized into integers.
4240
true_length: Actual length of the non-padded sequence
@@ -54,20 +52,26 @@ def decode(
5452
) -> Tuple[list[list[int]], np.ndarray]:
5553
"""Processes a result tokens into a list of token ids, handling multiple
5654
samples.
57-
5855
Args:
5956
slot: The slot at which to draw tokens from.
6057
slot_max_length: Max length for a sample in the slot.
6158
result_tokens: The tokens to access by slot.
6259
complete: Array representing the completion status of each sample in the
6360
slot.
6461
**kwargs: Additional keyword arguments.
65-
6662
Returns:
67-
sample_return: List of strings, one per sample.
63+
sample_return: List of token_ids, one per sample.
6864
complete: Updated complete.
6965
"""
70-
# TODO(bbahl): Add an option to return str from decode.
66+
67+
@abc.abstractmethod
68+
def decode_str(self, token_ids: list[int]) -> str:
69+
"""Processess input token ids to generate a string.
70+
Args:
71+
token_ids: List of token ids.
72+
Returns:
73+
str: String generated from the token ids.
74+
"""
7175

7276
@property
7377
@abc.abstractmethod
@@ -78,3 +82,8 @@ def pad_id(self) -> int:
7882
@abc.abstractmethod
7983
def eos_id(self) -> int:
8084
"""ID of EOS token."""
85+
86+
@property
87+
@abc.abstractmethod
88+
def bos_id(self) -> int:
89+
"""ID of BOS token."""

jetstream/tests/engine/test_mock_engine.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ def _prefill(self):
5252
# A 2 will be pre-pended as 'bos' token from the vocab.
5353
text = "AB"
5454
metadata = engine.get_tokenizer()
55-
vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)
56-
tokens, true_length = token_utils.tokenize_and_pad(text, vocab, is_bos=True)
55+
tokenizer = engine.build_tokenizer(metadata)
56+
tokens, true_length = tokenizer.encode(text, is_bos=True)
5757
prefill_result = engine.prefill(
5858
params=params, padded_tokens=tokens, true_length=3
5959
)
@@ -65,10 +65,8 @@ def _prefill_np(self):
6565
# A 2 will be pre-pended as 'bos' token from the vocab.
6666
text = "AB"
6767
metadata = engine.get_tokenizer()
68-
vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)
69-
tokens, true_length = token_utils.tokenize_and_pad(
70-
text, vocab, is_bos=True, jax_padding=False
71-
)
68+
tokenizer = engine.build_tokenizer(metadata)
69+
tokens, true_length = tokenizer.encode(text, is_bos=True, jax_padding=False)
7270
prefill_result = engine.prefill(
7371
params=params, padded_tokens=tokens, true_length=3
7472
)

0 commit comments

Comments
 (0)