|
6 | 6 | "metadata": {}, |
7 | 7 | "outputs": [], |
8 | 8 | "source": [ |
9 | | - "from labml import experiment\n", |
10 | | - "from python_autocomplete.train import Configs\n", |
11 | | - "from python_autocomplete.evaluate import evaluate, anomalies, complete, Predictor\n", |
| 9 | + "import string\n", |
| 10 | + "\n", |
12 | 11 | "import torch\n", |
13 | | - "import torch.nn\n", |
| 12 | + "from torch import nn\n", |
| 13 | + "\n", |
14 | 14 | "from labml import experiment, logger, lab\n", |
15 | 15 | "from labml_helpers.module import Module\n", |
16 | 16 | "from labml.logger import Text, Style\n", |
17 | 17 | "from labml.utils.pytorch import get_modules\n", |
18 | | - "from torch import nn\n", |
19 | | - "import string\n", |
| 18 | + "from labml.utils.cache import cache\n", |
| 19 | + "from labml_helpers.datasets.text import TextDataset\n", |
20 | 20 | "\n", |
21 | | - "from labml_helpers.datasets.text import TextDataset" |
| 21 | + "from python_autocomplete.train import Configs\n", |
| 22 | + "from python_autocomplete.evaluate import evaluate, anomalies, complete, Predictor" |
22 | 23 | ] |
23 | 24 | }, |
24 | 25 | { |
|
87 | 88 | "metadata": {}, |
88 | 89 | "outputs": [], |
89 | 90 | "source": [ |
90 | | - "conf_dict['device.cuda_device'] = 1" |
| 91 | + "conf_dict['device.cuda_device'] = 1\n", |
| 92 | + "# conf_dict['device.use_cuda'] = False" |
91 | 93 | ] |
92 | 94 | }, |
93 | 95 | { |
|
121 | 123 | "data": { |
122 | 124 | "text/html": [ |
123 | 125 | "<pre style=\"overflow-x: scroll;\">Prepare model...\n", |
124 | | - " Prepare n_tokens...\n", |
125 | | - " Prepare text...\n", |
126 | | - " Prepare tokenizer<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t3.16ms</span>\n", |
127 | | - " Load data<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t162.36ms</span>\n", |
128 | | - " Tokenize<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t3,109.56ms</span>\n", |
129 | | - " Build vocabulary<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t66.50ms</span>\n", |
130 | | - " Prepare text<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t3,350.38ms</span>\n", |
131 | | - " Prepare n_tokens<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t3,354.96ms</span>\n", |
132 | | - " Prepare transformer<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t1.12ms</span>\n", |
| 126 | + " Prepare n_tokens<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t1.32ms</span>\n", |
| 127 | + " Prepare transformer<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t1.05ms</span>\n", |
133 | 128 | " Prepare encoder...\n", |
134 | 129 | " Prepare encoder_layer...\n", |
135 | | - " Prepare encoder_attn<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t9.52ms</span>\n", |
136 | | - " Prepare feed_forward<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t13.47ms</span>\n", |
137 | | - " Prepare encoder_layer<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t28.71ms</span>\n", |
138 | | - " Prepare encoder<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t47.43ms</span>\n", |
139 | | - " Prepare src_embed<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t4.22ms</span>\n", |
| 130 | + " Prepare encoder_attn<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t12.84ms</span>\n", |
| 131 | + " Prepare feed_forward<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t12.51ms</span>\n", |
| 132 | + " Prepare encoder_layer<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t31.24ms</span>\n", |
| 133 | + " Prepare encoder<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t47.89ms</span>\n", |
| 134 | + " Prepare src_embed<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t5.40ms</span>\n", |
140 | 135 | " Prepare device...\n", |
141 | | - " Prepare device_info<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t57.05ms</span>\n", |
142 | | - " Prepare device<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t58.74ms</span>\n", |
143 | | - "Prepare model<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t4,815.99ms</span>\n", |
| 136 | + " Prepare device_info<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t37.79ms</span>\n", |
| 137 | + " Prepare device<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t39.82ms</span>\n", |
| 138 | + "Prepare model<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t1,351.06ms</span>\n", |
144 | 139 | "</pre>" |
145 | 140 | ], |
146 | 141 | "text/plain": [ |
|
175 | 170 | "data": { |
176 | 171 | "text/html": [ |
177 | 172 | "<pre style=\"overflow-x: scroll;\">Selected <span style=\"color: #60C6C8\">experiment</span> = <strong>source_code</strong> <span style=\"color: #60C6C8\">run</span> = <strong>39b03a1e454011ebbaff2b26e3148b3d</strong> <span style=\"color: #60C6C8\">checkpoint</span> = <strong>351023104</strong>\n", |
178 | | - "Loading checkpoint<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t67.46ms</span>\n", |
| 173 | + "Loading checkpoint<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t47.04ms</span>\n", |
179 | 174 | "\n", |
180 | | - "<strong><span style=\"text-decoration: underline\">Notebook Experiment</span></strong>: <span style=\"color: #208FFB\">0c69f4444ab011ebb618517d0c553d3c</span>\n", |
181 | | - "\t[dirty]: <strong><span style=\"color: #DDB62B\">\"evaluate\"</span></strong>\n", |
| 175 | + "<strong><span style=\"text-decoration: underline\">Notebook Experiment</span></strong>: <span style=\"color: #208FFB\">7f418b0e4b1d11eba0e89704bc602f06</span>\n", |
| 176 | + "\t[dirty]: <strong><span style=\"color: #DDB62B\">\"rename notebook\"</span></strong>\n", |
182 | 177 | "\tloaded from: <span style=\"color: #D160C4\">39b03a1e454011ebbaff2b26e3148b3d</span></pre>" |
183 | 178 | ], |
184 | 179 | "text/plain": [ |
|
191 | 186 | { |
192 | 187 | "data": { |
193 | 188 | "text/plain": [ |
194 | | - "<labml.internal.experiment.watcher.ExperimentWatcher at 0x7fd68e67b370>" |
| 189 | + "<labml.internal.experiment.watcher.ExperimentWatcher at 0x7f655c41a400>" |
195 | 190 | ] |
196 | 191 | }, |
197 | 192 | "execution_count": 9, |
|
209 | 204 | "metadata": {}, |
210 | 205 | "outputs": [], |
211 | 206 | "source": [ |
212 | | - "p = Predictor(conf.model, conf.text.stoi, conf.text.itos)\n", |
| 207 | + "p = Predictor(conf.model, cache('stoi', lambda: conf.text.stoi), cache('itos', lambda: conf.text.itos))\n", |
213 | 208 | "_ = conf.model.eval()" |
214 | 209 | ] |
215 | 210 | }, |
|
235 | 230 | }, |
236 | 231 | { |
237 | 232 | "cell_type": "code", |
238 | | - "execution_count": 16, |
| 233 | + "execution_count": 12, |
239 | 234 | "metadata": {}, |
240 | 235 | "outputs": [ |
241 | 236 | { |
242 | 237 | "name": "stdout", |
243 | 238 | "output_type": "stream", |
244 | 239 | "text": [ |
245 | | - "\"(LSTM\"\n", |
246 | | - "CPU times: user 386 ms, sys: 6.33 ms, total: 392 ms\n", |
247 | | - "Wall time: 41 ms\n" |
| 240 | + "\"\n", |
| 241 | + " super\"\n", |
| 242 | + "CPU times: user 950 ms, sys: 34.7 ms, total: 984 ms\n", |
| 243 | + "Wall time: 254 ms\n" |
248 | 244 | ] |
249 | 245 | } |
250 | 246 | ], |
|
256 | 252 | }, |
257 | 253 | { |
258 | 254 | "cell_type": "code", |
259 | | - "execution_count": 17, |
| 255 | + "execution_count": 13, |
260 | 256 | "metadata": {}, |
261 | 257 | "outputs": [ |
262 | 258 | { |
263 | 259 | "name": "stdout", |
264 | 260 | "output_type": "stream", |
265 | 261 | "text": [ |
266 | | - "\",\"\n" |
| 262 | + "\"(LSTM\"\n" |
267 | 263 | ] |
268 | 264 | } |
269 | 265 | ], |
|
275 | 271 | }, |
276 | 272 | { |
277 | 273 | "cell_type": "code", |
278 | | - "execution_count": 18, |
| 274 | + "execution_count": 14, |
279 | 275 | "metadata": {}, |
280 | 276 | "outputs": [ |
281 | 277 | { |
|
299 | 295 | }, |
300 | 296 | { |
301 | 297 | "cell_type": "code", |
302 | | - "execution_count": 19, |
| 298 | + "execution_count": 15, |
303 | 299 | "metadata": {}, |
304 | 300 | "outputs": [ |
305 | 301 | { |
|
428 | 424 | "name": "stdout", |
429 | 425 | "output_type": "stream", |
430 | 426 | "text": [ |
431 | | - "CPU times: user 2min 18s, sys: 1.17 s, total: 2min 19s\n", |
| 427 | + "CPU times: user 1min 59s, sys: 62.9 ms, total: 1min 59s\n", |
432 | 428 | "Wall time: 1min 23s\n" |
433 | 429 | ] |
434 | 430 | } |
|
440 | 436 | }, |
441 | 437 | { |
442 | 438 | "cell_type": "code", |
443 | | - "execution_count": 20, |
| 439 | + "execution_count": 16, |
444 | 440 | "metadata": {}, |
445 | 441 | "outputs": [ |
446 | 442 | { |
|
569 | 565 | }, |
570 | 566 | { |
571 | 567 | "cell_type": "code", |
572 | | - "execution_count": 21, |
| 568 | + "execution_count": 17, |
573 | 569 | "metadata": {}, |
574 | 570 | "outputs": [ |
575 | 571 | { |
|
0 commit comments