@@ -65,15 +65,14 @@
 import json
 import random
 import time
-from typing import Any, AsyncGenerator, List, Optional
+from typing import Any, AsyncGenerator, Optional

 import grpc
 from jetstream.core.proto import jetstream_pb2
 from jetstream.core.proto import jetstream_pb2_grpc
+from jetstream.engine.token_utils import load_vocab
 import numpy as np
-import tensorflow as tf
-import tensorflow_text as tftxt
-from tqdm.asyncio import tqdm
+from tqdm.asyncio import tqdm  # pytype: disable=pyi-error
 from eval_accuracy import eval_accuracy

@@ -106,9 +105,9 @@ class InputRequest:

 @dataclass
 class RequestFuncOutput:
-  input_request: InputRequest = None
-  generated_token_list: list[str] = None
-  generated_text: str = None
+  input_request: Optional[InputRequest] = None
+  generated_token_list: list[str] = field(default_factory=list)
+  generated_text: str = ""
   success: bool = False
   latency: float = 0
   ttft: float = 0
@@ -132,18 +131,16 @@ def get_tokenizer(tokenizer_name: str) -> Any:
   if tokenizer_name == "test":
     return "test"
   else:
-    with tf.io.gfile.GFile(tokenizer_name, "rb") as model_fp:
-      sp_model = model_fp.read()
-    sp_tokenizer = tftxt.SentencepieceTokenizer(
-        model=sp_model, add_bos=True, add_eos=False, reverse=False
-    )
-    return sp_tokenizer
+    # Use the JetStream tokenizer util, which wraps the sentencepiece
+    # tokenizer from the seqio library.
+    vocab = load_vocab(tokenizer_name)
+    return vocab.tokenizer


 def load_sharegpt_dataset(
     dataset_path: str,
     conversation_starter: str,
-) -> List[tuple[str]]:
+) -> list[tuple[Any, Any]]:
   # Load the dataset.
   with open(dataset_path, "r", encoding="utf-8") as f:
     dataset = json.load(f)
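Note on the new tokenizer path above: load_vocab is assumed to return a
seqio-style vocabulary whose .tokenizer attribute is a sentencepiece
processor, so the object handed back by get_tokenizer exposes encode() and
decode() directly. A minimal usage sketch under that assumption (the model
path is hypothetical):

# Sketch only, not part of the change: round-trip through the tokenizer
# returned by get_tokenizer, assuming the sentencepiece processor API.
tokenizer = get_tokenizer("/tmp/tokenizer.model")  # hypothetical path
ids = tokenizer.encode("hello world")  # -> list of int token ids
text = tokenizer.decode(ids)  # -> "hello world"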
@@ -166,7 +163,7 @@ def load_sharegpt_dataset(
   return dataset


-def load_openorca_dataset(dataset_path: str) -> List[tuple[str]]:
+def load_openorca_dataset(dataset_path: str) -> list[tuple[Any, Any]]:
   # Load the dataset.
   with open(dataset_path, "r", encoding="utf-8") as f:
     dataset = json.load(f)
@@ -179,9 +176,9 @@ def load_openorca_dataset(dataset_path: str) -> List[tuple[str]]:


 def tokenize_dataset(
-    dataset: List[tuple[str]],
+    dataset: list[tuple[Any, Any, Any]],
     tokenizer: Any,
-) -> List[tuple[Any]]:
+) -> list[tuple[str, Any, str, int, int, int]]:

   n = len(dataset)

@@ -194,10 +191,10 @@ def tokenize_dataset(
     outputs.append(output)
     indices.append(idx)

-  prompt_token_ids = tokenizer.tokenize(
+  prompt_token_ids = tokenizer.encode(
       prompts
   )  # adjust this code based on tokenizer method
-  outputs_token_ids = tokenizer.tokenize(
+  outputs_token_ids = tokenizer.encode(
       outputs
   )  # adjust this code based on tokenizer method

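The encode() calls above receive whole lists of strings, so the tokenizer is
assumed to support sentencepiece-style batch encoding that returns one id
list per input string. A small sketch of that assumption:

# Sketch only, not part of the change: batch encoding with a
# sentencepiece-style tokenizer returns one list of ids per prompt.
prompts = ["first prompt", "second prompt"]
prompt_token_ids = tokenizer.encode(prompts)  # -> [[ids...], [ids...]]
prompt_len = len(prompt_token_ids[0])  # per-prompt token count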
@@ -218,8 +215,9 @@


 def filter_dataset(
-    tokenized_dataset: List[tuple[Any]], max_output_length: Optional[int] = None
-) -> List[InputRequest]:
+    tokenized_dataset: list[tuple[str, Any, str, int, int, int]],
+    max_output_length: Optional[int] = None,
+) -> list[InputRequest]:
   if max_output_length is None:
     print("In InputRequest, pass in actual output_length for each sample")
   else:
@@ -229,7 +227,7 @@
     )

   # Filter out too long sequences.
-  filtered_dataset: List[InputRequest] = []
+  filtered_dataset: list[InputRequest] = []
   for (
       prompt,
       _,
@@ -258,12 +256,12 @@


 def sample_requests(
-    dataset: List[tuple[str]],
+    dataset: list[tuple[Any, Any]],
     tokenizer: Any,
     num_requests: int,
     max_output_length: Optional[int] = None,
     oversample_multiplier: float = 1.2,
-) -> List[InputRequest]:
+) -> list[InputRequest]:

   # Original dataset size
   n = len(dataset)
@@ -304,7 +302,7 @@


 async def get_request(
-    input_requests: List[InputRequest],
+    input_requests: list[InputRequest],
     request_rate: float,
 ) -> AsyncGenerator[InputRequest, None]:
   input_requests = iter(input_requests)
@@ -321,8 +319,8 @@


 def calculate_metrics(
-    input_requests: List[InputRequest],
-    outputs: List[RequestFuncOutput],
+    input_requests: list[InputRequest],
+    outputs: list[RequestFuncOutput],
     dur_s: float,
     tokenizer: Any,
 ) -> BenchmarkMetrics:
@@ -374,16 +372,17 @@ async def grpc_async_request(
   token_list = []
   request_start_time = time.perf_counter()
   response = stub.Decode(request)
-  async for token in response:
+  async for sample_list in response:
     if ttft == 0:
       ttft = time.perf_counter() - request_start_time
-    token_list.append(token.response[0])
+    token_list.extend(sample_list.response[0].token_ids)
   latency = time.perf_counter() - request_start_time
   return token_list, ttft, latency


 async def send_request(
     api_url: str,
+    tokenizer: Any,
     input_request: InputRequest,
     pbar: tqdm,
     session_cache: str,
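The reworked loop in grpc_async_request reflects the server streaming token
ids instead of text chunks: each streamed message is assumed to carry one
entry per sample, with a repeated token_ids field on response[0]. A minimal
consumer sketch (field names are taken from this diff, not verified against
the .proto definitions):

# Sketch only, not part of the change: accumulate ids across messages.
token_list = []
async for sample_list in stub.Decode(request):
  # response[0] is the first sample; token_ids is a repeated int field,
  # so extend() flattens it into one running list of ids.
  token_list.extend(sample_list.response[0].token_ids)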
@@ -405,7 +404,8 @@ async def send_request(
   output.ttft = ttft
   output.latency = latency
   output.generated_token_list = generated_token_list
-  output.generated_text = "".join(generated_token_list)
+  # generated_token_list holds token ids; decode them into generated_text.
+  output.generated_text = tokenizer.decode(generated_token_list)
   output.success = True
   if pbar:
     pbar.update(1)
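With ids instead of strings, "".join() can no longer rebuild the output
text, which is why send_request now decodes in one call. A sketch assuming
the sentencepiece processor API (id_to_piece and decode):

# Sketch only, not part of the change: why joining no longer works.
pieces = [tokenizer.id_to_piece(i) for i in generated_token_list]
# pieces carry raw subword markers such as "▁hello", so the text is
# rebuilt with a single decode() call instead:
generated_text = tokenizer.decode(generated_token_list)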
@@ -415,7 +415,7 @@
 async def benchmark(
     api_url: str,
     tokenizer: Any,
-    input_requests: List[InputRequest],
+    input_requests: list[InputRequest],
     request_rate: float,
     disable_tqdm: bool,
     session_cache: str,
@@ -433,6 +433,7 @@
         asyncio.create_task(
             send_request(
                 api_url=api_url,
+                tokenizer=tokenizer,
                 input_request=request,
                 pbar=pbar,
                 session_cache=session_cache,
@@ -442,7 +443,7 @@
     )
   outputs = await asyncio.gather(*tasks)

-  if not disable_tqdm:
+  if not disable_tqdm and pbar:
     pbar.close()

   benchmark_duration = time.perf_counter() - benchmark_start_time