Skip to content

Commit cad2105

Browse files
committed
fix mypy error
Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
1 parent a3eca14 commit cad2105

2 files changed

Lines changed: 6 additions & 2 deletions

File tree

monai/engines/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,8 +423,11 @@ def __call__(self, engine: Any, batchdata: dict[str, Any]) -> dict:
423423
"""
424424
acc = self.accumulation_steps
425425

426+
result: dict
427+
426428
if acc == 1:
427-
return engine._iteration(engine, batchdata)
429+
result = engine._iteration(engine, batchdata)
430+
return result
428431

429432
# engine.state.iteration is 1-indexed and already incremented before __call__
430433
epoch_length = engine.state.epoch_length # None for iterable datasets

tests/engines/test_gradient_accumulation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from __future__ import annotations
1313

1414
import unittest
15+
from typing import Any
1516
from unittest.mock import MagicMock
1617

1718
import torch
@@ -201,7 +202,7 @@ def test_scaler_is_none(self) -> None:
201202

202203
def test_batch_data_passed_correctly(self) -> None:
203204
engine = _make_engine(epoch_length=4, iteration=1)
204-
test_batch = {CommonKeys.IMAGE: torch.randn(1, 10), CommonKeys.LABEL: torch.randn(1, 1)}
205+
test_batch: dict[str, Any] = {CommonKeys.IMAGE: torch.randn(1, 10), CommonKeys.LABEL: torch.randn(1, 1)}
205206

206207
GradientAccumulation(accumulation_steps=2)(engine, test_batch)
207208

0 commit comments

Comments
 (0)