|
3 | 3 |
|
4 | 4 | import torch |
5 | 5 | import torch.nn as nn |
| 6 | +from torch.utils.data import DataLoader |
6 | 7 |
|
7 | 8 | from labml import lab, experiment, monit, logger, tracker |
8 | 9 | from labml.configs import option |
9 | 10 | from labml.logger import Text |
10 | | -from labml_helpers.datasets.text import TextDataset, SequentialDataLoader |
| 11 | +from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset |
11 | 12 | from labml_helpers.device import DeviceConfigs |
12 | 13 | from labml_helpers.metrics.accuracy import Accuracy |
13 | 14 | from labml_helpers.module import Module |
@@ -181,34 +182,65 @@ def source_code(c: Configs): |
181 | 182 |
|
182 | 183 |
|
@option(Configs.train_loader)
def sequential_train_loader(c: Configs):
    """Training batches taken sequentially (in text order) from the train split."""
    loader_kwargs = dict(text=c.text.train,
                         dataset=c.text,
                         batch_size=c.batch_size,
                         seq_len=c.seq_len)
    return SequentialDataLoader(**loader_kwargs)
189 | 190 |
|
190 | 191 |
|
@option(Configs.valid_loader)
def sequential_valid_loader(c: Configs):
    """Validation batches taken sequentially (in text order) from the valid split."""
    loader_kwargs = dict(text=c.text.valid,
                         dataset=c.text,
                         batch_size=c.batch_size,
                         seq_len=c.seq_len)
    return SequentialDataLoader(**loader_kwargs)
197 | 198 |
|
198 | 199 |
|
def transpose_batch(batch):
    """Collate a list of ``(src, tgt)`` sample pairs into a pair of tensors.

    ``DataLoader`` yields samples batch-first; stacking each side along
    dim 1 yields tensors with the batch on the second dimension, matching
    the layout the sequential loaders produce.
    """
    sources, targets = zip(*batch)
    return torch.stack(sources, dim=1), torch.stack(targets, dim=1)
| 206 | + |
| 207 | + |
@option(Configs.train_loader)
def shuffled_train_loader(c: Configs):
    """Training batches of randomly-ordered ``seq_len``-long windows of the train text."""
    train_dataset = SequentialUnBatchedDataset(text=c.text.train,
                                               dataset=c.text,
                                               seq_len=c.seq_len)
    # transpose_batch re-stacks samples so the batch is on the second dimension.
    return DataLoader(train_dataset,
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      shuffle=True)
| 216 | + |
| 217 | + |
@option(Configs.valid_loader)
def shuffled_valid_loader(c: Configs):
    """Validation batches of randomly-ordered ``seq_len``-long windows of the valid text."""
    valid_dataset = SequentialUnBatchedDataset(text=c.text.valid,
                                               dataset=c.text,
                                               seq_len=c.seq_len)
    # NOTE(review): shuffle=True on a validation loader is unusual — epoch-level
    # metrics are unaffected, but per-batch tracked curves become order-dependent.
    # Kept as-is to preserve behavior; confirm whether this is intentional.
    return DataLoader(valid_dataset,
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      shuffle=True)
| 226 | + |
| 227 | + |
199 | 228 | def main(): |
200 | 229 | conf = Configs() |
201 | 230 | # Assign one of transformer_mode, lstm_model, or rhn_model |
202 | 231 | experiment.create(name="source_code", |
203 | 232 | comment='lstm model') |
204 | 233 | experiment.configs(conf, { |
205 | | - 'model': 'lstm_model', |
206 | | - 'n_layers': 2, |
207 | | - 'batch_size': 2, |
| 234 | + 'model': 'transformer_model', |
| 235 | + 'n_layers': 6, |
| 236 | + 'batch_size': 12, |
208 | 237 | 'epochs': 32, |
209 | | - 'optimizer.optimizer': 'Adam', |
210 | | - 'optimizer.learning_rate': 2.5e-4, |
211 | | - 'device.cuda_device': 1 |
| 238 | + 'optimizer.optimizer': 'Noam', |
| 239 | + 'optimizer.learning_rate': 1.0, |
| 240 | + 'device.cuda_device': 0, |
| 241 | + 'seq_len': 512, |
| 242 | + 'train_loader': 'shuffled_train_loader', |
| 243 | + 'valid_loader': 'shuffled_valid_loader' |
212 | 244 | }) |
213 | 245 | experiment.add_pytorch_models(model=conf.model) |
214 | 246 | # experiment.load('d5ba7f56d88911eaa6629b54a83956dc') |
|
0 commit comments