Skip to content

Commit 8d9f667

Browse files
authored
Merge pull request #143 from iCSawyer/main
Add --save_references_path into args
2 parents be2a44c + a5fc279 commit 8d9f667

3 files changed

Lines changed: 11 additions & 4 deletions

File tree

bigcode_eval/evaluator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@ def evaluate(self, task_name):
8484
f"generations were saved at {self.args.save_generations_path}"
8585
)
8686
if self.args.save_references:
87-
with open("references.json", "w") as fp:
87+
with open(self.args.save_references_path, "w") as fp:
8888
json.dump(references, fp)
89-
print("references were saved at references.json")
89+
print(f"references were saved at {self.args.save_references_path}")
9090

9191
# make sure tokenizer plays nice with multiprocessing
9292
os.environ["TOKENIZERS_PARALLELISM"] = "false"

main.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,12 @@ def parse_args():
169169
action="store_true",
170170
help="Whether to save reference solutions/tests",
171171
)
172+
parser.add_argument(
173+
"--save_references_path",
174+
type=str,
175+
default="references.json",
176+
help="Path for saving the references solutions/tests",
177+
)
172178
parser.add_argument(
173179
"--prompt",
174180
type=str,
@@ -335,9 +341,9 @@ def main():
335341
json.dump(generations, fp)
336342
print(f"generations were saved at {args.save_generations_path}")
337343
if args.save_references:
338-
with open("references.json", "w") as fp:
344+
with open(args.save_references_path, "w") as fp:
339345
json.dump(references, fp)
340-
print("references were saved")
346+
print(f"references were saved at {args.save_references_path}")
341347
else:
342348
results[task] = evaluator.evaluate(task)
343349

tests/test_generation_evaluation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def update_args(args):
3131
args.save_generations = False
3232
args.save_generations_path = ""
3333
args.save_references = False
34+
args.save_references_path = ""
3435
args.metric_output_path = TMPDIR
3536
args.load_generations_path = None
3637
args.generation_only = False

0 commit comments

Comments
 (0)