Skip to content

Commit 47153aa

Browse files
committed
bug fixes
1 parent 7094232 commit 47153aa

4 files changed

Lines changed: 22 additions & 18 deletions

File tree

python_autocomplete/bpe.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def encode(self, codes: List[int]):
5858

5959
while self.heap:
6060
_, idx, pair = heappop(self.heap)
61+
self.merge(idx, pair)
6162

6263
return [c for c in self.codes if c != -1]
6364

@@ -111,7 +112,7 @@ def set_popular_words(self, popular_words):
111112
def calc(self):
112113
self.bpe_itos = self.calc_bpe_itos()
113114
self.bpe_stoi = {s: i for i, s in enumerate(self.bpe_itos)}
114-
self.pairs = {(p[0], p[1]): c for c, p in enumerate(self.bpe) if isinstance(p, tuple)}
115+
self.pairs = {(p[0], p[1]): c for c, p in enumerate(self.bpe) if not isinstance(p, int)}
115116

116117
self.encoder = _BPEEncoder(self.pairs)
117118

python_autocomplete/evaluate.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import torch.nn
77
from torch import nn
88

9-
from labml import experiment, logger, lab
9+
from labml import experiment, logger, lab, monit
1010
from labml.logger import Text, Style
1111
from labml.utils.pytorch import get_modules
1212
from labml_helpers.module import Module
@@ -55,7 +55,7 @@ def get_probabilities(self, prompt: str, state: Any) -> Tuple[np.ndarray, Any]:
5555

5656
return prediction.detach().cpu().numpy(), state
5757

58-
def get_next_char(self, prompt: str, state: Any) -> Tuple[str, Any]:
58+
def get_next_token(self, prompt: str, state: Any) -> Tuple[str, Any]:
5959
prediction, state = self.get_predictions(prompt, state)
6060
best = prediction.argmax(-1).squeeze().item()
6161
return self.itos[best], state
@@ -68,25 +68,26 @@ def get_start_state(self, prompt: str):
6868
if not self.is_token_by_token:
6969
return prompt, None
7070

71-
_, state = self.get_next_char(prompt[:-1], None)
71+
_, state = self.get_next_token(prompt[:-1], None)
7272
return prompt[-1], state
7373

74-
def get_token(self, prompt: str, token_chars: Optional[Set[str]], state: Any) -> Tuple[str, Any]:
74+
def get_next_word(self, prompt: str, token_chars: Optional[Set[str]], state: Any) -> Tuple[str, Any]:
7575
result = ''
7676
if token_chars is None:
7777
token_chars = set(string.ascii_letters + string.digits + ' ' + '\n' + '\r')
7878
while True:
79-
next_char, state = self.get_next_char(prompt, state)
80-
if len(result) > 2 and next_char not in token_chars or (next_char.strip() == '' and result.strip() != ''):
79+
next_token, state = self.get_next_token(prompt, state)
80+
if len(result) > 2 and next_token not in token_chars or (next_token.strip() == '' and result.strip() != ''):
8181
if not result:
82-
result += next_char
82+
result += next_token
8383
return result, state
84-
result += next_char
84+
result += next_token
8585
if len(result) > 20:
8686
return result, state
87-
prompt += next_char
8887
if self.is_token_by_token:
89-
prompt = prompt[-1:]
88+
prompt = next_token
89+
else:
90+
prompt += next_token
9091

9192

9293
def evaluate(predictor: Predictor, text: str):
@@ -99,7 +100,7 @@ def evaluate(predictor: Predictor, text: str):
99100
key_strokes = 0
100101

101102
while i + 1 < len(text):
102-
next_token, state = predictor.get_token(text[:i + 1], None, None)
103+
next_token, state = predictor.get_next_word(text[:i + 1], None, None)
103104
if next_token == text[i + 1: i + 1 + len(next_token)]:
104105
correct += len(next_token)
105106
right = True
@@ -187,7 +188,7 @@ def complete(predictor: Predictor, text: str, completion: int):
187188
if len(text) > i + 1:
188189
c = text[i + 1]
189190
else:
190-
c, _ = predictor.get_next_char(text[:i + 1], None)
191+
c, _ = predictor.get_next_token(text[:i + 1], None)
191192

192193
if c == '\n':
193194
logger.log(logs)
@@ -219,7 +220,8 @@ def get_predictor():
219220
# And for latest checkpoint
220221
# checkpoint = None
221222

222-
run_uuid = 'c45857026a2811eba16c27c69839e51f'
223+
run_uuid = '41dc02106d1611eb9ab213fdf628e807' # bpe
224+
# run_uuid = 'c45857026a2811eba16c27c69839e51f' # xl
223225
checkpoint = None
224226
# run_uuid, checkpoint = experiment.load_bundle(
225227
# lab.get_path() / 'saved_checkpoint.tar.gz',
@@ -242,7 +244,8 @@ def main():
242244

243245
with open(str(lab.get_data_path() / 'sample.py'), 'r') as f:
244246
sample = f.read()
245-
evaluate(predictor, sample)
247+
with monit.section('Evaluate'):
248+
evaluate(predictor, sample)
246249

247250

248251
if __name__ == '__main__':

python_autocomplete/serve.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def autocomplete():
2828
with monit.section('Predict') as s:
2929
acquired = lock.acquire(blocking=False)
3030
if acquired:
31-
res, state = predictor.get_token(prompt, TOKEN_CHARS, None)
31+
res, state = predictor.get_next_word(prompt, TOKEN_CHARS, None)
3232
lock.release()
3333
s.message = f'{json.dumps(prompt[-5:])} -> {json.dumps(res)}'
3434
return jsonify({'success': True, 'prediction': res})

python_autocomplete/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ class BPESourceCodeDataset(TextDataset):
3838

3939
def __init__(self, path: PurePath, bpe: BPE):
4040
with monit.section("Load data"):
41-
train = self.load(path / 'train.py') # [:1000_000]
42-
valid = self.load(path / 'valid.py') # [:1000_000]
41+
train = self.load(path / 'train.py') # [:100_000]
42+
valid = self.load(path / 'valid.py') # [:100_000]
4343

4444
super().__init__(path, bpe, train, valid, '',
4545
n_tokens=bpe.n_tokens,

0 commit comments

Comments
 (0)