Skip to content

Commit c20788c

Browse files
committed
bpe dataset evaluate
1 parent 0de5569 commit c20788c

3 files changed

Lines changed: 20 additions & 21 deletions

File tree

python_autocomplete/bpe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def itos(self):
     def stoi(self):
         return self.bpe.bpe_stoi

-    def encode(self, data: str, *, is_silent: bool = False):
+    def encode(self, data: str, *, is_silent: bool = True):
         words = self.tokenizer.tokenize(data, is_silent=is_silent)

         res = []

python_autocomplete/dataset.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class Tokenizer:
     stoi: Dict[str, int]
     is_trained: int

-    def encode(self, data: str, *, is_silent: bool = False):
+    def encode(self, data: str, *, is_silent: bool = True):
         raise NotImplementedError

     def train(self, data: str):
def train(self, data: str):
@@ -35,7 +35,7 @@ def __init__(self, retrain: bool):
         else:
             self.is_trained = not retrain

-    def encode(self, data: str, *, is_silent: bool = False):
+    def encode(self, data: str, *, is_silent: bool = True):
         return torch.tensor([self.stoi[c] for c in data if c in self.stoi], dtype=torch.long)

     def train(self, data: str):
@@ -91,7 +91,7 @@ class SourceCodeDataConfigs(BaseConfigs):
     seq_len: int

     def text_to_i(self, text: str, *, is_silent: bool = True) -> torch.Tensor:
-        return torch.tensor(self.tokenizer.encode(text, is_silent=is_silent))
+        return torch.tensor(self.tokenizer.encode(text, is_silent=is_silent), dtype=torch.long)


 @option(SourceCodeDataConfigs.dataset, 'default')

python_autocomplete/evaluate.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
 import string
-from typing import List, Dict, Set, Optional, Any, Tuple
+from typing import Set, Optional, Any, Tuple

 import numpy as np
 import torch
@@ -9,17 +9,16 @@
 from labml import experiment, logger, lab, monit
 from labml.logger import Text, Style
 from labml.utils.pytorch import get_modules
-from labml_helpers.datasets.text import TextDataset
 from labml_helpers.module import Module
+from python_autocomplete.dataset import Tokenizer
 from python_autocomplete.train import Configs, StateUpdater


 class Predictor:
-    def __init__(self, model: Module, text: TextDataset, *,
+    def __init__(self, model: Module, tokenizer: Tokenizer, *,
                  state_updater: StateUpdater,
                  is_token_by_token: bool):
-        text.is_silent = True
-        self.text = text
+        self.tokenizer = tokenizer
         self.is_token_by_token = is_token_by_token
         self.state_updater = state_updater
         self.model = model
@@ -30,8 +29,9 @@ def __init__(self, model: Module, text: TextDataset, *,
         self.time_check = 0

     def _get_predictions(self, prompt: str, state: Any) -> Tuple[torch.Tensor, Any]:
-        data = self.text.text_to_i(prompt)[-512:]
-        data = data.to(self.model.device).unsqueeze(-1)
+        data = torch.tensor(self.tokenizer.encode(prompt),
+                            dtype=torch.long,
+                            device=self.model.device)[-512:].unsqueeze(-1)

         # Get predictions
         with torch.no_grad():
with torch.no_grad():
@@ -57,7 +57,7 @@ def get_probabilities(self, prompt: str, state: Any) -> Tuple[np.ndarray, Any]:
     def get_next_token(self, prompt: str, state: Any) -> Tuple[str, Any]:
         prediction, state = self.get_predictions(prompt, state)
         best = prediction.argmax(-1).squeeze().item()
-        return self.text.itos[best], state
+        return self.tokenizer.itos[best], state

     def get_start_state(self, prompt: str):
         assert prompt
@@ -151,10 +151,10 @@ def anomalies(predictor: Predictor, text: str):
             logs = [(f"{line_no: 4d}: ", Text.meta)]
         elif c == '\r':
             continue
-        elif c not in predictor.text.stoi:
+        elif c not in predictor.tokenizer.stoi:
             logs.append(c)
         else:
-            next_id = predictor.text.stoi[c]
+            next_id = predictor.tokenizer.stoi[c]
             prob = preds[next_id]
             if prob > 0.9:
                 logs.append((c, [Style.bold, Text.success, Style.underline]))
@@ -219,22 +219,21 @@ def get_predictor():
     # And for latest checkpoint
     # checkpoint = None

-    run_uuid = '275e62e66dc711eb9d162f2ddfc33452'  # bpe
-    # run_uuid = 'c45857026a2811eba16c27c69839e51f'  # xl
+    run_uuid = '109d1b8c6e8611eb80e13584488b68a4'  # bpe
     checkpoint = None
-    run_uuid, checkpoint = experiment.load_bundle(
-        lab.get_path() / 'saved_checkpoint.tar.gz',
-        url='https://github.com/lab-ml/python_autocomplete/releases/download/0.0.4/transformer_checkpoint.tar.gz')
+    # run_uuid, checkpoint = experiment.load_bundle(
+    #     lab.get_path() / 'saved_checkpoint.tar.gz',
+    #     url='https://github.com/lab-ml/python_autocomplete/releases/download/0.0.4/transformer_checkpoint.tar.gz')

     conf_dict = experiment.load_configs(run_uuid)
-    conf_dict['is_load_data'] = False
+    conf_dict['text.is_load_data'] = False
     experiment.configs(conf, conf_dict)
     experiment.add_pytorch_models(get_modules(conf))
     experiment.load(run_uuid, checkpoint)

     experiment.start()
     conf.model.eval()
-    return Predictor(conf.model, conf.text,
+    return Predictor(conf.model, conf.text.tokenizer,
                      state_updater=conf.state_updater,
                      is_token_by_token=conf.is_token_by_token)

0 commit comments

Comments (0)