Skip to content

Commit 58e2c7e

Browse files
Merge pull request #3581 from AI-Hypercomputer:igorts/sft-timing-fix
PiperOrigin-RevId: 897856338
2 parents 46dbaf5 + 4440993 commit 58e2c7e

1 file changed

Lines changed: 9 additions & 2 deletions

File tree

  • src/maxtext/trainers/post_train/sft

src/maxtext/trainers/post_train/sft/hooks.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
from typing_extensions import override
2626

2727
import jax
28+
import time
2829
import jax.numpy as jnp
2930

3031
from flax import nnx
@@ -54,6 +55,7 @@ def __init__(self, config, mesh, learning_rate_schedule, goodput_recorder):
5455
self.metadata = {}
5556
self.train_metadata = defaultdict(float)
5657
self.eval_metadata = defaultdict(float)
58+
self.step_start_time = 0.0
5759

5860
@override
5961
def on_train_start(self, train_ctx: peft_trainer.PeftTrainer):
@@ -79,6 +81,7 @@ def on_train_start(self, train_ctx: peft_trainer.PeftTrainer):
7981
)
8082

8183
self.metadata["first_train_step"] = train_ctx.train_steps
84+
self.step_start_time = time.perf_counter()
8285

8386
@override
8487
def on_train_end(self, train_ctx: peft_trainer.PeftTrainer): # pylint: disable=unused-argument
@@ -109,22 +112,26 @@ def on_train_step_end(
109112
train_ctx: peft_trainer.PeftTrainer,
110113
train_step: int,
111114
train_loss: float,
112-
step_time: float,
115+
step_time: float = 0.0, # No longer provided. See https://github.com/google/tunix/pull/1289.
113116
):
114117
"""Called at the end of training step.
115118
This hook is called by Tunix after the step counter has been incremented for logging purposes.
116119
Therefore, using `train_step - 1` to refer to the state of the previous step counter.
117120
However, we will use the current `train_step` value to record metrics in this hook to be
118121
consistent with Tunix's metric logging convention.
119122
"""
120-
121123
assert train_step - 1 in self.train_metadata, (
122124
"SFTTrainingHooks.on_train_step_start() must be called before" " SFTTrainingHooks.on_train_step_end()"
123125
)
124126

125127
if self.metadata["first_train_step"] == train_step - 1:
126128
max_utils.print_mem_stats("After params initialized")
127129

130+
# Use our own timing since Tunix passes 0.0
131+
current_time = time.perf_counter()
132+
step_time = current_time - self.step_start_time
133+
self.step_start_time = current_time
134+
128135
metrics = {
129136
"scalar": {
130137
"learning/loss": train_loss,

0 commit comments

Comments (0)