Skip to content

Commit dacb46a

Browse files
committed
✨ download checkpoint
1 parent a971cfd commit dacb46a

6 files changed

Lines changed: 42 additions & 37 deletions

File tree

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,5 @@ labml_helpers
1212
einops
1313
*.egg-info
1414
build
15-
dist
15+
dist
16+
*.tar.gz

notebooks/evaluate.ipynb

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
"metadata": {},
2929
"outputs": [],
3030
"source": [
31-
"%%capture\n",
3231
"!pip install labml labml_python_autocomplete"
3332
]
3433
},
@@ -69,7 +68,14 @@
6968
"\n",
7069
"[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://web.lab-ml.com/run?uuid=39b03a1e454011ebbaff2b26e3148b3d)\n",
7170
"\n",
72-
"*If you want to try this on Colab you need to run this on the same space where you run the training, because models are saved locally.*"
71+
"If you have a locally trained model load it directly with:\n",
72+
"\n",
73+
"```python\n",
74+
"run_uuid = 'RUN_UUID'\n",
75+
"checkpoint = None # Get latest checkpoint\n",
76+
"```\n",
77+
"\n",
78+
"`load_bundle` will download an archive with a saved checkpoint (pretrained model)."
7379
]
7480
},
7581
{
@@ -78,7 +84,9 @@
7884
"metadata": {},
7985
"outputs": [],
8086
"source": [
81-
"TRAINING_RUN_UUID = '39b03a1e454011ebbaff2b26e3148b3d'"
87+
"run_uuid, checkpoint = experiment.load_bundle(\n",
88+
" lab.get_path() / 'saved_checkpoint.tar.gz',\n",
89+
" url='https://github.com/lab-ml/python_autocomplete/releases/download/0.0.4/transformer_checkpoint.tar.gz')"
8290
]
8391
},
8492
{
@@ -145,7 +153,7 @@
145153
}
146154
],
147155
"source": [
148-
"custom_conf = experiment.load_configs(TRAINING_RUN_UUID)\n",
156+
"custom_conf = experiment.load_configs(run_uuid)\n",
149157
"custom_conf"
150158
]
151159
},
@@ -234,7 +242,7 @@
234242
"metadata": {},
235243
"outputs": [],
236244
"source": [
237-
"experiment.load(TRAINING_RUN_UUID)"
245+
"experiment.load(run_uuid, checkpoint)"
238246
]
239247
},
240248
{
@@ -817,4 +825,4 @@
817825
},
818826
"nbformat": 4,
819827
"nbformat_minor": 4
820-
}
828+
}

python_autocomplete/bundle.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from labml import experiment, lab
2+
3+
if __name__ == '__main__':
4+
experiment.save_bundle(lab.get_path() / 'bundle.tar.gz', '39b03a1e454011ebbaff2b26e3148b3d',
5+
data_files=['cache/itos.json', 'cache/n_tokens.json', 'cache/stoi.json'])

python_autocomplete/evaluate.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -189,21 +189,33 @@ def complete(predictor: Predictor, text: str, completion: int):
189189
logger.log(logs)
190190

191191

192-
def main():
192+
def get_predictor():
193193
conf = Configs()
194194
experiment.evaluate()
195195

196-
# Replace this with your training experiment UUID
197-
run_uuid = '39b03a1e454011ebbaff2b26e3148b3d'
196+
# This will download a pretrained model checkpoint and some cached files.
197+
# It will download the archive as `saved_checkpoint.tar.gz` and extract it.
198+
#
199+
# If you have a locally trained model load it directly with
200+
# run_uuid = 'RUN_UUID'
201+
# And for latest checkpoint
202+
# checkpoint = None
203+
run_uuid, checkpoint = experiment.load_bundle(
204+
lab.get_path() / 'saved_checkpoint.tar.gz',
205+
url='https://github.com/lab-ml/python_autocomplete/releases/download/0.0.4/transformer_checkpoint.tar.gz')
198206

199207
conf_dict = experiment.load_configs(run_uuid)
200208
experiment.configs(conf, conf_dict)
201209
experiment.add_pytorch_models(get_modules(conf))
202-
experiment.load(run_uuid)
210+
experiment.load(run_uuid, checkpoint)
203211

204212
experiment.start()
205-
predictor = Predictor(conf.model, cache('stoi', lambda: conf.text.stoi), cache('itos', lambda: conf.text.itos))
206213
conf.model.eval()
214+
return Predictor(conf.model, cache('stoi', lambda: conf.text.stoi), cache('itos', lambda: conf.text.itos))
215+
216+
217+
def main():
218+
predictor = get_predictor()
207219

208220
with open(str(lab.get_data_path() / 'sample.py'), 'r') as f:
209221
sample = f.read()

python_autocomplete/serve.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,32 +4,11 @@
44

55
from flask import Flask, request, jsonify
66

7-
from labml import experiment, monit
8-
from labml.utils.cache import cache
9-
from labml.utils.pytorch import get_modules
10-
from python_autocomplete.evaluate import Predictor
11-
from python_autocomplete.train import Configs
7+
from labml import monit
8+
from python_autocomplete.evaluate import get_predictor
129

1310
TOKEN_CHARS = set(string.ascii_letters + string.digits + ' ' + '\n' + '\r' + '_')
1411

15-
16-
def get_predictor():
17-
conf = Configs()
18-
experiment.evaluate()
19-
20-
# Replace this with your training experiment UUID
21-
run_uuid = '39b03a1e454011ebbaff2b26e3148b3d'
22-
23-
conf_dict = experiment.load_configs(run_uuid)
24-
experiment.configs(conf, conf_dict)
25-
experiment.add_pytorch_models(get_modules(conf))
26-
experiment.load(run_uuid)
27-
28-
experiment.start()
29-
conf.model.eval()
30-
return Predictor(conf.model, cache('stoi', lambda: conf.text.stoi), cache('itos', lambda: conf.text.itos))
31-
32-
3312
app = Flask('python_autocomplete')
3413
predictor = get_predictor()
3514
lock = threading.Lock()

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
setuptools.setup(
77
name='labml_python_autocomplete',
8-
version='0.0.3',
8+
version='0.0.4',
99
author="Varuna Jayasiri",
1010
author_email="vpjayasiri@gmail.com",
1111
description="A simple model that learns to predict Python source code",
@@ -19,7 +19,7 @@
1919
'labml_nn', 'labml_nn.*',
2020
'labml', 'labml.*',
2121
'test', 'test.*')),
22-
install_requires=['labml>=0.4.86',
22+
install_requires=['labml>=0.4.98',
2323
'labml_helpers>=0.4.70',
2424
'labml_nn>=0.4.70'
2525
'torch',

0 commit comments

Comments
 (0)