Commit 1f04ad1

Merge pull request #3540 from AI-Hypercomputer:nicogrande/async-rollouts
PiperOrigin-RevId: 894247267
2 parents: d5cbf3c + daac9e0

4 files changed: 182 additions & 39 deletions


src/maxtext/configs/post_train/rl.yml

Lines changed: 15 additions & 0 deletions
```diff
@@ -54,6 +54,21 @@ rl:
   grpo_epsilon: 0.2
   loss_algo: 'grpo' # grpo or gspo-token
 
+  # ====== Agentic Rollout ======
+  # If True, uses the async AgenticGRPOLearner, which overlaps rollout generation
+  # with training for faster throughput via online vLLM inference.
+  use_agentic_rollout: False
+  # Max concurrent rollout requests when using agentic rollout.
+  max_concurrency: 256
+  # Number of off-policy steps tolerated before requiring a policy update.
+  off_policy_steps: 0
+  # System prompt injected into the agent at rollout time.
+  system_prompt: ''
+  # If True, mask degenerate groups (all-zero advantages) from contributing to the loss.
+  degenerate_group_masking: True
+  # Upper-bound clipping epsilon for GRPO loss; defaults to grpo_epsilon when null.
+  epsilon_high: null
+
 
 # ====== Models ======
 # for MaxText
```
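The `epsilon_high` knob above enables asymmetric clipping: the lower bound still comes from `grpo_epsilon`, while the upper bound can be loosened independently. A minimal JAX sketch of how such a fallback and clipped objective typically fit together (illustrative only, not the Tunix loss implementation):

```python
import jax.numpy as jnp

def clipped_objective(ratio, advantages, epsilon=0.2, epsilon_high=None):
  # epsilon_high falls back to the symmetric epsilon when unset (null in YAML).
  eps_high = epsilon if epsilon_high is None else epsilon_high
  clipped = jnp.clip(ratio, 1.0 - epsilon, 1.0 + eps_high)
  # PPO/GRPO-style pessimistic objective: take the smaller of the two terms.
  return jnp.minimum(ratio * advantages, clipped * advantages)
```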

src/maxtext/configs/types.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -1709,6 +1709,20 @@ class RL(BaseModel):
   grpo_beta: float = Field(0.08, description="Coefficient for the KL divergence penalty (β).")
   grpo_epsilon: float = Field(0.2, description="Epsilon value for clipping in the GRPO loss.")
   loss_algo: Literal["grpo", "gspo-token"] = Field("grpo", description="Loss algorithm, i.e., 'grpo' or 'gspo-token'.")
+  use_agentic_rollout: bool = Field(
+      False, description="If True, uses the asynchronous AgenticGRPOLearner for online vLLM rollouts."
+  )
+  max_concurrency: int = Field(256, description="Maximum number of concurrent rollout requests (agentic rollout only).")
+  off_policy_steps: int = Field(
+      0, description="Number of off-policy steps tolerated before requiring a policy update (agentic only)."
+  )
+  system_prompt: str = Field("", description="System prompt injected into the agent at rollout time (agentic only).")
+  degenerate_group_masking: bool = Field(
+      True, description="Mask degenerate groups (all-zero advantages) from contributing to loss (agentic only)."
+  )
+  epsilon_high: Optional[float] = Field(
+      None, description="Upper-bound clipping epsilon for GRPO loss. Defaults to epsilon when None (agentic only)."
+  )
 
 
 class RLDataset(BaseModel):
```
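Since these are plain Pydantic fields, the YAML defaults in rl.yml and the `Field` defaults here must be kept in sync by hand. A pared-down stand-in (hypothetical, not the real `RL` class) showing how the validated config behaves:

```python
from typing import Optional
from pydantic import BaseModel, Field

class RL(BaseModel):  # hypothetical stand-in for maxtext.configs.types.RL
  grpo_epsilon: float = Field(0.2)
  use_agentic_rollout: bool = Field(False)
  max_concurrency: int = Field(256)
  epsilon_high: Optional[float] = Field(None)

rl = RL(use_agentic_rollout=True)
# Omitted fields take their declared defaults; epsilon_high stays None until
# the loss code resolves it against grpo_epsilon.
assert rl.max_concurrency == 256 and rl.epsilon_high is None
```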

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 65 additions & 19 deletions
```diff
@@ -67,6 +67,7 @@
 from tunix.rl import rl_cluster as rl_cluster_lib
 from tunix.rl.rollout import base_rollout
 from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner
+from tunix.rl.agentic.agentic_grpo_learner import GrpoConfig as AgenticGrpoConfig, GrpoLearner as AgenticGrpoLearner
 from tunix.sft import metrics_logger, profiler
 
 # for vLLM we can skip JAX precompilation with this flag, it makes startup faster
```
```diff
@@ -386,6 +387,16 @@ def _filter_long_prompts(x):
     return len(tokens) <= trainer_config.max_prefill_predict_length
 
   train_dataset = train_dataset.filter(_filter_long_prompts)
+
+  # AgenticGRPOLearner uses a built-in chat parser that expects raw prompts
+  if getattr(trainer_config.rl, "use_agentic_rollout", False):
+
+    def _use_raw_prompt(x):
+      x["prompts"] = x["question"]
+      return x
+
+    train_dataset = train_dataset.map(_use_raw_prompt)
+
   dataset_size = int(trainer_config.num_batches * trainer_config.batch_size * trainer_config.train_fraction)
   train_dataset = train_dataset[:dataset_size]
   train_dataset = train_dataset.repeat(trainer_config.num_epoch)
```
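The remap is a plain dictionary update; with a hypothetical record it behaves as follows (illustrative values, not the real dataset):

```python
# Hypothetical record: "prompts" holds the chat-templated text produced by
# process_data, "question" the raw problem statement.
example = {"prompts": "<chat-templated prompt>", "question": "What is 2 + 2?"}

def _use_raw_prompt(x):
  x["prompts"] = x["question"]
  return x

# After the map, the learner's own chat parser sees only the raw question.
assert _use_raw_prompt(example)["prompts"] == "What is 2 + 2?"
```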
```diff
@@ -525,27 +536,24 @@ def create_rl_components(
           rollout_vllm_max_num_batched_tokens=trainer_config.max_num_batched_tokens,
           rollout_vllm_max_num_seqs=trainer_config.max_num_seqs,
           rollout_vllm_async_scheduling=trainer_config.async_scheduling,
+          rollout_vllm_server_mode=trainer_config.rl.use_agentic_rollout,
           rollout_vllm_kwargs={
               "hf_overrides": trainer_config.vllm_hf_overrides,
               "enable_expert_parallel": sampler_config.rollout_expert_parallelism > 1,
+              "enable_prefix_caching": True,  # Enable prefix caching to speed up generation for long prompts
           },
           rollout_vllm_sampling_kwargs={
               "stop": trainer_config.stop_strings,
               "detokenize": trainer_config.stop_strings is not None,
               "include_stop_str_in_output": trainer_config.stop_strings is not None,
           },
+          # AgenticGRPOLearner requires log-probabilities from the rollout engine
+          # to support off-policy filtering and multi-iteration training.
+          **({"return_logprobs": True} if trainer_config.rl.use_agentic_rollout else {}),
           **get_rollout_kwargs_for_parallelism(sampler_config, len(sampler_devices)),
       ),
   )
 
-  grpo_config = GrpoConfig(
-      num_generations=trainer_config.rl.num_generations,
-      num_iterations=trainer_config.rl.num_iterations,
-      beta=trainer_config.rl.grpo_beta,
-      epsilon=trainer_config.rl.grpo_epsilon,
-      loss_algo=trainer_config.rl.loss_algo,
-  )
-
   # Create RL cluster
   max_logging.log("Creating RL cluster...")
   rl_cluster_kwargs = {}
```
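The returned log-probabilities let the learner measure how stale a rollout is relative to the current policy. A hedged sketch of the kind of staleness bound `off_policy_steps` implies (an assumption for illustration, not Tunix's actual implementation):

```python
def is_sample_usable(sample_version: int, policy_version: int, off_policy_steps: int) -> bool:
  """Hypothetical staleness gate: accept a rollout only if the policy has
  advanced at most `off_policy_steps` updates since it was generated."""
  # off_policy_steps == 0 (the default) admits strictly on-policy samples only.
  return (policy_version - sample_version) <= off_policy_steps
```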
```diff
@@ -578,19 +586,57 @@ def _reward_fn(**kwargs):
 
     return _reward_fn
 
+  reward_fns = [  # type: ignore
+      make_reward_fn(utils_rl.match_format_exactly),
+      make_reward_fn(utils_rl.match_format_approximately),
+      # TODO(atwigg): comment out to simplify reward and overlap with check_numbers
+      make_reward_fn(utils_rl.check_answer),
+      make_reward_fn(utils_rl.check_numbers),
+  ]
+
   # Create RL trainer
   max_logging.log("Setting up RL trainer...")
-  rl_trainer = GrpoLearner(
-      rl_cluster=rl_cluster,
-      reward_fns=[  # type: ignore
-          make_reward_fn(utils_rl.match_format_exactly),
-          make_reward_fn(utils_rl.match_format_approximately),
-          # TODO(atwigg): comment out to simplify reward and overlap with check_numbers
-          make_reward_fn(utils_rl.check_answer),
-          make_reward_fn(utils_rl.check_numbers),
-      ],
-      algo_config=grpo_config,
-  )
+  if trainer_config.rl.use_agentic_rollout:
+    max_logging.log("Using AgenticGRPOLearner with async online rollouts.")
+    grpo_config = AgenticGrpoConfig(
+        num_generations=trainer_config.rl.num_generations,
+        num_iterations=trainer_config.rl.num_iterations,
+        beta=trainer_config.rl.grpo_beta,
+        epsilon=trainer_config.rl.grpo_epsilon,
+        loss_algo=trainer_config.rl.loss_algo,
+        max_response_length=trainer_config.max_target_length - trainer_config.max_prefill_predict_length,
+        max_concurrency=trainer_config.rl.max_concurrency,
+        off_policy_steps=trainer_config.rl.off_policy_steps,
+        system_prompt=trainer_config.rl.system_prompt,
+        degenerate_group_masking=trainer_config.rl.degenerate_group_masking,
+        epsilon_high=trainer_config.rl.epsilon_high,
+    )
+    # Instantiate the custom MaxText chat parser
+    template_config = load_template_from_file(trainer_config.chat_template_path)
+    chat_parser = utils_rl.MaxTextChatParser(
+        model_tokenizer=model_tokenizer, template_config=template_config, tmvp_config=trainer_config
+    )
+    rl_trainer = AgenticGrpoLearner(
+        rl_cluster=rl_cluster,
+        reward_fns=reward_fns,
+        algo_config=grpo_config,
+        chat_parser=chat_parser,
+        metric_fns=[utils_rl.get_correctness_metrics],
+    )
+  else:
+    max_logging.log("Using standard GRPOLearner with offline rollouts.")
+    grpo_config = GrpoConfig(
+        num_generations=trainer_config.rl.num_generations,
+        num_iterations=trainer_config.rl.num_iterations,
+        beta=trainer_config.rl.grpo_beta,
+        epsilon=trainer_config.rl.grpo_epsilon,
+        loss_algo=trainer_config.rl.loss_algo,
+    )
+    rl_trainer = GrpoLearner(
+        rl_cluster=rl_cluster,
+        reward_fns=reward_fns,
+        algo_config=grpo_config,
+    )
 
   return rl_cluster, rl_trainer, optimizer
```
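Note the derived budget passed to `AgenticGrpoConfig`: the response length is whatever remains of the sequence budget after the prompt allowance. With illustrative numbers (the real values come from `trainer_config`):

```python
max_target_length = 2048           # total tokens per sequence (hypothetical)
max_prefill_predict_length = 512   # tokens reserved for the prompt (hypothetical)
max_response_length = max_target_length - max_prefill_predict_length
assert max_response_length == 1536  # tokens left for the generated rollout
```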

src/maxtext/trainers/post_train/rl/utils_rl.py

Lines changed: 88 additions & 20 deletions
```diff
@@ -17,13 +17,17 @@
 import re
 import optax
 from maxtext.utils import max_logging
+import numpy as np
 
 
 from math_verify.errors import TimeoutException
 from math_verify.metric import math_metric
 from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig
 from math_verify import parse
 
+from tunix.rl.agentic.parser.chat_template_parser import parser as agentic_chat_template_parser
+
+
 # initialize math_verify_func once
 math_verify_func = math_metric(
     gold_extraction_target=(LatexExtractionConfig(),),
```
```diff
@@ -514,6 +518,23 @@ def make_optimizer(learning_rate):
   return optax.inject_hyperparams(make_optimizer)(learning_rate=schedule)
 
 
+def format_maxtext_messages(messages: list[dict[str, str]], template_config: dict, tmvp_config) -> list[dict[str, str]]:
+  """Helper to inject MaxText's system prompt into the input user messages."""
+  formatted_messages = []
+  for msg in messages:
+    formatted_content = template_config["TEMPLATE"].format(
+        system_prompt=template_config["SYSTEM_PROMPT"].format(
+            reasoning_start_token=tmvp_config.reasoning_start_token,
+            reasoning_end_token=tmvp_config.reasoning_end_token,
+            solution_start_token=tmvp_config.solution_start_token,
+            solution_end_token=tmvp_config.solution_end_token,
+        ),
+        question=msg,
+    )
+    formatted_messages.append({"role": "user", "content": formatted_content})
+  return formatted_messages
+
+
 def process_data(dataset_name, model_tokenizer, template_config, tmvp_config, x):
   """Function to process input dataset"""
 
```
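How the helper behaves with hypothetical template and token values (`process_data` below passes a list of raw question strings, each filling the `question` slot of one user message):

```python
from types import SimpleNamespace
# Assumes the module path from this commit; adjust if your layout differs.
from maxtext.trainers.post_train.rl.utils_rl import format_maxtext_messages

# Hypothetical template_config and special tokens, for illustration only.
template_config = {
    "SYSTEM_PROMPT": "Reason between {reasoning_start_token} and {reasoning_end_token}; "
                     "answer between {solution_start_token} and {solution_end_token}.",
    "TEMPLATE": "{system_prompt}\n\n{question}",
}
tmvp_config = SimpleNamespace(
    reasoning_start_token="<think>",
    reasoning_end_token="</think>",
    solution_start_token="<answer>",
    solution_end_token="</answer>",
)

formatted = format_maxtext_messages(["What is 7 * 6?"], template_config, tmvp_config)
# -> [{"role": "user", "content": "Reason between <think> and </think>; "
#     "answer between <answer> and </answer>.\n\nWhat is 7 * 6?"}]
```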
```diff
@@ -552,28 +573,75 @@ def _to_str(val):
   if dataset_name == "gsm8k":
     answer = extract_hash_answer(answer)
 
+  messages = [question]
+  formatted_messages = format_maxtext_messages(messages, template_config, tmvp_config)
+
+  prompts = model_tokenizer.apply_chat_template(
+      formatted_messages,
+      tokenize=False,
+      add_generation_prompt=True,
+  )
+
   return {
-      # passed to model forward pass
-      "prompts": model_tokenizer.apply_chat_template(
-          [
-              {
-                  "role": "user",
-                  "content": template_config["TEMPLATE"].format(
-                      system_prompt=template_config["SYSTEM_PROMPT"].format(
-                          reasoning_start_token=tmvp_config.reasoning_start_token,
-                          reasoning_end_token=tmvp_config.reasoning_end_token,
-                          solution_start_token=tmvp_config.solution_start_token,
-                          solution_end_token=tmvp_config.solution_end_token,
-                      ),
-                      question=question,
-                  ),
-              },
-          ],
-          tokenize=False,
-          add_generation_prompt=True,
-      ),
-      # passed to reward functions
+      # pre-formatted prompts for evaluation
+      "prompts": prompts,
+      # raw question for AgenticGRPOLearner to bypass formatting
       "question": question,
       # passed to reward functions
       "answer": answer,
   }
+
+
+def get_correctness_metrics(prompts, completions, rewards, advantages, **kwargs):
+  """Compute correctness statistics based on rewards."""
+  del prompts, completions, advantages, kwargs
+  solve_all = (rewards > 0.1).all()
+  solve_none = (rewards == 0).all()
+  solve_partial = (~solve_all) and (~solve_none)
+  solve_ratio = (rewards > 0.1).mean()
+  return {
+      "rewards/solve_all": (
+          1 if solve_all else 0,
+          np.mean,
+      ),
+      "rewards/solve_none": (
+          1 if solve_none else 0,
+          np.mean,
+      ),
+      "rewards/solve_partial": (
+          1 if solve_partial else 0,
+          np.mean,
+      ),
+      "rewards/solve_ratio": (
+          solve_ratio,
+          np.mean,
+      ),
+  }
+
+
+class MaxTextChatParser(agentic_chat_template_parser.DefaultChatTemplateParser):
+  """
+  Custom chat parser for MaxText that intercepts message lists dynamically
+  during agentic rollouts and injects the necessary system templates and
+  special tokens using the shared helper.
+  """
+
+  def __init__(self, model_tokenizer, template_config, tmvp_config):
+    super().__init__(model_tokenizer)
+    self.template_config = template_config
+    self.tmvp_config = tmvp_config
+
+  def parse(
+      self,
+      messages: list[dict[str, str]],
+      add_generation_prompt: bool = False,
+      is_first_msg: bool = False,
+  ) -> str:
+    """Overrides the default parse method to apply MaxText-specific formatting to the messages."""
+    # Apply MaxText-specific formatting to the messages
+    formatted_messages = format_maxtext_messages(messages, self.template_config, self.tmvp_config)
+
+    # Delegate to the Tunix default parser to apply the tokenizer's chat template
+    return super().parse(
+        messages=formatted_messages, add_generation_prompt=add_generation_prompt, is_first_msg=is_first_msg
+    )
```