Skip to content

Commit b8d502c

Browse files
committed
non-sequential dataloader
1 parent eef9f87 commit b8d502c

3 files changed

Lines changed: 48 additions & 13 deletions

File tree

python_autocomplete/create_dataset.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,11 @@ def extract_zip(file_path: Path, overwrite: bool = False):
157157
rm_tree(repo_source)
158158
else:
159159
return repo_source
160-
with zipfile.ZipFile(file_path, 'r') as repo_zip:
161-
repo_zip.extractall(repo_source)
160+
try:
161+
with zipfile.ZipFile(file_path, 'r') as repo_zip:
162+
repo_zip.extractall(repo_source)
163+
except zipfile.BadZipfile as e:
164+
print(file_path, e)
162165

163166
return repo_source
164167

@@ -213,7 +216,7 @@ def progressive(overwrite: bool = False):
213216

214217
def main():
215218
try:
216-
progressive()
219+
batch()
217220
except KeyboardInterrupt:
218221
pass
219222

python_autocomplete/models/transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def subsequent_mask(seq_len):
2222
mask = np.triu(np.ones(attn_shape, dtype=np.uint8), k=1)
2323
return (torch.from_numpy(mask) == 0).unsqueeze(-1)
2424

25-
def forward(self, src):
25+
def __call__(self, src):
2626
if self.src_mask is None or self.src_mask.size(0) != len(src):
2727
device = src.device
2828
mask = self.subsequent_mask(len(src)).to(device)

python_autocomplete/train.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33

44
import torch
55
import torch.nn as nn
6+
from torch.utils.data import DataLoader
67

78
from labml import lab, experiment, monit, logger, tracker
89
from labml.configs import option
910
from labml.logger import Text
10-
from labml_helpers.datasets.text import TextDataset, SequentialDataLoader
11+
from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset
1112
from labml_helpers.device import DeviceConfigs
1213
from labml_helpers.metrics.accuracy import Accuracy
1314
from labml_helpers.module import Module
@@ -181,34 +182,65 @@ def source_code(c: Configs):
181182

182183

183184
@option(Configs.train_loader)
184-
def train_loader(c: Configs):
185+
def sequential_train_loader(c: Configs):
185186
return SequentialDataLoader(text=c.text.train,
186187
dataset=c.text,
187188
batch_size=c.batch_size,
188189
seq_len=c.seq_len)
189190

190191

191192
@option(Configs.valid_loader)
192-
def train_loader(c: Configs):
193+
def sequential_valid_loader(c: Configs):
193194
return SequentialDataLoader(text=c.text.valid,
194195
dataset=c.text,
195196
batch_size=c.batch_size,
196197
seq_len=c.seq_len)
197198

198199

200+
def transpose_batch(batch):
201+
transposed_data = list(zip(*batch))
202+
src = torch.stack(transposed_data[0], 1)
203+
tgt = torch.stack(transposed_data[1], 1)
204+
205+
return src, tgt
206+
207+
208+
@option(Configs.train_loader)
209+
def shuffled_train_loader(c: Configs):
210+
return DataLoader(SequentialUnBatchedDataset(text=c.text.train,
211+
dataset=c.text,
212+
seq_len=c.seq_len),
213+
batch_size=c.batch_size,
214+
collate_fn=transpose_batch,
215+
shuffle=True)
216+
217+
218+
@option(Configs.valid_loader)
219+
def shuffled_valid_loader(c: Configs):
220+
return DataLoader(SequentialUnBatchedDataset(text=c.text.valid,
221+
dataset=c.text,
222+
seq_len=c.seq_len),
223+
batch_size=c.batch_size,
224+
collate_fn=transpose_batch,
225+
shuffle=True)
226+
227+
199228
def main():
200229
conf = Configs()
201230
# Assign one of transformer_model, lstm_model, or rhn_model
202231
experiment.create(name="source_code",
203232
comment='lstm model')
204233
experiment.configs(conf, {
205-
'model': 'lstm_model',
206-
'n_layers': 2,
207-
'batch_size': 2,
234+
'model': 'transformer_model',
235+
'n_layers': 6,
236+
'batch_size': 12,
208237
'epochs': 32,
209-
'optimizer.optimizer': 'Adam',
210-
'optimizer.learning_rate': 2.5e-4,
211-
'device.cuda_device': 1
238+
'optimizer.optimizer': 'Noam',
239+
'optimizer.learning_rate': 1.0,
240+
'device.cuda_device': 0,
241+
'seq_len': 512,
242+
'train_loader': 'shuffled_train_loader',
243+
'valid_loader': 'shuffled_valid_loader'
212244
})
213245
experiment.add_pytorch_models(model=conf.model)
214246
# experiment.load('d5ba7f56d88911eaa6629b54a83956dc')

0 commit comments

Comments
 (0)