@@ -300,16 +300,14 @@ def run_samples(samples, n_workers=4):
300300 return results
301301
302302
303- def estimate_at_k (num_samples , num_correct , k ):
304- """Estimates beyond@k of each problem and returns them in an array."""
305-
306- def cf (n , k ):
307- return math .gamma (n + 1 ) / (math .gamma (k + 1 ) * (math .gamma (n - k + 1 )))
303+ def estimate_pass_at_k (num_samples , num_correct , k ):
304+ """Estimates pass@k of each problem and returns them in an array."""
308305
309306 def estimator (n : int , c : int , k : int ) -> float :
307+ """Calculates 1 - comb(n - c, k) / comb(n, k)."""
310308 if n - c < k :
311309 return 1.0
312- return 1 - cf ( n - c , k ) / cf ( n , k )
310+ return 1.0 - np . prod ( 1.0 - k / np . arange ( n - c + 1 , n + 1 ) )
313311
314312 if isinstance (num_samples , int ):
315313 num_samples_it = itertools .repeat (num_samples , len (num_correct ))
@@ -319,6 +317,17 @@ def estimator(n: int, c: int, k: int) -> float:
319317
320318 return np .array ([estimator (int (n ), int (c ), k ) for n , c in zip (num_samples_it , num_correct )])
321319
320+ def estimate_beyond_at_k (runtimes , k ):
321+ """Estimates pass@k of each problem and returns them in an array."""
322+
323+ def estimator (runtimes : list , k : int ) -> float :
324+ """Calculates 1 - comb(n - c, k) / comb(n, k)."""
325+ print (runtimes )
326+ print ("============" )
327+ return sum (runtimes [:k ])/ len (runtimes )
328+
329+ return np .array ([estimator (r , k ) for r in runtimes ])
330+
322331def compute_beyond_eval (generations_list , reference_list , timeout = 30 ):
323332 sandbox = Sandbox ()
324333
@@ -353,7 +362,8 @@ def compute_beyond_eval(generations_list, reference_list, timeout=30):
353362 max_runtime = max (runtimes )
354363
355364 # Evaluate generated solutions
356- t_c , p_c , b_c = 0 , 0 , 0
365+ t_c , p_c = 0 , 0
366+ b_l = list ()
357367 difficulty = instance ['difficulty' ]
358368
359369 for index , solution in enumerate (generations ):
@@ -373,29 +383,30 @@ def compute_beyond_eval(generations_list, reference_list, timeout=30):
373383 # Calculate Beyond
374384 if result ['result' ] == "passed" :
375385 runtime = result ['runtime' ]
376- runtime = min (runtime , max_runtime )
377- runtime = max (runtime , min_runtime )
378- b_c += (max_runtime - runtime ) / (max_runtime - min_runtime )
379386 p_c += 1
380387 else :
381388 runtime = float ('inf' )
389+
390+ runtime = min (runtime , max_runtime )
391+ runtime = max (runtime , min_runtime )
392+ b_l += [(max_runtime - runtime ) / (max_runtime - min_runtime )]
382393
383394 scores [difficulty ]['total_c' ] += [t_c ]
384395 scores [difficulty ]['correct_c' ] += [p_c ]
385- scores [difficulty ]['beyond_c' ] += [b_c ]
396+ scores [difficulty ]['beyond_c' ] += [b_l ]
386397
387398 scores ['Average' ]['total_c' ] += [t_c ]
388399 scores ['Average' ]['correct_c' ] += [p_c ]
389- scores ['Average' ]['beyond_c' ] += [b_c ]
400+ scores ['Average' ]['beyond_c' ] += [b_l ]
390401
391402 results = dict ()
392403 for difficulty in ['Easy' , "Medium" , "Hard" , "Average" ]:
393404 total = np .array (scores [difficulty ]['total_c' ])
394405 correct = np .array (scores [difficulty ]['correct_c' ])
395- beyond = np . array ( scores [difficulty ]['beyond_c' ])
406+ beyond = scores [difficulty ]['beyond_c' ]
396407
397- pass_at_k = {f"{ difficulty } _pass@{ k } " : estimate_at_k (total , correct , k ).mean () for k in [1 ,3 ,5 ] if (total >= k ).all ()}
398- beyond_at_k = {f"{ difficulty } _beyond@{ k } " : estimate_at_k ( total , beyond , k ).mean () for k in [1 ,3 ,5 ] if (total >= k ).all ()}
408+ pass_at_k = {f"{ difficulty } _pass@{ k } " : estimate_pass_at_k (total , correct , k ).mean () for k in [1 ,3 ,5 ] if (total >= k ).all ()}
409+ beyond_at_k = {f"{ difficulty } _beyond@{ k } " : estimate_beyond_at_k ( beyond , k ).mean () for k in [1 ,3 ,5 ] if (total >= k ).all ()}
399410
400411 results .update (pass_at_k )
401412 results .update (beyond_at_k )
0 commit comments