@@ -237,7 +237,8 @@ def calculate_metrics(
237237
238238def grpc_sync_request (api_url : str , request : Any ) -> tuple [list [str ], float , float ]:
239239 """Send grpc synchronous request since the current grpc server is sync."""
240- with grpc .insecure_channel (api_url ) as channel :
240+ options = [("grpc.keepalive_timeout_ms" , 10000 )]
241+ with grpc .insecure_channel (api_url , options = options ) as channel :
241242 grpc .channel_ready_future (channel ).result ()
242243 stub = jetstream_pb2_grpc .OrchestratorStub (channel )
243244 print ("Making request" )
@@ -374,6 +375,24 @@ def mock_requests(total_mock_requests: int):
374375 return data
375376
376377
def sample_warmup_requests(requests, interesting_buckets=None):
  """Yield the first request found in each prompt-length bucket.

  Consecutive entries of ``interesting_buckets`` define half-open ranges
  ``(start, end]``. For every range, the first request (in input order) whose
  ``prompt_len`` falls inside it is yielded; ranges with no matching request
  are skipped. Results come out in bucket order, not input order.

  Args:
    requests: Iterable of objects exposing a numeric ``prompt_len`` attribute.
      Any iterable works, including one-shot iterators.
    interesting_buckets: Optional ascending sequence of bucket edges. Defaults
      to ``(0, 16, 32, 64, 128, 256, 512, 1024)``.

  Yields:
    At most one request per bucket.
  """
  if interesting_buckets is None:
    interesting_buckets = (0, 16, 32, 64, 128, 256, 512, 1024)
  edges = list(interesting_buckets)

  # Every bucket rescans the candidates from the start, so a one-shot iterator
  # must be materialized once; otherwise earlier buckets would consume items
  # that later buckets need to see.
  candidates = list(requests)

  for start, end in zip(edges[:-1], edges[1:]):
    for candidate in candidates:
      if start < candidate.prompt_len <= end:
        yield candidate
        break
394+
395+
377396def main (args : argparse .Namespace ):
378397 print (args )
379398 random .seed (args .seed )
@@ -390,6 +409,23 @@ def main(args: argparse.Namespace):
390409 else :
391410 input_requests = sample_requests (args .dataset , args .num_prompts , tokenizer , args .max_output_length )
392411
412+ if args .warmup_first :
413+ print ('Warm up start:' )
414+ warmup_requests = list (sample_warmup_requests (input_requests )) * 2
415+ benchmark_result , request_outputs = asyncio .run (
416+ benchmark (
417+ api_url = api_url ,
418+ tokenizer = tokenizer ,
419+ input_requests = warmup_requests ,
420+ request_rate = args .request_rate ,
421+ disable_tqdm = args .disable_tqdm ,
422+ session_cache = args .session_cache ,
423+ priority = args .priority ,
424+ threads = args .threads ,
425+ )
426+ )
427+ print ('Warm up done' )
428+
393429 benchmark_result , request_outputs = asyncio .run (
394430 benchmark (
395431 api_url = api_url ,
@@ -551,6 +587,14 @@ def main(args: argparse.Namespace):
551587 "File path to store request outputs"
552588 ),
553589 )
parser.add_argument(
    "--warmup-first",
    # NOTE: argparse's `type=bool` applies bool() to the raw string, so any
    # non-empty value (including "False" or "0") would evaluate to True.
    # Parse the text explicitly so `--warmup-first False` really disables it.
    type=lambda value: value.strip().lower() in ("1", "true", "t", "yes", "y"),
    default=False,
    help=(
        "Whether to send warmup req first"
    ),
)
554598
555599 args = parser .parse_args ()
556600 main (args )
0 commit comments