
Commit 688bc88

transformer xl

1 parent e714bf3

8 files changed: 174 additions & 43 deletions

python_autocomplete/distributed.py
Lines changed: 3 additions & 3 deletions

@@ -17,7 +17,7 @@ def init(self):
         tracker.set_queue("loss.*", 20, True)
         tracker.set_scalar("accuracy.*", True)
         hook_model_outputs(self.mode, self.ddp_model, 'model')
-        self.state_modules = [self.accuracy_func]
+        self.state_modules = [self.accuracy]
 
     def step(self, batch: any, batch_idx: BatchIndex):
         data, target = batch[0].to(self.device), batch[1].to(self.device)
@@ -29,8 +29,8 @@ def step(self, batch: any, batch_idx: BatchIndex):
         output, *_ = self.ddp_model(data)
 
         loss = self.loss_func(output, target)
-        self.accuracy_func(output, target)
-        self.accuracy_func.track()
+        self.accuracy(output, target)
+        self.accuracy.track()
         tracker.add("loss.", loss)
 
         if self.mode.is_train:

python_autocomplete/evaluate.py
Lines changed: 1 addition & 1 deletion

@@ -211,7 +211,7 @@ def get_predictor():
 
     experiment.start()
     conf.model.eval()
-    return Predictor(conf.model, cache('stoi', lambda: conf.text.stoi), cache('itos', lambda: conf.text.itos))
+    return Predictor(conf.model, conf.stoi, conf.itos)
 
 
 def main():
python_autocomplete/models/__init__.py
Lines changed: 10 additions & 0 deletions

@@ -1,3 +1,13 @@
+from typing import Any
 
+import torch
 
+from labml_helpers.module import Module
 
+
+class AutoregressiveModel(Module):
+    def __init__(self):
+        super().__init__()
+
+    def __call__(self, src: torch.Tensor, state: Any):
+        pass
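
The new AutoregressiveModel base class fixes a single calling convention for every model in this commit: __call__(src, state) returns (output, new_state), so the training loop can thread recurrent state, transformer memories, or nothing at all through one code path. A minimal sketch of a conforming model (TinyModel and its sizes are illustrative, not part of the repository):

import torch
from torch import nn

# Illustrative only: a toy model that satisfies the AutoregressiveModel
# contract above -- __call__(src, state) -> (logits, new_state).
class TinyModel(nn.Module):
    def __init__(self, n_tokens: int, d_model: int = 16):
        super().__init__()
        self.embedding = nn.Embedding(n_tokens, d_model)
        self.fc = nn.Linear(d_model, n_tokens)

    def forward(self, src: torch.Tensor, state=None):
        x = self.embedding(src)   # [seq, batch, d_model]
        return self.fc(x), state  # stateless: pass state through unchanged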

python_autocomplete/models/highway.py
Lines changed: 4 additions & 4 deletions

@@ -1,10 +1,10 @@
 from torch import nn
 
-from labml_helpers.module import Module
 from labml_nn.recurrent_highway_networks import RHN
+from python_autocomplete.models import AutoregressiveModel
 
 
-class RhnModel(Module):
+class RhnModel(AutoregressiveModel):
     def __init__(self, *,
                  n_tokens: int,
                  embedding_size: int,
@@ -20,10 +20,10 @@ def __init__(self, *,
                       depth=depth)
         self.fc = nn.Linear(hidden_size, n_tokens)
 
-    def __call__(self, x, s0=None):
+    def __call__(self, x, state=None):
         # shape of x is [seq, batch, feat]
         x = self.embedding(x)
-        out, s = self.rhn(x, s0)
+        out, s = self.rhn(x, state)
         logits = self.fc(out)
 
         return logits, s

python_autocomplete/models/lstm.py
Lines changed: 6 additions & 4 deletions

@@ -1,10 +1,13 @@
+from typing import Optional, Tuple
+
+import torch
 from torch import nn
 
-from labml_helpers.module import Module
 from labml_nn.lstm import LSTM
+from python_autocomplete.models import AutoregressiveModel
 
 
-class LstmModel(Module):
+class LstmModel(AutoregressiveModel):
     def __init__(self, *,
                  n_tokens: int,
                  embedding_size: int,
@@ -18,10 +21,9 @@ def __init__(self, *,
                         n_layers=n_layers)
         self.fc = nn.Linear(hidden_size, n_tokens)
 
-    def __call__(self, x, h0=None, c0=None):
+    def __call__(self, x: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):
         # shape of x is [seq, batch, feat]
         x = self.embedding(x)
-        state = (h0, c0) if h0 is not None else None
         out, (hn, cn) = self.lstm(x, state)
         logits = self.fc(out)
 
python_autocomplete/models/transformer.py
Lines changed: 8 additions & 14 deletions

@@ -1,13 +1,15 @@
-import numpy as np
+from typing import Any
+
 import torch
 from torch import nn
 
-from labml import monit
 from labml_helpers.module import Module
 from labml_nn.transformers import Encoder
+from labml_nn.transformers.utils import subsequent_mask
+from python_autocomplete.models import AutoregressiveModel
 
 
-class TransformerModel(Module):
+class TransformerModel(AutoregressiveModel):
     def __init__(self, n_tokens, d_model, encoder: Encoder, src_embed: Module):
         super().__init__()
         self.src_mask = None
@@ -16,20 +18,12 @@ def __init__(self, n_tokens, d_model, encoder: Encoder, src_embed: Module):
         self.d_model = d_model
         self.fc = nn.Linear(d_model, n_tokens)
 
-    @staticmethod
-    def subsequent_mask(seq_len):
-        attn_shape = (seq_len, seq_len)
-        mask = np.triu(np.ones(attn_shape, dtype=np.uint8), k=1)
-        return (torch.from_numpy(mask) == 0).unsqueeze(-1)
-
-    def __call__(self, src):
+    def __call__(self, src: torch.Tensor, _: Any = None):
         if self.src_mask is None or self.src_mask.size(0) != len(src):
-            device = src.device
-            mask = self.subsequent_mask(len(src)).to(device)
-            self.src_mask = mask
+            self.src_mask = subsequent_mask(len(src)).to(src.device)
 
         src = self.src_embed(src)
         # with monit.section("transformer"):
         output = self.encoder(src, self.src_mask)
         output = self.fc(output)
-        return output,
+        return output, None
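
The model now uses subsequent_mask from labml_nn.transformers.utils instead of its own copy. Judging by the deleted staticmethod (np.triu(..., k=1) == 0, then unsqueeze(-1)), it builds a lower-triangular causal mask of shape [seq_len, seq_len, 1]; a self-contained equivalent for reference, presumably matching the library helper:

import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    # Equivalent to the removed staticmethod: True where query i may
    # attend to key j (j <= i), shaped [seq_len, seq_len, 1].
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    return mask.unsqueeze(-1)

print(causal_mask(3)[:, :, 0])
# tensor([[ True, False, False],
#         [ True,  True, False],
#         [ True,  True,  True]])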

python_autocomplete/models/xl.py
Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
+from typing import List, Optional
+
+import torch
+from torch import nn
+
+from labml_nn.transformers.xl import TransformerXL
+from python_autocomplete.models import AutoregressiveModel
+
+
+class TransformerXLModel(AutoregressiveModel):
+    def __init__(self, n_vocab: int, d_model: int, transformer: TransformerXL):
+        super().__init__()
+        self.src_embed = nn.Embedding(n_vocab, d_model)
+        self.transformer = transformer
+        self.generator = nn.Linear(d_model, n_vocab)
+        self.mask_x = None
+        self.mask_mem = None
+
+    def __call__(self, x: torch.Tensor, mem: Optional[List[torch.Tensor]]):
+        m_len = len(mem[0]) if mem else 0
+        if self.mask_x is None or self.mask_x.shape[0] < len(x):
+            from labml_nn.transformers.utils import subsequent_mask
+            self.mask_x = subsequent_mask(len(x)).to(x.device)
+        if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
+            self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)
+
+        if m_len:
+            mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)
+        else:
+            mask = self.mask_x[:len(x), :len(x)]
+
+        x = self.src_embed(x)
+        res, mem = self.transformer(x, mem, mask)
+        res = self.generator(res)
+
+        return res, mem
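
In TransformerXLModel the attention mask must span both the reused memories and the current segment: every query may attend to all m_len memory positions, but only causally within the len(x) new positions. A toy-sized trace of the concatenation above (sizes are hypothetical; the [query, key, 1] shape convention follows the code):

import torch

seq_len, m_len = 3, 2
# Causal part over the current segment (what mask_x holds above)
mask_x = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(-1)
# Memory part: every query position may see every memory position
mask_mem = mask_x.new_ones(seq_len, m_len, 1)
mask = torch.cat((mask_mem, mask_x), dim=1)  # [seq_len, m_len + seq_len, 1]
print(mask[:, :, 0])
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True, False],
#         [ True,  True,  True,  True,  True]])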

python_autocomplete/train.py
Lines changed: 106 additions & 17 deletions

@@ -1,5 +1,5 @@
 from pathlib import PurePath
-from typing import Callable
+from typing import Callable, List, Dict
 
 import torch
 import torch.nn as nn
@@ -11,6 +11,7 @@
 from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset
 from labml_helpers.device import DeviceConfigs
 from labml_helpers.metrics.accuracy import Accuracy
+from labml_helpers.metrics.simple_state import SimpleStateModule
 from labml_helpers.module import Module
 from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
 from labml_nn.optimizers.configs import OptimizerConfigs
@@ -20,10 +21,15 @@
 class SourceCodeDataset(TextDataset):
     def __init__(self, path: PurePath, tokenizer: Callable):
         with monit.section("Load data"):
-            train = self.load(path / 'train.py')
-            valid = self.load(path / 'valid.py')
+            train = self.load(path / 'train.py')  # [:100000]
+            valid = self.load(path / 'valid.py')  # [:100000]
 
-        super().__init__(path, tokenizer, train, valid, '')
+        from labml.utils.cache import cache_get
+
+        super().__init__(path, tokenizer, train, valid, '',
+                         n_tokens=cache_get('n_tokens'),
+                         itos=cache_get('itos'),
+                         stoi=cache_get('stoi'))
 
 
 class Configs(TrainValidConfigs):
@@ -47,14 +53,23 @@ class Configs(TrainValidConfigs):
 
     transformer: TransformerConfigs
 
-    accuracy_func = Accuracy()
+    accuracy = Accuracy()
     loss_func: 'CrossEntropyLoss'
 
+    state_updater: 'StateUpdater'
+    state = SimpleStateModule()
+    mem_len: int = 512
+    grad_norm_clip: float = 1.0
+    is_token_by_token: bool = False
+
+    itos: List[str]
+    stoi: Dict[str, int]
+
     def init(self):
         tracker.set_queue("loss.*", 20, True)
         tracker.set_scalar("accuracy.*", True)
         hook_model_outputs(self.mode, self.model, 'model')
-        self.state_modules = [self.accuracy_func]
+        self.state_modules = [self.accuracy, self.state]
 
     def step(self, batch: any, batch_idx: BatchIndex):
         data, target = batch[0].to(self.device), batch[1].to(self.device)
@@ -63,16 +78,21 @@ def step(self, batch: any, batch_idx: BatchIndex):
         tracker.add_global_step(len(data))
 
         with self.mode.update(is_log_activations=batch_idx.is_last):
-            output, *_ = self.model(data)
+            state = self.state.get()
+            output, new_state = self.model(data, state)
+            state = self.state_updater(state, new_state)
+            self.state.set(state)
 
         loss = self.loss_func(output, target)
-        self.accuracy_func(output, target)
-        self.accuracy_func.track()
         tracker.add("loss.", loss)
 
+        self.accuracy(output, target)
+        self.accuracy.track()
+
         if self.mode.is_train:
             loss.backward()
 
+            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
             self.optimizer.step()
             if batch_idx.is_last:
                 tracker.add('model', self.model)
@@ -83,13 +103,17 @@ def step(self, batch: any, batch_idx: BatchIndex):
     def sample(self):
         prompt = 'def train('
         log = [(prompt, Text.subtle)]
+        state = None
         for i in monit.iterate('Sample', 25):
             data = self.text.text_to_i(prompt).unsqueeze(-1)
             data = data.to(self.device)
-            output, *_ = self.model(data)
-            output = output.argmax(dim=-1).squeeze()
-            prompt += '' + self.text.itos[output[-1]]
-            log += [('' + self.text.itos[output[-1]], Text.value)]
+            output, new_state = self.model(data, state)
+            output = output.argmax(dim=-1).squeeze(1)
+            prompt += '' + self.itos[output[-1]]
+            if self.is_token_by_token:
+                prompt = prompt[-1:]
+            log += [('' + self.itos[output[-1]], Text.value)]
+            state = self.state_updater(state, new_state)
 
         logger.log(log)
 
@@ -137,6 +161,18 @@ def _n_tokens(c: Configs):
     return cache('n_tokens', lambda: c.text.n_tokens)
 
 
+@option(Configs.itos)
+def _itos(c: Configs):
+    from labml.utils.cache import cache
+    return cache('itos', lambda: c.text.itos)
+
+
+@option(Configs.stoi)
+def _stoi(c: Configs):
+    from labml.utils.cache import cache
+    return cache('stoi', lambda: c.text.stoi)
+
+
 @option(Configs.model)
 def lstm_model(c: Configs):
     from python_autocomplete.models.lstm import LstmModel
@@ -169,6 +205,55 @@ def transformer_model(c: Configs):
     return m.to(c.device)
 
 
+@option(Configs.model)
+def transformer_xl_model(c: Configs):
+    from labml_nn.transformers.xl import RelativeMultiHeadAttention
+    from labml_nn.transformers.feed_forward import FeedForward
+    from labml_nn.transformers.xl import TransformerXL
+    from labml_nn.transformers.xl import TransformerXLLayer
+    from python_autocomplete.models.xl import TransformerXLModel
+    m = TransformerXLModel(c.n_tokens, c.d_model, TransformerXL(
+        TransformerXLLayer(d_model=c.d_model,
+                           self_attn=RelativeMultiHeadAttention(c.transformer.n_heads, c.d_model, c.dropout),
+                           feed_forward=FeedForward(c.d_model, c.transformer.ffn.d_ff, c.dropout),
+                           dropout_prob=c.dropout), c.n_layers))
+    return m.to(c.device)
+
+
+class StateUpdater:
+    def __call__(self, old_state, new_state):
+        return new_state
+
+
+class MemoryUpdater(StateUpdater):
+    def __init__(self, mem_len: int):
+        self.mem_len = mem_len
+
+    def __call__(self, old_mem, new_mem):
+        if self.mem_len == 0:
+            return []
+
+        if old_mem:
+            mem = [torch.cat((m, x), dim=0) for m, x in zip(old_mem, new_mem)]
+        else:
+            mem = new_mem
+
+        if len(mem[0]) > self.mem_len:
+            mem = [m[-self.mem_len:] for m in mem]
+
+        return mem
+
+
+@option(Configs.state_updater)
+def simple():
+    return StateUpdater()
+
+
+@option(Configs.state_updater)
+def transformer_memory(c: Configs):
+    return MemoryUpdater(c.mem_len)
+
+
 def character_tokenizer(x: str):
     return list(x)
 
@@ -231,18 +316,22 @@ def main():
     conf = Configs()
     # Assign one of transformer_model, lstm_model, or rhn_model
     experiment.create(name="source_code",
-                      comment='lstm model')
+                      comment='transformer xl model')
     experiment.configs(conf, {
-        'model': 'transformer_model',
+        # 'model': 'transformer_model',
+        'model': 'transformer_xl_model',
         'n_layers': 6,
         'batch_size': 12,
         'epochs': 32,
         'optimizer.optimizer': 'Noam',
         'optimizer.learning_rate': 1.0,
        'device.cuda_device': 0,
         'seq_len': 512,
-        'train_loader': 'shuffled_train_loader',
-        'valid_loader': 'shuffled_valid_loader'
+        'is_token_by_token': True,
+        # 'train_loader': 'shuffled_train_loader',
+        # 'valid_loader': 'shuffled_valid_loader',
+        'train_loader': 'sequential_train_loader',
+        'valid_loader': 'sequential_valid_loader',
     })
     experiment.add_pytorch_models(model=conf.model)
     # experiment.load('70df7f86450911eb887b25e3927208f3')
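
MemoryUpdater is what makes the loop above Transformer-XL training: after each step it concatenates the newly returned activations onto the stored memories and truncates to the most recent mem_len steps per layer. A standalone check of that logic on toy tensors (the [seq, batch, d_model] per-layer shape is an assumption):

import torch

# Reproducing MemoryUpdater.__call__ on toy data (mem_len = 4).
mem_len = 4
old_mem = [torch.zeros(3, 1, 8)]  # one layer, 3 steps already stored
new_mem = [torch.ones(2, 1, 8)]   # 2 fresh steps from the current batch

mem = [torch.cat((m, x), dim=0) for m, x in zip(old_mem, new_mem)]
if len(mem[0]) > mem_len:
    mem = [m[-mem_len:] for m in mem]  # drop the oldest steps

print(mem[0].shape)  # torch.Size([4, 1, 8]) -- oldest step discarded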
