
Commit 0de5569: bpe dataset
committed · 1 parent 876fde7

3 files changed · 201 additions & 142 deletions

python_autocomplete/bpe.py · 4 additions & 3 deletions
```diff
@@ -13,6 +13,7 @@ class BPE:
     def __init__(self, bpe_en_de: 'BPEEnDe', tokenizer):
         self.bpe = bpe_en_de
         self.tokenizer = tokenizer
+        self.is_trained = True
 
     @property
     def n_tokens(self):
@@ -134,7 +135,7 @@ def encode(self, word: str):
         return self.encoder.encode([self.char_stoi[c] for c in word if c in self.char_stoi])
 
 
-class Tokenizer:
+class WordTokenizer:
     def collect_words(self, data: str):
         raise NotImplementedError
 
@@ -145,7 +146,7 @@ def tokenize(self, data: str, *, is_silent: bool = False) -> List[str]:
         raise NotImplementedError
 
 
-class SourceCodeTokenizer(Tokenizer):
+class SourceCodeTokenizer(WordTokenizer):
     def __init__(self):
         self.words = {}
 
@@ -207,7 +208,7 @@ def get_words(self):
         return [w for _, w in words_list], [f for f, _ in words_list]
 
 
-class NoTokenizer(Tokenizer):
+class NoTokenizer(WordTokenizer):
     def __init__(self):
         self.data = ''
```
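The rename frees the `Tokenizer` name for the new encoder-level interface introduced in `dataset.py` below, while `WordTokenizer` keeps the word-splitting role. As a quick illustration of the renamed interface, here is a minimal sketch of a custom subclass; the `WhitespaceTokenizer` name and its whitespace-splitting behaviour are hypothetical, and the `collect_words`/`get_words`/`tokenize` method set is taken from the diff context above.

```python
from typing import List

from python_autocomplete.bpe import WordTokenizer


class WhitespaceTokenizer(WordTokenizer):
    """Hypothetical example: collect and tokenize on whitespace only."""

    def __init__(self):
        self.words = {}

    def collect_words(self, data: str):
        # Accumulate word frequencies across calls
        for w in data.split():
            self.words[w] = self.words.get(w, 0) + 1

    def get_words(self):
        # Same (words, frequencies) return convention as SourceCodeTokenizer
        words_list = [(f, w) for w, f in self.words.items()]
        words_list.sort(reverse=True)
        return [w for _, w in words_list], [f for f, _ in words_list]

    def tokenize(self, data: str, *, is_silent: bool = False) -> List[str]:
        return data.split()
```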

python_autocomplete/dataset.py (new file) · 181 additions & 0 deletions
```python
from pathlib import PurePath
from typing import Dict, List, Optional

import torch
from torch.utils.data import Dataset, DataLoader

from labml import lab, monit
from labml.configs import option, BaseConfigs
from labml_helpers.datasets.text import TextDataset
from python_autocomplete.bpe import BPE, SourceCodeTokenizer


class Tokenizer:
    """Encoder-level interface: maps raw text to tensors of token ids."""
    n_tokens: int
    itos: List[str]
    stoi: Dict[str, int]
    is_trained: bool

    def encode(self, data: str, *, is_silent: bool = False):
        raise NotImplementedError

    def train(self, data: str):
        pass
```
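This `Tokenizer` is the interface the config system programs against: `encode` maps text to a tensor of token ids, and `train` builds the vocabulary when `is_trained` is false. The `self.is_trained = True` added to `BPE` in `bpe.py` above lets the config code below skip retraining when the BPE tokenizer is restored from the cache. `CharacterTokenizer` is the simplest implementation of the interface: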
```python
class CharacterTokenizer(Tokenizer):
    def __init__(self, retrain: Optional[bool]):
        from labml.utils.cache import cache_get

        # Restore a previously built vocabulary from the labml cache, if any
        self.n_tokens = cache_get('n_tokens')
        self.itos = cache_get('itos')
        self.stoi = cache_get('stoi')
        if retrain is None:
            # Unspecified: train only if the cached vocabulary is incomplete
            self.is_trained = bool(self.n_tokens and self.itos and self.stoi)
        else:
            self.is_trained = not retrain

    def encode(self, data: str, *, is_silent: bool = False):
        # Characters outside the vocabulary are dropped
        return torch.tensor([self.stoi[c] for c in data if c in self.stoi], dtype=torch.long)

    def train(self, data: str):
        with monit.section("Build vocabulary"):
            self.itos = sorted(set(data))
            self.n_tokens = len(self.itos)
            self.stoi = {c: i for i, c in enumerate(self.itos)}

        from labml.utils.cache import cache_set

        cache_set('n_tokens', self.n_tokens)
        cache_set('itos', self.itos)
        cache_set('stoi', self.stoi)
```
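A tiny usage sketch (hypothetical input, and assuming a labml project context so the `cache_get`/`cache_set` calls have somewhere to read and write):

```python
tok = CharacterTokenizer(retrain=True)           # retrain requested, so is_trained is False
tok.train('def add(a, b):\n    return a + b\n')  # builds and caches the vocabulary
ids = tok.encode('a + b')                        # LongTensor of character ids
print(tok.n_tokens, ids.tolist())                # vocabulary size and encoded ids
```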
```python
class SourceCodeDataset:
    @staticmethod
    def load(path: PurePath):
        with open(str(path), 'r') as f:
            return f.read()

    @staticmethod
    def get_train_valid(path: PurePath, is_load_data: bool):
        if is_load_data:
            with monit.section("Load data"):
                # Cap each split at its first million characters
                train = TextDataset.load(path / 'train.py')[:1_000_000]
                valid = TextDataset.load(path / 'valid.py')[:1_000_000]
        else:
            train = ''
            valid = ''

        return train, valid

    def __init__(self, tokenizer: Tokenizer, train, valid):
        self.train = train
        self.valid = valid
        self.tokenizer = tokenizer

    def __repr__(self):
        # Split sizes in millions of characters
        return f'{len(self.train) / 1_000_000:,.2f}M, {len(self.valid) / 1_000_000:,.2f}M'
```
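`SourceCodeDataset` is a thin holder: it reads `train.py` and `valid.py` from the data directory, caps each split at a million characters, and reports the split sizes through `__repr__`. The config class below wires the dataset, tokenizer, and data loaders together.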
```python
class SourceCodeDataConfigs(BaseConfigs):
    dataset: SourceCodeDataset
    is_load_data: bool = True
    tokenizer: Tokenizer
    retrain_tokenizer: bool = True

    train_loader: DataLoader
    valid_loader: DataLoader
    is_shuffle: bool = True
    batch_size: int
    seq_len: int

    def text_to_i(self, text: str, *, is_silent: bool = True) -> torch.Tensor:
        return torch.tensor(self.tokenizer.encode(text, is_silent=is_silent))


@option(SourceCodeDataConfigs.dataset, 'default')
def _dataset(c: SourceCodeDataConfigs):
    train, valid = SourceCodeDataset.get_train_valid(lab.get_data_path(), c.is_load_data)
    if not c.tokenizer.is_trained:
        c.tokenizer.train(train + valid)
    return SourceCodeDataset(c.tokenizer, train, valid)


@option(SourceCodeDataConfigs.tokenizer, 'bpe')
def _bpe_tokenizer():
    from labml.utils.cache import cache_get
    from python_autocomplete.bpe import BPEEnDe

    bpe_cache = cache_get('bpe')

    if bpe_cache:
        bpe_en_de = BPEEnDe()
        bpe_en_de.load(**bpe_cache)
    else:
        raise RuntimeError('BPE not cached')

    return BPE(bpe_en_de, SourceCodeTokenizer())


@option(SourceCodeDataConfigs.tokenizer, 'char')
def _char_tokenizer(c: SourceCodeDataConfigs):
    return CharacterTokenizer(c.retrain_tokenizer)
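```

The `@option` functions register named alternatives for each config slot, so the tokenizer is selected by name. A rough sketch of how that selection might look (the surrounding call site is assumed, not part of this commit):

```python
from labml import experiment

conf = SourceCodeDataConfigs()
# 'char' -> CharacterTokenizer, 'bpe' -> BPE restored from the cache
experiment.configs(conf, {'tokenizer': 'char', 'batch_size': 32, 'seq_len': 512})
```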
```python
# Data loaders
class TokenDataset(Dataset):
    def __init__(self, *,
                 data: torch.Tensor,
                 batch_size: int,
                 seq_len: int,
                 drop_last: bool = False):
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.data = data
        # Number of full seq_len windows (each sample needs seq_len + 1 tokens)
        self.n_samples = (self.data.shape[0] - 1) // self.seq_len
        if drop_last:
            self.n_batches = self.n_samples // batch_size
        else:
            self.n_batches = (self.n_samples + batch_size - 1) // batch_size

    def __len__(self):
        return (self.data.shape[0] - 1) // self.seq_len

    def __getitem__(self, idx):
        # Lay samples out column-major, so slot i of batch b continues
        # slot i of batch b - 1 in the underlying token stream
        batch = idx // self.batch_size
        batch_idx = idx % self.batch_size
        idx = batch_idx * self.n_batches + batch
        start = idx * self.seq_len
        assert start + self.seq_len + 1 <= self.data.shape[0]
        end = start + self.seq_len
        data = self.data[start: end]
        target = self.data[start + 1: end + 1]
        return data, target


def transpose_batch(batch):
    # Stack along dim 1 so tensors come out as [seq_len, batch_size]
    transposed_data = list(zip(*batch))
    src = torch.stack(transposed_data[0], 1)
    tgt = torch.stack(transposed_data[1], 1)

    return src, tgt
```
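A toy check of the index arithmetic (hypothetical numbers, assuming `TokenDataset` from above is importable):

```python
import torch

# 101 tokens, seq_len 10 -> 10 samples; batch_size 2 -> 5 batches.
# Dataset index 3 is batch 1, slot 1, i.e. window 1 * 5 + 1 = 6,
# so it yields tokens [60:70] with targets [61:71].
ds = TokenDataset(data=torch.arange(101), batch_size=2, seq_len=10)
x, y = ds[3]
assert x[0].item() == 60 and y[0].item() == 61
```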
```python
@option(SourceCodeDataConfigs.train_loader)
def _train_loader(c: SourceCodeDataConfigs):
    return DataLoader(TokenDataset(data=c.text_to_i(c.dataset.train, is_silent=False),
                                   batch_size=c.batch_size,
                                   seq_len=c.seq_len),
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      shuffle=c.is_shuffle)


@option(SourceCodeDataConfigs.valid_loader)
def _valid_loader(c: SourceCodeDataConfigs):
    return DataLoader(TokenDataset(data=c.text_to_i(c.dataset.valid, is_silent=False),
                                   batch_size=c.batch_size,
                                   seq_len=c.seq_len),
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      shuffle=c.is_shuffle)
```
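Putting it together, a minimal sketch of how the loaders might be consumed; the experiment name and hyperparameters are illustrative, and this assumes labml resolves `conf.train_loader` lazily from the options defined above:

```python
from labml import experiment
from python_autocomplete.dataset import SourceCodeDataConfigs

conf = SourceCodeDataConfigs()
experiment.create(name='source_code_data')  # illustrative name
experiment.configs(conf, {'tokenizer': 'char', 'batch_size': 32, 'seq_len': 512})

with experiment.start():
    for src, tgt in conf.train_loader:
        pass  # each of src, tgt is a [seq_len, batch_size] LongTensor
```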
