
Commit 7094232: bpe training
1 parent a241fe0
2 files changed: 163 additions & 14 deletions

python_autocomplete/bpe.py (118 additions & 4 deletions)
@@ -3,28 +3,134 @@
 from typing import List, Tuple
 
 from labml import lab, monit
+from labml.utils.cache import cache_set
 
 ID_CHARS = set(string.ascii_letters + string.digits + '_')
 
 
 class BPE:
+    def __init__(self, bpe_en_de: 'BPEEnDe', tokenizer):
+        self.bpe = bpe_en_de
+        self.tokenizer = tokenizer
+
+    @property
+    def n_tokens(self):
+        return len(self.bpe.bpe)
+
+    @property
+    def itos(self):
+        return self.bpe.bpe_itos
+
+    @property
+    def stoi(self):
+        return self.bpe.bpe_stoi
+
+    def encode(self, data: str):
+        words = self.tokenizer.tokenize(data)
+
+        res = []
+        for w in monit.iterate('Encode words', words):
+            res += self.bpe.encode(w)
+
+        return res
+
+    def __call__(self, data: str):
+        encoded = self.encode(data)
+        return [self.itos[c] for c in encoded]
+
+
+class _BPEEncoder:
+    def __init__(self, pairs):
+        self.pairs = pairs
+        self.codes = []
+        self.next_idx = []
+        self.prev_idx = []
+        self.heap = []
+
+    def encode(self, codes: List[int]):
+        self.codes = codes
+        self.next_idx = BPELearner.default_next_pointers(len(codes))
+        self.prev_idx = BPELearner.default_prev_pointers(len(codes))
+        self.heap = []
+
+        for i in range(len(self.codes) - 1):
+            self.add_pair((self.codes[i], self.codes[i + 1]), i)
+
+        while self.heap:
+            _, idx, pair = heappop(self.heap)
+            self.merge(idx, pair)
+
+        return [c for c in self.codes if c != -1]
+
+    def merge(self, p2, pair):
+        p3 = self.next_idx[p2]
+
+        if p3 == -1 or pair[0] != self.codes[p2] or pair[1] != self.codes[p3]:
+            return
+
+        self.codes[p2] = self.pairs[pair]
+        self.codes[p3] = -1
+        p1 = self.prev_idx[p2]
+        p4 = self.next_idx[p3]
+
+        if p1 != -1:
+            self.add_pair((self.codes[p1], self.codes[p2]), p1)
+        self.next_idx[p2] = p4
+        if p4 != -1:
+            self.prev_idx[p4] = p2
+            self.add_pair((self.codes[p2], self.codes[p4]), p2)
+
+    def add_pair(self, pair, idx):
+        if pair not in self.pairs:
+            return
+
+        heappush(self.heap, (self.pairs[pair], idx, pair))
+
+
+class BPEEnDe:
     def __init__(self):
         self.char_itos = []
         self.char_stoi = {}
-        self.bpe_itos = []
         self.bpe = []
-        self.common = {}
+        self.popular_words = {}
+
+        self.bpe_itos = []
+        self.bpe_stoi = {}
+        self.pairs = {}
+        self.encoder = None
 
+    def load(self, char_itos, char_stoi, bpe):
+        self.char_itos = char_itos
+        self.char_stoi = char_stoi
+        self.bpe = bpe
+
+        self.calc()
+
+    def set_popular_words(self, popular_words):
+        self.popular_words = popular_words
+
+    def calc(self):
         self.bpe_itos = self.calc_bpe_itos()
+        self.bpe_stoi = {s: i for i, s in enumerate(self.bpe_itos)}
+        self.pairs = {(p[0], p[1]): c for c, p in enumerate(self.bpe) if isinstance(p, tuple)}
+
+        self.encoder = _BPEEncoder(self.pairs)
 
     def to_char_stoi(self, w: str):
         return [self.char_stoi[c] for c in w]
 
     def calc_bpe_itos(self):
         itos = list(self.char_itos)
-        itos += [itos[p1] + itos[p2] for p1, p2 in self.bpe[len(self.char_itos):]]
+        for p1, p2 in self.bpe[len(self.char_itos):]:
+            itos.append(itos[p1] + itos[p2])
         return itos
 
+    def encode(self, word: str):
+        if word in self.popular_words:
+            return self.popular_words[word]
+
+        return self.encoder.encode([self.char_stoi[c] for c in word])
+
 
 class Tokenizer:
     def collect_words(self, data: str):
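The core of this hunk is _BPEEncoder: next_idx and prev_idx form a doubly linked list over token positions, so applying a merge and finding the new neighbour pairs is O(1), while the heap always surfaces the pair with the smallest merge code, i.e. the earliest-learned merge. Below is a minimal, self-contained sketch of the same algorithm (the toy vocabulary and merge table at the end are invented for illustration):

from heapq import heappush, heappop

def bpe_encode(codes, pairs):
    # pairs maps a (left, right) code pair to its merged code; lower merged
    # codes were learned earlier, so the heap applies earlier merges first.
    n = len(codes)
    next_idx = list(range(1, n)) + [-1]   # doubly linked list over positions
    prev_idx = [-1] + list(range(n - 1))
    heap = []

    def add_pair(pair, idx):
        if pair in pairs:                 # only known merges enter the heap
            heappush(heap, (pairs[pair], idx, pair))

    for i in range(n - 1):
        add_pair((codes[i], codes[i + 1]), i)

    while heap:
        _, p2, pair = heappop(heap)
        p3 = next_idx[p2]
        # Heap entries can be stale: check the pair still exists at p2.
        if p3 == -1 or (codes[p2], codes[p3]) != pair:
            continue
        codes[p2] = pairs[pair]           # left position becomes the merged token
        codes[p3] = -1                    # right position is tombstoned
        p1, p4 = prev_idx[p2], next_idx[p3]
        if p1 != -1:
            add_pair((codes[p1], codes[p2]), p1)
        next_idx[p2] = p4
        if p4 != -1:
            prev_idx[p4] = p2
            add_pair((codes[p2], codes[p4]), p2)

    return [c for c in codes if c != -1]

# Toy run: 'l'=0, 'o'=1, 'w'=2 with merges ('l','o')->3 and ('lo','w')->4.
print(bpe_encode([0, 1, 2], {(0, 1): 3, (3, 2): 4}))   # -> [4]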
@@ -284,7 +390,7 @@ def main():
     path = lab.get_data_path() / 'train.py'
 
     with open(str(path), 'r') as f:
-        data = f.read()[:100_000]
+        data = f.read()
 
     tokenizer = SourceCodeTokenizer()
     tokenizer.collect_words(data)
@@ -295,6 +401,15 @@ def main():
     print(bpe.bpe_itos()[len(bpe.char_itos):])
     print(len(data), bpe.get_length())
 
+    cache_set('bpe', {
+        'char_itos': bpe.char_itos,
+        'char_stoi': bpe.char_stoi,
+        'bpe': bpe.bpe
+    })
+
+    bpe_en_de = BPEEnDe()
+    bpe_en_de.load(bpe.char_itos, bpe.char_stoi, bpe.bpe)
+
 
 if __name__ == '__main__':
     main()
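The cache_set call stores exactly the keyword arguments that BPEEnDe.load expects, which is how the training script recovers the tokenizer. A hypothetical usage sketch of the round trip (the sample string is invented; the class and function names are the ones added in this commit):

from labml.utils.cache import cache_get
from python_autocomplete.bpe import BPE, BPEEnDe, SourceCodeTokenizer

# Read back the state written by main() above; cache_get mirrors cache_set.
bpe_en_de = BPEEnDe()
bpe_en_de.load(**cache_get('bpe'))

# BPE wraps the encoder/decoder with a word-level tokenizer.
bpe = BPE(bpe_en_de, SourceCodeTokenizer())
ids = bpe.encode('def main():')    # subword token ids
print([bpe.itos[i] for i in ids])  # the same ids rendered back as strings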

python_autocomplete/train.py (45 additions & 10 deletions)
@@ -16,6 +16,7 @@
 from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
 from labml_nn.optimizers.configs import OptimizerConfigs
 from labml_nn.transformers import TransformerConfigs
+from python_autocomplete.bpe import BPE, SourceCodeTokenizer
 
 
 class SourceCodeDataset(TextDataset):
@@ -32,6 +33,23 @@ def __init__(self, path: PurePath, tokenizer: Callable):
                          stoi=cache_get('stoi'))
 
 
+class BPESourceCodeDataset(TextDataset):
+    tokenizer: BPE
+
+    def __init__(self, path: PurePath, bpe: BPE):
+        with monit.section("Load data"):
+            train = self.load(path / 'train.py')  # [:1000_000]
+            valid = self.load(path / 'valid.py')  # [:1000_000]
+
+        super().__init__(path, bpe, train, valid, '',
+                         n_tokens=bpe.n_tokens,
+                         itos=bpe.itos,
+                         stoi=bpe.stoi)
+
+    def text_to_i(self, text: str) -> torch.Tensor:
+        return torch.tensor(self.tokenizer.encode(text))
+
+
 class Configs(TrainValidConfigs):
     optimizer: torch.optim.Adam
     device: torch.device = DeviceConfigs()
@@ -268,6 +286,22 @@ def source_code(c: Configs):
     return SourceCodeDataset(lab.get_data_path(), c.tokenizer)
 
 
+@option(Configs.text)
+def source_code_bpe(c: Configs):
+    from labml.utils.cache import cache_get
+    from python_autocomplete.bpe import BPEEnDe
+    bpe_cache = cache_get('bpe')
+
+    if bpe_cache:
+        bpe_en_de = BPEEnDe()
+        bpe_en_de.load(**bpe_cache)
+    else:
+        raise RuntimeError('BPE not cached')
+
+    tokenizer = BPE(bpe_en_de, SourceCodeTokenizer())
+    return BPESourceCodeDataset(lab.get_data_path(), tokenizer)
+
+
 @option(Configs.train_loader)
 def sequential_train_loader(c: Configs):
     return SequentialDataLoader(text=c.text.train,
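Note the ordering this creates: source_code_bpe fails fast with RuntimeError('BPE not cached') unless a 'bpe' cache entry exists, and that entry is only written when main() in bpe.py runs the cache_set call added above. BPE training therefore has to be run once before starting this experiment.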
@@ -316,26 +350,27 @@ def main():
     conf = Configs()
     # Assign one of transformer_model, lstm_model, or rhn_model
     experiment.create(name="source_code",
-                      comment='transformer xl model')
+                      comment='bpe')
     experiment.configs(conf, {
-        # 'model': 'transformer_model',
-        'model': 'transformer_xl_model',
+        # 'text': 'source_code',
+        'text': 'source_code_bpe',
+        'model': 'transformer_model',
+        # 'model': 'transformer_xl_model',
         'n_layers': 6,
         'batch_size': 12,
         'epochs': 32,
         'optimizer.optimizer': 'Noam',
         'optimizer.learning_rate': 1.0,
         'device.cuda_device': 0,
         'seq_len': 512,
-        'is_token_by_token': True,
-        # 'train_loader': 'shuffled_train_loader',
-        # 'valid_loader': 'shuffled_valid_loader',
-        'train_loader': 'sequential_train_loader',
-        'valid_loader': 'sequential_valid_loader',
+        'is_token_by_token': False,
+        'state_updater': 'simple',
+        'train_loader': 'shuffled_train_loader',
+        'valid_loader': 'shuffled_valid_loader',
+        # 'train_loader': 'sequential_train_loader',
+        # 'valid_loader': 'sequential_valid_loader',
     })
     experiment.add_pytorch_models(model=conf.model)
-    # experiment.load('70df7f86450911eb887b25e3927208f3')
-    experiment.load('c45857026a2811eba16c27c69839e51f')
     with experiment.start():
         conf.run()
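Because Configs.text is a labml configs option, the dataset swap is purely declarative: passing 'text': 'source_code_bpe' selects the @option-registered calculator of that name, and the commented-out 'source_code' entry switches back to the character-level dataset with no other code changes.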
