@@ -478,32 +478,40 @@ def extract_hash_answer(text: str) -> str | None:
 
 def get_optimizer(tmvp_config, max_train_steps):
   """Function to obtain an optax optimizer, currently we use adamw."""
-  optimizer = optax.adamw(
-      learning_rate=optax.schedules.warmup_cosine_decay_schedule(
-          init_value=0.0,
-          peak_value=tmvp_config.learning_rate,
-          # Linearly increase learning rate from 0. to learning_rate in the first
-          # warmup_steps_fraction training steps, and then gradually decrease the
-          # learning rate to 0 using cosine scheduler.
-          warmup_steps=int(tmvp_config.warmup_steps_fraction * max_train_steps),
-          decay_steps=max_train_steps,
-          end_value=0.0,
-      ),
-      b1=tmvp_config.adam_b1,
-      b2=tmvp_config.adam_b2,
-      weight_decay=tmvp_config.adam_weight_decay,
+  schedule = optax.schedules.warmup_cosine_decay_schedule(
+      init_value=0.0,
+      peak_value=tmvp_config.learning_rate,
+      # Linearly increase learning rate from 0. to learning_rate in the first
+      # warmup_steps_fraction training steps, and then gradually decrease the
+      # learning rate to 0 using cosine scheduler.
+      warmup_steps=int(tmvp_config.warmup_steps_fraction * max_train_steps),
+      decay_steps=max_train_steps,
+      end_value=0.0,
   )
 
   # TODO: @mazumdera: try optimizer offloading with adamw
   # Add gradient clipping if specified
   # Grad clipping to prevent large gradients. We find this
   # important to keep KL divergence in check.
-  if tmvp_config.gradient_clipping_threshold > 0:
-    optimizer = optax.chain(
-        optax.clip_by_global_norm(max_norm=tmvp_config.gradient_clipping_threshold),
-        optimizer,
+  def make_optimizer(learning_rate):
+    transforms = []
+    if tmvp_config.gradient_clipping_threshold > 0:
+      transforms.append(optax.clip_by_global_norm(max_norm=tmvp_config.gradient_clipping_threshold))
+    transforms.append(
+        optax.adamw(
+            learning_rate=learning_rate,
+            b1=tmvp_config.adam_b1,
+            b2=tmvp_config.adam_b2,
+            weight_decay=tmvp_config.adam_weight_decay,
+        )
     )
-  return optimizer
+    return optax.chain(*transforms)
+
+  # Wrap the entire optimizer (including gradient clipping) with
+  # inject_hyperparams so opt_state.hyperparams['learning_rate'] is at the
+  # top level of the state tree. This is required for tunix's peft_trainer to
+  # automatically read and log the per-step learning rate.
+  return optax.inject_hyperparams(make_optimizer)(learning_rate=schedule)
 
 
 def process_data(dataset_name, model_tokenizer, template_config, tmvp_config, x):
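
For intuition about the schedule half of this change, a minimal sketch that evaluates the warmup-cosine schedule at a few steps is shown below. The peak rate of 3e-4, the 100 training steps, and the 10% warmup fraction are made-up stand-ins for the `tmvp_config` values, not taken from this commit:

```python
import optax

# Hypothetical stand-ins for tmvp_config.learning_rate,
# tmvp_config.warmup_steps_fraction, and max_train_steps.
peak_lr = 3e-4
warmup_steps_fraction = 0.1
max_train_steps = 100

schedule = optax.schedules.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=peak_lr,
    warmup_steps=int(warmup_steps_fraction * max_train_steps),
    decay_steps=max_train_steps,
    end_value=0.0,
)

# Linear warmup over the first 10 steps, then cosine decay back to 0.
for step in (0, 5, 10, 50, 100):
  print(step, float(schedule(step)))
```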
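And a rough check of the `inject_hyperparams` wiring that `peft_trainer` depends on: after each update call, the current learning rate should be readable at the top level of the optimizer state. `DummyConfig` and its values are hypothetical; only `get_optimizer` comes from this diff:

```python
import dataclasses

import jax
import jax.numpy as jnp
import optax


@dataclasses.dataclass
class DummyConfig:  # Hypothetical stand-in for tmvp_config.
  learning_rate: float = 3e-4
  warmup_steps_fraction: float = 0.1
  adam_b1: float = 0.9
  adam_b2: float = 0.99
  adam_weight_decay: float = 0.0
  gradient_clipping_threshold: float = 1.0


optimizer = get_optimizer(DummyConfig(), max_train_steps=100)

params = {"w": jnp.ones((4,))}
opt_state = optimizer.init(params)
grads = jax.tree_util.tree_map(jnp.ones_like, params)

for step in range(3):
  updates, opt_state = optimizer.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
  # inject_hyperparams exposes the evaluated schedule value here, which is
  # what lets the trainer log the per-step learning rate.
  print(step, float(opt_state.hyperparams["learning_rate"]))
```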