11import string
2- from typing import List , Dict , Set , Optional , Any , Tuple
2+ from typing import Set , Optional , Any , Tuple
33
44import numpy as np
55import torch
99from labml import experiment , logger , lab , monit
1010from labml .logger import Text , Style
1111from labml .utils .pytorch import get_modules
12- from labml_helpers .datasets .text import TextDataset
1312from labml_helpers .module import Module
13+ from python_autocomplete .dataset import Tokenizer
1414from python_autocomplete .train import Configs , StateUpdater
1515
1616
1717class Predictor :
18- def __init__ (self , model : Module , text : TextDataset , * ,
18+ def __init__ (self , model : Module , tokenizer : Tokenizer , * ,
1919 state_updater : StateUpdater ,
2020 is_token_by_token : bool ):
21- text .is_silent = True
22- self .text = text
21+ self .tokenizer = tokenizer
2322 self .is_token_by_token = is_token_by_token
2423 self .state_updater = state_updater
2524 self .model = model
@@ -30,8 +29,9 @@ def __init__(self, model: Module, text: TextDataset, *,
3029 self .time_check = 0
3130
3231 def _get_predictions (self , prompt : str , state : Any ) -> Tuple [torch .Tensor , Any ]:
33- data = self .text .text_to_i (prompt )[- 512 :]
34- data = data .to (self .model .device ).unsqueeze (- 1 )
32+ data = torch .tensor (self .tokenizer .encode (prompt ),
33+ dtype = torch .long ,
34+ device = self .model .device )[- 512 :].unsqueeze (- 1 )
3535
3636 # Get predictions
3737 with torch .no_grad ():
@@ -57,7 +57,7 @@ def get_probabilities(self, prompt: str, state: Any) -> Tuple[np.ndarray, Any]:
5757 def get_next_token (self , prompt : str , state : Any ) -> Tuple [str , Any ]:
5858 prediction , state = self .get_predictions (prompt , state )
5959 best = prediction .argmax (- 1 ).squeeze ().item ()
60- return self .text .itos [best ], state
60+ return self .tokenizer .itos [best ], state
6161
6262 def get_start_state (self , prompt : str ):
6363 assert prompt
@@ -151,10 +151,10 @@ def anomalies(predictor: Predictor, text: str):
151151 logs = [(f"{ line_no : 4d} : " , Text .meta )]
152152 elif c == '\r ' :
153153 continue
154- elif c not in predictor .text .stoi :
154+ elif c not in predictor .tokenizer .stoi :
155155 logs .append (c )
156156 else :
157- next_id = predictor .text .stoi [c ]
157+ next_id = predictor .tokenizer .stoi [c ]
158158 prob = preds [next_id ]
159159 if prob > 0.9 :
160160 logs .append ((c , [Style .bold , Text .success , Style .underline ]))
@@ -219,22 +219,21 @@ def get_predictor():
219219 # And for latest checkpoint
220220 # checkpoint = None
221221
222- run_uuid = '275e62e66dc711eb9d162f2ddfc33452' # bpe
223- # run_uuid = 'c45857026a2811eba16c27c69839e51f' # xl
222+ run_uuid = '109d1b8c6e8611eb80e13584488b68a4' # bpe
224223 checkpoint = None
225- run_uuid , checkpoint = experiment .load_bundle (
226- lab .get_path () / 'saved_checkpoint.tar.gz' ,
227- url = 'https://github.com/lab-ml/python_autocomplete/releases/download/0.0.4/transformer_checkpoint.tar.gz' )
224+ # run_uuid, checkpoint = experiment.load_bundle(
225+ # lab.get_path() / 'saved_checkpoint.tar.gz',
226+ # url='https://github.com/lab-ml/python_autocomplete/releases/download/0.0.4/transformer_checkpoint.tar.gz')
228227
229228 conf_dict = experiment .load_configs (run_uuid )
230- conf_dict ['is_load_data' ] = False
229+ conf_dict ['text. is_load_data' ] = False
231230 experiment .configs (conf , conf_dict )
232231 experiment .add_pytorch_models (get_modules (conf ))
233232 experiment .load (run_uuid , checkpoint )
234233
235234 experiment .start ()
236235 conf .model .eval ()
237- return Predictor (conf .model , conf .text ,
236+ return Predictor (conf .model , conf .text . tokenizer ,
238237 state_updater = conf .state_updater ,
239238 is_token_by_token = conf .is_token_by_token )
240239
0 commit comments