@@ -126,12 +126,17 @@ def sample_requests(
126126 num_requests : int ,
127127 tokenizer : Any ,
128128 max_output_length : int ,
129+ conversation_starter : str ,
129130) -> List [InputRequest ]:
130131 # Load the dataset.
131132 with open (dataset_path ) as f :
132133 dataset = json .load (f )
133134 # Filter out the conversations with less than 2 turns.
134135 dataset = [data for data in dataset if len (data ["conversations" ]) >= 2 ]
136+
137+ # Filter based on conversation starter
138+ if conversation_starter != "both" :
139+ dataset = [data for data in dataset if data ["conversations" ][0 ]["from" ] == conversation_starter ]
135140 # Only keep the first two turns of each conversation.
136141 dataset = [
137142 (data ["conversations" ][0 ]["value" ], data ["conversations" ][1 ]["value" ])
@@ -169,8 +174,8 @@ def sample_requests(
169174 if prompt_len > 1024 or prompt_len + output_len > 2048 :
170175 # Prune too long sequences.
171176 continue
172- reqeust = InputRequest (prompt , prompt_len , output , max_output_length )
173- filtered_dataset .append (reqeust )
177+ request = InputRequest (prompt , prompt_len , output , max_output_length )
178+ filtered_dataset .append (request )
174179
175180 # Sample the requests.
176181 sampled_requests = random .sample (filtered_dataset , num_requests )
@@ -409,7 +414,13 @@ def main(args: argparse.Namespace):
409414 if tokenizer == "test" or args .dataset == "test" :
410415 input_requests = mock_requests (args .total_mock_requests ) # e.g. [("AB", 2, "AB", 3)]
411416 else :
412- input_requests = sample_requests (args .dataset , args .num_prompts , tokenizer , args .max_output_length )
417+ input_requests = sample_requests (
418+ args .dataset ,
419+ args .num_prompts ,
420+ tokenizer ,
421+ args .max_output_length ,
422+ args .conversation_starter ,
423+ )
413424
414425 if args .warmup_first :
415426 print ('Warm up start:' )
@@ -597,6 +608,15 @@ def main(args: argparse.Namespace):
597608 "Whether to send warmup req first"
598609 ),
599610 )
611+ parser .add_argument (
612+ "--conversation-starter" ,
613+ type = str ,
614+ default = "human" ,
615+ choices = ["human" , "gpt" , "both" ],
616+ help = (
617+ "What entity should be the one starting the conversations."
618+ ),
619+ )
600620
601621 args = parser .parse_args ()
602622 main (args )
0 commit comments