2525 from typing_extensions import override
2626
2727import jax
28+ import time
2829import jax .numpy as jnp
2930
3031from flax import nnx
@@ -54,6 +55,7 @@ def __init__(self, config, mesh, learning_rate_schedule, goodput_recorder):
5455 self .metadata = {}
5556 self .train_metadata = defaultdict (float )
5657 self .eval_metadata = defaultdict (float )
58+ self .step_start_time = 0.0
5759
5860 @override
5961 def on_train_start (self , train_ctx : peft_trainer .PeftTrainer ):
@@ -79,6 +81,7 @@ def on_train_start(self, train_ctx: peft_trainer.PeftTrainer):
7981 )
8082
8183 self .metadata ["first_train_step" ] = train_ctx .train_steps
84+ self .step_start_time = time .perf_counter ()
8285
8386 @override
8487 def on_train_end (self , train_ctx : peft_trainer .PeftTrainer ): # pylint: disable=unused-argument
@@ -109,22 +112,26 @@ def on_train_step_end(
109112 train_ctx : peft_trainer .PeftTrainer ,
110113 train_step : int ,
111114 train_loss : float ,
112- step_time : float ,
115+ step_time : float = 0.0 , # No longer provided. See https://github.com/google/tunix/pull/1289.
113116 ):
114117 """Called at the end of training step.
115118 This hook is called by Tunix after the step counter has been incremented for logging purposes.
116119 Therefore, using `train_step - 1` to refer to the state of the previous step counter.
117120 However, we will use the current `train_step` value to record metrics in this hook to be
118121 consistent with Tunix's metric logging convention.
119122 """
120-
121123 assert train_step - 1 in self .train_metadata , (
122124 "SFTTrainingHooks.on_train_step_start() must be called before" " SFTTrainingHooks.on_train_step_end()"
123125 )
124126
125127 if self .metadata ["first_train_step" ] == train_step - 1 :
126128 max_utils .print_mem_stats ("After params initialized" )
127129
130+ # Use our own timing since Tunix passes 0.0
131+ current_time = time .perf_counter ()
132+ step_time = current_time - self .step_start_time
133+ self .step_start_time = current_time
134+
128135 metrics = {
129136 "scalar" : {
130137 "learning/loss" : train_loss ,
0 commit comments