File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -127,6 +127,7 @@ def sample_requests(
127127 tokenizer : Any ,
128128 max_output_length : int ,
129129 conversation_starter : str ,
130+ oversample_multiplier : float = 1.2 ,
130131) -> List [InputRequest ]:
131132 # Load the dataset.
132133 with open (dataset_path ) as f :
@@ -143,8 +144,16 @@ def sample_requests(
143144 for data in dataset
144145 ]
145146
147+ # Create necessary number of requests even if bigger than dataset size
148+ sampled_indices = random .sample (range (len (dataset )),
149+ min (int (num_requests * oversample_multiplier ), len (dataset )))
150+ if num_requests > len (sampled_indices ):
151+ print (f"Number of requests { num_requests } is larger than size of dataset { len (dataset )} .\n " ,
152+ f"Repeating data to meet number of requests.\n " )
153+ sampled_indices = sampled_indices * int (np .ceil (num_requests / len (sampled_indices )))
154+
155+ print (f"{ len (sampled_indices )= } " )
146156 # some of these will be filtered out, so sample more than we need
147- sampled_indices = random .sample (range (len (dataset )), int (num_requests * 1.2 ))
148157 dataset = [dataset [i ] for i in sampled_indices ]
149158
150159 # Tokenize the prompts and completions.
You can’t perform that action at this time.
0 commit comments