Skip to content

Commit 04d5c13

Browse files
authored
Use batch_per_epoch to slice the train data (#11)
1 parent a1a9caf commit 04d5c13

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

script/run.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import sys
33
import math
44
import pprint
5+
from itertools import islice
56

67
import torch
78
import torch_geometric as pyg
@@ -59,7 +60,7 @@ def train_and_validate(cfg, model, train_data, valid_data, device, logger, filte
5960

6061
losses = []
6162
sampler.set_epoch(epoch)
62-
for batch in train_loader:
63+
for batch in islice(train_loader, batch_per_epoch):
6364
batch = tasks.negative_sampling(train_data, batch, cfg.task.num_negative,
6465
strict=cfg.task.strict_negative)
6566
pred = parallel_model(train_data, batch)

0 commit comments

Comments
 (0)