@@ -91,7 +91,8 @@ class InputRequest:
9191@dataclass
9292class RequestFuncOutput :
9393 input_request : InputRequest = None
94- generated_text : str = ""
94+ generated_token_list : list [str ] = None
95+ generated_text : str = None
9596 success : bool = False
9697 latency : float = 0
9798 ttft : float = 0
@@ -124,6 +125,7 @@ def sample_requests(
124125 dataset_path : str ,
125126 num_requests : int ,
126127 tokenizer : Any ,
128+ max_output_length : int ,
127129) -> List [InputRequest ]:
128130 # Load the dataset.
129131 with open (dataset_path ) as f :
@@ -167,7 +169,7 @@ def sample_requests(
167169 if prompt_len > 1024 or prompt_len + output_len > 2048 :
168170 # Prune too long sequences.
169171 continue
170- reqeust = InputRequest (prompt , prompt_len , output , output_len )
172+ reqeust = InputRequest (prompt , prompt_len , output , max_output_length )
171173 filtered_dataset .append (reqeust )
172174
173175 # Sample the requests.
@@ -206,9 +208,9 @@ def calculate_metrics(
206208 for i in range (len (outputs )):
207209 if outputs [i ].success :
208210 output_len = len (
209- tokenizer . tokenize ( outputs [i ].generated_text )
211+ outputs [i ].generated_token_list
210212 if tokenizer != "test"
211- else "ĊŌƟ"
213+ else [ "Ċ" , "Ō" , "Ɵ" ]
212214 )
213215 total_output += output_len
214216 total_input += input_requests [i ].prompt_len
@@ -234,9 +236,10 @@ def calculate_metrics(
234236 return metrics
235237
236238
237- def grpc_sync_request (api_url : str , request : Any ) -> tuple [str , float , float ]:
239+ def grpc_sync_request (api_url : str , request : Any ) -> tuple [list [ str ] , float , float ]:
238240 """Send grpc synchronous request since the current grpc server is sync."""
239- with grpc .insecure_channel (api_url ) as channel :
241+ options = [("grpc.keepalive_timeout_ms" , 10000 )]
242+ with grpc .insecure_channel (api_url , options = options ) as channel :
240243 grpc .channel_ready_future (channel ).result ()
241244 stub = jetstream_pb2_grpc .OrchestratorStub (channel )
242245 print ("Making request" )
@@ -249,8 +252,7 @@ def grpc_sync_request(api_url: str, request: Any) -> tuple[str, float, float]:
249252 ttft = time .perf_counter () - request_start_time
250253 token_list .append (token .response [0 ])
251254 latency = time .perf_counter () - request_start_time
252- generated_text = "" .join (token_list )
253- return generated_text , ttft , latency
255+ return token_list , ttft , latency
254256
255257
256258async def send_request (
@@ -273,12 +275,13 @@ async def send_request(
273275 output = RequestFuncOutput ()
274276 output .input_request = input_request
275277 output .prompt_len = input_request .prompt_len
276- generated_text , ttft , latency = await loop .run_in_executor (
278+ generated_token_list , ttft , latency = await loop .run_in_executor (
277279 None , grpc_sync_request , api_url , request
278280 )
279281 output .ttft = ttft
280282 output .latency = latency
281- output .generated_text = generated_text
283+ output .generated_token_list = generated_token_list
284+ output .generated_text = "" .join (generated_token_list )
282285 output .success = True
283286 if pbar :
284287 pbar .update (1 )
@@ -374,6 +377,24 @@ def mock_requests(total_mock_requests: int):
374377 return data
375378
376379
def sample_warmup_requests(
    requests, interesting_buckets=(0, 16, 32, 64, 128, 256, 512, 1024)
):
  """Yield one warmup request per prompt-length bucket.

  For each half-open bucket ``(start, end]`` defined by consecutive entries
  of `interesting_buckets`, yield the first request whose ``prompt_len``
  falls inside that bucket. This lets a warmup run exercise a spread of
  prompt lengths without replaying the whole dataset. Buckets with no
  matching request are skipped, as are requests longer than the last edge.

  Args:
    requests: Re-iterable of InputRequest-like objects exposing a
      ``prompt_len`` attribute (a generator would be exhausted after the
      first bucket scan).
    interesting_buckets: Ascending bucket edges. Defaults to powers of two
      up to 1024, matching the prompt-length cap used when sampling.

  Yields:
    At most ``len(interesting_buckets) - 1`` requests, one per non-empty
    bucket, in bucket order.
  """
  for start, end in zip(interesting_buckets[:-1], interesting_buckets[1:]):
    # First request landing in this bucket wins; `break` moves on to the
    # next bucket and restarts the scan from the beginning of `requests`.
    for request in requests:
      if start < request.prompt_len <= end:
        yield request
        break
396+
397+
377398def main (args : argparse .Namespace ):
378399 print (args )
379400 random .seed (args .seed )
@@ -388,7 +409,24 @@ def main(args: argparse.Namespace):
388409 if tokenizer == "test" or args .dataset == "test" :
389410 input_requests = mock_requests (args .total_mock_requests ) # e.g. [("AB", 2, "AB", 3)]
390411 else :
391- input_requests = sample_requests (args .dataset , args .num_prompts , tokenizer )
412+ input_requests = sample_requests (args .dataset , args .num_prompts , tokenizer , args .max_output_length )
413+
414+ if args .warmup_first :
415+ print ('Warm up start:' )
416+ warmup_requests = list (sample_warmup_requests (input_requests )) * 2
417+ benchmark_result , request_outputs = asyncio .run (
418+ benchmark (
419+ api_url = api_url ,
420+ tokenizer = tokenizer ,
421+ input_requests = warmup_requests ,
422+ request_rate = args .request_rate ,
423+ disable_tqdm = args .disable_tqdm ,
424+ session_cache = args .session_cache ,
425+ priority = args .priority ,
426+ threads = args .threads ,
427+ )
428+ )
429+ print ('Warm up done' )
392430
393431 benchmark_result , request_outputs = asyncio .run (
394432 benchmark (
@@ -501,6 +539,14 @@ def main(args: argparse.Namespace):
501539 default = 150 ,
502540 help = "The maximum number of mock requests to send for benchmark testing." ,
503541 )
542+
543+ parser .add_argument (
544+ "--max-output-length" ,
545+ type = int ,
546+ default = 1024 ,
547+ help = "The maximum output length for reference request." ,
548+ )
549+
504550 parser .add_argument ("--seed" , type = int , default = 0 )
505551 parser .add_argument (
506552 "--disable-tqdm" ,
@@ -543,6 +589,14 @@ def main(args: argparse.Namespace):
543589 "File path to store request outputs"
544590 ),
545591 )
592+ parser .add_argument (
593+ "--warmup-first" ,
594+ type = bool ,
595+ default = False ,
596+ help = (
597+ "Whether to send warmup req first"
598+ ),
599+ )
546600
547601 args = parser .parse_args ()
548602 main (args )
0 commit comments