1+ import json
12import math
3+ import re
24import warnings
35from collections import defaultdict
6+ from typing import List , Optional
47
58import torch
69from torch .utils .data import IterableDataset
@@ -49,7 +52,7 @@ def __iter__(self):
4952 prompts_encoder = []
5053 infill = []
5154 instruction = []
52- for sample in range (self .limit_start , self .limit_start + self .n_tasks ):
55+ for sample in range (self .limit_start , self .limit_start + self .n_tasks ):
5356 prompt_contents = self .task .get_prompt (self .dataset [sample ])
5457 if isinstance (prompt_contents , str ):
5558 # Normal code completion mode
@@ -111,8 +114,6 @@ def __iter__(self):
111114 return_token_type_ids = return_token_type_ids ,
112115 )
113116
114-
115-
116117 if self .n_copies == 1 and self .n_tasks % self .num_devices != 0 :
117118 self .n_copies = 2
118119 warnings .warn (
@@ -127,7 +128,9 @@ def __iter__(self):
127128 "ids_encoder" : outputs_encoder .input_ids [sample ],
128129 "task_id" : sample ,
129130 "input_len" : outputs .attention_mask [sample ].sum (),
130- "input_len_encoder" : outputs_encoder .attention_mask [sample ].sum (),
131+ "input_len_encoder" : outputs_encoder .attention_mask [
132+ sample
133+ ].sum (),
131134 }
132135 else :
133136 yield {
@@ -231,14 +234,20 @@ def complete_code(
231234 instruction_tokens = None ,
232235 postprocess = True ,
233236 is_wrapped = False ,
237+ save_every_k_tasks : int = - 1 ,
238+ intermediate_generations : Optional [List [Optional [List [Optional [str ]]]]] = None ,
239+ intermediate_save_generations_path : Optional [str ] = None ,
234240 ** gen_kwargs ,
235241):
236242 """Generate multiple codes for each task in the dataset using multiple GPUs with accelerate.
237243 dataloader sends all the prompts from the evalution dataset to the model as the following:
238244 [p_0_0, p_0_1, ..., p_0_nc-1, p_1_0, ..., p_nt-1_nc-1] where nc is the number of copies of the prompt,
239245 and nt is the number of tasks. nc is such that num_samples(for each task)= nc * batch_size
240246 """
241-
247+ # keep track of the list of generated codes
248+ # where len(code_gens) = n_tasks and len(code_gens[0]) = number of generated code samples
249+ code_gens : List [List [Optional [str ]]] = [[] for _ in range (n_tasks )]
250+ generations = [] if not intermediate_generations else intermediate_generations
242251 gen_token_dict = defaultdict (list ) # dict of list of generated tokens
243252 for step , batch in tqdm (
244253 enumerate (dataloader ),
@@ -251,12 +260,14 @@ def complete_code(
251260 # Set the start_length after which to check for stopping to be the longest input ignoring padding
252261 max_len = batch ["input_len" ].max ().item ()
253262 if "ids_encoder" in batch :
254- max_len += 1 # Add 1 for decoder_start_token_id
263+ max_len += 1 # Add 1 for decoder_start_token_id
255264 gen_kwargs ["stopping_criteria" ][0 ].start_length = max_len
256265 if hasattr (task , "max_length_multiplier" ) and task .max_length_multiplier :
257266 idx = 1 if task .stop_words else 0
258- gen_kwargs ["stopping_criteria" ][idx ].input_length = batch ["input_len" ].max ().item ()
259-
267+ gen_kwargs ["stopping_criteria" ][idx ].input_length = (
268+ batch ["input_len" ].max ().item ()
269+ )
270+
260271 inputs = batch ["ids" ][:, : batch ["input_len" ]]
261272 if "ids_encoder" in batch :
262273 if is_wrapped :
@@ -306,7 +317,55 @@ def complete_code(
306317 for sample , generated_tokens in zip (generated_tasks , generated_tokens ):
307318 gen_token_dict [sample ].append (generated_tokens )
308319
309- code_gens = [[] for _ in range (n_tasks )]
320+ if save_every_k_tasks >= 1 and (step + 1 ) % save_every_k_tasks == 0 :
321+ if not intermediate_save_generations_path :
322+ raise ValueError (
323+ "intermediate_save_generations_path cannot be empty!"
324+ )
325+
326+ code_gens = update_code_gens (
327+ task ,
328+ tokenizer ,
329+ limit_start ,
330+ prefix ,
331+ instruction_tokens ,
332+ postprocess ,
333+ code_gens ,
334+ gen_token_dict ,
335+ )
336+ with open (intermediate_save_generations_path , "w" ) as fp :
337+ json .dump (generations + code_gens , fp )
338+ print (
339+ f"intermediate generations were saved at { intermediate_save_generations_path } "
340+ )
341+ # reset gen_token_dict - prevent redundant decoding
342+ gen_token_dict = defaultdict (list )
343+
344+ code_gens = update_code_gens (
345+ task ,
346+ tokenizer ,
347+ limit_start ,
348+ prefix ,
349+ instruction_tokens ,
350+ postprocess ,
351+ code_gens ,
352+ gen_token_dict ,
353+ )
354+
355+ generations .extend (code_gens )
356+ return generations
357+
358+
359+ def update_code_gens (
360+ task ,
361+ tokenizer ,
362+ limit_start ,
363+ prefix ,
364+ instruction_tokens ,
365+ postprocess ,
366+ code_gens ,
367+ gen_token_dict ,
368+ ):
310369 for sample , generated_tokens in gen_token_dict .items ():
311370 for s in generated_tokens :
312371 if INFILL_MODE or tokenizer .eos_token in task .stop_words :
@@ -315,7 +374,7 @@ def complete_code(
315374 # Treat eos token as a regular stop word not removing it from the output
316375 # If it's removed it may have the effect of removing it in the middle of a
317376 # longer generation in case a batch size > 1 is used, which will result in
318- # a wrong generation as it won't be used for splitting lateron
377+ # a wrong generation as it won't be used for splitting lateron
319378 gen_code = tokenizer .decode (
320379 s , skip_special_tokens = False , clean_up_tokenization_spaces = False
321380 )
@@ -338,13 +397,9 @@ def complete_code(
338397 "model output is not postprocessed, this might lower evaluation scores"
339398 )
340399 code_gens [sample ].append (gen_code )
341-
342400 return code_gens
343401
344402
345- import re
346-
347-
348403def remove_after_return (code ):
349404 """
350405 Takes as input a code, and removes everything that is after the return.
@@ -361,6 +416,6 @@ def remove_after_return(code):
361416 and start_match < len (code )
362417 and code [start_match ].strip () != ""
363418 ):
364- return code [0 :start_match ]
419+ return code [0 : start_match ]
365420 end_last_match = end_match
366421 return code
0 commit comments