Skip to content

Commit 4e81fbb

Browse files
committed
Log optimizer step duration via callback
Add OptimizerTimerCallback to basics/base_task.py to measure GPU optimizer step time using torch.cuda.Event and torch.cuda.synchronize. The callback records start/end events around optimizer steps (after epoch 0) and logs the elapsed milliseconds as "stats/optimizer_step_duration_ms" via pl_module.log (on_step, shown in prog_bar). The callback is registered in the Trainer callbacks so durations appear in TensorBoard/console. Note: a local timer_callback variable is instantiated but the callbacks list also constructs a new OptimizerTimerCallback (minor redundancy). Update base_task.py
1 parent f0a1c19 commit 4e81fbb

1 file changed

Lines changed: 33 additions & 0 deletions

File tree

basics/base_task.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from torchmetrics import Metric, MeanMetric
1616
import lightning.pytorch as pl
1717
from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info, rank_zero_only
18+
from lightning.pytorch.callbacks import Callback
1819

1920
from basics.base_module import CategorizedModule
2021
from utils.hparams import hparams
@@ -32,6 +33,37 @@
3233
format=log_format, datefmt='%m/%d %I:%M:%S %p')
3334

3435

36+
class OptimizerTimerCallback(Callback):
37+
def __init__(self):
38+
super().__init__()
39+
# 使用 CUDA Event 确保获取的是 GPU 真实执行时间,而非 CPU 发射时间
40+
self.start_event = torch.cuda.Event(enable_timing=True)
41+
self.end_event = torch.cuda.Event(enable_timing=True)
42+
43+
def on_before_optimizer_step(self, trainer, pl_module, optimizer):
44+
# 只在第一个 Epoch 之后开始计时
45+
if trainer.current_epoch > 0:
46+
self.start_event.record()
47+
48+
def on_after_optimizer_step(self, trainer, pl_module, optimizer):
49+
if trainer.current_epoch > 0:
50+
self.end_event.record()
51+
torch.cuda.synchronize() # 等待 GPU 完成该 Step 的所有计算
52+
53+
# 计算耗时(毫秒)
54+
epoch_time_ms = self.start_event.elapsed_time(self.end_event)
55+
56+
# 记录到 TensorBoard
57+
# pl_module.log 会自动寻找当前配置的 Logger (如 TensorBoardLogger)
58+
pl_module.log(
59+
"stats/optimizer_step_duration_ms",
60+
epoch_time_ms,
61+
on_step=True,
62+
on_epoch=False,
63+
prog_bar=True
64+
)
65+
66+
3567
class BaseTask(pl.LightningModule):
3668
"""
3769
Base class for training tasks.
@@ -423,6 +455,7 @@ def start(cls):
423455
),
424456
# LearningRateMonitor(logging_interval='step'),
425457
DsTQDMProgressBar(),
458+
OptimizerTimerCallback(),
426459
],
427460
logger=DsTensorBoardLogger(
428461
save_dir=str(work_dir),

0 commit comments

Comments
 (0)