Skip to content

Commit e0e4e1e

Browse files
committed
evaluate xl
1 parent 688bc88 commit e0e4e1e

3 files changed

Lines changed: 44 additions & 31 deletions

File tree

python_autocomplete/evaluate.py

Lines changed: 42 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,75 +1,81 @@
11
import string
2-
from typing import List, Dict, Set, Optional
2+
from typing import List, Dict, Set, Optional, Any, Tuple
33

4+
import numpy as np
45
import torch
56
import torch.nn
6-
from labml.utils.cache import cache
77
from torch import nn
88

99
from labml import experiment, logger, lab
1010
from labml.logger import Text, Style
1111
from labml.utils.pytorch import get_modules
1212
from labml_helpers.module import Module
13-
from python_autocomplete.train import Configs
13+
from python_autocomplete.train import Configs, StateUpdater
1414

1515

1616
class Predictor:
17-
def __init__(self, model: Module, stoi: Dict[str, int], itos: List[str]):
17+
def __init__(self, model: Module, stoi: Dict[str, int], itos: List[str], *,
18+
state_updater: StateUpdater,
19+
is_token_by_token: bool):
20+
self.is_token_by_token = is_token_by_token
21+
self.state_updater = state_updater
1822
self.stoi = stoi
1923
self.itos = itos
2024
self.model = model
2125

22-
# Initial state
23-
self._state = None
24-
2526
# For timing
2627
self.time_add = 0
2728
self.time_predict = 0
2829
self.time_check = 0
2930

30-
def _get_predictions(self, prompt: str) -> torch.Tensor:
31+
def _get_predictions(self, prompt: str, state: Any) -> Tuple[torch.Tensor, Any]:
3132
prompt = prompt[-512:]
3233
data = torch.tensor([[self.stoi[c]] for c in prompt if c in self.stoi],
3334
dtype=torch.long,
3435
device=self.model.device)
3536

3637
# Get predictions
3738
with torch.no_grad():
38-
prediction, *_ = self.model(data)
39+
prediction, new_state = self.model(data, state)
40+
41+
state = self.state_updater(state, new_state)
3942

4043
# Final prediction
41-
return prediction[-1, :, :]
44+
return prediction[-1, :, :], state
4245

43-
def get_predictions(self, prompt: str) -> torch.Tensor:
44-
prediction = self._get_predictions(prompt)
46+
def get_predictions(self, prompt: str, state: Any) -> Tuple[np.ndarray, Any]:
47+
prediction, state = self._get_predictions(prompt, state)
4548

46-
return prediction.detach().cpu().numpy()
49+
return prediction.detach().cpu().numpy(), state
4750

48-
def get_probabilities(self, prompt: str) -> torch.Tensor:
51+
def get_probabilities(self, prompt: str, state: Any) -> Tuple[np.ndarray, Any]:
4952
# Final prediction
50-
prediction = nn.Softmax(-1)(self._get_predictions(prompt))
53+
prediction, state = self._get_predictions(prompt, state)
54+
prediction = nn.Softmax(-1)(prediction)
5155

52-
return prediction.detach().cpu().numpy()
56+
return prediction.detach().cpu().numpy(), state
5357

54-
def get_next_char(self, prompt: str) -> str:
55-
prediction = self.get_predictions(prompt)
58+
def get_next_char(self, prompt: str, state: Any) -> Tuple[str, Any]:
59+
prediction, state = self.get_predictions(prompt, state)
5660
best = prediction.argmax(-1).squeeze().item()
57-
return self.itos[best]
61+
return self.itos[best], state
5862

59-
def get_token(self, prompt: str, token_chars: Optional[Set[str]] = None) -> str:
63+
def get_token(self, prompt: str, token_chars: Optional[Set[str]], state: Any) -> Tuple[str, Any]:
6064
result = ''
6165
if token_chars is None:
6266
token_chars = set(string.ascii_letters + string.digits + ' ' + '\n' + '\r')
6367
while True:
64-
next_char = self.get_next_char(prompt)
68+
next_char, state = self.get_next_char(prompt, state)
6569
if len(result) > 2 and next_char not in token_chars or (next_char.strip() == '' and result.strip() != ''):
6670
if not result:
6771
result += next_char
68-
return result
72+
return result, state
6973
result += next_char
7074
if len(result) > 20:
71-
return result
75+
return result, state
7276
prompt += next_char
77+
if self.is_token_by_token:
78+
prompt = prompt[-1:]
7379

7480

7581
def evaluate(predictor: Predictor, text: str):
@@ -82,7 +88,7 @@ def evaluate(predictor: Predictor, text: str):
8288
key_strokes = 0
8389

8490
while i + 1 < len(text):
85-
next_token = predictor.get_token(text[:i + 1])
91+
next_token, state = predictor.get_token(text[:i + 1], None, None)
8692
if next_token == text[i + 1: i + 1 + len(next_token)]:
8793
correct += len(next_token)
8894
right = True
@@ -124,7 +130,8 @@ def anomalies(predictor: Predictor, text: str):
124130

125131
while i + 1 < len(text):
126132
# print(i, self.predictor.prompt)
127-
preds = predictor.get_probabilities(text[:i + 1])[0, :]
133+
preds, _ = predictor.get_probabilities(text[:i + 1], None)
134+
preds = preds[0, :]
128135
c = text[i + 1]
129136

130137
if c == '\n':
@@ -169,7 +176,7 @@ def complete(predictor: Predictor, text: str, completion: int):
169176
if len(text) > i + 1:
170177
c = text[i + 1]
171178
else:
172-
c = predictor.get_next_char(text[:i + 1])
179+
c, _ = predictor.get_next_char(text[:i + 1], None)
173180

174181
if c == '\n':
175182
logger.log(logs)
@@ -200,9 +207,12 @@ def get_predictor():
200207
# run_uuid = 'RUN_UUID'
201208
# And for latest checkpoint
202209
# checkpoint = None
203-
run_uuid, checkpoint = experiment.load_bundle(
204-
lab.get_path() / 'saved_checkpoint.tar.gz',
205-
url='https://github.com/lab-ml/python_autocomplete/releases/download/0.0.4/transformer_checkpoint.tar.gz')
210+
211+
run_uuid = 'c45857026a2811eba16c27c69839e51f'
212+
checkpoint = None
213+
# run_uuid, checkpoint = experiment.load_bundle(
214+
# lab.get_path() / 'saved_checkpoint.tar.gz',
215+
# url='https://github.com/lab-ml/python_autocomplete/releases/download/0.0.4/transformer_checkpoint.tar.gz')
206216

207217
conf_dict = experiment.load_configs(run_uuid)
208218
experiment.configs(conf, conf_dict)
@@ -211,7 +221,9 @@ def get_predictor():
211221

212222
experiment.start()
213223
conf.model.eval()
214-
return Predictor(conf.model, conf.stoi, conf.itos)
224+
return Predictor(conf.model, conf.stoi, conf.itos,
225+
state_updater=conf.state_updater,
226+
is_token_by_token=conf.is_token_by_token)
215227

216228

217229
def main():

python_autocomplete/serve.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def autocomplete():
2828
with monit.section('Predict') as s:
2929
acquired = lock.acquire(blocking=False)
3030
if acquired:
31-
res = predictor.get_token(prompt, token_chars=TOKEN_CHARS)
31+
res, state = predictor.get_token(prompt, TOKEN_CHARS, None)
3232
lock.release()
3333
s.message = f'{json.dumps(prompt[-5:])} -> {json.dumps(res)}'
3434
return jsonify({'success': True, 'prediction': res})

python_autocomplete/train.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,7 @@ def main():
335335
})
336336
experiment.add_pytorch_models(model=conf.model)
337337
# experiment.load('70df7f86450911eb887b25e3927208f3')
338+
experiment.load('c45857026a2811eba16c27c69839e51f')
338339
with experiment.start():
339340
conf.run()
340341

0 commit comments

Comments
 (0)