@@ -131,6 +131,12 @@ class SupervisedTrainer(Trainer):
             `torch.Tensor` before forward pass, then converted back afterward with copied meta information.
         compile_kwargs: dict of the args for `torch.compile()` API, for more details:
             https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile.
+        accumulation_steps: number of mini-batches over which to accumulate gradients before
+            calling ``optimizer.step()``, effectively simulating a larger batch size on
+            memory-constrained hardware. Must be a positive integer. Default: 1 (no accumulation).
+            When ``epoch_length`` is known and not divisible by ``accumulation_steps``, a flush
+            (optimizer step) is performed at the end of each epoch so no gradients are silently
+            discarded. The loss stored in ``engine.state.output`` is always the **unscaled** value.
     """

     def __init__(
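As a usage illustration (not part of the commit): a minimal sketch of driving the new argument, assuming a typical MONAI setup. Here `train_loader` is a hypothetical `DataLoader` yielding `{"image": ..., "label": ...}` dict batches, and the `UNet` hyperparameters are arbitrary.

```python
# Minimal sketch: accumulation_steps=4 with a per-iteration batch size of 2
# gives an effective optimizer batch of 8 without the extra GPU memory.
import torch
from monai.engines import SupervisedTrainer
from monai.losses import DiceLoss
from monai.networks.nets import UNet

net = UNet(spatial_dims=2, in_channels=1, out_channels=2, channels=(8, 16, 32), strides=(2, 2))
trainer = SupervisedTrainer(
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    max_epochs=10,
    train_data_loader=train_loader,  # hypothetical DataLoader of {"image", "label"} dicts
    network=net,
    optimizer=torch.optim.Adam(net.parameters(), lr=1e-4),
    loss_function=DiceLoss(to_onehot_y=True, softmax=True),
    accumulation_steps=4,  # the argument added by this commit
)
trainer.run()
```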
@@ -160,7 +166,10 @@ def __init__(
         amp_kwargs: dict | None = None,
         compile: bool = False,
         compile_kwargs: dict | None = None,
+        accumulation_steps: int = 1,
     ) -> None:
+        if accumulation_steps < 1:
+            raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.")
         super().__init__(
             device=device,
             max_epochs=max_epochs,
@@ -190,6 +199,7 @@ def __init__(
         self.loss_function = loss_function
         self.inferer = SimpleInferer() if inferer is None else inferer
         self.optim_set_to_none = optim_set_to_none
+        self.accumulation_steps = accumulation_steps

     def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tensor]) -> dict:
         """
@@ -245,21 +255,42 @@ def _compute_pred_loss():
             engine.state.output[Keys.LOSS] = engine.loss_function(engine.state.output[Keys.PRED], targets).mean()
             engine.fire_event(IterationEvents.LOSS_COMPLETED)

+        # Determine gradient accumulation state
+        acc = engine.accumulation_steps
+        if acc > 1:
+            epoch_length = engine.state.epoch_length
+            if epoch_length is not None:
+                local_iter = (engine.state.iteration - 1) % epoch_length  # 0-indexed within epoch
+                should_zero_grad = local_iter % acc == 0
+                should_step = (local_iter + 1) % acc == 0 or (local_iter + 1) == epoch_length
+            else:
+                local_iter = engine.state.iteration - 1  # 0-indexed global
+                should_zero_grad = local_iter % acc == 0
+                should_step = (local_iter + 1) % acc == 0
+        else:
+            should_zero_grad = True
+            should_step = True
+
         engine.network.train()
-        engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
+        if should_zero_grad:
+            engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)

         if engine.amp and engine.scaler is not None:
             with torch.autocast("cuda", **engine.amp_kwargs):
                 _compute_pred_loss()
-            engine.scaler.scale(engine.state.output[Keys.LOSS]).backward()
+            loss = engine.state.output[Keys.LOSS]
+            engine.scaler.scale(loss / acc if acc > 1 else loss).backward()
             engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
-            engine.scaler.step(engine.optimizer)
-            engine.scaler.update()
+            if should_step:
+                engine.scaler.step(engine.optimizer)
+                engine.scaler.update()
         else:
             _compute_pred_loss()
-            engine.state.output[Keys.LOSS].backward()
+            loss = engine.state.output[Keys.LOSS]
+            (loss / acc if acc > 1 else loss).backward()
             engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
-            engine.optimizer.step()
+            if should_step:
+                engine.optimizer.step()
         # copy back meta info
         if self.compile:
             if inputs_meta is not None:
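A standalone sanity check of the schedule above (plain Python, not MONAI code): it replays the `should_zero_grad`/`should_step` arithmetic for a known epoch length, showing the end-of-epoch flush on the final partial window. Since the loss is always divided by the full `acc` before `backward()`, each optimizer step applies roughly the mean gradient of its window; the flushed partial window is divided by the same factor and thus slightly down-weighted, which appears to be an accepted trade-off here.

```python
# Replays the accumulation schedule from _iteration for epoch_length=10, acc=4.
# Expected optimizer steps: after local iterations 3, 7 and 9 (0-indexed);
# the last step flushes a partial window of only 2 batches at the epoch end.
def accumulation_schedule(epoch_length: int, acc: int) -> list[tuple[int, bool, bool]]:
    events = []
    for local_iter in range(epoch_length):
        should_zero_grad = local_iter % acc == 0
        should_step = (local_iter + 1) % acc == 0 or (local_iter + 1) == epoch_length
        events.append((local_iter, should_zero_grad, should_step))
    return events

for it, zero_grad, step in accumulation_schedule(epoch_length=10, acc=4):
    print(f"iter={it}: zero_grad={zero_grad} step={step}")
# steps fire at iter 3, 7 and 9; zero_grad at iter 0, 4 and 8
```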