|
15 | 15 | # pylint: disable=bare-except, consider-using-generator, chained-comparison, broad-exception-caught |
16 | 16 | """RL Utils Module.""" |
17 | 17 | import re |
| 18 | +import uuid |
| 19 | +from etils import epath |
18 | 20 | import optax |
19 | | -from maxtext.utils import max_logging |
20 | 21 | import numpy as np |
21 | 22 |
|
22 | 23 |
|
@@ -433,13 +434,6 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs): |
433 | 434 | extracted_responses.append(fallback_matches[-1].strip() if fallback_matches else None) |
434 | 435 |
|
435 | 436 | scores = [] |
436 | | - if tmvp_config.debug.rl: |
437 | | - max_logging.log("START ============================") |
438 | | - max_logging.log(f"Question: {question[0]}") |
439 | | - max_logging.log(f"Answer: {answer[0]}") |
440 | | - max_logging.log(f"Response: {completions[0]}") |
441 | | - max_logging.log(f"Extracted: {extracted_responses[0]}") |
442 | | - max_logging.log("END ==============================") |
443 | 437 |
|
444 | 438 | for guess, true_answer in zip(extracted_responses, answer): |
445 | 439 | if guess is None: |
@@ -469,6 +463,20 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs): |
469 | 463 | scores.append(tmvp_config.reward_exact_answer if guess_val == true_val else 0.0) |
470 | 464 | except: |
471 | 465 | scores.append(0) |
| 466 | + if tmvp_config.debug.rl: |
| 467 | + debug_log_path = epath.Path(tmvp_config.base_output_directory) / tmvp_config.run_name / "debug_rl_logs" |
| 468 | + debug_log_path.mkdir(parents=True, exist_ok=True) |
| 469 | + log_file = debug_log_path / f"check_numbers_{uuid.uuid4().hex}.txt" |
| 470 | + log_content = ( |
| 471 | + "START ============================\n" |
| 472 | + f"Question: {question[0]}\n" |
| 473 | + f"Answer: {answer[0]}\n" |
| 474 | + f"Response: {completions[0]}\n" |
| 475 | + f"Extracted: {extracted_responses[0]}\n" |
| 476 | + f"Reward Score: {scores[0]}\n" |
| 477 | + "END ==============================\n" |
| 478 | + ) |
| 479 | + log_file.write_text(log_content) |
472 | 480 |
|
473 | 481 | return scores |
474 | 482 |
|
|
0 commit comments