@@ -81,14 +81,33 @@ class BenchmarkMetrics:
8181 p99_tpot_ms : float
8282
8383
84+ @dataclass
85+ class InputRequest :
86+ prompt : str = ""
87+ prompt_len : int = 0
88+ output : str = ""
89+ output_len : int = 0
90+
8491@dataclass
8592class RequestFuncOutput :
93+ input_request : InputRequest = None
8694 generated_text : str = ""
8795 success : bool = False
8896 latency : float = 0
8997 ttft : float = 0
9098 prompt_len : int = 0
9199
100+ # Flatten the structure and return only the necessary results
101+ def to_dict (self ):
102+ return {
103+ "prompt" : self .input_request .prompt ,
104+ "original_output" : self .input_request .output ,
105+ "generated_text" : self .generated_text ,
106+ "success" : self .success ,
107+ "latency" : self .latency ,
108+ "prompt_len" : self .prompt_len
109+ }
110+
92111
93112def get_tokenizer (tokenizer_name : str ) -> Any :
94113 """Return a tokenizer or a tokenizer placholder."""
@@ -105,7 +124,7 @@ def sample_requests(
105124 dataset_path : str ,
106125 num_requests : int ,
107126 tokenizer : Any ,
108- ) -> List [Tuple [ str , int , int ] ]:
127+ ) -> List [InputRequest ]:
109128 # Load the dataset.
110129 with open (dataset_path ) as f :
111130 dataset = json .load (f )
@@ -133,11 +152,12 @@ def sample_requests(
133152 tokenized_dataset = []
134153 for i in range (len (dataset )):
135154 output_len = len (completion_token_ids [i ])
136- tokenized_dataset .append ((prompts [i ], prompt_token_ids [i ], output_len ))
155+ tokenized_dataset .append ((prompts [i ], prompt_token_ids [i ], completions [ i ], output_len ))
137156
138157 # Filter out too long sequences.
139- filtered_dataset : List [Tuple [str , int , int ]] = []
140- for prompt , prompt_token_ids , output_len in tokenized_dataset :
158+ filtered_dataset : List [InputRequest ] = []
159+
160+ for prompt , prompt_token_ids , output , output_len in tokenized_dataset :
141161 prompt_len = len (prompt_token_ids )
142162 if prompt_len < 4 or output_len < 4 :
143163 # Prune too short sequences.
@@ -147,17 +167,18 @@ def sample_requests(
147167 if prompt_len > 1024 or prompt_len + output_len > 2048 :
148168 # Prune too long sequences.
149169 continue
150- filtered_dataset .append ((prompt , prompt_len , output_len ))
170+ reqeust = InputRequest (prompt , prompt_len , output , output_len )
171+ filtered_dataset .append (reqeust )
151172
152173 # Sample the requests.
153174 sampled_requests = random .sample (filtered_dataset , num_requests )
154175 return sampled_requests
155176
156177
157178async def get_request (
158- input_requests : List [Tuple [ str , int , int ] ],
179+ input_requests : List [InputRequest ],
159180 request_rate : float ,
160- ) -> AsyncGenerator [Tuple [ str , int , int ] , None ]:
181+ ) -> AsyncGenerator [InputRequest , None ]:
161182 input_requests = iter (input_requests )
162183 for request in input_requests :
163184 yield request
@@ -172,7 +193,7 @@ async def get_request(
172193
173194
174195def calculate_metrics (
175- input_requests : List [Tuple [ str , int , int ] ],
196+ input_requests : List [InputRequest ],
176197 outputs : List [RequestFuncOutput ],
177198 dur_s : float ,
178199 tokenizer : Any ,
@@ -190,7 +211,7 @@ def calculate_metrics(
190211 else "ĊŌƟ"
191212 )
192213 total_output += output_len
193- total_input += input_requests [i ][ 1 ]
214+ total_input += input_requests [i ]. prompt_len
194215 per_token_latencies .append (outputs [i ].latency / output_len )
195216 ttfts .append (outputs [i ].ttft )
196217 completed += 1
@@ -234,25 +255,24 @@ def grpc_sync_request(api_url: str, request: Any) -> tuple[str, float, float]:
234255
235256async def send_request (
236257 api_url : str ,
237- prompt : str ,
238- prompt_len : int ,
258+ input_request : InputRequest ,
239259 pbar : tqdm ,
240260 session_cache : str ,
241261 priority : int ,
242- max_tokens : int ,
243262 threads : int ,
244263) -> RequestFuncOutput :
245264 """Send the request to JetStream server."""
246265 loop = asyncio .get_running_loop ()
247266 loop .set_default_executor (ThreadPoolExecutor (max_workers = threads ))
248267 request = jetstream_pb2 .DecodeRequest (
249268 session_cache = session_cache ,
250- additional_text = prompt ,
269+ additional_text = input_request . prompt ,
251270 priority = priority ,
252- max_tokens = max_tokens ,
271+ max_tokens = input_request . output_len ,
253272 )
254273 output = RequestFuncOutput ()
255- output .prompt_len = prompt_len
274+ output .input_request = input_request
275+ output .prompt_len = input_request .prompt_len
256276 generated_text , ttft , latency = await loop .run_in_executor (
257277 None , grpc_sync_request , api_url , request
258278 )
@@ -268,7 +288,7 @@ async def send_request(
268288async def benchmark (
269289 api_url : str ,
270290 tokenizer : Any ,
271- input_requests : List [Tuple [ str , int , int ] ],
291+ input_requests : List [InputRequest ],
272292 request_rate : float ,
273293 disable_tqdm : bool ,
274294 session_cache : str ,
@@ -283,17 +303,14 @@ async def benchmark(
283303 benchmark_start_time = time .perf_counter ()
284304 tasks = []
285305 async for request in get_request (input_requests , request_rate ):
286- prompt , prompt_len , output_len = request
287306 tasks .append (
288307 asyncio .create_task (
289308 send_request (
290309 api_url = api_url ,
291- prompt = prompt ,
292- prompt_len = prompt_len ,
310+ input_request = request ,
293311 pbar = pbar ,
294312 session_cache = session_cache ,
295313 priority = priority ,
296- max_tokens = output_len ,
297314 threads = threads ,
298315 )
299316 )
@@ -341,17 +358,19 @@ async def benchmark(
341358 "median_tpot_ms" : metrics .median_tpot_ms ,
342359 "p99_tpot_ms" : metrics .p99_tpot_ms ,
343360 }
344- return result
361+ return result , outputs
345362
346363
347364def mock_requests (total_mock_requests : int ):
348365 """Generates a list of mock requests containing mock data."""
349366 data = []
350367 for _ in range (total_mock_requests ):
351- name = f"Item { random .randint (1 , 1000 )} "
352- price = random .randint (10 , 100 )
353- quantity = random .randint (1 , 10 )
354- data .append ((name , price , quantity ))
368+ request = InputRequest ()
369+ request .prompt = f"Prompt { random .randint (1 , 1000 )} "
370+ request .prompt_len = random .randint (10 , 100 )
371+ request .output = f"Output { random .randint (1 , 1000 )} "
372+ request .output_len = random .randint (1 , 10 )
373+ data .append (request )
355374 return data
356375
357376
@@ -367,11 +386,11 @@ def main(args: argparse.Namespace):
367386
368387 tokenizer = get_tokenizer (tokenizer_id )
369388 if tokenizer == "test" or args .dataset == "test" :
370- input_requests = mock_requests (args .total_mock_requests ) # e.g. [("AB", 2, 3)]
389+ input_requests = mock_requests (args .total_mock_requests ) # e.g. [InputRequest(prompt, prompt_len, output, output_len)]
371390 else :
372391 input_requests = sample_requests (args .dataset , args .num_prompts , tokenizer )
373392
374- benchmark_result = asyncio .run (
393+ benchmark_result , request_outputs = asyncio .run (
375394 benchmark (
376395 api_url = api_url ,
377396 tokenizer = tokenizer ,
@@ -411,6 +430,11 @@ def main(args: argparse.Namespace):
411430 with open (file_name , "w" ) as outfile :
412431 json .dump (result_json , outfile )
413432
433+ if args .save_request_outputs :
434+ file_path = args .request_outputs_file_path
435+ with open (file_path , "w" ) as output_file :
436+ json .dump ([output .to_dict () for output in request_outputs ], output_file , indent = 4 )
437+
414438
415439if __name__ == "__main__" :
416440 parser = argparse .ArgumentParser (
@@ -506,6 +530,19 @@ def main(args: argparse.Namespace):
506530 " not implemented, use default empty str)"
507531 ),
508532 )
533+ parser .add_argument (
534+ "--save-request-outputs" ,
535+ action = "store_true" ,
536+ help = "Specify to store request outputs into a json file" ,
537+ )
538+ parser .add_argument (
539+ "--request-outputs-file-path" ,
540+ type = str ,
541+ default = "/tmp/request-outputs.json" ,
542+ help = (
543+ "File path to store request outputs"
544+ ),
545+ )
509546
510547 args = parser .parse_args ()
511548 main (args )
0 commit comments