Skip to content

Commit a3eca14

Browse files
committed
fix for Flake8-py3 code format error
Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
1 parent 1db8cc1 commit a3eca14

2 files changed

Lines changed: 27 additions & 40 deletions

File tree

monai/engines/utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -404,9 +404,7 @@ class GradientAccumulation:
404404

405405
def __init__(self, accumulation_steps: int = 2) -> None:
406406
if not isinstance(accumulation_steps, int) or accumulation_steps < 1:
407-
raise ValueError(
408-
f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}."
409-
)
407+
raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.")
410408
self.accumulation_steps = accumulation_steps
411409

412410
def __repr__(self) -> str:

tests/engines/test_gradient_accumulation.py

Lines changed: 26 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -24,30 +24,16 @@
2424

2525
_, has_ignite = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version)
2626

27-
INVALID_ACCUMULATION_STEPS = [
28-
(0,), (-1,), (2.5,), ("2",),
29-
]
27+
INVALID_ACCUMULATION_STEPS = [(0,), (-1,), (2.5,), ("2",)]
3028

3129
SUPPRESSION_CASES = [
3230
# (attr_name, acc, epoch_length, num_iters, expected)
33-
(
34-
"zero_grad", 4, 12, 12,
35-
[True, False, False, False, True, False, False, False, True, False, False, False],
36-
),
37-
(
38-
"step", 4, 12, 12,
39-
[False, False, False, True, False, False, False, True, False, False, False, True],
40-
),
31+
("zero_grad", 4, 12, 12, [True, False, False, False, True, False, False, False, True, False, False, False]),
32+
("step", 4, 12, 12, [False, False, False, True, False, False, False, True, False, False, False, True]),
4133
# epoch_length=11 not divisible by 4 → flush at epoch end
42-
(
43-
"step", 4, 11, 11,
44-
[False, False, False, True, False, False, False, True, False, False, True],
45-
),
34+
("step", 4, 11, 11, [False, False, False, True, False, False, False, True, False, False, True]),
4635
# epoch_length=None (iterable dataset) → no epoch flush
47-
(
48-
"step", 4, None, 10,
49-
[False, False, False, True, False, False, False, True, False, False],
50-
),
36+
("step", 4, None, 10, [False, False, False, True, False, False, False, True, False, False]),
5137
]
5238

5339

@@ -249,10 +235,7 @@ def test_integration_gradient_equivalence(self) -> None:
249235

250236
torch.manual_seed(42)
251237
acc_steps, lr = 4, 0.1
252-
batches = [
253-
{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)}
254-
for _ in range(acc_steps)
255-
]
238+
batches = [{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} for _ in range(acc_steps)]
256239

257240
ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
258241

@@ -263,8 +246,12 @@ def test_integration_gradient_equivalence(self) -> None:
263246
ref_opt.step()
264247

265248
trainer = SupervisedTrainer(
266-
device=torch.device("cpu"), max_epochs=1, train_data_loader=batches,
267-
network=test_model, optimizer=test_opt, loss_function=nn.MSELoss(),
249+
device=torch.device("cpu"),
250+
max_epochs=1,
251+
train_data_loader=batches,
252+
network=test_model,
253+
optimizer=test_opt,
254+
loss_function=nn.MSELoss(),
268255
iteration_update=GradientAccumulation(accumulation_steps=acc_steps),
269256
)
270257
trainer.run()
@@ -279,10 +266,7 @@ def test_integration_epoch_boundary_flush(self) -> None:
279266

280267
torch.manual_seed(123)
281268
acc_steps, lr = 3, 0.1
282-
batches = [
283-
{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)}
284-
for _ in range(5)
285-
]
269+
batches = [{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} for _ in range(5)]
286270

287271
ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
288272

@@ -294,8 +278,12 @@ def test_integration_epoch_boundary_flush(self) -> None:
294278
ref_opt.step()
295279

296280
trainer = SupervisedTrainer(
297-
device=torch.device("cpu"), max_epochs=1, train_data_loader=batches,
298-
network=test_model, optimizer=test_opt, loss_function=nn.MSELoss(),
281+
device=torch.device("cpu"),
282+
max_epochs=1,
283+
train_data_loader=batches,
284+
network=test_model,
285+
optimizer=test_opt,
286+
loss_function=nn.MSELoss(),
299287
iteration_update=GradientAccumulation(accumulation_steps=acc_steps),
300288
)
301289
trainer.run()
@@ -310,10 +298,7 @@ def test_integration_multi_epoch(self) -> None:
310298

311299
torch.manual_seed(42)
312300
acc_steps, lr, num_epochs = 2, 0.1, 3
313-
batches = [
314-
{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)}
315-
for _ in range(4)
316-
]
301+
batches = [{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} for _ in range(4)]
317302

318303
ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
319304

@@ -326,8 +311,12 @@ def test_integration_multi_epoch(self) -> None:
326311
ref_opt.step()
327312

328313
trainer = SupervisedTrainer(
329-
device=torch.device("cpu"), max_epochs=num_epochs, train_data_loader=batches,
330-
network=test_model, optimizer=test_opt, loss_function=nn.MSELoss(),
314+
device=torch.device("cpu"),
315+
max_epochs=num_epochs,
316+
train_data_loader=batches,
317+
network=test_model,
318+
optimizer=test_opt,
319+
loss_function=nn.MSELoss(),
331320
iteration_update=GradientAccumulation(accumulation_steps=acc_steps),
332321
)
333322
trainer.run()

0 commit comments

Comments (0)