1515"""Benchmark JetStream online serving.
1616
1717On the server side, run one of the following commands:
18- * For real server, you need to pass correct server config (include the model config that
19- being passed into your engine impl) to the command below. Refer to config_lib.py and
18+ * For real server, you need to pass correct server config (include the model config that
19+ being passed into your engine impl) to the command below. Refer to config_lib.py and
2020 implementations/mock/config.py for config impl detail.
2121
2222 (run with real server)
2828
2929On the client side, run:
3030 * For real server and shareGPT dataset, you need to pass the tokenizer, server config, and
31- dataset flags to the command below, and make some changes to the tokenizer logic in the
31+ dataset flags to the command below, and make some changes to the tokenizer logic in the
3232 benchmark script (get_tokenizer and sample_requests func) to use your tokenizer correctly.
3333 * Add `--save-result` flag to save the benchmark result to a json file in current folder.
3434 * Add `--threads` flag to set the maximum number of threads used for request dispatching.
3535
3636 (run with real model and engines)
3737 python -m benchmarks.benchmark_serving \
38- --tokenizer <your_tokenizer> --dataset <target_dataset_path> \
38+ --tokenizer <your_tokenizer> \
39+ --dataset <target_dataset_name> \
40+ --dataset-path <target_dataset_path> \
3941 --request-rate <request_rate>
4042
4143 (run with mock)
4244 python -m benchmarks.benchmark_serving \
4345 --request-rate 1
4446
45- e2e example: python3 benchmark_serving.py --tokenizer /home/rwitten/maxtext/assets/tokenizer --num-prompts 100 --dataset ~/ShareGPT_V3_unfiltered_cleaned_split.json
47+ e2e example:
48+ python3 benchmark_serving.py \
49+ --tokenizer /home/{username}/maxtext/assets/tokenizer \
50+ --num-prompts 100 \
51+ --dataset sharegpt \
52+ --dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json
53+
4654"""
 
 
 
 import argparse
 import asyncio
-from concurrent.futures import ThreadPoolExecutor
+
 from dataclasses import dataclass
 from datetime import datetime
 import json
 import random
 import time
-from typing import Any, AsyncGenerator, List, Tuple
+from typing import Any, AsyncGenerator, List, Optional
 import grpc
 from jetstream.core.proto import jetstream_pb2
 from jetstream.core.proto import jetstream_pb2_grpc
@@ -99,15 +107,15 @@ class RequestFuncOutput:
   prompt_len: int = 0
 
   # Flatten the structure and return only the necessary results
-  def to_dict(self):
+  def to_dict(self):
     return {
         "prompt": self.input_request.prompt,
         "original_output": self.input_request.output,
         "generated_text": self.generated_text,
         "success": self.success,
         "latency": self.latency,
         "prompt_len": self.prompt_len
-    }
+    }
 
 
 def get_tokenizer(tokenizer_name: str) -> Any:
@@ -121,13 +129,11 @@ def get_tokenizer(tokenizer_name: str) -> Any:
       model=sp_model, add_bos=True, add_eos=False, reverse=False)
   return sp_tokenizer
 
-def sample_requests(
+def load_sharegpt_dataset(
     dataset_path: str,
-    num_requests: int,
     tokenizer: Any,
-    max_output_length: int,
     conversation_starter: str,
-    oversample_multiplier: float = 1.2,
+    max_output_length: Optional[int] = None,
 ) -> List[InputRequest]:
   # Load the dataset.
   with open(dataset_path) as f:
@@ -144,18 +150,6 @@ def sample_requests(
       for data in dataset
   ]
 
-  # Create necessary number of requests even if bigger than dataset size
-  sampled_indices = random.sample(range(len(dataset)),
-                                  min(int(num_requests * oversample_multiplier), len(dataset)))
-  if num_requests > len(sampled_indices):
-    print(f"Number of requests {num_requests} is larger than size of dataset {len(dataset)}.\n",
-          f"Repeating data to meet number of requests.\n")
-    sampled_indices = sampled_indices * int(np.ceil(num_requests / len(sampled_indices)))
-
-  print(f"{len(sampled_indices)=}")
-  # some of these will be filtered out, so sample more than we need
-  dataset = [dataset[i] for i in sampled_indices]
-
   # Tokenize the prompts and completions.
   prompts = [prompt for prompt, _ in dataset]
   prompt_token_ids = tokenizer.tokenize(
@@ -167,27 +161,104 @@ def sample_requests(
   )  # adjust this code based on tokenizer method
   tokenized_dataset = []
   for i in range(len(dataset)):
-    output_len = len(completion_token_ids[i])
-    tokenized_dataset.append((prompts[i], prompt_token_ids[i], completions[i], output_len))
+    prompt_len = len(prompt_token_ids[i])
+    completion_len = len(completion_token_ids[i])
+    tokenized_dataset.append(
+        (prompts[i], prompt_token_ids[i], completions[i], prompt_len, completion_len)
+    )
 
   # Filter out too long sequences.
   filtered_dataset: List[InputRequest] = []
 
-  for prompt, prompt_token_ids, output, output_len in tokenized_dataset:
-    prompt_len = len(prompt_token_ids)
-    if prompt_len < 4 or output_len < 4:
+  for prompt, prompt_token_ids, completion, prompt_len, completion_len in tokenized_dataset:
+    if prompt_len < 4 or completion_len < 4:
       # Prune too short sequences.
       # This is because TGI causes errors when the input or output length
       # is too short.
       continue
+    if prompt_len > 1024 or prompt_len + completion_len > 2048:
+      # Prune too long sequences.
+      continue
+    request = InputRequest(prompt, prompt_len, completion, max_output_length or completion_len)
+    filtered_dataset.append(request)
+
+  if max_output_length is None:
+    print("In InputRequest, pass in actual output_length for each sample")
+  else:
+    print(f"In InputRequest, pass in max_output_length: {max_output_length} for each sample")
+
+  print(f"The dataset contains {len(tokenized_dataset)} samples.")
+  print(f"The filtered dataset contains {len(filtered_dataset)} samples.")
+
+  return filtered_dataset
+
+
+def load_openorca_dataset(
+    dataset_path: str,
+    tokenizer: Any,
+    max_output_length: Optional[int] = None,
+) -> List[InputRequest]:
+
+  # Load the dataset.
+  with open(dataset_path) as f:
+    dataset = json.load(f)
+
+  # Tokenize the prompts and completions.
+  prompts = dataset["prompts"]
+  outputs = dataset["results"]
+  n = len(prompts)
+  prompt_token_ids = tokenizer.tokenize(prompts)
+  output_token_ids = tokenizer.tokenize(outputs)
+
+  tokenized_dataset = []
+  for i in range(n):
+    prompt_len = len(prompt_token_ids[i])
+    output_len = len(output_token_ids[i])
+    tokenized_dataset.append((prompts[i], prompt_token_ids[i], outputs[i], prompt_len, output_len))
+
+  # Filter out too long sequences.
+  filtered_dataset: List[InputRequest] = []
+
+  for prompt, prompt_token_ids, output, prompt_len, output_len in tokenized_dataset:
     if prompt_len > 1024 or prompt_len + output_len > 2048:
       # Prune too long sequences.
       continue
-    request = InputRequest(prompt, prompt_len, output, max_output_length)
+    request = InputRequest(prompt, prompt_len, output, max_output_length or output_len)
     filtered_dataset.append(request)
 
+  if max_output_length is None:
+    print("In InputRequest, pass in actual output_length for each sample")
+  else:
+    print(f"In InputRequest, pass in max_output_length: {max_output_length} for each sample")
+
+  print(f"The dataset contains {len(tokenized_dataset)} samples.")
+  print(f"The filtered dataset contains {len(filtered_dataset)} samples.")
+
+  return filtered_dataset
+
+
+def sample_requests(
+    dataset: List[InputRequest],
+    num_requests: int,
+    oversample_multiplier: float = 1.2,
+) -> List[InputRequest]:
+
+  # Create necessary number of requests even if bigger than dataset size
+  sampled_indices = random.sample(
+      range(len(dataset)), min(int(num_requests * oversample_multiplier), len(dataset)))
+
+  if num_requests > len(sampled_indices):
+    print(f"Number of requests {num_requests} is larger than size of dataset {len(dataset)}.\n",
+          f"Repeating data to meet number of requests.\n")
+    sampled_indices = sampled_indices * int(np.ceil(num_requests / len(sampled_indices)))
+
+  print(f"{len(sampled_indices)=}")
+  # some of these will be filtered out, so sample more than we need
+  dataset = [dataset[i] for i in sampled_indices]
+
   # Sample the requests.
-  sampled_requests = random.sample(filtered_dataset, num_requests)
+  sampled_requests = random.sample(dataset, num_requests)
+
   return sampled_requests
 
 
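Note: load_openorca_dataset above reads only the "prompts" and "results" keys from the JSON file, so the file passed via --dataset-path is assumed to be shaped like the illustrative sketch below (the example strings are made up, not taken from a real OpenOrca export):

    # Illustrative only: writes a minimal file in the shape load_openorca_dataset
    # expects, i.e. parallel "prompts" and "results" lists of equal length.
    import json

    sample = {
        "prompts": ["Summarize the following article ...", "Translate to French: Good morning"],
        "results": ["The article argues that ...", "Bonjour"],
    }
    with open("openorca_sample.json", "w") as f:
        json.dump(sample, f)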
@@ -396,7 +467,7 @@ def sample_warmup_requests(requests):
       256,
       512,
       1024,]
-
+
   for start, end in zip(interesting_buckets[:-1], interesting_buckets[1:]):
     for request in requests:
       if start < request.prompt_len <= end:
@@ -418,12 +489,26 @@ def main(args: argparse.Namespace):
   if tokenizer == "test" or args.dataset == "test":
     input_requests = mock_requests(args.total_mock_requests)  # e.g. [("AB", 2, "AB", 3)]
   else:
+    if args.dataset == "openorca":
+      dataset = load_openorca_dataset(
+          args.dataset_path,
+          tokenizer,
+          args.max_output_length
+      )
+    elif args.dataset == "sharegpt":
+      dataset = load_sharegpt_dataset(
+          args.dataset_path,
+          tokenizer,
+          args.conversation_starter,
+          args.max_output_length
+      )
+
+    # A given args.max_output_length value caps the generation steps; when it
+    # is left at its default of None, each sample's golden output length
+    # decides the generation steps.
     input_requests = sample_requests(
-        args.dataset,
-        args.num_prompts,
-        tokenizer,
-        args.max_output_length,
-        args.conversation_starter,
+        dataset,
+        args.num_prompts,
     )
 
   if args.warmup_first:
@@ -486,7 +571,7 @@ def main(args: argparse.Namespace):
   if args.save_request_outputs:
     file_path = args.request_outputs_file_path
     with open(file_path, "w") as output_file:
-      json.dump([output.to_dict() for output in request_outputs], output_file, indent=4)
+      json.dump([output.to_dict() for output in request_outputs], output_file, indent=4)
 
 
 if __name__ == "__main__":
@@ -501,7 +586,10 @@ def main(args: argparse.Namespace):
   )
   parser.add_argument("--port", type=str, default=9000)
   parser.add_argument(
-      "--dataset", type=str, default="test", help="Path to the dataset."
+      "--dataset", type=str, default="test", choices=["test", "sharegpt", "openorca"], help="The dataset name."
+  )
+  parser.add_argument(
+      "--dataset-path", type=str, help="Path to the dataset."
   )
   parser.add_argument(
       "--model",
@@ -558,7 +646,7 @@ def main(args: argparse.Namespace):
   parser.add_argument(
       "--max-output-length",
       type=int,
-      default=1024,
+      default=None,
       help="The maximum output length for reference request.",
   )
 
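For reference, a minimal sketch (not part of this commit) of how the refactored pieces compose, mirroring what main() now does for --dataset sharegpt; the tokenizer path, dataset path, and conversation_starter value are placeholders:

    # Illustrative usage of the helpers defined in this file; paths and the
    # conversation_starter value are placeholders.
    tokenizer = get_tokenizer("assets/tokenizer")  # SentencePiece model path
    dataset = load_sharegpt_dataset(
        dataset_path="ShareGPT_V3_unfiltered_cleaned_split.json",
        tokenizer=tokenizer,
        conversation_starter="human",  # placeholder value
        max_output_length=None,  # None -> use each sample's golden output length
    )
    input_requests = sample_requests(dataset, num_requests=100)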