|
67 | 67 | from tunix.rl import rl_cluster as rl_cluster_lib |
68 | 68 | from tunix.rl.rollout import base_rollout |
69 | 69 | from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner |
| 70 | +from tunix.rl.agentic.agentic_grpo_learner import GrpoConfig as AgenticGrpoConfig, GrpoLearner as AgenticGrpoLearner |
70 | 71 | from tunix.sft import metrics_logger, profiler |
71 | 72 |
|
72 | 73 | # For vLLM, we can skip JAX precompilation with this flag, which makes startup faster. |
@@ -386,6 +387,16 @@ def _filter_long_prompts(x): |
386 | 387 | return len(tokens) <= trainer_config.max_prefill_predict_length |
387 | 388 |
|
388 | 389 | train_dataset = train_dataset.filter(_filter_long_prompts) |
| 390 | + |
| 391 | + # AgenticGrpoLearner uses a built-in chat parser that expects raw, untemplated prompts. |
| 392 | + if getattr(trainer_config.rl, "use_agentic_rollout", False): |
| 393 | + |
| 394 | + def _use_raw_prompt(x): |
| 395 | + x["prompts"] = x["question"] |
| 396 | + return x |
| 397 | + |
| 398 | + train_dataset = train_dataset.map(_use_raw_prompt) |
| 399 | + |
389 | 400 | dataset_size = int(trainer_config.num_batches * trainer_config.batch_size * trainer_config.train_fraction) |
390 | 401 | train_dataset = train_dataset[:dataset_size] |
391 | 402 | train_dataset = train_dataset.repeat(trainer_config.num_epoch) |
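The raw-prompt remap above exists because the agentic learner's chat parser applies the chat template itself; handing it the pre-templated `prompts` field would wrap the question in the template twice. A toy sketch of that failure mode (the template and row shapes here are assumptions for illustration, not the MaxText pipeline; the real template comes from `chat_template_path`):

```python
# Toy illustration only: shows why the agentic path must see raw questions.
def apply_template(text):
    return f"<start_of_turn>user\n{text}<end_of_turn>\n<start_of_turn>model\n"

row = {"question": "What is 17 * 3?"}
row["prompts"] = apply_template(row["question"])  # standard (offline) path

# If the agentic chat parser templated this again, the model would see a
# nested user turn instead of a clean one:
double = apply_template(row["prompts"])
assert "<start_of_turn>user\n<start_of_turn>user" in double

# Hence the remap: hand the parser the raw question so it templates exactly once.
row["prompts"] = row["question"]
```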
@@ -525,27 +536,24 @@ def create_rl_components( |
525 | 536 | rollout_vllm_max_num_batched_tokens=trainer_config.max_num_batched_tokens, |
526 | 537 | rollout_vllm_max_num_seqs=trainer_config.max_num_seqs, |
527 | 538 | rollout_vllm_async_scheduling=trainer_config.async_scheduling, |
| 539 | + rollout_vllm_server_mode=trainer_config.rl.use_agentic_rollout, |
528 | 540 | rollout_vllm_kwargs={ |
529 | 541 | "hf_overrides": trainer_config.vllm_hf_overrides, |
530 | 542 | "enable_expert_parallel": sampler_config.rollout_expert_parallelism > 1, |
| 543 | + "enable_prefix_caching": True, # Reuse KV cache across shared prompt prefixes to speed up generation |
531 | 544 | }, |
532 | 545 | rollout_vllm_sampling_kwargs={ |
533 | 546 | "stop": trainer_config.stop_strings, |
534 | 547 | "detokenize": trainer_config.stop_strings is not None, |
535 | 548 | "include_stop_str_in_output": trainer_config.stop_strings is not None, |
536 | 549 | }, |
| 550 | + # AgenticGrpoLearner requires log-probabilities from the rollout engine |
| 551 | + # to support off-policy filtering and multi-iteration training. |
| 552 | + **({"return_logprobs": True} if trainer_config.rl.use_agentic_rollout else {}), |
537 | 553 | **get_rollout_kwargs_for_parallelism(sampler_config, len(sampler_devices)), |
538 | 554 | ), |
539 | 555 | ) |
540 | 556 |
|
541 | | - grpo_config = GrpoConfig( |
542 | | - num_generations=trainer_config.rl.num_generations, |
543 | | - num_iterations=trainer_config.rl.num_iterations, |
544 | | - beta=trainer_config.rl.grpo_beta, |
545 | | - epsilon=trainer_config.rl.grpo_epsilon, |
546 | | - loss_algo=trainer_config.rl.loss_algo, |
547 | | - ) |
548 | | - |
549 | 557 | # Create RL cluster |
550 | 558 | max_logging.log("Creating RL cluster...") |
551 | 559 | rl_cluster_kwargs = {} |
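The `return_logprobs` toggle above is what makes off-policy filtering and multi-iteration training sound: once a sample is reused across optimizer steps, each token must be reweighted by the ratio between the current policy and the rollout policy that produced it. A hedged sketch of that clipped-ratio reweighting (illustrative names, not the Tunix internals):

```python
import jax.numpy as jnp

def clipped_token_loss(policy_logps, rollout_logps, advantages, eps_low, eps_high):
    # Importance ratio pi_theta / pi_rollout per token, from the logprobs
    # the rollout engine returned alongside each sample.
    ratio = jnp.exp(policy_logps - rollout_logps)
    # PPO-style asymmetric clipping; eps_high corresponds to the epsilon_high
    # knob in the agentic config below.
    clipped = jnp.clip(ratio, 1.0 - eps_low, 1.0 + eps_high)
    # Pessimistic (min) surrogate, negated so gradient descent maximizes it.
    return -jnp.minimum(ratio * advantages, clipped * advantages)
```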
@@ -578,19 +586,57 @@ def _reward_fn(**kwargs): |
578 | 586 |
|
579 | 587 | return _reward_fn |
580 | 588 |
|
| 589 | + reward_fns = [ # type: ignore |
| 590 | + make_reward_fn(utils_rl.match_format_exactly), |
| 591 | + make_reward_fn(utils_rl.match_format_approximately), |
| 592 | + # TODO(atwigg): consider commenting out to simplify the reward; overlaps with check_numbers |
| 593 | + make_reward_fn(utils_rl.check_answer), |
| 594 | + make_reward_fn(utils_rl.check_numbers), |
| 595 | + ] |
| 596 | + |
581 | 597 | # Create RL trainer |
582 | 598 | max_logging.log("Setting up RL trainer...") |
583 | | - rl_trainer = GrpoLearner( |
584 | | - rl_cluster=rl_cluster, |
585 | | - reward_fns=[ # type: ignore |
586 | | - make_reward_fn(utils_rl.match_format_exactly), |
587 | | - make_reward_fn(utils_rl.match_format_approximately), |
588 | | - # TODO(atwigg): comment out to simplify reward and overlap with check_numbers |
589 | | - make_reward_fn(utils_rl.check_answer), |
590 | | - make_reward_fn(utils_rl.check_numbers), |
591 | | - ], |
592 | | - algo_config=grpo_config, |
593 | | - ) |
| 599 | + if trainer_config.rl.use_agentic_rollout: |
| 600 | + max_logging.log("Using AgenticGrpoLearner with async online rollouts.") |
| 601 | + grpo_config = AgenticGrpoConfig( |
| 602 | + num_generations=trainer_config.rl.num_generations, |
| 603 | + num_iterations=trainer_config.rl.num_iterations, |
| 604 | + beta=trainer_config.rl.grpo_beta, |
| 605 | + epsilon=trainer_config.rl.grpo_epsilon, |
| 606 | + loss_algo=trainer_config.rl.loss_algo, |
| 607 | + max_response_length=trainer_config.max_target_length - trainer_config.max_prefill_predict_length, |
| 608 | + max_concurrency=trainer_config.rl.max_concurrency, |
| 609 | + off_policy_steps=trainer_config.rl.off_policy_steps, |
| 610 | + system_prompt=trainer_config.rl.system_prompt, |
| 611 | + degenerate_group_masking=trainer_config.rl.degenerate_group_masking, |
| 612 | + epsilon_high=trainer_config.rl.epsilon_high, |
| 613 | + ) |
| 614 | + # Instantiate the custom MaxText chat parser used by the agentic rollouts |
| 615 | + template_config = load_template_from_file(trainer_config.chat_template_path) |
| 616 | + chat_parser = utils_rl.MaxTextChatParser( |
| 617 | + model_tokenizer=model_tokenizer, template_config=template_config, tmvp_config=trainer_config |
| 618 | + ) |
| 619 | + rl_trainer = AgenticGrpoLearner( |
| 620 | + rl_cluster=rl_cluster, |
| 621 | + reward_fns=reward_fns, |
| 622 | + algo_config=grpo_config, |
| 623 | + chat_parser=chat_parser, |
| 624 | + metric_fns=[utils_rl.get_correctness_metrics], |
| 625 | + ) |
| 626 | + else: |
| 627 | + max_logging.log("Using standard GrpoLearner with offline rollouts.") |
| 628 | + grpo_config = GrpoConfig( |
| 629 | + num_generations=trainer_config.rl.num_generations, |
| 630 | + num_iterations=trainer_config.rl.num_iterations, |
| 631 | + beta=trainer_config.rl.grpo_beta, |
| 632 | + epsilon=trainer_config.rl.grpo_epsilon, |
| 633 | + loss_algo=trainer_config.rl.loss_algo, |
| 634 | + ) |
| 635 | + rl_trainer = GrpoLearner( |
| 636 | + rl_cluster=rl_cluster, |
| 637 | + reward_fns=reward_fns, |
| 638 | + algo_config=grpo_config, |
| 639 | + ) |
594 | 640 |
|
595 | 641 | return rl_cluster, rl_trainer, optimizer |
596 | 642 |
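For reference, the new branch is driven entirely by `rl.*` config attributes read in this diff. A hypothetical sketch of those knobs (the names mirror the attribute accesses above, but the values, types, and comments are assumptions, not the MaxText config schema):

```python
from types import SimpleNamespace

# Illustrative stand-in for trainer_config.rl; defaults are made up.
rl = SimpleNamespace(
    use_agentic_rollout=True,      # also flips rollout_vllm_server_mode above
    num_generations=8,             # GRPO group size (shared with the standard path)
    num_iterations=2,              # >1 is what makes return_logprobs load-bearing
    grpo_beta=0.04,                # KL penalty weight
    grpo_epsilon=0.2,              # lower clip for the importance ratio
    epsilon_high=0.28,             # asymmetric upper clip (agentic path only)
    loss_algo="grpo",
    max_concurrency=64,            # async rollout fan-out
    off_policy_steps=1,            # how stale a rollout may be before it is dropped
    system_prompt=None,            # prepended by the chat parser when set
    degenerate_group_masking=True, # presumably masks groups with identical rewards
)
```

Note that `max_response_length` is not a knob here: the agentic config derives it above as `max_target_length - max_prefill_predict_length`.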