Commit f6751d2

Add an abstract class for Tokenizer (#53)
* Add an abstract class for tokenizer
* Add sentence piece tokenizer as a subclass of Tokenizer
* Fix decode method for SentencePieceTokenizer
* Fix circular import issue
* Fix type annotations
* Fix linting issues
* Format files using pyink
* Update the tokenizer decode interface to return ids instead of str
* Format using pyink
* Move Tokenizer class to a tokenizer_api.py file
* Update engine.build_tokenizer method to return SentencePieceTokenizer by default
1 parent 3fdacc8 commit f6751d2

4 files changed: 180 additions & 12 deletions
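At a high level, the orchestrator stops loading a SentencePiece vocabulary directly and instead asks the engine to build a tokenizer object, then calls `encode` and `decode` on it. Below is a minimal sketch of the resulting call pattern; `engine`, `text`, and `result_tokens` are placeholders for objects the orchestrator already holds, and the slot number and lengths are illustrative, not part of this commit.

```python
import numpy as np


def tokenizer_usage_sketch(engine, text, result_tokens):
  """Illustrative only: mirrors the call sites changed in orchestrator.py."""
  metadata = engine.get_tokenizer()             # tokenizer_pb2.TokenizerParameters
  tokenizer = engine.build_tokenizer(metadata)  # SentencePieceTokenizer by default

  # Prefill side: padded token ids plus the true (unpadded) length.
  padded_tokens, true_length = tokenizer.encode(
      text, is_bos=True, max_prefill_length=engine.max_prefill_length
  )

  # Detokenize side: per-sample token id lists and an updated completion mask.
  complete = np.zeros((1,), dtype=np.bool_)
  results, complete = tokenizer.decode(
      slot=0,
      slot_max_length=1024,
      result_tokens=result_tokens,
      complete=complete,
  )
  return padded_tokens, true_length, results, complete
```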

File tree

jetstream/core/orchestrator.py
jetstream/engine/engine_api.py
jetstream/engine/token_utils.py
jetstream/engine/tokenizer_api.py

jetstream/core/orchestrator.py

Lines changed: 4 additions & 8 deletions
@@ -93,7 +93,6 @@
 from jetstream.core.proto import jetstream_pb2_grpc
 from jetstream.core.utils import async_multifuture
 from jetstream.engine import engine_api
-from jetstream.engine import token_utils
 import numpy as np


@@ -397,7 +396,7 @@ def _prefill_thread(self, idx: int):
     prefill_engine = self._prefill_engines[idx]
     prefill_params = self._prefill_params[idx]
     metadata = prefill_engine.get_tokenizer()
-    vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)
+    tokenizer = prefill_engine.build_tokenizer(metadata)
     logging.info("---------Prefill params %d loaded.---------", idx)

     while self.live:
@@ -429,9 +428,8 @@ def _prefill_thread(self, idx: int):
           is_bos,
           request.history_path,
       )
-      padded_tokens, true_length = token_utils.tokenize_and_pad(
+      padded_tokens, true_length = tokenizer.encode(
           request.prefill_text,
-          vocab,
           is_bos=is_bos,
           max_prefill_length=prefill_engine.max_prefill_length,
           jax_padding=self._jax_padding,
@@ -568,8 +566,7 @@ def _detokenize_thread(self, idx: int):
     my_slots = self._generate_slots[idx]

     metadata = my_generate_engine.get_tokenizer()
-    vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)
-
+    tokenizer = my_generate_engine.build_tokenizer(metadata)
     my_live_requests = {
         i: None for i in range(my_generate_engine.max_concurrent_decodes)
     }
@@ -587,11 +584,10 @@ def _detokenize_thread(self, idx: int):

         for slot, request in my_live_requests.items():
           if request is not None:
-            results, complete = token_utils.process_result_tokens(
+            results, complete = tokenizer.decode(
                 slot=slot,
                 slot_max_length=request.max_tokens,
                 result_tokens=result_tokens,
-                vocab=vocab,
                 complete=request.complete,
             )
             request.complete = complete

jetstream/engine/engine_api.py

Lines changed: 11 additions & 1 deletion
@@ -26,6 +26,7 @@
 import numpy as np

 from jetstream.engine import tokenizer_pb2
+from jetstream.engine import token_utils


 # The model parameters - their partitioning will be unique for different prefill
@@ -39,6 +40,8 @@
 DeviceTokens = Any
 # Cpus asscociated with the mesh.
 CpuDevices = Any
+# Tokenizer used by the engine.
+Tokenizer = Any


 @struct.dataclass
@@ -200,7 +203,14 @@ def get_prefix_destination_sharding(self) -> Any:
   def get_tokenizer(
       self,
   ) -> tokenizer_pb2.TokenizerParameters:
-    """Returns the info to construct a sentencepiece tokenizer in py/c++."""
+    """Returns the info to construct a tokenizer in py/c++."""
+
+  def build_tokenizer(
+      self,
+      metadata: tokenizer_pb2.TokenizerParameters,
+  ) -> Tokenizer:
+    """Builds a new tokenizer object and returns it."""
+    return token_utils.SentencePieceTokenizer(metadata)

   @abc.abstractmethod
   def init_decode_state(self, *args, **kwargs) -> DecodeState:
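`build_tokenizer` is given a default body rather than being marked `@abc.abstractmethod`, so existing engines get `SentencePieceTokenizer` for free while a subclass can still override it. A hypothetical override is sketched below; `MyEngine` is an illustrative name, the sketch assumes the abstract base class in `engine_api.py` is called `Engine` (only part of it is visible in this diff), and the other abstract methods are omitted.

```python
from jetstream.engine import engine_api
from jetstream.engine import token_utils
from jetstream.engine import tokenizer_pb2


class MyEngine(engine_api.Engine):  # other abstract methods omitted for brevity
  """Sketch of an engine that customizes tokenizer construction."""

  def build_tokenizer(
      self, metadata: tokenizer_pb2.TokenizerParameters
  ) -> engine_api.Tokenizer:
    # Any object satisfying the tokenizer_api.Tokenizer interface works here;
    # this sketch simply mirrors the default and returns a SentencePieceTokenizer.
    return token_utils.SentencePieceTokenizer(metadata)
```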

jetstream/engine/token_utils.py

Lines changed: 85 additions & 3 deletions
@@ -16,16 +16,20 @@

 from bisect import bisect_left
 import logging
-from typing import List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union

 import jax
 import jax.numpy as jnp
 import numpy as np
 from seqio.vocabularies import SentencePieceVocabulary
 from seqio.vocabularies import Vocabulary

-from jetstream.engine import engine_api
 from jetstream.engine import mock_utils
+from jetstream.engine import tokenizer_api
+from jetstream.engine import tokenizer_pb2
+
+# ResultTokens class to store token ids.
+ResultTokens = Any


 def take_nearest_length(lengths: list[int], length: int) -> int:
@@ -112,7 +116,7 @@ def tokenize_and_pad(
 def process_result_tokens(
     slot: int,
     slot_max_length: int,
-    result_tokens: engine_api.ResultTokens,
+    result_tokens: ResultTokens,
     vocab: Vocabulary,
     complete: np.ndarray,
     debug: bool = False,
@@ -196,3 +200,81 @@ def load_vocab(path: str, extra_ids: int = 0) -> Vocabulary:
   sp_model = vocab.sp_model
   del sp_model
   return vocab
+
+
+class SentencePieceTokenizer(tokenizer_api.Tokenizer):
+  """Tokenizer to convert strings to token ids and vice-versa."""
+
+  def __init__(self, metadata: tokenizer_pb2.TokenizerParameters):
+    self.vocab = load_vocab(metadata.path, metadata.extra_ids)
+
+  def encode(
+      self, s: str, **kwargs
+  ) -> Tuple[Union[jax.Array, np.ndarray], int]:
+    """Tokenizes a string.
+
+    Args:
+      s: String to tokenize.
+      **kwargs: Additional keyword arguments.
+
+    Returns:
+      tokens: Tokenized into integers.
+      true_length: Actual length of the non-padded sequence
+        if padding is used.
+    """
+    is_bos = kwargs.pop("is_bos", True)
+    prefill_lengths = kwargs.pop("prefill_lengths", None)
+    max_prefill_length = kwargs.pop("max_prefill_length", None)
+
+    tokens, true_length = tokenize_and_pad(
+        s,
+        self.vocab,
+        is_bos=is_bos,
+        prefill_lengths=prefill_lengths,
+        max_prefill_length=max_prefill_length,
+    )
+    return tokens, true_length
+
+  def decode(
+      self,
+      slot: int,
+      slot_max_length: int,
+      result_tokens: ResultTokens,
+      complete: np.ndarray,
+      **kwargs,
+  ) -> Tuple[List[List[int]], np.ndarray]:
+    """Processes result tokens into lists of token ids, handling multiple
+    samples.
+
+    Args:
+      slot: The slot at which to draw tokens from.
+      slot_max_length: Max length for a sample in the slot.
+      result_tokens: The tokens to access by slot.
+      complete: Array representing the completion status of each sample in
+        the slot.
+      **kwargs: Additional keyword arguments.
+
+    Returns:
+      sample_return: List of token id lists, one per sample.
+      complete: Updated completion status of each sample.
+    """
+    debug = kwargs.pop("debug", False)
+    results, complete = process_result_tokens(
+        slot=slot,
+        slot_max_length=slot_max_length,
+        result_tokens=result_tokens,
+        vocab=self.vocab,
+        complete=complete,
+        debug=debug,
+    )
+    return results, complete
+
+  @property
+  def pad_id(self) -> int:
+    """ID of the pad token."""
+    return self.vocab.pad_id
+
+  @property
+  def eos_id(self) -> int:
+    """ID of the EOS token."""
+    return self.vocab.eos_id
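`SentencePieceTokenizer` is a thin wrapper over the existing `load_vocab`, `tokenize_and_pad`, and `process_result_tokens` helpers. A standalone sketch of constructing and using one, assuming a SentencePiece model file at a hypothetical path:

```python
from jetstream.engine import token_utils
from jetstream.engine import tokenizer_pb2

# Hypothetical model path; TokenizerParameters carries the vocabulary path and
# the number of extra ids that load_vocab expects.
metadata = tokenizer_pb2.TokenizerParameters(
    path="/tmp/tokenizer.model", extra_ids=0
)
tokenizer = token_utils.SentencePieceTokenizer(metadata)

# Padded token ids plus the true (unpadded) length of the prompt.
padded_tokens, true_length = tokenizer.encode(
    "hello world", is_bos=True, max_prefill_length=1024
)
print(true_length, tokenizer.pad_id, tokenizer.eos_id)
```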

jetstream/engine/tokenizer_api.py

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Defines the JetStream Tokenizer API."""
+
+import abc
+from typing import Any, Tuple, Union
+
+import numpy as np
+import jax
+
+# Class to store token ids.
+ResultTokens = Any
+
+
+class Tokenizer(abc.ABC):
+  """Tokenizer to convert strings to token ids and vice-versa."""
+
+  @abc.abstractmethod
+  def encode(
+      self, s: str, **kwargs
+  ) -> Tuple[Union[jax.Array, np.ndarray], int]:
+    """Tokenizes a string.
+
+    Args:
+      s: String to tokenize.
+      **kwargs: Additional keyword arguments.
+
+    Returns:
+      tokens: Tokenized into integers.
+      true_length: Actual length of the non-padded sequence
+        if padding is used.
+    """
+
+  @abc.abstractmethod
+  def decode(
+      self,
+      slot: int,
+      slot_max_length: int,
+      result_tokens: ResultTokens,
+      complete: np.ndarray,
+      **kwargs,
+  ) -> Tuple[list[list[int]], np.ndarray]:
+    """Processes result tokens into lists of token ids, handling multiple
+    samples.
+
+    Args:
+      slot: The slot at which to draw tokens from.
+      slot_max_length: Max length for a sample in the slot.
+      result_tokens: The tokens to access by slot.
+      complete: Array representing the completion status of each sample in
+        the slot.
+      **kwargs: Additional keyword arguments.
+
+    Returns:
+      sample_return: List of token id lists, one per sample.
+      complete: Updated completion status of each sample.
+    """
+    # TODO(bbahl): Add an option to return str from decode.
+
+  @property
+  @abc.abstractmethod
+  def pad_id(self) -> int:
+    """ID of the pad token."""
+
+  @property
+  @abc.abstractmethod
+  def eos_id(self) -> int:
+    """ID of the EOS token."""
