5454"""
5555
5656
57- import tensorflow as tf
58- import tensorflow_text as tftxt
59-
6057import argparse
6158import asyncio
62-
6359from dataclasses import dataclass
6460from datetime import datetime
6561import json
6662import random
6763import time
6864from typing import Any , AsyncGenerator , List , Optional
65+
6966import grpc
7067from jetstream .core .proto import jetstream_pb2
7168from jetstream .core .proto import jetstream_pb2_grpc
7269import numpy as np
70+ import tensorflow as tf
71+ import tensorflow_text as tftxt
7372from tqdm .asyncio import tqdm
7473
7574
@@ -96,6 +95,7 @@ class InputRequest:
  output: str = ""
  output_len: int = 0

+
@dataclass
class RequestFuncOutput:
  input_request: InputRequest = None
@@ -109,12 +109,12 @@ class RequestFuncOutput:
  # Flatten the structure and return only the necessary results
  def to_dict(self):
    return {
-      "prompt": self.input_request.prompt,
-      "original_output": self.input_request.output,
-      "generated_text": self.generated_text,
-      "success": self.success,
-      "latency": self.latency,
-      "prompt_len": self.prompt_len
+        "prompt": self.input_request.prompt,
+        "original_output": self.input_request.output,
+        "generated_text": self.generated_text,
+        "success": self.success,
+        "latency": self.latency,
+        "prompt_len": self.prompt_len,
    }

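For reference, a minimal sketch (not part of the commit) of what RequestFuncOutput.to_dict flattens a completed request into, with made-up values; it assumes InputRequest's first two fields are prompt and prompt_len, as the constructor calls later in this file suggest:

req = InputRequest(prompt="2+2=", prompt_len=4, output="4", output_len=1)
out = RequestFuncOutput()
out.input_request = req
out.generated_text = "4"
out.success = True
out.latency = 0.12
out.prompt_len = req.prompt_len
print(out.to_dict())
# {'prompt': '2+2=', 'original_output': '4', 'generated_text': '4',
#  'success': True, 'latency': 0.12, 'prompt_len': 4}
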
@@ -123,12 +123,14 @@ def get_tokenizer(tokenizer_name: str) -> Any:
  if tokenizer_name == "test":
    return "test"
  else:
-    with tf.io.gfile.GFile(tokenizer_name, 'rb') as model_fp:
+    with tf.io.gfile.GFile(tokenizer_name, "rb") as model_fp:
      sp_model = model_fp.read()
      sp_tokenizer = tftxt.SentencepieceTokenizer(
-        model=sp_model, add_bos=True, add_eos=False, reverse=False)
+          model=sp_model, add_bos=True, add_eos=False, reverse=False
+      )
    return sp_tokenizer

+
def load_sharegpt_dataset(
    dataset_path: str,
    conversation_starter: str,
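As a usage sketch for the get_tokenizer change above (the model path is hypothetical; any valid SentencePiece model file would do):

tokenizer = get_tokenizer("/tmp/tokenizer.model")  # hypothetical path
ids = tokenizer.tokenize("Hello world")  # int token ids, BOS prepended since add_bos=True
text = tokenizer.detokenize(ids)  # decodes the ids back to text
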
@@ -141,7 +143,11 @@ def load_sharegpt_dataset(

  # Filter based on conversation starter
  if conversation_starter != "both":
-    dataset = [data for data in dataset if data["conversations"][0]["from"] == conversation_starter]
+    dataset = [
+        data
+        for data in dataset
+        if data["conversations"][0]["from"] == conversation_starter
+    ]
  # Only keep the first two turns of each conversation.
  dataset = [
      (data["conversations"][0]["value"], data["conversations"][1]["value"])
@@ -151,9 +157,7 @@ def load_sharegpt_dataset(
  return dataset


-def load_openorca_dataset(
-    dataset_path: str
-) -> List[tuple[str]]:
+def load_openorca_dataset(dataset_path: str) -> List[tuple[str]]:
  # Load the dataset.
  with open(dataset_path) as f:
    dataset = json.load(f)
@@ -187,23 +191,31 @@ def tokenize_dataset(
    prompt_len = len(prompt_token_ids[i])
    output_len = len(outputs_token_ids[i])
    tokenized_dataset.append(
-      (prompts[i], prompt_token_ids[i], outputs[i], prompt_len, output_len)
+        (prompts[i], prompt_token_ids[i], outputs[i], prompt_len, output_len)
    )
  return tokenized_dataset


def filter_dataset(
-    tokenized_dataset: List[tuple[Any]],
-    max_output_length: Optional[int] = None
+    tokenized_dataset: List[tuple[Any]], max_output_length: Optional[int] = None
) -> List[InputRequest]:
  if max_output_length is None:
    print("In InputRequest, pass in actual output_length for each sample")
  else:
-    print(f"In InputRequest, pass in max_output_length: {max_output_length} for each sample")
+    print(
+        f"In InputRequest, pass in max_output_length: {max_output_length} for"
+        " each sample"
+    )

  # Filter out too long sequences.
  filtered_dataset: List[InputRequest] = []
-  for prompt, prompt_token_ids, output, prompt_len, output_len in tokenized_dataset:
+  for (
+      prompt,
+      prompt_token_ids,
+      output,
+      prompt_len,
+      output_len,
+  ) in tokenized_dataset:
    if prompt_len < 4 or output_len < 4:
      # Prune too short sequences.
      # This is because TGI causes errors when the input or output length
@@ -212,7 +224,9 @@ def filter_dataset(
    if prompt_len > 1024 or prompt_len + output_len > 2048:
      # Prune too long sequences.
      continue
-    request = InputRequest(prompt, prompt_len, output, max_output_length or output_len)
+    request = InputRequest(
+        prompt, prompt_len, output, max_output_length or output_len
+    )
    filtered_dataset.append(request)

  print(f"The dataset contains {len(tokenized_dataset)} samples.")
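A quick worked illustration of the pruning rules above, on hypothetical (prompt_len, output_len) pairs:

cases = [(3, 50), (1025, 10), (1000, 1200), (512, 512)]
for prompt_len, output_len in cases:
  too_short = prompt_len < 4 or output_len < 4
  too_long = prompt_len > 1024 or prompt_len + output_len > 2048
  print((prompt_len, output_len), "pruned" if too_short or too_long else "kept")
# Only (512, 512) survives: 512 <= 1024 and 512 + 512 <= 2048.
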
@@ -226,20 +240,26 @@ def sample_requests(
    tokenizer: Any,
    num_requests: int,
    max_output_length: Optional[int] = None,
-  oversample_multiplier: float = 1.2,
-  ) -> List[InputRequest]:
+    oversample_multiplier: float = 1.2,
+) -> List[InputRequest]:

  # Original dataset size
  n = len(dataset)

  # Create necessary number of requests even if bigger than dataset size
  sampled_indices = random.sample(
-      range(n), min(int(num_requests * oversample_multiplier), n))
+      range(n), min(int(num_requests * oversample_multiplier), n)
+  )

  if num_requests > len(sampled_indices):
-    print(f"Number of requests {num_requests} is larger than size of dataset {n}.\n",
-          f"Repeating data to meet number of requests.\n")
-    sampled_indices = sampled_indices * int(np.ceil(num_requests / len(sampled_indices)))
+    print(
+        f"Number of requests {num_requests} is larger than size of dataset"
+        f" {n}.\n",
+        f"Repeating data to meet number of requests.\n",
+    )
+    sampled_indices = sampled_indices * int(
+        np.ceil(num_requests / len(sampled_indices))
+    )

  print(f"{len(sampled_indices)=}")
  # some of these will be filtered out, so sample more than we need
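To make the oversample-then-repeat arithmetic above concrete, a small worked example with made-up sizes:

num_requests, oversample_multiplier, n = 500, 1.2, 300  # hypothetical values
num_sampled = min(int(num_requests * oversample_multiplier), n)  # min(600, 300) -> 300
if num_requests > num_sampled:
  num_sampled *= int(np.ceil(num_requests / num_sampled))  # 300 * ceil(500/300) -> 600
# 600 candidate indices remain, so filtering can drop some and still leave 500.
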
@@ -315,7 +335,9 @@ def calculate_metrics(
  return metrics


-async def grpc_async_request(api_url: str, request: Any) -> tuple[list[str], float, float]:
+async def grpc_async_request(
+    api_url: str, request: Any
+) -> tuple[list[str], float, float]:
  """Send grpc synchronous request since the current grpc server is sync."""
  options = [("grpc.keepalive_timeout_ms", 10000)]
  async with grpc.aio.insecure_channel(api_url, options=options) as channel:
@@ -351,7 +373,9 @@ async def send_request(
  output = RequestFuncOutput()
  output.input_request = input_request
  output.prompt_len = input_request.prompt_len
-  generated_token_list, ttft, latency = await grpc_async_request(api_url, request)
+  generated_token_list, ttft, latency = await grpc_async_request(
+      api_url, request
+  )
  output.ttft = ttft
  output.latency = latency
  output.generated_token_list = generated_token_list
@@ -453,14 +477,15 @@ def mock_requests(total_mock_requests: int):

def sample_warmup_requests(requests):
  interesting_buckets = [
-    0,
-    16,
-    32,
-    64,
-    128,
-    256,
-    512,
-    1024,]
+      0,
+      16,
+      32,
+      64,
+      128,
+      256,
+      512,
+      1024,
+  ]

  for start, end in zip(interesting_buckets[:-1], interesting_buckets[1:]):
    for request in requests:
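The zip(...) in the loop above pairs consecutive bucket edges; for illustration:

buckets = [0, 16, 32, 64, 128, 256, 512, 1024]
print(list(zip(buckets[:-1], buckets[1:])))
# [(0, 16), (16, 32), (32, 64), (64, 128), (128, 256), (256, 512), (512, 1024)]
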
@@ -481,28 +506,30 @@ def main(args: argparse.Namespace):

  tokenizer = get_tokenizer(tokenizer_id)
  if tokenizer == "test" or args.dataset == "test":
-    input_requests = mock_requests(args.total_mock_requests) # e.g. [("AB", 2, "AB", 3)]
+    input_requests = mock_requests(
+        args.total_mock_requests
+    )  # e.g. [("AB", 2, "AB", 3)]
  else:
    if args.dataset == "openorca":
      dataset = load_openorca_dataset(args.dataset_path)
    elif args.dataset == "sharegpt":
      dataset = load_sharegpt_dataset(
-        args.dataset_path,
-        args.conversation_starter,
+          args.dataset_path,
+          args.conversation_starter,
      )

  # A given args.max_output_length value is the max generation step,
  # when the args.max_output_length is default to None, the sample's golden output length
  # will be used to decide the generation step
  input_requests = sample_requests(
-    dataset=dataset,
-    tokenizer=tokenizer,
-    num_requests=args.num_prompts,
-    max_output_length=args.max_output_length
+      dataset=dataset,
+      tokenizer=tokenizer,
+      num_requests=args.num_prompts,
+      max_output_length=args.max_output_length,
  )

  if args.warmup_first:
-    print(' Warm up start:')
+    print(" Warm up start:")
    warmup_requests = list(sample_warmup_requests(input_requests)) * 2
    benchmark_result, request_outputs = asyncio.run(
        benchmark(
@@ -516,7 +543,7 @@ def main(args: argparse.Namespace):
            threads=args.threads,
        )
    )
-    print(' Warm up done')
+    print(" Warm up done")

  benchmark_result, request_outputs = asyncio.run(
      benchmark(
@@ -561,7 +588,11 @@ def main(args: argparse.Namespace):
  if args.save_request_outputs:
    file_path = args.request_outputs_file_path
    with open(file_path, "w") as output_file:
-      json.dump([output.to_dict() for output in request_outputs], output_file, indent=4)
+      json.dump(
+          [output.to_dict() for output in request_outputs],
+          output_file,
+          indent=4,
+      )


if __name__ == "__main__":
@@ -576,11 +607,13 @@ def main(args: argparse.Namespace):
  )
  parser.add_argument("--port", type=str, default=9000)
  parser.add_argument(
-      "--dataset", type=str, default="test", choices=["test", "sharegpt", "openorca"], help="The dataset name."
-  )
-  parser.add_argument(
-      "--dataset-path", type=str, help="Path to the dataset."
+      "--dataset",
+      type=str,
+      default="test",
+      choices=["test", "sharegpt", "openorca"],
+      help="The dataset name.",
  )
+  parser.add_argument("--dataset-path", type=str, help="Path to the dataset.")
  parser.add_argument(
      "--model",
      type=str,
@@ -637,7 +670,16 @@ def main(args: argparse.Namespace):
      "--max-output-length",
      type=int,
      default=None,
-      help="The maximum output length for reference request.",
+      help=(
+          "The maximum output length for reference request. It would be passed"
+          " to `max_tokens` parameter of the JetStream's DecodeRequest proto,"
+          " and used in JetStream to control the output/decode length of a"
+          " sequence. It would not be used in the engine. We should always set"
+          " max_tokens <= (max_target_length - max_prefill_predict_length)."
+          " max_target_length is the maximum length of a sequence;"
+          " max_prefill_predict_length is the maximum length of the"
+          " input/prefill of a sequence."
+      ),
  )

  parser.add_argument("--seed", type=int, default=0)
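The expanded help text encodes a sizing rule; as a tiny arithmetic sketch, with hypothetical engine limits (neither value is set anywhere in this script):

max_target_length = 2048           # hypothetical max total sequence length
max_prefill_predict_length = 1024  # hypothetical max input/prefill length
max_tokens_bound = max_target_length - max_prefill_predict_length
print(max_tokens_bound)  # 1024, so --max-output-length should be <= 1024 here
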
@@ -678,26 +720,20 @@ def main(args: argparse.Namespace):
      "--request-outputs-file-path",
      type=str,
      default="/tmp/request-outputs.json",
-      help=(
-          "File path to store request outputs"
-      ),
+      help="File path to store request outputs",
  )
  parser.add_argument(
      "--warmup-first",
      type=bool,
      default=False,
-      help=(
-          "Whether to send warmup req first"
-      ),
+      help="Whether to send warmup req first",
  )
  parser.add_argument(
      "--conversation-starter",
      type=str,
      default="human",
      choices=["human", "gpt", "both"],
-      help=(
-          "What entity should be the one starting the conversations."
-      ),
+      help="What entity should be the one starting the conversations.",
  )

  args = parser.parse_args()
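Finally, a sketch of exercising the reshuffled flags without a live server, by handing parse_args an explicit argv; the dataset path is hypothetical, and this assumes no flag elsewhere in the file is marked required:

args = parser.parse_args([
    "--dataset", "sharegpt",
    "--dataset-path", "/tmp/sharegpt.json",  # hypothetical file
    "--conversation-starter", "human",
])
print(args.dataset, args.conversation_starter, args.max_output_length)
# sharegpt human None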