Commit 0492112

cache stoi/itos

1 parent 32cc6a8 commit 0492112

5 files changed: 51 additions & 48 deletions
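
The commit's theme: instead of rebuilding the text dataset to recover the vocabulary mappings (stoi/itos) and vocabulary size (n_tokens), those values are wrapped in labml's `cache(name, loader)` helper so later runs load them from disk. A minimal sketch of the idea behind such a helper follows; the `.cache` directory and pickle format are illustrative assumptions, not labml internals.

# Minimal sketch of a disk-memoizing cache helper; the .cache directory and
# pickle format are assumptions for illustration, not labml's actual internals.
import pickle
from pathlib import Path
from typing import Any, Callable

CACHE_DIR = Path('.cache')  # hypothetical location; labml uses its own lab paths

def cache(name: str, loader: Callable[[], Any]) -> Any:
    """Return the value for `name`, calling `loader` only on a cache miss."""
    path = CACHE_DIR / f'{name}.pkl'
    if path.exists():
        with path.open('rb') as f:
            return pickle.load(f)
    value = loader()  # expensive: e.g. building stoi/itos means tokenizing the corpus
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    with path.open('wb') as f:
        pickle.dump(value, f)
    return value

With this shape, `cache('stoi', lambda: conf.text.stoi)` only touches `conf.text` on the first run, which is what drops the notebook's model-prepare time from ~4.8 s to ~1.4 s in the output diffs below.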

notebooks/evaluate.ipynb
Lines changed: 37 additions & 41 deletions

@@ -6,19 +6,20 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"from labml import experiment\n",
-"from python_autocomplete.train import Configs\n",
-"from python_autocomplete.evaluate import evaluate, anomalies, complete, Predictor\n",
+"import string\n",
+"\n",
 "import torch\n",
-"import torch.nn\n",
+"from torch import nn\n",
+"\n",
 "from labml import experiment, logger, lab\n",
 "from labml_helpers.module import Module\n",
 "from labml.logger import Text, Style\n",
 "from labml.utils.pytorch import get_modules\n",
-"from torch import nn\n",
-"import string\n",
+"from labml.utils.cache import cache\n",
+"from labml_helpers.datasets.text import TextDataset\n",
 "\n",
-"from labml_helpers.datasets.text import TextDataset"
+"from python_autocomplete.train import Configs\n",
+"from python_autocomplete.evaluate import evaluate, anomalies, complete, Predictor"
 ]
 },
 {
@@ -87,7 +88,8 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"conf_dict['device.cuda_device'] = 1"
+"conf_dict['device.cuda_device'] = 1\n",
+"# conf_dict['device.use_cuda'] = False"
 ]
 },
 {
@@ -121,26 +123,19 @@
 "data": {
 "text/html": [
 "<pre style=\"overflow-x: scroll;\">Prepare model...\n",
-" Prepare n_tokens...\n",
-" Prepare text...\n",
-" Prepare tokenizer<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t3.16ms</span>\n",
-" Load data<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t162.36ms</span>\n",
-" Tokenize<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t3,109.56ms</span>\n",
-" Build vocabulary<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t66.50ms</span>\n",
-" Prepare text<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t3,350.38ms</span>\n",
-" Prepare n_tokens<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t3,354.96ms</span>\n",
-" Prepare transformer<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t1.12ms</span>\n",
+" Prepare n_tokens<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t1.32ms</span>\n",
+" Prepare transformer<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t1.05ms</span>\n",
 " Prepare encoder...\n",
 " Prepare encoder_layer...\n",
-" Prepare encoder_attn<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t9.52ms</span>\n",
-" Prepare feed_forward<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t13.47ms</span>\n",
-" Prepare encoder_layer<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t28.71ms</span>\n",
-" Prepare encoder<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t47.43ms</span>\n",
-" Prepare src_embed<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t4.22ms</span>\n",
+" Prepare encoder_attn<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t12.84ms</span>\n",
+" Prepare feed_forward<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t12.51ms</span>\n",
+" Prepare encoder_layer<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t31.24ms</span>\n",
+" Prepare encoder<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t47.89ms</span>\n",
+" Prepare src_embed<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t5.40ms</span>\n",
 " Prepare device...\n",
-" Prepare device_info<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t57.05ms</span>\n",
-" Prepare device<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t58.74ms</span>\n",
-"Prepare model<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t4,815.99ms</span>\n",
+" Prepare device_info<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t37.79ms</span>\n",
+" Prepare device<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t39.82ms</span>\n",
+"Prepare model<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t1,351.06ms</span>\n",
 "</pre>"
 ],
 "text/plain": [
@@ -175,10 +170,10 @@
 "data": {
 "text/html": [
 "<pre style=\"overflow-x: scroll;\">Selected <span style=\"color: #60C6C8\">experiment</span> = <strong>source_code</strong> <span style=\"color: #60C6C8\">run</span> = <strong>39b03a1e454011ebbaff2b26e3148b3d</strong> <span style=\"color: #60C6C8\">checkpoint</span> = <strong>351023104</strong>\n",
-"Loading checkpoint<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t67.46ms</span>\n",
+"Loading checkpoint<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t47.04ms</span>\n",
 "\n",
-"<strong><span style=\"text-decoration: underline\">Notebook Experiment</span></strong>: <span style=\"color: #208FFB\">0c69f4444ab011ebb618517d0c553d3c</span>\n",
-"\t[dirty]: <strong><span style=\"color: #DDB62B\">\"evaluate\"</span></strong>\n",
+"<strong><span style=\"text-decoration: underline\">Notebook Experiment</span></strong>: <span style=\"color: #208FFB\">7f418b0e4b1d11eba0e89704bc602f06</span>\n",
+"\t[dirty]: <strong><span style=\"color: #DDB62B\">\"rename notebook\"</span></strong>\n",
 "\tloaded from: <span style=\"color: #D160C4\">39b03a1e454011ebbaff2b26e3148b3d</span></pre>"
 ],
 "text/plain": [
@@ -191,7 +186,7 @@
 {
 "data": {
 "text/plain": [
-"<labml.internal.experiment.watcher.ExperimentWatcher at 0x7fd68e67b370>"
+"<labml.internal.experiment.watcher.ExperimentWatcher at 0x7f655c41a400>"
 ]
 },
 "execution_count": 9,
@@ -209,7 +204,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"p = Predictor(conf.model, conf.text.stoi, conf.text.itos)\n",
+"p = Predictor(conf.model, cache('stoi', lambda: conf.text.stoi), cache('itos', lambda: conf.text.itos))\n",
 "_ = conf.model.eval()"
 ]
 },
@@ -235,16 +230,17 @@
 },
 {
 "cell_type": "code",
-"execution_count": 16,
+"execution_count": 12,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"\"(LSTM\"\n",
-"CPU times: user 386 ms, sys: 6.33 ms, total: 392 ms\n",
-"Wall time: 41 ms\n"
+"\"\n",
+" super\"\n",
+"CPU times: user 950 ms, sys: 34.7 ms, total: 984 ms\n",
+"Wall time: 254 ms\n"
 ]
 }
 ],
@@ -256,14 +252,14 @@
 },
 {
 "cell_type": "code",
-"execution_count": 17,
+"execution_count": 13,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"\",\"\n"
+"\"(LSTM\"\n"
 ]
 }
 ],
@@ -275,7 +271,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 18,
+"execution_count": 14,
 "metadata": {},
 "outputs": [
 {
@@ -299,7 +295,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 19,
+"execution_count": 15,
 "metadata": {},
 "outputs": [
 {
@@ -428,7 +424,7 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"CPU times: user 2min 18s, sys: 1.17 s, total: 2min 19s\n",
+"CPU times: user 1min 59s, sys: 62.9 ms, total: 1min 59s\n",
 "Wall time: 1min 23s\n"
 ]
 }
@@ -440,7 +436,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 20,
+"execution_count": 16,
 "metadata": {},
 "outputs": [
 {
@@ -569,7 +565,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 21,
+"execution_count": 17,
 "metadata": {},
 "outputs": [
 {
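
The key notebook change is the Predictor cell: the vocabulary loaders are passed as lambdas, so `conf.text` is only built on a cache miss. A hedged restatement of that cell, assuming the imports from the first cell are in scope:

# From the diff above: wrap the vocabulary lookups in labml's cache so that
# re-running the notebook reads stoi/itos from disk instead of rebuilding conf.text.
stoi = cache('stoi', lambda: conf.text.stoi)  # character -> token id
itos = cache('itos', lambda: conf.text.itos)  # token id -> character
p = Predictor(conf.model, stoi, itos)
_ = conf.model.eval()  # inference mode: disables dropout etc.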

python_autocomplete/evaluate.py
Lines changed: 10 additions & 4 deletions

@@ -3,9 +3,10 @@
 
 import torch
 import torch.nn
+from labml.utils.cache import cache
 from torch import nn
 
-from labml import experiment, logger
+from labml import experiment, logger, lab
 from labml.logger import Text, Style
 from labml.utils.pytorch import get_modules
 from labml_helpers.module import Module
@@ -191,16 +192,21 @@ def main():
     conf = Configs()
     experiment.evaluate()
 
-    run_uuid = '39b03a1e454011ebbaff2b26e3148b3d'
     # Replace this with your training experiment UUID
+    run_uuid = '39b03a1e454011ebbaff2b26e3148b3d'
+
     conf_dict = experiment.load_configs(run_uuid)
     experiment.configs(conf, conf_dict)
     experiment.add_pytorch_models(get_modules(conf))
     experiment.load(run_uuid)
 
     experiment.start()
-    predictor = Predictor(conf.model, conf.text.stoi, conf.text.itos)
-    evaluate(predictor, conf.text.valid[:1000])
+    predictor = Predictor(conf.model, cache('stoi', lambda: conf.text.stoi), cache('itos', lambda: conf.text.itos))
+    conf.model.eval()
+
+    with open(str(lab.get_data_path() / 'sample.py'), 'r') as f:
+        sample = f.read()
+    evaluate(predictor, sample)
 
 
 if __name__ == '__main__':
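
With this change, main() evaluates a sample.py from the lab's data directory instead of the first 1,000 characters of the validation split. A small sketch of seeding that file, assuming `lab.get_data_path()` points at the project's data folder (as the diff suggests); the snippet contents are arbitrary:

# Hedged sketch: drop a file named sample.py into the lab data directory so
# main() above has something to evaluate; the code it contains is arbitrary.
from labml import lab

sample_path = lab.get_data_path() / 'sample.py'
sample_path.parent.mkdir(parents=True, exist_ok=True)
sample_path.write_text("class LSTM(nn.Module):\n    def __init__(self):\n        super().__init__()\n")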

python_autocomplete/train.py
Lines changed: 2 additions & 1 deletion

@@ -133,7 +133,8 @@ def _loss_func(c: Configs):
 
 @option(Configs.n_tokens)
 def _n_tokens(c: Configs):
-    return c.text.n_tokens
+    from labml.utils.cache import cache
+    return cache('n_tokens', lambda: c.text.n_tokens)
 
 
 @option(Configs.model)
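
The `@option` hook computes `n_tokens` lazily the first time the config is read; caching the result means an evaluation run no longer has to load and tokenize the training corpus just to learn the vocabulary size. A self-contained sketch of the same pattern; `DemoConfigs` and the cached value are hypothetical, only the decorator usage mirrors the real train.py:

# Hypothetical config illustrating the @option + cache pattern from the diff.
from labml.configs import BaseConfigs, option
from labml.utils.cache import cache

class DemoConfigs(BaseConfigs):
    n_tokens: int = 0

@option(DemoConfigs.n_tokens)
def _n_tokens(c: DemoConfigs):
    # the lambda runs only on a cache miss; later runs read the value from disk
    return cache('demo_n_tokens', lambda: len(set('hello world')))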

requirements.txt
Lines changed: 1 addition & 1 deletion

@@ -2,4 +2,4 @@ labml>=0.4.74
 torch
 numpy
 labml-helpers>=0.4.70
-labml-nn>=0.4.83
+labml-nn>=0.4.86

setup.py
Lines changed: 1 addition & 1 deletion

@@ -19,7 +19,7 @@
 'labml_nn', 'labml_nn.*',
 'labml', 'labml.*',
 'test', 'test.*')),
-install_requires=['labml>=0.4.83',
+install_requires=['labml>=0.4.86',
 'labml_helpers>=0.4.70',
 'labml_nn>=0.4.70'
 'torch',
