Commit a0df320

Align Tokenizer in JetStream (#40)
* Align Tokenizer in JetStream
* Update requirements with pytest dep
* Remove mix_decode unit test
1 parent f6f9b06 commit a0df320

16 files changed: 136 additions & 143 deletions

.github/workflows/unit_tests.yaml

Lines changed: 2 additions & 1 deletion
@@ -47,9 +47,10 @@ jobs:
           pip install pylint
           pip install pyink
           pip install -r requirements.txt
+          pip install -r benchmarks/requirements.in
       - name: Typecheck the code with pytype
         run: |
-          pytype --jobs auto --disable import-error --disable module-attr jetstream/
+          pytype --jobs auto --disable import-error --disable module-attr jetstream/ benchmarks/
       - name: Analysing the code with pylint
         run: |
           pylint jetstream/ benchmarks/

README.md

Lines changed: 4 additions & 4 deletions
@@ -57,15 +57,15 @@ python -m jetstream.tools.load_tester
 ### Test core modules
 ```
 # Test JetStream core orchestrator
-python -m jetstream.core.orchestrator_test
+python -m jetstream.tests.core.test_orchestrator
 
 # Test JetStream core server library
-python -m jetstream.core.server_test
+python -m jetstream.tests.core.test_server
 
 # Test mock JetStream engine implementation
-python -m jetstream.engine.mock_engine_test
+python -m jetstream.tests.engine.test_mock_engine
 
 # Test mock JetStream token utils
-python -m jetstream.engine.utils_test
+python -m jetstream.tests.engine.test_utils
 
 ```

benchmarks/benchmark_serving.py

Lines changed: 33 additions & 32 deletions
@@ -65,15 +65,14 @@
 import json
 import random
 import time
-from typing import Any, AsyncGenerator, List, Optional
+from typing import Any, AsyncGenerator, Optional
 
 import grpc
 from jetstream.core.proto import jetstream_pb2
 from jetstream.core.proto import jetstream_pb2_grpc
+from jetstream.engine.token_utils import load_vocab
 import numpy as np
-import tensorflow as tf
-import tensorflow_text as tftxt
-from tqdm.asyncio import tqdm
+from tqdm.asyncio import tqdm  # pytype: disable=pyi-error
 from eval_accuracy import eval_accuracy
 
 
@@ -106,9 +105,9 @@ class InputRequest:
 
 @dataclass
 class RequestFuncOutput:
-  input_request: InputRequest = None
-  generated_token_list: list[str] = None
-  generated_text: str = None
+  input_request: Optional[InputRequest] = None
+  generated_token_list: list[str] = []
+  generated_text: str = ""
   success: bool = False
   latency: float = 0
   ttft: float = 0
@@ -132,18 +131,16 @@ def get_tokenizer(tokenizer_name: str) -> Any:
   if tokenizer_name == "test":
     return "test"
   else:
-    with tf.io.gfile.GFile(tokenizer_name, "rb") as model_fp:
-      sp_model = model_fp.read()
-    sp_tokenizer = tftxt.SentencepieceTokenizer(
-        model=sp_model, add_bos=True, add_eos=False, reverse=False
-    )
-    return sp_tokenizer
+    # Use JetStream tokenizer util. It's using the sentencepiece wrapper in
+    # seqio library.
+    vocab = load_vocab(tokenizer_name)
+    return vocab.tokenizer
 
 
 def load_sharegpt_dataset(
     dataset_path: str,
     conversation_starter: str,
-) -> List[tuple[str]]:
+) -> list[tuple[Any, Any]]:
   # Load the dataset.
   with open(dataset_path, "r", encoding="utf-8") as f:
     dataset = json.load(f)
@@ -166,7 +163,7 @@ def load_sharegpt_dataset(
   return dataset
 
 
-def load_openorca_dataset(dataset_path: str) -> List[tuple[str]]:
+def load_openorca_dataset(dataset_path: str) -> list[tuple[Any, Any]]:
   # Load the dataset.
   with open(dataset_path, "r", encoding="utf-8") as f:
     dataset = json.load(f)
@@ -179,9 +176,9 @@ def load_openorca_dataset(dataset_path: str) -> List[tuple[str]]:
 
 
 def tokenize_dataset(
-    dataset: List[tuple[str]],
+    dataset: list[tuple[Any, Any, Any]],
     tokenizer: Any,
-) -> List[tuple[Any]]:
+) -> list[tuple[str, Any, str, int, int, int]]:
 
   n = len(dataset)
 
@@ -194,10 +191,10 @@ def tokenize_dataset(
     outputs.append(output)
     indices.append(idx)
 
-  prompt_token_ids = tokenizer.tokenize(
+  prompt_token_ids = tokenizer.encode(
       prompts
   )  # adjust this code based on tokenizer method
-  outputs_token_ids = tokenizer.tokenize(
+  outputs_token_ids = tokenizer.encode(
       outputs
   )  # adjust this code based on tokenizer method
 
@@ -218,8 +215,9 @@
 
 
 def filter_dataset(
-    tokenized_dataset: List[tuple[Any]], max_output_length: Optional[int] = None
-) -> List[InputRequest]:
+    tokenized_dataset: list[tuple[str, Any, str, int, int, int]],
+    max_output_length: Optional[int] = None,
+) -> list[InputRequest]:
   if max_output_length is None:
     print("In InputRequest, pass in actual output_length for each sample")
   else:
@@ -229,7 +227,7 @@
     )
 
   # Filter out too long sequences.
-  filtered_dataset: List[InputRequest] = []
+  filtered_dataset: list[InputRequest] = []
   for (
       prompt,
       _,
@@ -258,12 +256,12 @@
 
 
 def sample_requests(
-    dataset: List[tuple[str]],
+    dataset: list[tuple[Any, Any]],
     tokenizer: Any,
     num_requests: int,
     max_output_length: Optional[int] = None,
     oversample_multiplier: float = 1.2,
-) -> List[InputRequest]:
+) -> list[InputRequest]:
 
   # Original dataset size
   n = len(dataset)
@@ -304,7 +302,7 @@
 
 
 async def get_request(
-    input_requests: List[InputRequest],
+    input_requests: list[InputRequest],
     request_rate: float,
 ) -> AsyncGenerator[InputRequest, None]:
   input_requests = iter(input_requests)
@@ -321,8 +319,8 @@
 
 
 def calculate_metrics(
-    input_requests: List[InputRequest],
-    outputs: List[RequestFuncOutput],
+    input_requests: list[InputRequest],
+    outputs: list[RequestFuncOutput],
     dur_s: float,
     tokenizer: Any,
 ) -> BenchmarkMetrics:
@@ -374,16 +372,17 @@ async def grpc_async_request(
   token_list = []
   request_start_time = time.perf_counter()
   response = stub.Decode(request)
-  async for token in response:
+  async for sample_list in response:
     if ttft == 0:
       ttft = time.perf_counter() - request_start_time
-    token_list.append(token.response[0])
+    token_list.extend(sample_list.response[0].token_ids)
   latency = time.perf_counter() - request_start_time
   return token_list, ttft, latency
 
 
 async def send_request(
     api_url: str,
+    tokenizer: Any,
     input_request: InputRequest,
     pbar: tqdm,
     session_cache: str,
@@ -405,7 +404,8 @@ async def send_request(
   output.ttft = ttft
   output.latency = latency
   output.generated_token_list = generated_token_list
-  output.generated_text = "".join(generated_token_list)
+  # generated_token_list is a list of token ids, decode it to generated_text.
+  output.generated_text = tokenizer.decode(generated_token_list)
   output.success = True
   if pbar:
     pbar.update(1)
@@ -415,7 +415,7 @@
 async def benchmark(
     api_url: str,
     tokenizer: Any,
-    input_requests: List[InputRequest],
+    input_requests: list[InputRequest],
     request_rate: float,
     disable_tqdm: bool,
     session_cache: str,
@@ -433,6 +433,7 @@
     asyncio.create_task(
         send_request(
             api_url=api_url,
+            tokenizer=tokenizer,
            input_request=request,
             pbar=pbar,
             session_cache=session_cache,
@@ -442,7 +443,7 @@
     )
   outputs = await asyncio.gather(*tasks)
 
-  if not disable_tqdm:
+  if not disable_tqdm and pbar:
     pbar.close()
 
   benchmark_duration = time.perf_counter() - benchmark_start_time
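
The thrust of these hunks: `get_tokenizer` now goes through JetStream's own `load_vocab` (a seqio SentencePiece wrapper), dataset tokenization switches from `tokenize` to `encode`, and the client accumulates raw token ids, decoding only once per request. A minimal sketch of that round trip, assuming a local SentencePiece model file (the path is hypothetical):

```
from jetstream.engine.token_utils import load_vocab

# Load the vocabulary; the returned object wraps a SentencePiece model.
# "tokenizer.model" below is an assumed path for illustration.
vocab = load_vocab("tokenizer.model")
tokenizer = vocab.tokenizer

# Encode once up front, accumulate generated ids during the run...
prompt_ids = tokenizer.encode("The quick brown fox")
# ...and decode a full id list only when building generated_text.
text = tokenizer.decode(prompt_ids)
```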

benchmarks/requirements.in

Lines changed: 2 additions & 1 deletion
@@ -1,3 +1,4 @@
 nltk
 evaluate
-rouge-score
+rouge-score
+tqdm

jetstream/core/orchestrator.py

Lines changed: 8 additions & 3 deletions
@@ -127,7 +127,7 @@ class ActiveRequest:
   # We keep prefill and decode information together in the same object so that
   # there is less indirection about where this return channel is.
   # The return channel returns a list of strings, one per sample for that query.
-  return_channel: async_multifuture.AsyncMultifuture[list[str]]
+  return_channel: async_multifuture.AsyncMultifuture[list[list[int]]]
   # [num_samples,] which corresponds to whether each sample is complete for the
   # requests.
   complete: Optional[np.ndarray] = None
@@ -139,7 +139,7 @@ class ActiveRequest:
   # Which generate step this was added at.
   generate_timestep_added: Optional[int] = None
 
-  def enqueue_tokens(self, generated_tokens: list[str]):
+  def enqueue_tokens(self, generated_tokens: list[list[int]]):
     """Records information about the step.
 
     Args:
@@ -662,4 +662,9 @@ async def Decode(  # pylint: disable=invalid-overridden-method
       # The DecodeResponse stream should consume all generated tokens in
       # return_channel when complete signal is received. It should check if
      # return_channel is empty to decide if it should exit the while loop.
-      yield jetstream_pb2.DecodeResponse(response=response)
+      repeated_token_ids = []
+      for token_ids in response:
+        repeated_token_ids.append(
+            jetstream_pb2.RepeatedTokenIds(token_ids=token_ids)
+        )
+      yield jetstream_pb2.DecodeResponse(response=repeated_token_ids)
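
The final hunk changes what the server streams back: each step's per-sample token ids are wrapped in `RepeatedTokenIds` messages instead of strings. A small self-contained sketch of that construction and its round trip (the id values are made up):

```
from jetstream.core.proto import jetstream_pb2

# Two samples' worth of ids from one decode step (values hypothetical).
response = [[318, 257, 922], [373]]

decode_response = jetstream_pb2.DecodeResponse(
    response=[
        jetstream_pb2.RepeatedTokenIds(token_ids=token_ids)
        for token_ids in response
    ]
)
# Each entry in decode_response.response holds one sample's ids.
assert list(decode_response.response[0].token_ids) == [318, 257, 922]
```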

jetstream/core/proto/jetstream.proto

Lines changed: 6 additions & 2 deletions
@@ -37,6 +37,10 @@ message DecodeRequest {
   int32 max_tokens = 4;
 }
 message DecodeResponse {
-  // List of responses, one per sample.
-  repeated string response = 1;
+  // List of responses, one per sample. The list size depends on text generation strategy the engine used.
+  repeated RepeatedTokenIds response = 1;
 }
+message RepeatedTokenIds {
+  // List of token ids, one list per sample. When speculative decoding is disabled, the list size should be 1; When speculative decoding is enabled, the list size should be >= 1.
+  repeated int32 token_ids = 1;
+}
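
On the client side, each streamed `DecodeResponse` now carries one `RepeatedTokenIds` entry per sample. A sketch of a consumer that keeps the first sample's ids, assuming an async gRPC stub like the one benchmark_serving.py uses:

```
async def collect_first_sample_ids(stub, request) -> list[int]:
  """Drains a Decode stream, keeping token ids for the first sample only."""
  token_ids: list[int] = []
  response = stub.Decode(request)
  async for decode_response in response:
    # decode_response.response is a list of RepeatedTokenIds, one per sample.
    token_ids.extend(decode_response.response[0].token_ids)
  return token_ids
```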

jetstream/core/proto/jetstream_pb2.py

Lines changed: 6 additions & 4 deletions
@@ -28,7 +28,7 @@
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"e\n\rDecodeRequest\x12\x15\n\rsession_cache\x18\x01 \x01(\t\x12\x17\n\x0f\x61\x64\x64itional_text\x18\x02 \x01(\t\x12\x10\n\x08priority\x18\x03 \x01(\x05\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05""\n\x0e\x44\x65\x63odeResponse\x12\x10\n\x08response\x18\x01 \x03(\t2]\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x62\x06proto3'
+    b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"e\n\rDecodeRequest\x12\x15\n\rsession_cache\x18\x01 \x01(\t\x12\x17\n\x0f\x61\x64\x64itional_text\x18\x02 \x01(\t\x12\x10\n\x08priority\x18\x03 \x01(\x05\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05"E\n\x0e\x44\x65\x63odeResponse\x12\x33\n\x08response\x18\x01 \x03(\x0b\x32!.jetstream_proto.RepeatedTokenIds"%\n\x10RepeatedTokenIds\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x32]\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x62\x06proto3'
 )
@@ -41,7 +41,9 @@
   _globals["_DECODEREQUEST"]._serialized_start = 57
   _globals["_DECODEREQUEST"]._serialized_end = 158
   _globals["_DECODERESPONSE"]._serialized_start = 160
-  _globals["_DECODERESPONSE"]._serialized_end = 194
-  _globals["_ORCHESTRATOR"]._serialized_start = 196
-  _globals["_ORCHESTRATOR"]._serialized_end = 289
+  _globals["_DECODERESPONSE"]._serialized_end = 229
+  _globals["_REPEATEDTOKENIDS"]._serialized_start = 231
+  _globals["_REPEATEDTOKENIDS"]._serialized_end = 268
+  _globals["_ORCHESTRATOR"]._serialized_start = 270
+  _globals["_ORCHESTRATOR"]._serialized_end = 363
 # @@protoc_insertion_point(module_scope)
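
These hunks are protoc output rather than hand edits: after changing jetstream.proto, the bindings are regenerated. One plausible invocation from the repo root, assuming grpcio-tools is installed (the repo's actual build step may differ):

```
python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. jetstream/core/proto/jetstream.proto
```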

jetstream/engine/mock_utils.py

Lines changed: 5 additions & 3 deletions
@@ -62,9 +62,7 @@ def _encode(self, s: str) -> Sequence[int]:
 
   def _decode(self, ids: np.ndarray):
     """Converts a numpy array into a string."""
-    # 'We use array methods, not python iterables so we don't
-    # implement this method in the mock vocab.
-    raise NotImplementedError
+    return "".join([chr(r) for r in list(ids)])
 
   def _encode_tf(self, s: str) -> np.ndarray:
     """Converts a string into a numpy array."""
@@ -78,6 +76,10 @@ def _decode_tf(self, ids: np.ndarray) -> List[str]:
     results = np.split(ids, ids.shape[0])
     return ["".join([chr(r) for r in list(line[0])]) for line in results]
 
+  def decode(self, ids: np.ndarray):
+    """Converts a numpy array into a string."""
+    return self._decode(ids)
+
   def encode_tf(self, s: str) -> np.ndarray:
     """Converts a string into a numpy array."""
     return self._encode_tf(s)
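
With the orchestrator now returning ids, the mock vocab needs a working public decode instead of raising `NotImplementedError`. A self-contained sketch of the chr()-per-id behavior the mock adopts:

```
import numpy as np

def mock_decode(ids: np.ndarray) -> str:
  # Mirrors the mock vocab's _decode: each token id is treated as a
  # character code point.
  return "".join([chr(r) for r in list(ids)])

assert mock_decode(np.array([104, 105])) == "hi"
```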

jetstream/engine/token_utils.py

Lines changed: 5 additions & 29 deletions
@@ -28,21 +28,6 @@
 from jetstream.engine import mock_utils
 
 
-def mix_decode(vocab: Vocabulary, tok_id: int):
-  """
-  The IdToPiece and decode results differ for 344 tokens in Llama2.
-  Use the decode function to generate the correct strings for these 344 tokens.
-  If IdToPiece returns a hex string (e.g., '<0x0A>') for a token within these
-  344, utilize IdToPiece to convert it into a string, likely with a space
-  placeholder (' ') for the corresponding tokens.
-  """
-  p_token = vocab.tokenizer.IdToPiece(tok_id)
-  # SentencePiece escapes the whitespace with a meta symbol "▁" (U+2581)
-  p_token = p_token.replace("▁", " ")
-  d_token = vocab.tokenizer.decode([tok_id])
-  return p_token if p_token.lstrip() == d_token else d_token
-
-
 def take_nearest_length(lengths: list[int], length: int) -> int:
   """Gets the nearest length to the right in a set of lengths."""
   pos = bisect_left(lengths, length)
@@ -131,7 +116,7 @@ def process_result_tokens(
     vocab: Vocabulary,
     complete: np.ndarray,
     debug: bool = False,
-) -> Tuple[List[str], np.ndarray]:
+) -> Tuple[List[List[int]], np.ndarray]:
   """Processes a result tokens into a list of strings, handling multiple
   samples.
 
@@ -145,7 +130,7 @@ def process_result_tokens(
     debug: Whether to log step by step detokenisation.
 
   Returns:
-    sample_return: List of strings, one per sample.
+    sample_return: List of tok_id list, one list per sample.
     complete: Updated complete.
   """
   # tokens: [samples, speculations]
@@ -166,7 +151,7 @@ def process_result_tokens(
   )
   sample_return = []
   for idx in range(samples):
-    string_so_far = ""
+    tok_id_so_far = []
     if not complete[idx].item():
       for spec_idx in range(speculations):
         tok_id = slot_tokens[idx, spec_idx].item()
@@ -182,17 +167,8 @@ def process_result_tokens(
           complete[idx] = True
           break
         else:
-          try:
-            # pytype: disable=attribute-error
-            token = mix_decode(vocab, tok_id)
-          except ValueError:
-            # This error only occurs when using tests where the vocab range is
-            # computed via addition and int->char is computed using chr(). Real
-            # models have vocab logits which are at max the size of the vocab.
-            logging.warning("%d exceeded vocab range", tok_id)
-            token = "<sampled_outside_vocab>"
-          string_so_far += token
-    sample_return.append(string_so_far)
+          tok_id_so_far.append(tok_id)
+    sample_return.append(tok_id_so_far)
     if debug:
       logging.info("Sampled return %s", str(sample_return))
   return sample_return, complete
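
Dropping `mix_decode` is safe because the per-token string assembly it patched up is gone entirely: SentencePiece detokenization is context sensitive (whitespace meta symbols, byte-fallback pieces such as `<0x0A>`), so decoding ids one at a time can disagree with decoding the whole sequence at once. A sketch against the sentencepiece package directly, with the model path assumed:

```
import sentencepiece as spm

# Path is an assumption for illustration.
sp = spm.SentencePieceProcessor(model_file="tokenizer.model")

ids = sp.encode("hello world")
# Old approach: decode each id separately and concatenate; leading-space
# information attached to pieces can be lost or altered this way.
per_token = "".join(sp.decode([i]) for i in ids)
# New approach: accumulate ids and decode the full list once.
full = sp.decode(ids)
assert full == "hello world"
```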
