Skip to content

Commit 879d494

Browse files
committed
fix sampling when training
1 parent 14f7ff1 commit 879d494

2 files changed

Lines changed: 32 additions & 19 deletions

File tree

python_autocomplete/dataset/dataset.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from pathlib import PurePath
2+
from typing import Optional
23

34
import torch
45
from torch.utils.data import Dataset, DataLoader
@@ -48,8 +49,8 @@ def load(path: PurePath):
4849
def get_train_valid(path: PurePath, is_load_data: bool):
4950
if is_load_data:
5051
with monit.section("Load data"):
51-
train = TextDataset.load(path / 'train.py')[:1000_000]
52-
valid = TextDataset.load(path / 'valid.py')[:1000_000]
52+
train = TextDataset.load(path / 'train.py')
53+
valid = TextDataset.load(path / 'valid.py')
5354
else:
5455
train = ''
5556
valid = ''
@@ -67,6 +68,7 @@ def __repr__(self):
6768

6869
class SourceCodeDataConfigs(BaseConfigs):
6970
dataset: SourceCodeDataset
71+
truncate_data: int = 0
7072
is_load_data: bool = True
7173
tokenizer: Tokenizer
7274
retrain_tokenizer: bool = True
@@ -84,6 +86,8 @@ def text_to_i(self, text: str, *, is_silent: bool = True) -> torch.Tensor:
8486
@option(SourceCodeDataConfigs.dataset, 'default')
8587
def _dataset(c: SourceCodeDataConfigs):
8688
train, valid = SourceCodeDataset.get_train_valid(lab.get_data_path(), c.is_load_data)
89+
if c.truncate_data:
90+
train, valid = train[:c.truncate_data], valid[:c.truncate_data]
8791
if not c.tokenizer.is_trained:
8892
c.tokenizer.train(train + valid)
8993
return SourceCodeDataset(c.tokenizer, train, valid)
@@ -120,7 +124,7 @@ def __init__(self, *,
120124
self.seq_len = seq_len
121125
self.data = data
122126
self.n_samples = (self.data.shape[0] - 1) // self.seq_len
123-
if drop_last:
127+
if not drop_last:
124128
self.n_batches = (self.n_samples + batch_size - 1) // batch_size
125129
else:
126130
self.n_batches = self.n_samples // batch_size
@@ -152,17 +156,21 @@ def transpose_batch(batch):
152156
def _train_loader(c: SourceCodeDataConfigs):
153157
return DataLoader(TokenDataset(data=c.text_to_i(c.dataset.train, is_silent=False),
154158
batch_size=c.batch_size,
155-
seq_len=c.seq_len),
159+
seq_len=c.seq_len,
160+
drop_last=True),
156161
batch_size=c.batch_size,
157162
collate_fn=transpose_batch,
158-
shuffle=c.is_shuffle)
163+
shuffle=c.is_shuffle,
164+
drop_last=True)
159165

160166

161167
@option(SourceCodeDataConfigs.valid_loader)
162168
def _valid_loader(c: SourceCodeDataConfigs):
163169
return DataLoader(TokenDataset(data=c.text_to_i(c.dataset.valid, is_silent=False),
164170
batch_size=c.batch_size,
165-
seq_len=c.seq_len),
171+
seq_len=c.seq_len,
172+
drop_last=True),
166173
batch_size=c.batch_size,
167174
collate_fn=transpose_batch,
168-
shuffle=c.is_shuffle)
175+
shuffle=c.is_shuffle,
176+
drop_last=True)

python_autocomplete/train.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def step(self, batch: any, batch_idx: BatchIndex):
5151
data, target = batch[0].to(self.device), batch[1].to(self.device)
5252

5353
if self.mode.is_train:
54-
tracker.add_global_step(len(data))
54+
tracker.add_global_step(target.shape[0] * target.shape[1])
5555

5656
with self.mode.update(is_log_activations=batch_idx.is_last):
5757
state = self.state.get()
@@ -87,7 +87,9 @@ def sample(self):
8787
output = output.argmax(dim=-1).squeeze(1)
8888
prompt += '' + self.text.tokenizer.itos[output[-1]]
8989
if self.is_token_by_token:
90-
prompt = prompt[-1:]
90+
prompt = self.text.tokenizer.itos[output[-1]]
91+
else:
92+
prompt += '' + self.text.tokenizer.itos[output[-1]]
9193
log += [('' + self.text.tokenizer.itos[output[-1]], Text.value)]
9294
state = self.state_updater(state, new_state)
9395

@@ -260,22 +262,25 @@ def main():
260262
experiment.create(name="source_code",
261263
comment='bpe')
262264
experiment.configs(conf, {
263-
'model': 'transformer_model',
264-
# 'model': 'transformer_xl_model',
265+
# 'model': 'transformer_model',
266+
'model': 'transformer_xl_model',
265267
'n_layers': 6,
266268
'epochs': 32,
267-
'optimizer.optimizer': 'Noam',
268-
'optimizer.learning_rate': 1.0,
269+
'optimizer.optimizer': 'AdamW',
270+
'optimizer.learning_rate': 1.25e-4,
269271
'device.cuda_device': 0,
270-
'is_token_by_token': False,
271-
'state_updater': 'simple',
272+
273+
'is_token_by_token': True,
274+
'state_updater': 'transformer_memory',
275+
'mem_len': 256,
272276

273277
'text.is_shuffle': False,
274-
'text.tokenizer': 'char',
278+
'text.tokenizer': 'bpe',
275279
'text.batch_size': 12,
276-
'text.seq_len': 512,
277-
278-
'inner_iterations': 10,
280+
'text.seq_len': 256,
281+
#
282+
# 'inner_iterations': 10,
283+
# 'text.truncate_data': 100_000,
279284
})
280285
experiment.add_pytorch_models(model=conf.model)
281286
with experiment.start():

0 commit comments

Comments (0)