Skip to content

Commit 44fc6d0

Browse files
Merge pull request #3580 from AI-Hypercomputer:anisha-microbatch
PiperOrigin-RevId: 895488431
2 parents ae52e46 + 6893345 commit 44fc6d0

4 files changed

Lines changed: 26 additions & 13 deletions

File tree

src/maxtext/configs/post_train/rl.yml

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -111,7 +111,8 @@ batch_size: 1
111111
num_batches: 4
112112
# A batch can be split into multiple micro batches for memory management
113113
# and/or async sampling and training.
114-
micro_batch_size: -1
114+
train_micro_batch_size: -1
115+
rollout_micro_batch_size: -1
115116
# Keep `num_test_batches` low so that evaluation runs quickly. It can be
116117
# increased to a max. of 330 (if batch size is 4).
117118
num_test_batches: 5 # 200

src/maxtext/configs/types.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1733,7 +1733,8 @@ class RLDataset(BaseModel):
17331733
num_test_batches: int = Field(5, description="Number of batches for RL evaluation.")
17341734
test_batch_start_index: int = Field(0, description="Start index for the test dataset")
17351735
train_fraction: float = Field(1.0, description="Fraction of the dataset to be used for training.")
1736-
micro_batch_size: int = Field(-1, description="Micro batch size for rollout and training.")
1736+
train_micro_batch_size: int = Field(-1, description="Micro batch size for training.")
1737+
rollout_micro_batch_size: int = Field(-1, description="Micro batch size for rollout.")
17371738

17381739

17391740
class RLEvaluation(BaseModel):

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -461,7 +461,10 @@ def create_rl_components(
461461
checkpoint_dir = None
462462

463463
# Set up micro batching
464-
micro_batch_size = None if trainer_config.micro_batch_size == -1 else trainer_config.micro_batch_size
464+
train_micro_batch_size = None if trainer_config.train_micro_batch_size == -1 else trainer_config.train_micro_batch_size
465+
rollout_micro_batch_size = (
466+
None if trainer_config.rollout_micro_batch_size == -1 else trainer_config.rollout_micro_batch_size
467+
)
465468

466469
# Setup metrics logging
467470
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
@@ -511,8 +514,8 @@ def create_rl_components(
511514
eval_every_n_steps=trainer_config.eval_interval,
512515
max_steps=max_train_steps,
513516
mini_batch_size=trainer_config.batch_size,
514-
train_micro_batch_size=micro_batch_size,
515-
rollout_micro_batch_size=micro_batch_size,
517+
train_micro_batch_size=train_micro_batch_size,
518+
rollout_micro_batch_size=rollout_micro_batch_size,
516519
metrics_logging_options=metrics_logging_options,
517520
profiler_options=profiler_options,
518521
checkpoint_root_directory=checkpoint_dir,

src/maxtext/trainers/post_train/rl/utils_rl.py

Lines changed: 16 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -15,8 +15,9 @@
1515
# pylint: disable=bare-except, consider-using-generator, chained-comparison, broad-exception-caught
1616
"""RL Utils Module."""
1717
import re
18+
import uuid
19+
from etils import epath
1820
import optax
19-
from maxtext.utils import max_logging
2021
import numpy as np
2122

2223

@@ -433,13 +434,6 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs):
433434
extracted_responses.append(fallback_matches[-1].strip() if fallback_matches else None)
434435

435436
scores = []
436-
if tmvp_config.debug.rl:
437-
max_logging.log("START ============================")
438-
max_logging.log(f"Question: {question[0]}")
439-
max_logging.log(f"Answer: {answer[0]}")
440-
max_logging.log(f"Response: {completions[0]}")
441-
max_logging.log(f"Extracted: {extracted_responses[0]}")
442-
max_logging.log("END ==============================")
443437

444438
for guess, true_answer in zip(extracted_responses, answer):
445439
if guess is None:
@@ -469,6 +463,20 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs):
469463
scores.append(tmvp_config.reward_exact_answer if guess_val == true_val else 0.0)
470464
except:
471465
scores.append(0)
466+
if tmvp_config.debug.rl:
467+
debug_log_path = epath.Path(tmvp_config.base_output_directory) / tmvp_config.run_name / "debug_rl_logs"
468+
debug_log_path.mkdir(parents=True, exist_ok=True)
469+
log_file = debug_log_path / f"check_numbers_{uuid.uuid4().hex}.txt"
470+
log_content = (
471+
"START ============================\n"
472+
f"Question: {question[0]}\n"
473+
f"Answer: {answer[0]}\n"
474+
f"Response: {completions[0]}\n"
475+
f"Extracted: {extracted_responses[0]}\n"
476+
f"Reward Score: {scores[0]}\n"
477+
"END ==============================\n"
478+
)
479+
log_file.write_text(log_content)
472480

473481
return scores
474482

0 commit comments

Comments
 (0)