@@ -131,10 +131,8 @@ def get_tokenizer(tokenizer_name: str) -> Any:
 
 def load_sharegpt_dataset(
     dataset_path: str,
-    tokenizer: Any,
     conversation_starter: str,
-    max_output_length: Optional[int] = None,
-) -> List[InputRequest]:
+) -> List[tuple[str]]:
     # Load the dataset.
     with open(dataset_path) as f:
         dataset = json.load(f)
@@ -150,116 +148,112 @@ def load_sharegpt_dataset(
         for data in dataset
     ]
 
+    return dataset
+
+
+def load_openorca_dataset(
+    dataset_path: str
+) -> List[tuple[str]]:
+    # Load the dataset.
+    with open(dataset_path) as f:
+        dataset = json.load(f)
+
     # Tokenize the prompts and completions.
+    prompts = dataset["prompts"]
+    outputs = dataset["results"]
+
+    return [(prompt, output) for prompt, output in zip(prompts, outputs)]
+
+
+def tokenize_dataset(
+    dataset: List[tuple[str]],
+    tokenizer: Any,
+) -> List[tuple[Any]]:
+
+    n = len(dataset)
+
     prompts = [prompt for prompt, _ in dataset]
+    outputs = [output for _, output in dataset]
+
     prompt_token_ids = tokenizer.tokenize(
         prompts
     ) # adjust this code based on tokenizer method
-    completions = [completion for _, completion in dataset]
-    completion_token_ids = tokenizer.tokenize(
-        completions
+    outputs_token_ids = tokenizer.tokenize(
+        outputs
     ) # adjust this code based on tokenizer method
+
     tokenized_dataset = []
-    for i in range(len(dataset)):
+    for i in range(n):
         prompt_len = len(prompt_token_ids[i])
-        completion_len = len(completion_token_ids[i])
+        output_len = len(outputs_token_ids[i])
         tokenized_dataset.append(
-            (prompts[i], prompt_token_ids[i], completions[i], prompt_len, completion_len)
+            (prompts[i], prompt_token_ids[i], outputs[i], prompt_len, output_len)
         )
+    return tokenized_dataset
 
-    # Filter out too long sequences.
-    filtered_dataset: List[InputRequest] = []
-
-    for prompt, prompt_token_ids, completion, prompt_len, completion_len in tokenized_dataset:
-        if prompt_len < 4 or completion_len < 4:
-            # Prune too short sequences.
-            # This is because TGI causes errors when the input or output length
-            # is too short.
-            continue
-        if prompt_len > 1024 or prompt_len + completion_len > 2048:
-            # Prune too long sequences.
-            continue
-        request = InputRequest(prompt, prompt_len, completion, max_output_length or completion_len)
-        filtered_dataset.append(request)
 
+def filter_dataset(
+    tokenized_dataset: List[tuple[Any]],
+    max_output_length: Optional[int] = None
+) -> List[InputRequest]:
     if max_output_length is None:
         print("In InputRequest, pass in actual output_length for each sample")
     else:
         print(f"In InputRequest, pass in max_output_length: {max_output_length} for each sample")
 
-    print(f"The dataset contains {len(tokenized_dataset)} samples.")
-    print(f"The filtered dataset contains {len(filtered_dataset)} samples.")
-
-    return filtered_dataset
-
-
-def load_openorca_dataset(
-    dataset_path: str,
-    tokenizer: Any,
-    max_output_length: Optional[int] = None,
-) -> List[InputRequest]:
-
-    # Load the dataset.
-    with open(dataset_path) as f:
-        dataset = json.load(f)
-
-    # Tokenize the prompts and completions.
-    prompts = dataset["prompts"]
-    outputs = dataset["results"]
-    n = len(prompts)
-    prompt_token_ids = tokenizer.tokenize(prompts)
-    output_token_ids = tokenizer.tokenize(outputs)
-
-    tokenized_dataset = []
-    for i in range(n):
-        prompt_len = len(prompt_token_ids[i])
-        output_len = len(output_token_ids[i])
-        tokenized_dataset.append((prompts[i], prompt_token_ids[i], outputs[i], prompt_len, output_len))
-
     # Filter out too long sequences.
     filtered_dataset: List[InputRequest] = []
-
     for prompt, prompt_token_ids, output, prompt_len, output_len in tokenized_dataset:
+        if prompt_len < 4 or output_len < 4:
+            # Prune too short sequences.
+            # This is because TGI causes errors when the input or output length
+            # is too short.
+            continue
         if prompt_len > 1024 or prompt_len + output_len > 2048:
             # Prune too long sequences.
             continue
         request = InputRequest(prompt, prompt_len, output, max_output_length or output_len)
         filtered_dataset.append(request)
 
-    if max_output_length is None:
-        print("In InputRequest, pass in actual output_length for each sample")
-    else:
-        print(f"In InputRequest, pass in max_output_length: {max_output_length} for each sample")
-
     print(f"The dataset contains {len(tokenized_dataset)} samples.")
     print(f"The filtered dataset contains {len(filtered_dataset)} samples.")
 
     return filtered_dataset
 
 
 def sample_requests(
-    dataset: List[InputRequest],
+    dataset: List[tuple[str]],
+    tokenizer: Any,
     num_requests: int,
+    max_output_length: Optional[int] = None,
     oversample_multiplier: float = 1.2,
 ) -> List[InputRequest]:
 
+    # Original dataset size
+    n = len(dataset)
+
     # Create necessary number of requests even if bigger than dataset size
     sampled_indices = random.sample(
-        range(len(dataset)), min(int(num_requests * oversample_multiplier), len(dataset)))
+        range(n), min(int(num_requests * oversample_multiplier), n))
 
     if num_requests > len(sampled_indices):
-        print(f"Number of requests {num_requests} is larger than size of dataset {len(dataset)}.\n",
+        print(f"Number of requests {num_requests} is larger than size of dataset {n}.\n",
               f"Repeating data to meet number of requests.\n")
         sampled_indices = sampled_indices * int(np.ceil(num_requests / len(sampled_indices)))
 
     print(f"{len(sampled_indices)=}")
     # some of these will be filtered out, so sample more than we need
     dataset = [dataset[i] for i in sampled_indices]
 
+    tokenized_dataset = tokenize_dataset(dataset, tokenizer)
+
+    input_requests = filter_dataset(tokenized_dataset, max_output_length)
+
     # Sample the requests.
-    sampled_requests = random.sample(dataset, num_requests)
+    if len(input_requests) > num_requests:
+        input_requests = random.sample(input_requests, num_requests)
 
-    return sampled_requests
+    return input_requests
 
 
 async def get_request(
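
Editor's note on the hunk above: the refactor splits the old monolithic loaders into free-standing steps (load, tokenize, filter) that sample_requests now chains together. The sketch below is a minimal, self-contained illustration of the tokenize-then-filter half of that pipeline; the whitespace tokenizer and the plain result tuples are hypothetical stand-ins for the real tokenizer and InputRequest, while the length thresholds mirror the diff.

from typing import Any, List, Optional, Tuple

class WhitespaceTokenizer:
    # Hypothetical stand-in: the real tokenizer only needs a batch tokenize() method.
    def tokenize(self, texts: List[str]) -> List[List[str]]:
        return [t.split() for t in texts]

def tokenize_dataset(dataset: List[Tuple[str, str]], tokenizer: Any):
    # Build (prompt, prompt_token_ids, output, prompt_len, output_len) tuples,
    # mirroring the tokenize_dataset added in the hunk above.
    prompts = [p for p, _ in dataset]
    outputs = [o for _, o in dataset]
    prompt_ids = tokenizer.tokenize(prompts)
    output_ids = tokenizer.tokenize(outputs)
    return [
        (p, pi, o, len(pi), len(oi))
        for p, pi, o, oi in zip(prompts, prompt_ids, outputs, output_ids)
    ]

def filter_dataset(tokenized, max_output_length: Optional[int] = None):
    # Same pruning rules as the diff: drop sequences too short for TGI and
    # sequences over the 1024-token prompt / 2048-token total budgets.
    kept = []
    for prompt, _, output, prompt_len, output_len in tokenized:
        if prompt_len < 4 or output_len < 4:
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            continue
        kept.append((prompt, prompt_len, output, max_output_length or output_len))
    return kept

pairs = [("what is the capital of france ?", "the capital of france is paris .")]
print(filter_dataset(tokenize_dataset(pairs, WhitespaceTokenizer())))
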
@@ -490,25 +484,21 @@ def main(args: argparse.Namespace):
         input_requests = mock_requests(args.total_mock_requests) # e.g. [("AB", 2, "AB", 3)]
     else:
         if args.dataset == "openorca":
-            dataset = load_openorca_dataset(
-                args.dataset_path,
-                tokenizer,
-                args.max_output_length
-            )
+            dataset = load_openorca_dataset(args.dataset_path)
         elif args.dataset == "sharegpt":
             dataset = load_sharegpt_dataset(
                 args.dataset_path,
-                tokenizer,
                 args.conversation_starter,
-                args.max_output_length
             )
 
         # A given args.max_output_length value is the max generation step;
         # when args.max_output_length is left at its default of None, each sample's
         # golden output length will be used to decide the generation step.
         input_requests = sample_requests(
-            dataset,
-            args.num_prompts,
+            dataset=dataset,
+            tokenizer=tokenizer,
+            num_requests=args.num_prompts,
+            max_output_length=args.max_output_length
         )
 
     if args.warmup_first:
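
The sampling change is the heart of the commit: because filtering now happens inside sample_requests, it first oversamples indices (1.2x by default) so enough requests survive the filter, repeats indices when the dataset is smaller than the requested count, and only downsamples at the end if too many remain. A minimal sketch of just that index arithmetic (the helper name oversampled_indices is hypothetical, and math.ceil stands in for the int(np.ceil(...)) in the diff):

import math
import random

def oversampled_indices(n: int, num_requests: int,
                        oversample_multiplier: float = 1.2) -> list:
    # Sample extra indices up front, since filtering will drop some items.
    indices = random.sample(range(n), min(int(num_requests * oversample_multiplier), n))
    if num_requests > len(indices):
        # Dataset smaller than the request count: repeat indices to cover it.
        indices = indices * math.ceil(num_requests / len(indices))
    return indices

print(len(oversampled_indices(n=1000, num_requests=100)))  # 120
print(len(oversampled_indices(n=50, num_requests=100)))    # 100

Note also that filter_dataset passes max_output_length or output_len into InputRequest, so any falsy max_output_length (None, or an explicit 0) falls back to the sample's golden output length.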