@@ -189,21 +189,33 @@ def complete(predictor: Predictor, text: str, completion: int):
189189 logger .log (logs )
190190
191191
def get_predictor():
    """Build and return a ``Predictor`` backed by a pretrained transformer.

    Loads experiment configs and model weights from a checkpoint bundle,
    puts the model in eval mode, and wraps it with cached stoi/itos
    vocabulary lookups.

    Returns:
        Predictor: ready-to-use predictor over the loaded model.
    """
    conf = Configs()
    experiment.evaluate()

    # This will download a pretrained model checkpoint and some cached files.
    # It will download the archive as `saved_checkpoint.tar.gz` and extract it.
    #
    # If you have a locally trained model load it directly with
    # run_uuid = 'RUN_UUID'
    # And for latest checkpoint
    # checkpoint = None
    run_uuid, checkpoint = experiment.load_bundle(
        lab.get_path() / 'saved_checkpoint.tar.gz',
        url='https://github.com/lab-ml/python_autocomplete/releases/download/0.0.4/transformer_checkpoint.tar.gz')

    # Restore the saved hyper-parameters, register the model modules,
    # then load the specific checkpoint from the bundle.
    conf_dict = experiment.load_configs(run_uuid)
    experiment.configs(conf, conf_dict)
    experiment.add_pytorch_models(get_modules(conf))
    experiment.load(run_uuid, checkpoint)

    experiment.start()
    # Inference only: disable dropout/batch-norm training behavior.
    conf.model.eval()
    # stoi/itos are cached to disk so the text dataset need not be re-read.
    return Predictor(conf.model, cache('stoi', lambda: conf.text.stoi), cache('itos', lambda: conf.text.itos))
215+
216+
217+ def main ():
218+ predictor = get_predictor ()
207219
208220 with open (str (lab .get_data_path () / 'sample.py' ), 'r' ) as f :
209221 sample = f .read ()
0 commit comments