Skip to content

Commit 14f7ff1

Browse files
committed
beam search
1 parent c20788c commit 14f7ff1

5 files changed

Lines changed: 237 additions & 85 deletions

File tree

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import string
2+
from typing import Dict, List, Tuple
3+
4+
ID_CHARS = set(string.ascii_letters + string.digits + '_')
5+
6+
7+
class Tokenizer:
    """Base interface for tokenizers that map between source text and
    integer token ids."""

    # Vocabulary size
    n_tokens: int
    # Token id -> token string
    itos: List[str]
    # Token string -> token id
    stoi: Dict[str, int]
    # Whether the vocabulary has been built (assigned bools by subclasses)
    is_trained: bool

    def encode(self, data: str, *, is_silent: bool = True):
        """Encode `data` into a sequence of token ids. Must be overridden.

        :param is_silent: suppress progress reporting in implementations
        """
        raise NotImplementedError

    def train(self, data: str):
        """Build the vocabulary from `data`. Default is a no-op."""
        pass

    def rstrip(self, data: str) -> Tuple[str, List[int]]:
        """Strip any trailing partial token and return the kept text with
        its encoding. Default keeps `data` unchanged."""
        return data, self.encode(data)
Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,16 @@
1-
import string
21
from functools import lru_cache
32
from heapq import heappush, heappop
43
from typing import List, Tuple
54

65
from labml import lab, monit
76
from labml.utils.cache import cache_set
7+
from python_autocomplete.dataset import Tokenizer, ID_CHARS
88

9-
ID_CHARS = set(string.ascii_letters + string.digits + '_')
109

11-
12-
class BPE:
13-
def __init__(self, bpe_en_de: 'BPEEnDe', tokenizer):
10+
class BPE(Tokenizer):
11+
def __init__(self, bpe_en_de: 'BPEEnDe', word_tokenizer):
    """
    :param bpe_en_de: holds the learned BPE merge rules and vocabulary
    :param word_tokenizer: splits raw text into words before BPE encoding
    """
    self.bpe = bpe_en_de
    self.word_tokenizer = word_tokenizer
    # The BPE vocabulary arrives pre-trained; nothing to learn here
    self.is_trained = True
1715

1816
@property
@@ -28,7 +26,7 @@ def stoi(self):
2826
return self.bpe.bpe_stoi
2927

3028
def encode(self, data: str, *, is_silent: bool = True):
31-
words = self.tokenizer.tokenize(data, is_silent=is_silent)
29+
words = self.word_tokenizer.tokenize(data, is_silent=is_silent)
3230

3331
res = []
3432
for w in monit.iterate('Encode words', words, is_silent=is_silent):
@@ -40,6 +38,15 @@ def __call__(self, data: str):
4038
encoded = self.encode(data)
4139
return [self.itos[c] for c in encoded]
4240

41+
def rstrip(self, data: str):
    """Drop the trailing (possibly incomplete) word and encode the rest.

    :param data: raw text ending, potentially, in a partial word
    :return: the retained text and its BPE token ids
    """
    complete_words = self.word_tokenizer.tokenize(data, is_silent=True)[:-1]

    encoded: List[int] = []
    for word in complete_words:
        encoded.extend(self.bpe.encode(word))

    return ''.join(complete_words), encoded
49+
4350

4451
class _BPEEncoder:
4552
def __init__(self, pairs):
Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,13 @@
11
from pathlib import PurePath
2-
from typing import Dict, List
32

43
import torch
54
from torch.utils.data import Dataset, DataLoader
65

76
from labml import lab, monit
87
from labml.configs import option, BaseConfigs
98
from labml_helpers.datasets.text import TextDataset
10-
from python_autocomplete.bpe import BPE, SourceCodeTokenizer
11-
12-
13-
class Tokenizer:
14-
n_tokens: int
15-
itos: List[str]
16-
stoi: Dict[str, int]
17-
is_trained: int
18-
19-
def encode(self, data: str, *, is_silent: bool = True):
20-
raise NotImplementedError
21-
22-
def train(self, data: str):
23-
pass
9+
from python_autocomplete.dataset import Tokenizer
10+
from python_autocomplete.dataset.bpe import BPE, SourceCodeTokenizer
2411

2512

2613
class CharacterTokenizer(Tokenizer):
@@ -36,7 +23,7 @@ def __init__(self, retrain: bool):
3623
self.is_trained = not retrain
3724

3825
def encode(self, data: str, *, is_silent: bool = True):
    """Map each known character of `data` to its token id.

    Characters missing from the vocabulary are silently skipped.

    :param is_silent: unused here; kept for interface compatibility
    """
    encoded = []
    for character in data:
        if character in self.stoi:
            encoded.append(self.stoi[character])
    return encoded
4027

4128
def train(self, data: str):
4229
with monit.section("Build vocabulary"):
@@ -105,7 +92,7 @@ def _dataset(c: SourceCodeDataConfigs):
10592
@option(SourceCodeDataConfigs.tokenizer, 'bpe')
10693
def _bpe_tokenizer():
10794
from labml.utils.cache import cache_get
108-
from python_autocomplete.bpe import BPEEnDe
95+
from python_autocomplete.dataset.bpe import BPEEnDe
10996
bpe_cache = cache_get('bpe')
11097

11198
if bpe_cache:

python_autocomplete/evaluate.py

Lines changed: 168 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
import string
2-
from typing import Set, Optional, Any, Tuple
1+
from heapq import heappush, heappop
2+
from typing import Any, Tuple, List, Optional, NamedTuple
33

4-
import numpy as np
54
import torch
65
import torch.nn
76
from torch import nn
@@ -10,10 +9,134 @@
109
from labml.logger import Text, Style
1110
from labml.utils.pytorch import get_modules
1211
from labml_helpers.module import Module
13-
from python_autocomplete.dataset import Tokenizer
12+
from python_autocomplete.dataset import Tokenizer, ID_CHARS
1413
from python_autocomplete.train import Configs, StateUpdater
1514

1615

16+
class PredictionComplete:
    """Callable that decides whether appending a token finishes a prediction."""

    def __call__(self, text, token_str: str):
        """
        :param text: text generated so far for this candidate
        :param token_str: token about to be appended
        :return: whether the candidate is complete — must be overridden
        """
        raise NotImplementedError
19+
20+
21+
class NextWordPredictionComplete(PredictionComplete):
    """Stop criterion for next-word beam search: a candidate is complete
    when the newly added token stays in the same character class
    (identifier vs non-identifier) as the end of the prompt, or mixes
    both classes (crossing a word boundary)."""

    def __init__(self, prompt: str):
        # Whether the prompt ends inside an identifier
        self.is_id = False
        if prompt and prompt[-1] in ID_CHARS:
            self.is_id = True

    def __call__(self, text, token_str: str):
        """Return whether `token_str` ends the word being predicted."""
        prediction = set(token_str)
        intersection = prediction.intersection(ID_CHARS)
        # Token consists purely of identifier characters.
        # `bool(...)` matters: `intersection and ...` would yield the empty
        # set itself for non-identifier tokens, and `set() == False` is
        # always False, breaking the comparison with `self.is_id` below.
        is_id = bool(intersection) and intersection == prediction
        # Token contains at least one non-identifier character
        is_not_id = intersection != prediction
        # A mixed token (identifier and non-identifier characters) always
        # crosses a word boundary, so the prediction is complete.
        # NOTE(review): the original guard `is_id and is_not_id` was
        # unsatisfiable (the two can never both hold); this is the evident
        # intent of that dead branch.
        if intersection and is_not_id:
            return True
        return is_id == self.is_id
35+
36+
37+
class BeamSearch:
    """Beam search over next tokens, tracking the most probable word
    completions across a batch of candidate prefixes ("beams")."""

    def __init__(self, beam_size: int, prediction_complete: PredictionComplete,
                 max_beam_size: int, rest: str,
                 state_updater: 'StateUpdater',
                 probs: Optional[List[float]],
                 is_token_by_token: bool):
        """
        :param beam_size: number of candidate prefixes currently in the batch
        :param prediction_complete: decides when a candidate forms a full word
        :param max_beam_size: max entries kept in the beam/result heaps
        :param rest: text already typed beyond the prompt; candidates must
            stay consistent with it
        :param state_updater: slices/batches model state between steps
        :param probs: prior probability per beam; uniform when `None`
        :param is_token_by_token: feed the model only the new token each step
            instead of a sliding window over the prompt
        """
        self.is_token_by_token = is_token_by_token
        self.state_updater = state_updater
        self.prediction_complete = prediction_complete
        self.max_beam_size = max_beam_size
        self.rest = rest

        if probs is None:
            probs = [1 / beam_size] * beam_size
        assert len(probs) == beam_size
        self.probs = probs

        # Min-heaps of (probability, payload): the lowest-probability entry
        # sits at index 0, so it is the one evicted when a heap is full.
        self.result_heap = []
        # Text accumulated so far for each beam
        self.text = [''] * beam_size
        self.beam_heap = []

    @staticmethod
    def is_substr(original, token_str):
        """Whether `token_str` is consistent with `original`: one is a
        prefix of the other. An empty/`None` `original` matches anything."""
        if not original:
            return True

        n = min(len(original), len(token_str))
        return original[:n] == token_str[:n]

    def add_prediction(self, prob: float, beam_idx: int, token_str: str, state):
        """Record a completed prediction, evicting the least probable entry
        when the result heap is already full.

        NOTE(review): heap entries are `(prob, (text, state))`; when two
        probabilities tie, tuple comparison falls through to `text` and then
        `state`, and comparing states may raise — confirm ties are benign.
        """
        if len(self.result_heap) == self.max_beam_size:
            if self.result_heap[0][0] > prob:
                return
            heappop(self.result_heap)

        # Detach this beam's slice of the state so it survives rebatching
        state = self.state_updater.get_from_batch(state, beam_idx)
        text = self.text[beam_idx] + token_str
        heappush(self.result_heap, (prob, (text, state)))

    def add_beam(self, prob: float, beam_idx: int, token: int):
        """Queue `(beam_idx, token)` for expansion in the next step,
        keeping at most `max_beam_size` candidates."""
        # Don't expand candidates already beaten by a finished result
        if self.result_heap and self.result_heap[0][0] > prob:
            return

        if len(self.beam_heap) == self.max_beam_size:
            if self.beam_heap[0][0] > prob:
                return
            heappop(self.beam_heap)

        heappush(self.beam_heap, (prob, (beam_idx, token)))

    def next_batch(self, prompt: torch.Tensor, state: Any, itos: List[str]):
        """Build the model inputs for the next step from the queued beams.

        :param prompt: current prompt batch; indexing (`prompt[1:, b]`)
            suggests shape `[seq_len, batch]` — TODO confirm
        :return: `(new_prompt, new_state)`, or `(None, None)` when there is
            nothing left to expand
        """
        if not self.beam_heap:
            return None, None

        new_prompt = []
        new_state = []

        texts = self.text
        self.text = []
        self.probs = []

        for prob, (b, token) in self.beam_heap:
            token = prompt.new_tensor([token])
            if self.is_token_by_token:
                # Feed only the newly chosen token
                new_prompt.append(token)
            else:
                # Slide the window: drop the oldest position, append the token
                new_prompt.append(torch.cat((prompt[1:, b], token)))
            new_state.append(self.state_updater.get_from_batch(state, b))
            self.probs.append(prob)
            # `token` is a 1-element tensor here; list indexing works via __index__
            self.text.append(texts[b] + itos[token])

        new_prompt = torch.stack(new_prompt, dim=1)
        new_state = self.state_updater.make_batch(new_state)

        self.beam_heap = []

        return new_prompt, new_state

    def update(self, next_token, itos: List[str], state):
        """Expand every beam with every vocabulary token, collecting both
        completed predictions and candidates for the next step.

        :param next_token: per-beam next-token probabilities, indexed
            `next_token[beam][token]`
        :param itos: token id -> token string
        :param state: model state for the current batch
        """
        self.beam_heap = []

        for b, text in enumerate(self.text):
            text = self.text[b]
            # Portion of the already-typed text this beam still must match
            if len(text) >= len(self.rest):
                check_rest = None
            else:
                check_rest = self.rest[len(text):]

            for token, token_str in enumerate(itos):
                # Skip tokens inconsistent with the already-typed text
                if not self.is_substr(check_rest, token_str):
                    continue

                if self.prediction_complete(text, token_str):
                    self.add_prediction(self.probs[b] * next_token[b][token].item(), b, token_str, state)
                self.add_beam(self.probs[b] * next_token[b][token].item(), b, token)
132+
133+
134+
class Prediction(NamedTuple):
    """A completed beam-search prediction."""
    # Probability assigned to this completion
    prob: float
    # Predicted text (includes the already-typed `rest` prefix — see
    # `evaluate`, which slices it off with `text[len(rest):]`)
    text: str
    # Model state after this completion, for continuing from here
    state: Any
138+
139+
17140
class Predictor:
18141
def __init__(self, model: Module, tokenizer: Tokenizer, *,
19142
state_updater: StateUpdater,
@@ -28,65 +151,41 @@ def __init__(self, model: Module, tokenizer: Tokenizer, *,
28151
self.time_predict = 0
29152
self.time_check = 0
30153

31-
def _get_predictions(self, prompt: str, state: Any) -> Tuple[torch.Tensor, Any]:
32-
data = torch.tensor(self.tokenizer.encode(prompt),
33-
dtype=torch.long,
34-
device=self.model.device)[-512:].unsqueeze(-1)
154+
def _get_predictions(self, prompt: torch.Tensor, state: Any) -> Tuple[torch.Tensor, Any]:
    """Run the model on `prompt` and return next-token probabilities.

    :param prompt: token ids; the `shape[0] == 0` check and `prediction[-1]`
        indexing suggest time-major shape `[seq_len, batch]` — TODO confirm
    :param state: recurrent state carried between calls (may be `None`)
    :return: `(probs, state)`: softmax probabilities for the last position,
        one row per batch column, plus the updated state
    """
    # Empty prompt: fall back to a uniform distribution over the vocabulary
    if prompt.shape[0] == 0:
        return prompt.new_ones(prompt.shape[1], len(self.tokenizer.itos)) / len(self.tokenizer.itos), state
    prompt = prompt.to(self.model.device)

    # Get predictions
    with torch.no_grad():
        prediction, new_state = self.model(prompt, state)

    state = self.state_updater(state, new_state)
    # Probabilities over the vocabulary at the final time step
    prediction = nn.Softmax(-1)(prediction[-1])

    # Final prediction
    return prediction, state
44168

45-
def get_predictions(self, prompt: str, state: Any) -> Tuple[np.ndarray, Any]:
46-
prediction, state = self._get_predictions(prompt, state)
169+
def get_next_word(self, prompt: torch.Tensor, state: Any, rest: str, probs: List[float],
                  prediction_complete: PredictionComplete,
                  max_beam_size: int, *, max_steps: int = 10) -> List[Prediction]:
    """Beam-search the most likely next-word completions.

    :param prompt: token ids of shape `[seq_len, batch]` (one column per beam)
    :param state: recurrent model state (or `None`)
    :param rest: text already typed past the tokenized prompt; candidates
        must stay consistent with it
    :param probs: prior probability per beam
    :param prediction_complete: decides when a candidate is a full word
    :param max_beam_size: number of candidates kept per step and in results
    :param max_steps: maximum tokens to expand before giving up (was a
        hard-coded constant; keyword-only, defaults preserve old behavior)
    :return: completed predictions, unsorted — callers sort by probability
    """
    beam = BeamSearch(prompt.shape[1], prediction_complete, max_beam_size, rest, self.state_updater,
                      probs, self.is_token_by_token)

    for _ in range(max_steps):
        next_token, state = self._get_predictions(prompt, state)
        beam.update(next_token, self.tokenizer.itos, state)
        prompt, state = beam.next_batch(prompt, state, self.tokenizer.itos)

        # All beams exhausted (nothing left worth expanding)
        if prompt is None:
            break

    # Heap entries are (prob, (text, state))
    results = [Prediction(r[0], r[1][0], r[1][1]) for r in beam.result_heap]
    return results
186+
187+
def rstrip(self, prompt: str) -> Tuple[str, List[int]]:
    """Delegate to the tokenizer: strip any trailing partial token and
    return the kept text together with its encoding."""
    stripped_text, encoded = self.tokenizer.rstrip(prompt)
    return stripped_text, encoded
90189

91190

92191
def evaluate(predictor: Predictor, text: str):
@@ -95,12 +194,23 @@ def evaluate(predictor: Predictor, text: str):
95194

96195
correct = 0
97196
i = 0
98-
right = False
99197
key_strokes = 0
100198

101199
while i + 1 < len(text):
102-
next_token, state = predictor.get_next_word(text[:i + 1], None, None)
103-
if next_token == text[i + 1: i + 1 + len(next_token)]:
200+
prefix = text[:i + 1]
201+
stripped, prompt = predictor.rstrip(prefix)
202+
rest = prefix[len(stripped):]
203+
prediction_complete = NextWordPredictionComplete(stripped)
204+
prompt = torch.tensor(prompt, dtype=torch.long).unsqueeze(-1)
205+
206+
predictions = predictor.get_next_word(prompt, None, rest, [1.], prediction_complete, 5)
207+
predictions.sort(key=lambda x: -x[0])
208+
if predictions:
209+
next_token = predictions[0].text[len(rest):]
210+
else:
211+
next_token = ''
212+
213+
if next_token and next_token == text[i + 1: i + 1 + len(next_token)]:
104214
correct += len(next_token)
105215
right = True
106216
else:
@@ -141,7 +251,7 @@ def anomalies(predictor: Predictor, text: str):
141251

142252
while i + 1 < len(text):
143253
# print(i, self.predictor.prompt)
144-
preds, _ = predictor.get_probabilities(text[:i + 1], None)
254+
preds, _ = predictor.get_predictions(text[:i + 1], None, calc_probs=True)
145255
preds = preds[0, :]
146256
c = text[i + 1]
147257

@@ -207,7 +317,7 @@ def complete(predictor: Predictor, text: str, completion: int):
207317
logger.log(logs)
208318

209319

210-
def get_predictor():
320+
def get_predictor() -> Predictor:
211321
conf = Configs()
212322
experiment.evaluate()
213323

0 commit comments

Comments (0)