2424
2525_ , has_ignite = optional_import ("ignite.engine" , IgniteInfo .OPT_IMPORT_VERSION , min_version )
2626
27- INVALID_ACCUMULATION_STEPS = [
28- (0 ,), (- 1 ,), (2.5 ,), ("2" ,),
29- ]
27+ INVALID_ACCUMULATION_STEPS = [(0 ,), (- 1 ,), (2.5 ,), ("2" ,)]
3028
3129SUPPRESSION_CASES = [
3230 # (attr_name, acc, epoch_length, num_iters, expected)
33- (
34- "zero_grad" , 4 , 12 , 12 ,
35- [True , False , False , False , True , False , False , False , True , False , False , False ],
36- ),
37- (
38- "step" , 4 , 12 , 12 ,
39- [False , False , False , True , False , False , False , True , False , False , False , True ],
40- ),
31+ ("zero_grad" , 4 , 12 , 12 , [True , False , False , False , True , False , False , False , True , False , False , False ]),
32+ ("step" , 4 , 12 , 12 , [False , False , False , True , False , False , False , True , False , False , False , True ]),
4133 # epoch_length=11 not divisible by 4 → flush at epoch end
42- (
43- "step" , 4 , 11 , 11 ,
44- [False , False , False , True , False , False , False , True , False , False , True ],
45- ),
34+ ("step" , 4 , 11 , 11 , [False , False , False , True , False , False , False , True , False , False , True ]),
4635 # epoch_length=None (iterable dataset) → no epoch flush
47- (
48- "step" , 4 , None , 10 ,
49- [False , False , False , True , False , False , False , True , False , False ],
50- ),
36+ ("step" , 4 , None , 10 , [False , False , False , True , False , False , False , True , False , False ]),
5137]
5238
5339
@@ -249,10 +235,7 @@ def test_integration_gradient_equivalence(self) -> None:
249235
250236 torch .manual_seed (42 )
251237 acc_steps , lr = 4 , 0.1
252- batches = [
253- {CommonKeys .IMAGE : torch .randn (1 , 4 ), CommonKeys .LABEL : torch .randn (1 , 1 )}
254- for _ in range (acc_steps )
255- ]
238+ batches = [{CommonKeys .IMAGE : torch .randn (1 , 4 ), CommonKeys .LABEL : torch .randn (1 , 1 )} for _ in range (acc_steps )]
256239
257240 ref_model , test_model , ref_opt , test_opt , init_weight = _make_model_pair (lr )
258241
@@ -263,8 +246,12 @@ def test_integration_gradient_equivalence(self) -> None:
263246 ref_opt .step ()
264247
265248 trainer = SupervisedTrainer (
266- device = torch .device ("cpu" ), max_epochs = 1 , train_data_loader = batches ,
267- network = test_model , optimizer = test_opt , loss_function = nn .MSELoss (),
249+ device = torch .device ("cpu" ),
250+ max_epochs = 1 ,
251+ train_data_loader = batches ,
252+ network = test_model ,
253+ optimizer = test_opt ,
254+ loss_function = nn .MSELoss (),
268255 iteration_update = GradientAccumulation (accumulation_steps = acc_steps ),
269256 )
270257 trainer .run ()
@@ -279,10 +266,7 @@ def test_integration_epoch_boundary_flush(self) -> None:
279266
280267 torch .manual_seed (123 )
281268 acc_steps , lr = 3 , 0.1
282- batches = [
283- {CommonKeys .IMAGE : torch .randn (1 , 4 ), CommonKeys .LABEL : torch .randn (1 , 1 )}
284- for _ in range (5 )
285- ]
269+ batches = [{CommonKeys .IMAGE : torch .randn (1 , 4 ), CommonKeys .LABEL : torch .randn (1 , 1 )} for _ in range (5 )]
286270
287271 ref_model , test_model , ref_opt , test_opt , init_weight = _make_model_pair (lr )
288272
@@ -294,8 +278,12 @@ def test_integration_epoch_boundary_flush(self) -> None:
294278 ref_opt .step ()
295279
296280 trainer = SupervisedTrainer (
297- device = torch .device ("cpu" ), max_epochs = 1 , train_data_loader = batches ,
298- network = test_model , optimizer = test_opt , loss_function = nn .MSELoss (),
281+ device = torch .device ("cpu" ),
282+ max_epochs = 1 ,
283+ train_data_loader = batches ,
284+ network = test_model ,
285+ optimizer = test_opt ,
286+ loss_function = nn .MSELoss (),
299287 iteration_update = GradientAccumulation (accumulation_steps = acc_steps ),
300288 )
301289 trainer .run ()
@@ -310,10 +298,7 @@ def test_integration_multi_epoch(self) -> None:
310298
311299 torch .manual_seed (42 )
312300 acc_steps , lr , num_epochs = 2 , 0.1 , 3
313- batches = [
314- {CommonKeys .IMAGE : torch .randn (1 , 4 ), CommonKeys .LABEL : torch .randn (1 , 1 )}
315- for _ in range (4 )
316- ]
301+ batches = [{CommonKeys .IMAGE : torch .randn (1 , 4 ), CommonKeys .LABEL : torch .randn (1 , 1 )} for _ in range (4 )]
317302
318303 ref_model , test_model , ref_opt , test_opt , init_weight = _make_model_pair (lr )
319304
@@ -326,8 +311,12 @@ def test_integration_multi_epoch(self) -> None:
326311 ref_opt .step ()
327312
328313 trainer = SupervisedTrainer (
329- device = torch .device ("cpu" ), max_epochs = num_epochs , train_data_loader = batches ,
330- network = test_model , optimizer = test_opt , loss_function = nn .MSELoss (),
314+ device = torch .device ("cpu" ),
315+ max_epochs = num_epochs ,
316+ train_data_loader = batches ,
317+ network = test_model ,
318+ optimizer = test_opt ,
319+ loss_function = nn .MSELoss (),
331320 iteration_update = GradientAccumulation (accumulation_steps = acc_steps ),
332321 )
333322 trainer .run ()
0 commit comments