1- import string
2- from typing import Set , Optional , Any , Tuple
1+ from heapq import heappush , heappop
2+ from typing import Any , Tuple , List , Optional , NamedTuple
33
4- import numpy as np
54import torch
65import torch .nn
76from torch import nn
109from labml .logger import Text , Style
1110from labml .utils .pytorch import get_modules
1211from labml_helpers .module import Module
13- from python_autocomplete .dataset import Tokenizer
12+ from python_autocomplete .dataset import Tokenizer , ID_CHARS
1413from python_autocomplete .train import Configs , StateUpdater
1514
1615
class PredictionComplete:
    """Callable interface that decides when a predicted continuation is done.

    Subclasses implement `__call__`; the base class only defines the contract.
    """

    def __call__(self, text, token_str: str):
        """Return True when appending `token_str` to `text` completes a prediction."""
        raise NotImplementedError
19+
20+
class NextWordPredictionComplete(PredictionComplete):
    """Complete the prediction when the token's character class flips.

    A "word" ends when generation switches between identifier characters
    (`ID_CHARS`) and non-identifier characters.
    """

    def __init__(self, prompt: str):
        # Whether the prompt currently ends inside an identifier
        self.is_id = False
        if prompt and prompt[-1] in ID_CHARS:
            self.is_id = True

    def __call__(self, text, token_str: str):
        chars = set(token_str)
        id_chars = chars.intersection(ID_CHARS)
        # BUG FIX: the original computed `is_id = intersection and intersection
        # == prediction`, which evaluates to either `True` or an *empty set*;
        # comparing that set against the bool `self.is_id` was always False,
        # and the mixed-token branch below was unreachable.  Use real booleans.
        has_id = bool(id_chars)
        has_non_id = id_chars != chars
        # A token mixing identifier and non-identifier characters always
        # terminates the word.
        if has_id and has_non_id:
            return True
        # Otherwise complete once the token's character class matches the
        # class the prompt ended in.
        return has_id == self.is_id
35+
36+
class BeamSearch:
    """Beam search over next-token predictions.

    Maintains two bounded min-heaps keyed on probability (so the weakest entry
    is cheap to evict):

    - `beam_heap`: `(prob, (beam_idx, token))` candidates to expand next step
    - `result_heap`: `(prob, (text, state))` completed predictions

    The `(prob, (text, state))` layout of `result_heap` is consumed by
    `Predictor.get_next_word` and must not change.
    """

    def __init__(self, beam_size: int, prediction_complete: 'PredictionComplete',
                 max_beam_size: int, rest: str,
                 state_updater: 'StateUpdater',
                 probs: Optional[List[float]],
                 is_token_by_token: bool):
        self.is_token_by_token = is_token_by_token
        self.state_updater = state_updater
        self.prediction_complete = prediction_complete
        self.max_beam_size = max_beam_size
        self.rest = rest

        # Default to a uniform distribution over the initial beams
        if probs is None:
            probs = [1 / beam_size] * beam_size
        assert len(probs) == beam_size
        self.probs = probs

        self.result_heap = []
        # Accumulated text per live beam
        self.text = [''] * beam_size
        self.beam_heap = []

    @staticmethod
    def is_substr(original, token_str):
        """Whether `token_str` and `original` agree on their common prefix.

        `original` being None/empty means there is nothing left to match.
        """
        if not original:
            return True

        n = min(len(original), len(token_str))
        return original[:n] == token_str[:n]

    def add_prediction(self, prob: float, beam_idx: int, token_str: str, state):
        """Record a completed prediction, evicting the weakest if the heap is full."""
        if len(self.result_heap) == self.max_beam_size:
            if self.result_heap[0][0] > prob:
                return
            heappop(self.result_heap)

        state = self.state_updater.get_from_batch(state, beam_idx)
        text = self.text[beam_idx] + token_str
        # NOTE(review): on equal probabilities heapq falls back to comparing
        # the (text, state) tuples; if `state` is not orderable this can
        # raise — TODO confirm probabilities never tie in practice.
        heappush(self.result_heap, (prob, (text, state)))

    def add_beam(self, prob: float, beam_idx: int, token: int):
        """Queue a candidate continuation, evicting the weakest if the heap is full."""
        # A candidate weaker than the weakest completed result can never win
        if self.result_heap and self.result_heap[0][0] > prob:
            return

        if len(self.beam_heap) == self.max_beam_size:
            if self.beam_heap[0][0] > prob:
                return
            heappop(self.beam_heap)

        heappush(self.beam_heap, (prob, (beam_idx, token)))

    def next_batch(self, prompt: torch.Tensor, state: Any, itos: List[str]):
        """Build the model inputs for the queued candidates.

        Returns `(new_prompt, new_state)`, or `(None, None)` when no beams
        remain.  Also rebuilds `self.text` / `self.probs` for the new beams.
        """
        if not self.beam_heap:
            return None, None

        new_prompt = []
        new_state = []

        texts = self.text
        self.text = []
        self.probs = []

        for prob, (b, token) in self.beam_heap:
            # Keep the int `token` for the `itos` lookup below; the original
            # shadowed it with a tensor and indexed the list only via
            # tensor.__index__.
            token_tensor = prompt.new_tensor([token])
            if self.is_token_by_token:
                new_prompt.append(token_tensor)
            else:
                # Slide the window: drop the oldest position, append the new token
                new_prompt.append(torch.cat((prompt[1:, b], token_tensor)))
            new_state.append(self.state_updater.get_from_batch(state, b))
            self.probs.append(prob)
            self.text.append(texts[b] + itos[token])

        new_prompt = torch.stack(new_prompt, dim=1)
        new_state = self.state_updater.make_batch(new_state)

        self.beam_heap = []

        return new_prompt, new_state

    def update(self, next_token, itos: List[str], state):
        """Score every vocabulary token for every live beam.

        `next_token` is assumed to be a (beams, vocab) probability tensor —
        TODO confirm against `Predictor._get_predictions`.
        """
        self.beam_heap = []

        for b, text in enumerate(self.text):
            # Remaining characters of `rest` that the continuation must match
            if len(text) >= len(self.rest):
                check_rest = None
            else:
                check_rest = self.rest[len(text):]

            for token, token_str in enumerate(itos):
                if not self.is_substr(check_rest, token_str):
                    continue

                # Hoisted: the original evaluated `.item()` twice per token
                prob = self.probs[b] * next_token[b][token].item()
                if self.prediction_complete(text, token_str):
                    self.add_prediction(prob, b, token_str, state)
                self.add_beam(prob, b, token)
132+
133+
class Prediction(NamedTuple):
    """A completed beam-search result."""
    # Accumulated probability of the predicted text
    prob: float
    # The predicted continuation text
    text: str
    # Model state after consuming the prediction
    state: Any
138+
139+
17140class Predictor :
18141 def __init__ (self , model : Module , tokenizer : Tokenizer , * ,
19142 state_updater : StateUpdater ,
@@ -28,65 +151,41 @@ def __init__(self, model: Module, tokenizer: Tokenizer, *,
28151 self .time_predict = 0
29152 self .time_check = 0
30153
31- def _get_predictions (self , prompt : str , state : Any ) -> Tuple [torch .Tensor , Any ]:
32- data = torch . tensor ( self . tokenizer . encode ( prompt ),
33- dtype = torch . long ,
34- device = self .model .device )[ - 512 :]. unsqueeze ( - 1 )
154+ def _get_predictions (self , prompt : torch . Tensor , state : Any ) -> Tuple [torch .Tensor , Any ]:
155+ if prompt . shape [ 0 ] == 0 :
156+ return prompt . new_ones ( prompt . shape [ 1 ], len ( self . tokenizer . itos )) / len ( self . tokenizer . itos ), state
157+ prompt = prompt . to ( self .model .device )
35158
36159 # Get predictions
37160 with torch .no_grad ():
38- prediction , new_state = self .model (data , state )
161+ prediction , new_state = self .model (prompt , state )
39162
40163 state = self .state_updater (state , new_state )
164+ prediction = nn .Softmax (- 1 )(prediction [- 1 ])
41165
42166 # Final prediction
43- return prediction [ - 1 , :, :] , state
167+ return prediction , state
44168
45- def get_predictions (self , prompt : str , state : Any ) -> Tuple [np .ndarray , Any ]:
46- prediction , state = self ._get_predictions (prompt , state )
169+ def get_next_word (self , prompt : torch .Tensor , state : Any , rest : str , probs : List [float ],
170+ prediction_complete : PredictionComplete ,
171+ max_beam_size : int ) -> \
172+ List [Prediction ]:
173+ beam = BeamSearch (prompt .shape [1 ], prediction_complete , max_beam_size , rest , self .state_updater ,
174+ probs , self .is_token_by_token )
47175
48- return prediction .detach ().cpu ().numpy (), state
176+ for _ in range (10 ):
177+ next_token , state = self ._get_predictions (prompt , state )
178+ beam .update (next_token , self .tokenizer .itos , state )
179+ prompt , state = beam .next_batch (prompt , state , self .tokenizer .itos )
49180
50- def get_probabilities (self , prompt : str , state : Any ) -> Tuple [np .ndarray , Any ]:
51- # Final prediction
52- prediction , state = self ._get_predictions (prompt , state )
53- prediction = nn .Softmax (- 1 )(prediction )
54-
55- return prediction .detach ().cpu ().numpy (), state
56-
57- def get_next_token (self , prompt : str , state : Any ) -> Tuple [str , Any ]:
58- prediction , state = self .get_predictions (prompt , state )
59- best = prediction .argmax (- 1 ).squeeze ().item ()
60- return self .tokenizer .itos [best ], state
61-
62- def get_start_state (self , prompt : str ):
63- assert prompt
64-
65- if len (prompt ) == 1 :
66- return prompt , None
67- if not self .is_token_by_token :
68- return prompt , None
69-
70- _ , state = self .get_next_token (prompt [:- 1 ], None )
71- return prompt [- 1 ], state
72-
73- def get_next_word (self , prompt : str , token_chars : Optional [Set [str ]], state : Any ) -> Tuple [str , Any ]:
74- result = ''
75- if token_chars is None :
76- token_chars = set (string .ascii_letters + string .digits + ' ' + '\n ' + '\r ' )
77- while True :
78- next_token , state = self .get_next_token (prompt , state )
79- if len (result ) > 2 and next_token not in token_chars or (next_token .strip () == '' and result .strip () != '' ):
80- if not result :
81- result += next_token
82- return result , state
83- result += next_token
84- if len (result ) > 20 :
85- return result , state
86- if self .is_token_by_token :
87- prompt = next_token
88- else :
89- prompt += next_token
181+ if prompt is None :
182+ break
183+
184+ results = [Prediction (r [0 ], r [1 ][0 ], r [1 ][1 ]) for r in beam .result_heap ]
185+ return results
186+
187+ def rstrip (self , prompt : str ) -> Tuple [str , List [int ]]:
188+ return self .tokenizer .rstrip (prompt )
90189
91190
92191def evaluate (predictor : Predictor , text : str ):
@@ -95,12 +194,23 @@ def evaluate(predictor: Predictor, text: str):
95194
96195 correct = 0
97196 i = 0
98- right = False
99197 key_strokes = 0
100198
101199 while i + 1 < len (text ):
102- next_token , state = predictor .get_next_word (text [:i + 1 ], None , None )
103- if next_token == text [i + 1 : i + 1 + len (next_token )]:
200+ prefix = text [:i + 1 ]
201+ stripped , prompt = predictor .rstrip (prefix )
202+ rest = prefix [len (stripped ):]
203+ prediction_complete = NextWordPredictionComplete (stripped )
204+ prompt = torch .tensor (prompt , dtype = torch .long ).unsqueeze (- 1 )
205+
206+ predictions = predictor .get_next_word (prompt , None , rest , [1. ], prediction_complete , 5 )
207+ predictions .sort (key = lambda x : - x [0 ])
208+ if predictions :
209+ next_token = predictions [0 ].text [len (rest ):]
210+ else :
211+ next_token = ''
212+
213+ if next_token and next_token == text [i + 1 : i + 1 + len (next_token )]:
104214 correct += len (next_token )
105215 right = True
106216 else :
@@ -141,7 +251,7 @@ def anomalies(predictor: Predictor, text: str):
141251
142252 while i + 1 < len (text ):
143253 # print(i, self.predictor.prompt)
144- preds , _ = predictor .get_probabilities (text [:i + 1 ], None )
254+ preds , _ = predictor .get_predictions (text [:i + 1 ], None , calc_probs = True )
145255 preds = preds [0 , :]
146256 c = text [i + 1 ]
147257
@@ -207,7 +317,7 @@ def complete(predictor: Predictor, text: str, completion: int):
207317 logger .log (logs )
208318
209319
210- def get_predictor ():
320+ def get_predictor () -> Predictor :
211321 conf = Configs ()
212322 experiment .evaluate ()
213323
0 commit comments