Skip to content

Commit 342aed8

Browse files
committed
feat: support selective eval task
1 parent 8cdcdfe commit 342aed8

1 file changed

Lines changed: 15 additions & 4 deletions

File tree

bigcodebench/evaluate.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def evaluate(
119119
samples: Optional[str] = None,
120120
no_execute: bool = False,
121121
local_execute: bool = False,
122+
selective_evaluate: str = "",
122123
remote_execute_api: str = "https://bigcode-bigcodebench-evaluator.hf.space/",
123124
pass_k: str = "1,5,10",
124125
save_pass_rate: bool = True,
@@ -168,6 +169,7 @@ def evaluate(
168169
calibrated=calibrated,
169170
check_gt_only=check_gt_only,
170171
no_gt=no_gt,
172+
selective_evaluate=selective_evaluate,
171173
api_name="/predict"
172174
)
173175
break
@@ -193,6 +195,14 @@ def evaluate(
193195
samples = "__dummy__.jsonl"
194196

195197
problems = get_bigcodebench(subset=subset)
198+
199+
# Add selective evaluation logic
200+
if selective_evaluate:
201+
selected_ids = set(selective_evaluate.split(","))
202+
problems = {k: v for k, v in problems.items() if k in selected_ids}
203+
if not problems:
204+
raise ValueError(f"None of the provided task IDs {selected_ids} were found in the dataset")
205+
196206
dataset_hash = get_bigcodebench_hash(subset=subset)
197207

198208
if not no_gt:
@@ -240,10 +250,9 @@ def evaluate(
240250
task_id = sample["task_id"]
241251

242252
if task_id not in problems:
243-
warn(
244-
f"Task {task_id} is found in the samples but not found in the dataset"
245-
)
253+
# Skip if task is not in problems (either not in dataset or filtered out by selective_evaluate)
246254
continue
255+
247256
solution = (
248257
sample["solution"]
249258
if "solution" in sample
@@ -267,8 +276,10 @@ def evaluate(
267276
completion_id[task_id] += 1
268277
n_samples += 1
269278

279+
# Modify the assertion to account for selective evaluation
270280
assert n_samples == len(remainings), "Missing problems in unfinished"
271-
assert len(completion_id) == len(problems), "Missing problems in samples"
281+
# Only check against problems that weren't filtered out
282+
assert len(completion_id) == len(problems), f"Missing problems in samples. Expected {len(problems)} problems, got {len(completion_id)}"
272283

273284
def stucking_checker():
274285
while remainings:

0 commit comments

Comments (0)