
Commit 90b2a9d

Support JetStream MaxText user guide (#28)
1 parent 426c915 · commit 90b2a9d

6 files changed

Lines changed: 128 additions & 85 deletions

benchmarks/benchmark_serving.py

Lines changed: 98 additions & 62 deletions
@@ -54,22 +54,21 @@
 """
 
 
-import tensorflow as tf
-import tensorflow_text as tftxt
-
 import argparse
 import asyncio
-
 from dataclasses import dataclass
 from datetime import datetime
 import json
 import random
 import time
 from typing import Any, AsyncGenerator, List, Optional
+
 import grpc
 from jetstream.core.proto import jetstream_pb2
 from jetstream.core.proto import jetstream_pb2_grpc
 import numpy as np
+import tensorflow as tf
+import tensorflow_text as tftxt
 from tqdm.asyncio import tqdm
 
 
@@ -96,6 +95,7 @@ class InputRequest:
   output: str = ""
   output_len: int = 0
 
+
 @dataclass
 class RequestFuncOutput:
   input_request: InputRequest = None
@@ -109,12 +109,12 @@ class RequestFuncOutput:
   # Flatten the structure and return only the necessary results
   def to_dict(self):
     return {
-      "prompt": self.input_request.prompt,
-      "original_output": self.input_request.output,
-      "generated_text": self.generated_text,
-      "success": self.success,
-      "latency": self.latency,
-      "prompt_len": self.prompt_len
+        "prompt": self.input_request.prompt,
+        "original_output": self.input_request.output,
+        "generated_text": self.generated_text,
+        "success": self.success,
+        "latency": self.latency,
+        "prompt_len": self.prompt_len,
     }
 
 
@@ -123,12 +123,14 @@ def get_tokenizer(tokenizer_name: str) -> Any:
   if tokenizer_name == "test":
     return "test"
   else:
-    with tf.io.gfile.GFile(tokenizer_name, 'rb') as model_fp:
+    with tf.io.gfile.GFile(tokenizer_name, "rb") as model_fp:
       sp_model = model_fp.read()
     sp_tokenizer = tftxt.SentencepieceTokenizer(
-        model=sp_model, add_bos=True, add_eos=False, reverse=False)
+        model=sp_model, add_bos=True, add_eos=False, reverse=False
+    )
     return sp_tokenizer
 
+
 def load_sharegpt_dataset(
     dataset_path: str,
     conversation_starter: str,
@@ -141,7 +143,11 @@ def load_sharegpt_dataset(
 
   # Filter based on conversation starter
   if conversation_starter != "both":
-    dataset = [data for data in dataset if data["conversations"][0]["from"] == conversation_starter]
+    dataset = [
+        data
+        for data in dataset
+        if data["conversations"][0]["from"] == conversation_starter
+    ]
   # Only keep the first two turns of each conversation.
   dataset = [
       (data["conversations"][0]["value"], data["conversations"][1]["value"])
@@ -151,9 +157,7 @@ def load_sharegpt_dataset(
   return dataset
 
 
-def load_openorca_dataset(
-    dataset_path: str
-) -> List[tuple[str]]:
+def load_openorca_dataset(dataset_path: str) -> List[tuple[str]]:
   # Load the dataset.
   with open(dataset_path) as f:
     dataset = json.load(f)
@@ -187,23 +191,31 @@ def tokenize_dataset(
     prompt_len = len(prompt_token_ids[i])
     output_len = len(outputs_token_ids[i])
     tokenized_dataset.append(
-      (prompts[i], prompt_token_ids[i], outputs[i], prompt_len, output_len)
+        (prompts[i], prompt_token_ids[i], outputs[i], prompt_len, output_len)
     )
   return tokenized_dataset
 
 
 def filter_dataset(
-    tokenized_dataset: List[tuple[Any]],
-    max_output_length: Optional[int] = None
+    tokenized_dataset: List[tuple[Any]], max_output_length: Optional[int] = None
 ) -> List[InputRequest]:
   if max_output_length is None:
     print("In InputRequest, pass in actual output_length for each sample")
   else:
-    print(f"In InputRequest, pass in max_output_length: {max_output_length} for each sample")
+    print(
+        f"In InputRequest, pass in max_output_length: {max_output_length} for"
+        " each sample"
+    )
 
   # Filter out too long sequences.
   filtered_dataset: List[InputRequest] = []
-  for prompt, prompt_token_ids, output, prompt_len, output_len in tokenized_dataset:
+  for (
+      prompt,
+      prompt_token_ids,
+      output,
+      prompt_len,
+      output_len,
+  ) in tokenized_dataset:
     if prompt_len < 4 or output_len < 4:
       # Prune too short sequences.
       # This is because TGI causes errors when the input or output length
@@ -212,7 +224,9 @@ def filter_dataset(
     if prompt_len > 1024 or prompt_len + output_len > 2048:
       # Prune too long sequences.
       continue
-    request = InputRequest(prompt, prompt_len, output, max_output_length or output_len)
+    request = InputRequest(
+        prompt, prompt_len, output, max_output_length or output_len
+    )
     filtered_dataset.append(request)
 
   print(f"The dataset contains {len(tokenized_dataset)} samples.")
@@ -226,20 +240,26 @@ def sample_requests(
     tokenizer: Any,
     num_requests: int,
     max_output_length: Optional[int] = None,
-    oversample_multiplier: float=1.2,
-  ) -> List[InputRequest]:
+    oversample_multiplier: float = 1.2,
+) -> List[InputRequest]:
 
   # Original dataset size
   n = len(dataset)
 
   # Create necessary number of requests even if bigger than dataset size
   sampled_indices = random.sample(
-      range(n), min(int(num_requests * oversample_multiplier), n))
+      range(n), min(int(num_requests * oversample_multiplier), n)
+  )
 
   if num_requests > len(sampled_indices):
-    print(f"Number of requests {num_requests} is larger than size of dataset {n}.\n",
-          f"Repeating data to meet number of requests.\n")
-    sampled_indices = sampled_indices * int(np.ceil(num_requests / len(sampled_indices)))
+    print(
+        f"Number of requests {num_requests} is larger than size of dataset"
+        f" {n}.\n",
+        f"Repeating data to meet number of requests.\n",
+    )
+    sampled_indices = sampled_indices * int(
+        np.ceil(num_requests / len(sampled_indices))
+    )
 
   print(f"{len(sampled_indices)=}")
   # some of these will be filtered out, so sample more than we need
@@ -315,7 +335,9 @@ def calculate_metrics(
   return metrics
 
 
-async def grpc_async_request(api_url: str, request: Any) -> tuple[list[str], float, float]:
+async def grpc_async_request(
+    api_url: str, request: Any
+) -> tuple[list[str], float, float]:
   """Send grpc synchronous request since the current grpc server is sync."""
   options = [("grpc.keepalive_timeout_ms", 10000)]
   async with grpc.aio.insecure_channel(api_url, options=options) as channel:
@@ -351,7 +373,9 @@ async def send_request(
   output = RequestFuncOutput()
   output.input_request = input_request
   output.prompt_len = input_request.prompt_len
-  generated_token_list, ttft, latency = await grpc_async_request(api_url, request)
+  generated_token_list, ttft, latency = await grpc_async_request(
+      api_url, request
+  )
   output.ttft = ttft
   output.latency = latency
   output.generated_token_list = generated_token_list
@@ -453,14 +477,15 @@ def mock_requests(total_mock_requests: int):
 
 def sample_warmup_requests(requests):
   interesting_buckets = [
-    0,
-    16,
-    32,
-    64,
-    128,
-    256,
-    512,
-    1024,]
+      0,
+      16,
+      32,
+      64,
+      128,
+      256,
+      512,
+      1024,
+  ]
 
   for start, end in zip(interesting_buckets[:-1], interesting_buckets[1:]):
     for request in requests:
@@ -481,28 +506,30 @@ def main(args: argparse.Namespace):
 
   tokenizer = get_tokenizer(tokenizer_id)
   if tokenizer == "test" or args.dataset == "test":
-    input_requests = mock_requests(args.total_mock_requests) # e.g. [("AB", 2, "AB", 3)]
+    input_requests = mock_requests(
+        args.total_mock_requests
+    )  # e.g. [("AB", 2, "AB", 3)]
   else:
     if args.dataset == "openorca":
       dataset = load_openorca_dataset(args.dataset_path)
     elif args.dataset == "sharegpt":
       dataset = load_sharegpt_dataset(
-        args.dataset_path,
-        args.conversation_starter,
+          args.dataset_path,
+          args.conversation_starter,
       )
 
     # A given args.max_output_length value is the max generation step,
     # when the args.max_output_length is default to None, the sample's golden output length
     # will be used to decide the generation step
     input_requests = sample_requests(
-      dataset=dataset,
-      tokenizer=tokenizer,
-      num_requests=args.num_prompts,
-      max_output_length=args.max_output_length
+        dataset=dataset,
+        tokenizer=tokenizer,
+        num_requests=args.num_prompts,
+        max_output_length=args.max_output_length,
     )
 
   if args.warmup_first:
-    print('Warm up start:' )
+    print("Warm up start:")
    warmup_requests = list(sample_warmup_requests(input_requests)) * 2
    benchmark_result, request_outputs = asyncio.run(
        benchmark(
@@ -516,7 +543,7 @@ def main(args: argparse.Namespace):
             threads=args.threads,
         )
     )
-    print('Warm up done')
+    print("Warm up done")
 
   benchmark_result, request_outputs = asyncio.run(
       benchmark(
@@ -561,7 +588,11 @@ def main(args: argparse.Namespace):
   if args.save_request_outputs:
     file_path = args.request_outputs_file_path
     with open(file_path, "w") as output_file:
-      json.dump([output.to_dict() for output in request_outputs], output_file, indent=4)
+      json.dump(
+          [output.to_dict() for output in request_outputs],
+          output_file,
+          indent=4,
+      )
 
 
 if __name__ == "__main__":
@@ -576,11 +607,13 @@ def main(args: argparse.Namespace):
   )
   parser.add_argument("--port", type=str, default=9000)
   parser.add_argument(
-      "--dataset", type=str, default="test", choices=["test", "sharegpt", "openorca"], help="The dataset name."
-  )
-  parser.add_argument(
-      "--dataset-path", type=str, help="Path to the dataset."
+      "--dataset",
+      type=str,
+      default="test",
+      choices=["test", "sharegpt", "openorca"],
+      help="The dataset name.",
   )
+  parser.add_argument("--dataset-path", type=str, help="Path to the dataset.")
   parser.add_argument(
       "--model",
       type=str,
@@ -637,7 +670,16 @@ def main(args: argparse.Namespace):
       "--max-output-length",
       type=int,
       default=None,
-      help="The maximum output length for reference request.",
+      help=(
+          "The maximum output length for reference request. It would be passed"
+          " to `max_tokens` parameter of the JetStream's DecodeRequest proto,"
+          " and used in JetStream to control the output/decode length of a"
+          " sequence. It would not be used in the engine. We should always set"
+          " max_tokens <= (max_target_length - max_prefill_predict_length)."
+          " max_target_length is the maximum length of a sequence;"
+          " max_prefill_predict_length is the maximum length of the"
+          " input/prefill of a sequence."
+      ),
   )
 
   parser.add_argument("--seed", type=int, default=0)
@@ -678,26 +720,20 @@ def main(args: argparse.Namespace):
       "--request-outputs-file-path",
       type=str,
      default="/tmp/request-outputs.json",
-      help=(
-          "File path to store request outputs"
-      ),
+      help="File path to store request outputs",
   )
   parser.add_argument(
       "--warmup-first",
       type=bool,
       default=False,
-      help=(
-          "Whether to send warmup req first"
-      ),
+      help="Whether to send warmup req first",
   )
   parser.add_argument(
       "--conversation-starter",
       type=str,
       default="human",
       choices=["human", "gpt", "both"],
-      help=(
-          "What entity should be the one starting the conversations."
-      ),
+      help="What entity should be the one starting the conversations.",
   )
 
   args = parser.parse_args()

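For orientation, the help text added above for --max-output-length says the value feeds the max_tokens field of JetStream's DecodeRequest and must satisfy max_tokens <= (max_target_length - max_prefill_predict_length). A minimal worked sketch of that arithmetic; the config values and the clamping step below are illustrative assumptions, not part of this commit:

max_target_length = 2048           # assumed MaxText config: max total sequence length
max_prefill_predict_length = 1024  # assumed MaxText config: max input/prefill length
requested = 1536                   # hypothetical value passed via --max-output-length

# The decode budget is whatever remains after the prefill budget.
decode_budget = max_target_length - max_prefill_predict_length  # 1024

# Clamp the requested output length so max_tokens stays within the budget.
max_tokens = min(requested, decode_budget)
print(max_tokens)  # -> 1024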
jetstream/core/proto/jetstream.proto

Lines changed: 6 additions & 1 deletion
@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-
 syntax = "proto3";
 
 package jetstream_proto;
@@ -29,6 +28,12 @@ message DecodeRequest {
   // New text from a user or tool.
   string additional_text = 2;
   int32 priority = 3;
+  // The maximum output length of a sequence. It's used in JetStream to control
+  // the output/decode length of a sequence. It would not be used in the engine.
+  // We should always set max_tokens <= (max_target_length -
+  // max_prefill_predict_length). max_target_length is the maximum length of a
+  // sequence; max_prefill_predict_length is the maximum length of the
+  // input/prefill of a sequence.
   int32 max_tokens = 4;
 }
 message DecodeResponse {

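The hunk above documents the new max_tokens field on DecodeRequest. A minimal sketch of setting it from Python via the generated jetstream_pb2 bindings imported in benchmarks/benchmark_serving.py; the prompt text and numeric values are illustrative, and only fields visible in this diff are set:

from jetstream.core.proto import jetstream_pb2

# Illustrative request; max_tokens caps the output/decode length of this
# sequence and should not exceed max_target_length - max_prefill_predict_length.
request = jetstream_pb2.DecodeRequest(
    additional_text="What is JetStream?",  # new text from a user or tool
    priority=0,
    max_tokens=1024,
)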
0 commit comments
