Commit 876fde7: bpe evaluate
Parent: 014c50c

3 files changed: 53 additions, 54 deletions

python_autocomplete/bpe.py (8 additions, 8 deletions)
@@ -26,11 +26,11 @@ def itos(self):
     def stoi(self):
         return self.bpe.bpe_stoi

-    def encode(self, data: str):
-        words = self.tokenizer.tokenize(data)
+    def encode(self, data: str, *, is_silent: bool = False):
+        words = self.tokenizer.tokenize(data, is_silent=is_silent)

         res = []
-        for w in monit.iterate('Encode words', words):
+        for w in monit.iterate('Encode words', words, is_silent=is_silent):
             res += self.bpe.encode(w)

         return res
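The new keyword-only is_silent flag lets callers that encode many short strings, such as the evaluation path added in this commit, suppress labml's progress output; the diff relies on monit.iterate and monit.enum accepting is_silent in the labml version targeted here. A minimal sketch of the pattern, where encode_word is a hypothetical stand-in for the repo's self.bpe.encode(w):

from labml import monit

def encode_word(word: str) -> list:
    # Hypothetical stand-in for the BPE merge step; returns character codes.
    return [ord(c) for c in word]

def encode(words, *, is_silent: bool = False) -> list:
    res = []
    # Forwarding is_silent keeps the progress output quiet for tiny prompts.
    for w in monit.iterate('Encode words', words, is_silent=is_silent):
        res += encode_word(w)
    return res

print(encode(['def', 'foo'], is_silent=True))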
@@ -131,7 +131,7 @@ def encode(self, word: str):
         if word in self.popular_words:
             return self.popular_words[word]

-        return self.encoder.encode([self.char_stoi[c] for c in word])
+        return self.encoder.encode([self.char_stoi[c] for c in word if c in self.char_stoi])


 class Tokenizer:
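The added "if c in self.char_stoi" guard makes BPE.encode drop characters that never appeared in the training data instead of raising a KeyError, which matters once arbitrary user prompts are encoded at evaluation time. A self-contained illustration of the difference (the char_stoi table here is made up):

char_stoi = {'a': 0, 'b': 1, 'c': 2}   # hypothetical character vocabulary
word = 'ab€c'                          # '€' was never seen during training

# Old behaviour: KeyError on the unknown character.
# ids = [char_stoi[ch] for ch in word]

# New behaviour: unknown characters are silently skipped.
ids = [char_stoi[ch] for ch in word if ch in char_stoi]
print(ids)   # [0, 1, 2]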
@@ -141,7 +141,7 @@ def collect_words(self, data: str):
     def get_words(self) -> Tuple[List[str], List[int]]:
         raise NotImplementedError

-    def tokenize(self, data: str) -> List[str]:
+    def tokenize(self, data: str, *, is_silent: bool = False) -> List[str]:
         raise NotImplementedError


@@ -158,12 +158,12 @@ def add_word(self, word):
         else:
             self.words[word] += 1

-    def tokenize(self, data: str) -> List[str]:
+    def tokenize(self, data: str, *, is_silent: bool = False) -> List[str]:
         last_idx = 0
         is_id = False
         res = []

-        for i, c in monit.enum('Collect words', data):
+        for i, c in monit.enum('Collect words', data, is_silent=is_silent):
             if c in ID_CHARS:
                 if not is_id:
                     if last_idx < i:
@@ -217,7 +217,7 @@ def collect_words(self, data):
     def get_words(self):
         return [self.data], [1]

-    def tokenize(self, data: str) -> List[str]:
+    def tokenize(self, data: str, *, is_silent: bool = False) -> List[str]:
         return [data]

python_autocomplete/evaluate.py (15 additions, 15 deletions)
@@ -9,18 +9,19 @@
 from labml import experiment, logger, lab, monit
 from labml.logger import Text, Style
 from labml.utils.pytorch import get_modules
+from labml_helpers.datasets.text import TextDataset
 from labml_helpers.module import Module
 from python_autocomplete.train import Configs, StateUpdater


 class Predictor:
-    def __init__(self, model: Module, stoi: Dict[str, int], itos: List[str], *,
+    def __init__(self, model: Module, text: TextDataset, *,
                  state_updater: StateUpdater,
                  is_token_by_token: bool):
+        text.is_silent = True
+        self.text = text
         self.is_token_by_token = is_token_by_token
         self.state_updater = state_updater
-        self.stoi = stoi
-        self.itos = itos
         self.model = model

         # For timing
@@ -29,10 +30,8 @@ def __init__(self, model: Module, stoi: Dict[str, int], itos: List[str], *,
         self.time_check = 0

     def _get_predictions(self, prompt: str, state: Any) -> Tuple[torch.Tensor, Any]:
-        prompt = prompt[-512:]
-        data = torch.tensor([[self.stoi[c]] for c in prompt if c in self.stoi],
-                            dtype=torch.long,
-                            device=self.model.device)
+        data = self.text.text_to_i(prompt)[-512:]
+        data = data.to(self.model.device).unsqueeze(-1)

         # Get predictions
         with torch.no_grad():
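_get_predictions now defers tokenization to the dataset's text_to_i, so the same predictor works for the character and the BPE vocabularies; the 512-element cut is applied to tokens after encoding rather than to raw characters before it. A small sketch of the new preprocessing in isolation (the to_i function below is a toy character-level stand-in for TextDataset.text_to_i, not the repo's implementation):

import torch

def prepare_prompt(to_i, prompt: str, device: torch.device) -> torch.Tensor:
    # Mirrors the new preprocessing: tokenize, keep the last 512 tokens,
    # move to the model's device and add a trailing batch dimension.
    data = to_i(prompt)[-512:]
    return data.to(device).unsqueeze(-1)

stoi = {c: i for i, c in enumerate('abcdef ')}
to_i = lambda s: torch.tensor([stoi[c] for c in s if c in stoi], dtype=torch.long)

x = prepare_prompt(to_i, 'add cab', torch.device('cpu'))
print(x.shape)   # torch.Size([7, 1])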
@@ -58,7 +57,7 @@ def get_probabilities(self, prompt: str, state: Any) -> Tuple[np.ndarray, Any]:
     def get_next_token(self, prompt: str, state: Any) -> Tuple[str, Any]:
         prediction, state = self.get_predictions(prompt, state)
         best = prediction.argmax(-1).squeeze().item()
-        return self.itos[best], state
+        return self.text.itos[best], state

     def get_start_state(self, prompt: str):
         assert prompt
@@ -152,10 +151,10 @@ def anomalies(predictor: Predictor, text: str):
             logs = [(f"{line_no: 4d}: ", Text.meta)]
         elif c == '\r':
             continue
-        elif c not in predictor.stoi:
+        elif c not in predictor.text.stoi:
             logs.append(c)
         else:
-            next_id = predictor.stoi[c]
+            next_id = predictor.text.stoi[c]
             prob = preds[next_id]
             if prob > 0.9:
                 logs.append((c, [Style.bold, Text.success, Style.underline]))
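With the dataset attached to the predictor, both directions of the vocabulary mapping come from one object: stoi for scoring characters in anomalies, itos for decoding the argmax in get_next_token. A toy version of that decode step, with an illustrative vocabulary rather than the repo's:

import torch

itos = ['a', 'b', 'c', '\n']                      # toy vocabulary
logits = torch.tensor([0.1, 2.5, 0.3, -1.0])      # model output for one position

best = logits.argmax(-1).item()
print(itos[best])   # 'b'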
@@ -220,21 +219,22 @@ def get_predictor():
     # And for latest checkpoint
     # checkpoint = None

-    run_uuid = '41dc02106d1611eb9ab213fdf628e807'  # bpe
+    run_uuid = '275e62e66dc711eb9d162f2ddfc33452'  # bpe
     # run_uuid = 'c45857026a2811eba16c27c69839e51f'  # xl
     checkpoint = None
-    # run_uuid, checkpoint = experiment.load_bundle(
-    #     lab.get_path() / 'saved_checkpoint.tar.gz',
-    #     url='https://github.com/lab-ml/python_autocomplete/releases/download/0.0.4/transformer_checkpoint.tar.gz')
+    run_uuid, checkpoint = experiment.load_bundle(
+        lab.get_path() / 'saved_checkpoint.tar.gz',
+        url='https://github.com/lab-ml/python_autocomplete/releases/download/0.0.4/transformer_checkpoint.tar.gz')

     conf_dict = experiment.load_configs(run_uuid)
+    conf_dict['is_load_data'] = False
     experiment.configs(conf, conf_dict)
     experiment.add_pytorch_models(get_modules(conf))
     experiment.load(run_uuid, checkpoint)

     experiment.start()
     conf.model.eval()
-    return Predictor(conf.model, conf.stoi, conf.itos,
+    return Predictor(conf.model, conf.text,
                      state_updater=conf.state_updater,
                      is_token_by_token=conf.is_token_by_token)

python_autocomplete/train.py (30 additions, 31 deletions)
@@ -20,34 +20,48 @@


 class SourceCodeDataset(TextDataset):
-    def __init__(self, path: PurePath, tokenizer: Callable):
-        with monit.section("Load data"):
-            train = self.load(path / 'train.py')  # [:100000]
-            valid = self.load(path / 'valid.py')  # [:100000]
+    def __init__(self, path: PurePath, tokenizer: Callable, dont_load: bool):
+        if not dont_load:
+            with monit.section("Load data"):
+                train = self.load(path / 'train.py')  # [:100000]
+                valid = self.load(path / 'valid.py')  # [:100000]
+        else:
+            train = ''
+            valid = ''

-        from labml.utils.cache import cache_get
+        from labml.utils.cache import cache_get, cache_set

         super().__init__(path, tokenizer, train, valid, '',
                          n_tokens=cache_get('n_tokens'),
                          itos=cache_get('itos'),
                          stoi=cache_get('stoi'))

+        cache_set(f'n_tokens', self.n_tokens)
+        cache_set(f'itos', self.itos)
+        cache_set(f'stoi', self.stoi)
+

 class BPESourceCodeDataset(TextDataset):
     tokenizer: BPE

-    def __init__(self, path: PurePath, bpe: BPE):
-        with monit.section("Load data"):
-            train = self.load(path / 'train.py')  # [:100_000]
-            valid = self.load(path / 'valid.py')  # [:100_000]
+    def __init__(self, path: PurePath, bpe: BPE, dont_load: bool):
+        if not dont_load:
+            with monit.section("Load data"):
+                train = self.load(path / 'train.py')  # [:100_000]
+                valid = self.load(path / 'valid.py')  # [:100_000]
+        else:
+            train = ''
+            valid = ''
+
+        self.is_silent = False

         super().__init__(path, bpe, train, valid, '',
                          n_tokens=bpe.n_tokens,
                          itos=bpe.itos,
                          stoi=bpe.stoi)

     def text_to_i(self, text: str) -> torch.Tensor:
-        return torch.tensor(self.tokenizer.encode(text))
+        return torch.tensor(self.tokenizer.encode(text, is_silent=self.is_silent))


 class Configs(TrainValidConfigs):
@@ -80,10 +94,8 @@ class Configs(TrainValidConfigs):
     grad_norm_clip: float = 1.0
     is_token_by_token: bool = False

-    itos: List[str]
-    stoi: Dict[str, int]
-
     cache_name: str = ''
+    is_load_data: bool = True

     def init(self):
         tracker.set_queue("loss.*", 20, True)
@@ -129,10 +141,10 @@ def sample(self):
             data = data.to(self.device)
             output, new_state = self.model(data, state)
             output = output.argmax(dim=-1).squeeze(1)
-            prompt += '' + self.itos[output[-1]]
+            prompt += '' + self.text.itos[output[-1]]
             if self.is_token_by_token:
                 prompt = prompt[-1:]
-            log += [('' + self.itos[output[-1]], Text.value)]
+            log += [('' + self.text.itos[output[-1]], Text.value)]
             state = self.state_updater(state, new_state)

         logger.log(log)
@@ -177,20 +189,7 @@ def _loss_func(c: Configs):

 @option(Configs.n_tokens)
 def _n_tokens(c: Configs):
-    from labml.utils.cache import cache
-    return cache(f'n_tokens{c.cache_name}', lambda: c.text.n_tokens)
-
-
-@option(Configs.itos)
-def _itos(c: Configs):
-    from labml.utils.cache import cache
-    return cache(f'itos{c.cache_name}', lambda: c.text.itos)
-
-
-@option(Configs.stoi)
-def _stoi(c: Configs):
-    from labml.utils.cache import cache
-    return cache(f'stoi{c.cache_name}', lambda: c.text.stoi)
+    return c.text.n_tokens


 @option(Configs.model)
@@ -285,7 +284,7 @@ def character():

 @option(Configs.text)
 def source_code(c: Configs):
-    return SourceCodeDataset(lab.get_data_path(), c.tokenizer)
+    return SourceCodeDataset(lab.get_data_path(), c.tokenizer, c.is_load_data)


 @option(Configs.text)
@@ -301,7 +300,7 @@ def source_code_bpe(c: Configs):
         raise RuntimeError('BPE not cached')

     tokenizer = BPE(bpe_en_de, SourceCodeTokenizer())
-    return BPESourceCodeDataset(lab.get_data_path(), tokenizer)
+    return BPESourceCodeDataset(lab.get_data_path(), tokenizer, c.is_load_data)


 @option(Configs.train_loader)
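The dataset constructors can now be built without reading train.py and valid.py from disk, which is what the evaluator relies on when it sets is_load_data to False, and SourceCodeDataset now writes n_tokens, itos and stoi back to labml's cache after construction. A rough sketch of the load-or-skip pattern, independent of the repo's classes (the load_file helper is hypothetical):

from pathlib import Path

def load_file(path: Path) -> str:
    with open(str(path)) as f:
        return f.read()

def make_splits(path: Path, dont_load: bool):
    # When evaluation only needs the cached vocabulary, skip reading the
    # training files and fall back to empty strings.
    if not dont_load:
        train = load_file(path / 'train.py')
        valid = load_file(path / 'valid.py')
    else:
        train = ''
        valid = ''
    return train, valid

Note that the constructors take the flag as dont_load while the config option is is_load_data, so the two names read with opposite senses where they meet in source_code and source_code_bpe.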