
Commit 199eeec

Merge pull request #166 from bigcode-project/max/save-intermediate-gen
QOL changes for generations
2 parents: 8d9f667 + 6b18f1e

6 files changed

Lines changed: 179 additions & 44 deletions


.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         pip install flake8 pytest
-        pip install transformers==4.21.1 accelerate==0.13.2 datasets==2.6.1 evaluate==0.2.2 pyext==0.7 mosestokenizer==1.0.0 "fsspec<2023.10.0"
+        pip install transformers==4.21.1 accelerate==0.13.2 datasets==2.14.6 evaluate==0.2.2 pyext==0.7 mosestokenizer==1.0.0 "fsspec<2023.10.0"
     #- name: Lint with flake8
     #  run: |
     #    flake8 .

bigcode_eval/evaluator.py

Lines changed: 42 additions & 15 deletions
@@ -3,6 +3,9 @@
 import os
 import warnings
 
+from typing import List
+
+
 from bigcode_eval import tasks
 from bigcode_eval.generation import parallel_generations
 
@@ -24,7 +27,6 @@
 ################################################################################\
 """
 
-
 class Evaluator:
     def __init__(self, accelerator, model, tokenizer, args):
         self.accelerator = accelerator
@@ -38,11 +40,16 @@ def __init__(self, accelerator, model, tokenizer, args):
         # code evaluation permission
         self.allow_code_execution = args.allow_code_execution
 
-    def generate_text(self, task_name):
+    def generate_text(self, task_name, intermediate_generations=None):
         task = tasks.get_task(task_name, self.args)
         dataset = task.get_dataset()
         # if args.limit is None, use all samples
-        n_tasks = self.args.limit if self.args.limit else len(dataset)
+        # if args.limit is used, make sure args.limit_start + args.limit <= len(dataset)
+        n_tasks = min(self.args.limit, len(dataset) - self.args.limit_start) if self.args.limit else len(dataset)
+        # when args.limit is None
+        # adjust n_tasks by args.limit_start to prevent out of bounds issues
+        if not self.args.limit:
+            n_tasks -= self.args.limit_start
         references = [task.get_reference(dataset[i]) for i in range(self.args.limit_start, self.args.limit_start+n_tasks)]
 
         if self.args.check_references:
@@ -52,6 +59,13 @@ def generate_text(self, task_name):
             solutions = [[ref] for ref in references]
             return solutions, references
 
+        curr_generations = []  # list[list[str | None] | None]
+        if intermediate_generations:
+            curr_generations = [gen for gen in intermediate_generations if gen]
+            n_tasks -= len(curr_generations)
+        intermediate_save_generations_path = f"{os.path.splitext(self.args.save_generations_path)[0]}_{task_name}_intermediate.json"
+        curr_sample_idx = len(curr_generations)
+
         generations = parallel_generations(
             task,
             dataset,
@@ -60,33 +74,30 @@ def generate_text(self, task_name):
             self.tokenizer,
             n_tasks=n_tasks,
             args=self.args,
+            curr_sample_idx=curr_sample_idx,  # curr_sample_idx will be added to limit_start to fix indexing
+            save_every_k_tasks=self.args.save_every_k_tasks,
+            intermediate_generations=curr_generations,
+            intermediate_save_generations_path=intermediate_save_generations_path,
         )
+
         if len(generations[0]) > self.args.n_samples:
             generations = [l[: self.args.n_samples] for l in generations]
             warnings.warn(
                 f"Number of tasks wasn't proportional to number of devices, we removed extra predictions to only keep nsamples={self.args.n_samples}"
             )
         return generations, references
 
-    def evaluate(self, task_name):
+    def evaluate(self, task_name, intermediate_generations=None):
         task = tasks.get_task(task_name, self.args)
         if task.requires_execution and not self.allow_code_execution:
             raise ValueError(_WARNING)
 
-        generations, references = self.generate_text(task_name)
+        generations, references = self.generate_text(task_name, intermediate_generations=intermediate_generations)
 
         if self.accelerator.is_main_process:
             if not self.args.load_generations_path:
-                if self.args.save_generations:
-                    with open(self.args.save_generations_path, "w") as fp:
-                        json.dump(generations, fp)
-                        print(
-                            f"generations were saved at {self.args.save_generations_path}"
-                        )
-                if self.args.save_references:
-                    with open(self.args.save_references_path, "w") as fp:
-                        json.dump(references, fp)
-                        print(f"references were saved at {self.args.save_references_path}")
+                save_generations_path = f"{os.path.splitext(self.args.save_generations_path)[0]}_{task_name}.json"
+                self.save_json_files(generations, references, save_generations_path, f"references_{task_name}.json")
 
         # make sure tokenizer plays nice with multiprocessing
         os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -95,3 +106,19 @@ def evaluate(self, task_name):
         print("Evaluating generations...")
         results = task.process_results(generations, references)
         return results
+
+    def save_json_files(
+        self,
+        generations: List[str],
+        references: List[str],
+        save_generations_path: str,
+        save_references_path: str,
+    ) -> None:
+        if self.args.save_generations:
+            with open(save_generations_path, "w") as fp:
+                json.dump(generations, fp)
+                print(f"generations were saved at {save_generations_path}")
+        if self.args.save_references:
+            with open(save_references_path, "w") as fp:
+                json.dump(references, fp)
+                print(f"references were saved at {save_references_path}")

bigcode_eval/generation.py

Lines changed: 20 additions & 3 deletions
@@ -1,6 +1,8 @@
 import json
 from math import ceil
 
+from typing import List, Optional
+
 from accelerate.utils import set_seed
 from torch.utils.data.dataloader import DataLoader
 from transformers import StoppingCriteria, StoppingCriteriaList
@@ -37,7 +39,19 @@ def __call__(self, input_ids, scores, **kwargs):
         return input_ids.shape[1] > int(self.input_length * self.multiplier)
 
 
-def parallel_generations(task, dataset, accelerator, model, tokenizer, n_tasks, args):
+def parallel_generations(
+    task,
+    dataset,
+    accelerator,
+    model,
+    tokenizer,
+    n_tasks,
+    args,
+    curr_sample_idx: int = 0,
+    save_every_k_tasks: int = -1,
+    intermediate_generations: Optional[List[Optional[List[Optional[str]]]]] = None,
+    intermediate_save_generations_path: Optional[str] = None,
+):
     if args.load_generations_path:
         # load generated code
         with open(args.load_generations_path) as fp:
@@ -100,7 +114,7 @@ def parallel_generations(task, dataset, accelerator, model, tokenizer, n_tasks,
         tokenizer,
         num_devices=accelerator.state.num_processes,
         max_length=args.max_length_generation,
-        limit_start=args.limit_start,
+        limit_start=args.limit_start + curr_sample_idx,
         n_tasks=n_tasks,
         n_copies=n_copies,
         prefix=args.prefix,
@@ -131,12 +145,15 @@ def parallel_generations(task, dataset, accelerator, model, tokenizer, n_tasks,
         tokenizer,
         ds_loader,
         n_tasks=n_tasks,
-        limit_start=args.limit_start,
+        limit_start=args.limit_start + curr_sample_idx,
         batch_size=args.batch_size,
         prefix=args.prefix,
         instruction_tokens=instruction_tokens,
         postprocess=args.postprocess,
         is_wrapped=is_loaded_in_8bit or is_loaded_in_4bit,
+        save_every_k_tasks=save_every_k_tasks,
+        intermediate_generations=intermediate_generations,
+        intermediate_save_generations_path=intermediate_save_generations_path,
         **gen_kwargs,
     )
     return generations
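The `limit_start + curr_sample_idx` offset is the crux of the resume logic: both the tokenized dataset and `complete_code` start past the tasks already covered by the checkpoint. A standalone sketch of the resulting index window, with illustrative numbers (not code from this PR), mirroring the arithmetic in `Evaluator.generate_text` and `parallel_generations`:

def generation_window(limit_start, limit, dataset_len, n_resumed):
    # n_tasks as computed in Evaluator.generate_text
    if limit:
        n_tasks = min(limit, dataset_len - limit_start)
    else:
        n_tasks = dataset_len - limit_start
    n_tasks -= n_resumed             # tasks that still need generations
    start = limit_start + n_resumed  # limit_start + curr_sample_idx
    return range(start, start + n_tasks)


# 100-example dataset, limit_start=10, limit=50, 20 tasks already checkpointed:
assert generation_window(10, 50, 100, 20) == range(30, 60)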

bigcode_eval/utils.py

Lines changed: 70 additions & 15 deletions
@@ -1,6 +1,9 @@
+import json
 import math
+import re
 import warnings
 from collections import defaultdict
+from typing import List, Optional
 
 import torch
 from torch.utils.data import IterableDataset
@@ -49,7 +52,7 @@ def __iter__(self):
         prompts_encoder = []
         infill = []
         instruction = []
-        for sample in range(self.limit_start, self.limit_start+self.n_tasks):
+        for sample in range(self.limit_start, self.limit_start + self.n_tasks):
             prompt_contents = self.task.get_prompt(self.dataset[sample])
             if isinstance(prompt_contents, str):
                 # Normal code completion mode
@@ -111,8 +114,6 @@ def __iter__(self):
                 return_token_type_ids=return_token_type_ids,
             )
 
-
-
         if self.n_copies == 1 and self.n_tasks % self.num_devices != 0:
             self.n_copies = 2
             warnings.warn(
@@ -127,7 +128,9 @@ def __iter__(self):
                     "ids_encoder": outputs_encoder.input_ids[sample],
                     "task_id": sample,
                     "input_len": outputs.attention_mask[sample].sum(),
-                    "input_len_encoder": outputs_encoder.attention_mask[sample].sum(),
+                    "input_len_encoder": outputs_encoder.attention_mask[
+                        sample
+                    ].sum(),
                 }
             else:
                 yield {
@@ -231,14 +234,20 @@ def complete_code(
     instruction_tokens=None,
     postprocess=True,
     is_wrapped=False,
+    save_every_k_tasks: int = -1,
+    intermediate_generations: Optional[List[Optional[List[Optional[str]]]]] = None,
+    intermediate_save_generations_path: Optional[str] = None,
     **gen_kwargs,
 ):
     """Generate multiple codes for each task in the dataset using multiple GPUs with accelerate.
     dataloader sends all the prompts from the evalution dataset to the model as the following:
     [p_0_0, p_0_1, ..., p_0_nc-1, p_1_0, ..., p_nt-1_nc-1] where nc is the number of copies of the prompt,
     and nt is the number of tasks. nc is such that num_samples(for each task)= nc * batch_size
     """
-
+    # keep track of the list of generated codes
+    # where len(code_gens) = n_tasks and len(code_gens[0]) = number of generated code samples
+    code_gens: List[List[Optional[str]]] = [[] for _ in range(n_tasks)]
+    generations = [] if not intermediate_generations else intermediate_generations
     gen_token_dict = defaultdict(list)  # dict of list of generated tokens
     for step, batch in tqdm(
         enumerate(dataloader),
@@ -251,12 +260,14 @@ def complete_code(
         # Set the start_length after which to check for stopping to be the longest input ignoring padding
         max_len = batch["input_len"].max().item()
         if "ids_encoder" in batch:
-            max_len += 1 # Add 1 for decoder_start_token_id
+            max_len += 1  # Add 1 for decoder_start_token_id
         gen_kwargs["stopping_criteria"][0].start_length = max_len
         if hasattr(task, "max_length_multiplier") and task.max_length_multiplier:
             idx = 1 if task.stop_words else 0
-            gen_kwargs["stopping_criteria"][idx].input_length = batch["input_len"].max().item()
-
+            gen_kwargs["stopping_criteria"][idx].input_length = (
+                batch["input_len"].max().item()
+            )
+
         inputs = batch["ids"][:, : batch["input_len"]]
         if "ids_encoder" in batch:
             if is_wrapped:
@@ -306,7 +317,55 @@ def complete_code(
         for sample, generated_tokens in zip(generated_tasks, generated_tokens):
             gen_token_dict[sample].append(generated_tokens)
 
-    code_gens = [[] for _ in range(n_tasks)]
+        if save_every_k_tasks >= 1 and (step + 1) % save_every_k_tasks == 0:
+            if not intermediate_save_generations_path:
+                raise ValueError(
+                    "intermediate_save_generations_path cannot be empty!"
+                )
+
+            code_gens = update_code_gens(
+                task,
+                tokenizer,
+                limit_start,
+                prefix,
+                instruction_tokens,
+                postprocess,
+                code_gens,
+                gen_token_dict,
+            )
+            with open(intermediate_save_generations_path, "w") as fp:
+                json.dump(generations + code_gens, fp)
+                print(
+                    f"intermediate generations were saved at {intermediate_save_generations_path}"
+                )
+            # reset gen_token_dict - prevent redundant decoding
+            gen_token_dict = defaultdict(list)
+
+    code_gens = update_code_gens(
+        task,
+        tokenizer,
+        limit_start,
+        prefix,
+        instruction_tokens,
+        postprocess,
+        code_gens,
+        gen_token_dict,
+    )
+
+    generations.extend(code_gens)
+    return generations
+
+
+def update_code_gens(
+    task,
+    tokenizer,
+    limit_start,
+    prefix,
+    instruction_tokens,
+    postprocess,
+    code_gens,
+    gen_token_dict,
+):
     for sample, generated_tokens in gen_token_dict.items():
         for s in generated_tokens:
             if INFILL_MODE or tokenizer.eos_token in task.stop_words:
@@ -315,7 +374,7 @@
                 # Treat eos token as a regular stop word not removing it from the output
                 # If it's removed it may have the effect of removing it in the middle of a
                 # longer generation in case a batch size > 1 is used, which will result in
-                # a wrong generation as it won't be used for splitting lateron
+                # a wrong generation as it won't be used for splitting lateron
                 gen_code = tokenizer.decode(
                     s, skip_special_tokens=False, clean_up_tokenization_spaces=False
                 )
@@ -338,13 +397,9 @@
                     "model output is not postprocessed, this might lower evaluation scores"
                 )
             code_gens[sample].append(gen_code)
-
     return code_gens
 
 
-import re
-
-
 def remove_after_return(code):
     """
     Takes as input a code, and removes everything that is after the return.
@@ -361,6 +416,6 @@ def remove_after_return(code):
             and start_match < len(code)
             and code[start_match].strip() != ""
         ):
-            return code[0:start_match]
+            return code[0: start_match]
         end_last_match = end_match
     return code
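For context on the hunk above, `remove_after_return` truncates a sampled completion at the first line that returns to column zero, i.e. leaves the function body. A self-contained approximation, written from the docstring and the loop body visible in the last two hunks; the regex and loop header sit outside this diff, so they are assumptions here:

import re


def remove_after_return_sketch(code):
    # Scan line by line; cut before the first non-indented line
    # after the first line, mirroring the visible loop body.
    end_last_match = None
    for match in re.finditer(r"[^\n]+(\n|$)", code):
        start_match, end_match = match.span()
        if (
            end_last_match is not None
            and start_match < len(code)
            and code[start_match].strip() != ""
        ):
            return code[0:start_match]
        end_last_match = end_match
    return code


print(remove_after_return_sketch("    return x + y\nprint('leak')\n"))
# -> "    return x + y\n"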
