11import string
2- from typing import List , Dict , Set , Optional
2+ from typing import List , Dict , Set , Optional , Any , Tuple
33
4+ import numpy as np
45import torch
56import torch .nn
6- from labml .utils .cache import cache
77from torch import nn
88
99from labml import experiment , logger , lab
1010from labml .logger import Text , Style
1111from labml .utils .pytorch import get_modules
1212from labml_helpers .module import Module
13- from python_autocomplete .train import Configs
13+ from python_autocomplete .train import Configs , StateUpdater
1414
1515
class Predictor:
    """Wrap a trained character-level language model for incremental prediction.

    Holds the vocabulary mappings (``stoi``/``itos``) and threads an opaque
    model ``state`` through every call instead of storing it on the instance,
    so callers control the state's lifetime.
    """

    def __init__(self, model: Module, stoi: Dict[str, int], itos: List[str], *,
                 state_updater: StateUpdater,
                 is_token_by_token: bool):
        """
        :param model: trained model; invoked as ``model(data, state)`` and
            expected to return ``(prediction, new_state)``
        :param stoi: character -> token id mapping
        :param itos: token id -> character mapping
        :param state_updater: merges the previous state with the new state
            returned by the model
        :param is_token_by_token: when ``True``, only the last character is fed
            back as the next prompt — history is presumably carried in the
            state (TODO confirm against the training config)
        """
        self.is_token_by_token = is_token_by_token
        self.state_updater = state_updater
        self.stoi = stoi
        self.itos = itos
        self.model = model

        # For timing
        self.time_add = 0
        self.time_predict = 0
        self.time_check = 0

    def _get_predictions(self, prompt: str, state: Any) -> Tuple[torch.Tensor, Any]:
        """Run the model on *prompt* and return (last-step logits, new state).

        Characters missing from the vocabulary are silently dropped; only the
        last 512 characters of the prompt are used.
        """
        prompt = prompt[-512:]
        # Shape (seq_len, 1): one token id per row, batch size 1
        data = torch.tensor([[self.stoi[c]] for c in prompt if c in self.stoi],
                            dtype=torch.long,
                            device=self.model.device)

        # Get predictions (inference only — no gradients)
        with torch.no_grad():
            prediction, new_state = self.model(data, state)

        # Let the caller-supplied policy decide how old and new state combine
        state = self.state_updater(state, new_state)

        # Logits of the final position only
        return prediction[-1, :, :], state

    def get_predictions(self, prompt: str, state: Any) -> Tuple[np.ndarray, Any]:
        """Return the final-position logits as a NumPy array, plus the state."""
        prediction, state = self._get_predictions(prompt, state)

        return prediction.detach().cpu().numpy(), state

    def get_probabilities(self, prompt: str, state: Any) -> Tuple[np.ndarray, Any]:
        """Return softmax probabilities over the vocabulary, plus the state."""
        prediction, state = self._get_predictions(prompt, state)
        # torch.softmax avoids building a throwaway nn.Softmax module per call
        prediction = torch.softmax(prediction, dim=-1)

        return prediction.detach().cpu().numpy(), state

    def get_next_char(self, prompt: str, state: Any) -> Tuple[str, Any]:
        """Greedily pick the most likely next character, plus the state."""
        prediction, state = self.get_predictions(prompt, state)
        best = prediction.argmax(-1).squeeze().item()
        return self.itos[best], state

    def get_token(self, prompt: str, token_chars: Optional[Set[str]], state: Any) -> Tuple[str, Any]:
        """Greedily extend *prompt* one character at a time into a "token".

        Stops when a non-token character follows at least 3 accepted
        characters, when whitespace follows non-whitespace, or after 21
        characters — whichever comes first.

        :param token_chars: characters allowed inside a token; ``None`` means
            letters, digits, space, newline and carriage return
        """
        result = ''
        if token_chars is None:
            token_chars = set(string.ascii_letters + string.digits + ' ' + '\n' + '\r')
        while True:
            next_char, state = self.get_next_char(prompt, state)
            # Parenthesized for clarity; matches Python's and/or precedence.
            if (len(result) > 2 and next_char not in token_chars) or \
                    (next_char.strip() == '' and result.strip() != ''):
                if not result:
                    result += next_char
                return result, state
            result += next_char
            if len(result) > 20:
                return result, state
            prompt += next_char
            if self.is_token_by_token:
                # State carries the history; feed only the newest character
                prompt = prompt[-1:]
7379
7480
7581def evaluate (predictor : Predictor , text : str ):
@@ -82,7 +88,7 @@ def evaluate(predictor: Predictor, text: str):
8288 key_strokes = 0
8389
8490 while i + 1 < len (text ):
85- next_token = predictor .get_token (text [:i + 1 ])
91+ next_token , state = predictor .get_token (text [:i + 1 ], None , None )
8692 if next_token == text [i + 1 : i + 1 + len (next_token )]:
8793 correct += len (next_token )
8894 right = True
@@ -124,7 +130,8 @@ def anomalies(predictor: Predictor, text: str):
124130
125131 while i + 1 < len (text ):
126132 # print(i, self.predictor.prompt)
127- preds = predictor .get_probabilities (text [:i + 1 ])[0 , :]
133+ preds , _ = predictor .get_probabilities (text [:i + 1 ], None )
134+ preds = preds [0 , :]
128135 c = text [i + 1 ]
129136
130137 if c == '\n ' :
@@ -169,7 +176,7 @@ def complete(predictor: Predictor, text: str, completion: int):
169176 if len (text ) > i + 1 :
170177 c = text [i + 1 ]
171178 else :
172- c = predictor .get_next_char (text [:i + 1 ])
179+ c , _ = predictor .get_next_char (text [:i + 1 ], None )
173180
174181 if c == '\n ' :
175182 logger .log (logs )
@@ -200,9 +207,12 @@ def get_predictor():
200207 # run_uuid = 'RUN_UUID'
201208 # And for latest checkpoint
202209 # checkpoint = None
203- run_uuid , checkpoint = experiment .load_bundle (
204- lab .get_path () / 'saved_checkpoint.tar.gz' ,
205- url = 'https://github.com/lab-ml/python_autocomplete/releases/download/0.0.4/transformer_checkpoint.tar.gz' )
210+
211+ run_uuid = 'c45857026a2811eba16c27c69839e51f'
212+ checkpoint = None
213+ # run_uuid, checkpoint = experiment.load_bundle(
214+ # lab.get_path() / 'saved_checkpoint.tar.gz',
215+ # url='https://github.com/lab-ml/python_autocomplete/releases/download/0.0.4/transformer_checkpoint.tar.gz')
206216
207217 conf_dict = experiment .load_configs (run_uuid )
208218 experiment .configs (conf , conf_dict )
@@ -211,7 +221,9 @@ def get_predictor():
211221
212222 experiment .start ()
213223 conf .model .eval ()
214- return Predictor (conf .model , conf .stoi , conf .itos )
224+ return Predictor (conf .model , conf .stoi , conf .itos ,
225+ state_updater = conf .state_updater ,
226+ is_token_by_token = conf .is_token_by_token )
215227
216228
217229def main ():
0 commit comments