Skip to content

Commit a1a9caf

Browse files
authored
Adding batch per epoch option to the finetune script (#10)
* Adding batch per epoch option to the finetune script Allow to specify the batch per epoch in the finetune script. * Remove copy paste error :)
1 parent 33c6e6b commit a1a9caf

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

script/run.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,12 +289,12 @@ def test(cfg, model, test_data, device, logger, filtered_data=None, return_metri
289289
val_filtered_data = val_filtered_data.to(device)
290290
test_filtered_data = test_filtered_data.to(device)
291291

292-
train_and_validate(cfg, model, train_data, valid_data, filtered_data=val_filtered_data, device=device, logger=logger)
292+
train_and_validate(cfg, model, train_data, valid_data, filtered_data=val_filtered_data, device=device, batch_per_epoch=cfg.train.batch_per_epoch, logger=logger)
293293
if util.get_rank() == 0:
294294
logger.warning(separator)
295295
logger.warning("Evaluate on valid")
296296
test(cfg, model, valid_data, filtered_data=val_filtered_data, device=device, logger=logger)
297297
if util.get_rank() == 0:
298298
logger.warning(separator)
299299
logger.warning("Evaluate on test")
300-
test(cfg, model, test_data, filtered_data=test_filtered_data, device=device, logger=logger)
300+
test(cfg, model, test_data, filtered_data=test_filtered_data, device=device, logger=logger)

0 commit comments

Comments
 (0)