Skip to content

Commit b954cc2

Browse files
committed
🐛 trainer
1 parent c1cee50 commit b954cc2

6 files changed

Lines changed: 79 additions & 62 deletions

File tree

.labml.yaml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +0,0 @@
1-
check_repo_dirty: False
2-
web_api: https://api.lab-ml.com/api/v1/track?labml_token=team-samples

python_autocomplete/create_dataset.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,20 @@
1313
PRINTABLE = set(string.printable)
1414

1515

16-
class _PythonFile(NamedTuple):
16+
class PythonFile(NamedTuple):
1717
relative_path: str
1818
project: str
1919
path: Path
2020

2121

22-
class _GetPythonFiles:
22+
class GetPythonFiles:
2323
"""
2424
Get list of python files and their paths inside `data/source` folder
2525
"""
2626

2727
def __init__(self):
2828
self.source_path = Path(lab.get_data_path() / 'source')
29-
self.files: List[_PythonFile] = []
29+
self.files: List[PythonFile] = []
3030
self.get_python_files(self.source_path)
3131

3232
logger.inspect([f.path for f in self.files])
@@ -36,28 +36,21 @@ def add_file(self, path: Path):
3636
Add a file to the list of files
3737
"""
3838
project = path.relative_to(self.source_path).parents
39-
project = project[len(project) - 2]
40-
relative_path = path.relative_to(self.source_path / project)
39+
relative_path = path.relative_to(self.source_path / project[len(project) - 3])
4140

42-
self.files.append(_PythonFile(relative_path=str(relative_path),
43-
project=str(project),
44-
path=path))
41+
self.files.append(PythonFile(relative_path=str(relative_path),
42+
project=str(project[len(project) - 2]),
43+
path=path))
4544

4645
def get_python_files(self, path: Path):
4746
"""
4847
Recursively collect files
4948
"""
5049
for p in path.iterdir():
51-
if p.is_symlink():
52-
p.unlink()
53-
continue
5450
if p.is_dir():
5551
self.get_python_files(p)
5652
else:
57-
if p.suffix == '.py':
58-
self.add_file(p)
59-
else:
60-
p.unlink()
53+
self.add_file(p)
6154

6255

6356
def _read_file(path: Path) -> str:
@@ -72,15 +65,15 @@ def _read_file(path: Path) -> str:
7265
return content
7366

7467

75-
def _load_code(path: PurePath, source_files: List[_PythonFile]):
68+
def _load_code(path: PurePath, source_files: List[PythonFile]):
7669
with open(str(path), 'w') as f:
7770
for i, source in monit.enum(f"Write {path.name}", source_files):
7871
f.write(f"# PROJECT: {source.project} FILE: {str(source.relative_path)}\n")
7972
f.write(_read_file(source.path) + "\n")
8073

8174

8275
def main():
83-
source_files = _GetPythonFiles().files
76+
source_files = GetPythonFiles().files
8477

8578
np.random.shuffle(source_files)
8679

python_autocomplete/evaluate.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def eval(self):
9494
continue
9595
else:
9696
if next_char == self.text[i + 1]:
97-
logs.append((self.text[i + 1], Style.underline))
97+
logs.append((self.text[i + 1], [Text.success, Style.underline]))
9898
else:
9999
logs.append((self.text[i + 1], Text.subtle))
100100

@@ -107,17 +107,17 @@ def eval(self):
107107

108108
def main():
109109
conf = Configs()
110-
experiment.create(name="source_code_eval",
111-
comment='lstm model')
110+
experiment.evaluate()
112111

113112
# Replace this with your training experiment UUID
114-
conf_dict = experiment.load_configs('6f10a292e77211ea89d69979079dc3d6')
115-
experiment.configs(conf, conf_dict, 'run')
113+
conf_dict = experiment.load_configs('8d16abcc3f6211ebb0be67ed81588441')
114+
experiment.configs(conf, conf_dict)
116115
experiment.add_pytorch_models(get_modules(conf))
117-
experiment.load('6f10a292e77211ea89d69979079dc3d6')
116+
experiment.load('8d16abcc3f6211ebb0be67ed81588441')
118117

119118
experiment.start()
120-
evaluator = Evaluator(conf.model, conf.text, conf.text.valid, False)
119+
from python_autocomplete.models.transformer import TransformerModel
120+
evaluator = Evaluator(conf.model, conf.text, conf.text.valid, not isinstance(conf.model, TransformerModel))
121121
evaluator.eval()
122122

123123

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,29 @@
11
import zipfile
22
from pathlib import Path
33

4+
from labml.internal.util import rm_tree
5+
46
from labml import lab, monit
57

68

7-
def main():
9+
def extract_zips(overwrite: bool = False):
810
download = Path(lab.get_data_path() / 'download')
911
source = Path(lab.get_data_path() / 'source')
1012

13+
if not source.exists():
14+
source.mkdir(parents=True)
15+
1116
for repo in download.iterdir():
1217
with monit.section(f"Extract {repo.stem}"):
1318
repo_source = source / repo.stem
1419
if repo_source.exists():
15-
continue
20+
if overwrite:
21+
rm_tree(repo_source)
22+
else:
23+
continue
1624
with zipfile.ZipFile(repo, 'r') as repo_zip:
1725
repo_zip.extractall(repo_source)
1826

19-
if not source.exists():
20-
source.mkdir(parents=True)
21-
2227

2328
if __name__ == '__main__':
24-
main()
29+
extract_zips()

python_autocomplete/train.py

Lines changed: 49 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,17 @@
33

44
import torch
55
import torch.nn as nn
6-
from labml import lab, experiment, monit, logger
6+
7+
from labml import lab, experiment, monit, logger, tracker
78
from labml.configs import option
89
from labml.logger import Text
910
from labml.utils.pytorch import get_modules
10-
1111
from labml_helpers.datasets.text import TextDataset, SequentialDataLoader
1212
from labml_helpers.device import DeviceConfigs
13+
from labml_helpers.metrics.accuracy import Accuracy
1314
from labml_helpers.module import Module
14-
from labml_helpers.optimizer import OptimizerConfigs
15-
from labml_helpers.train_valid import TrainValidConfigs
15+
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
16+
from labml_nn.optimizers.configs import OptimizerConfigs
1617
from labml_nn.transformers import TransformerConfigs
1718

1819

@@ -26,7 +27,9 @@ def __init__(self, path: PurePath, tokenizer: Callable):
2627

2728

2829
class Configs(TrainValidConfigs):
29-
device = DeviceConfigs()
30+
optimizer: torch.optim.Adam
31+
device: torch.device = DeviceConfigs()
32+
3033
model: Module
3134
text: TextDataset
3235
batch_size: int = 16
@@ -44,32 +47,50 @@ class Configs(TrainValidConfigs):
4447

4548
transformer: TransformerConfigs
4649

47-
def run(self):
48-
for _ in self.training_loop:
49-
prompt = 'def train('
50-
log = [(prompt, Text.subtle)]
51-
for i in monit.iterate('Sample', 25):
52-
data = self.text.text_to_i(prompt).unsqueeze(-1)
53-
data = data.to(self.device)
54-
output, *_ = self.model(data)
55-
output = output.argmax(dim=-1).squeeze()
56-
prompt += '' + self.text.itos[output[-1]]
57-
log += [('' + self.text.itos[output[-1]], Text.value)]
50+
accuracy_func = Accuracy()
51+
loss_func: 'CrossEntropyLoss'
52+
53+
def init(self):
54+
tracker.set_queue("loss.*", 20, True)
55+
tracker.set_scalar("accuracy.*", True)
56+
hook_model_outputs(self.mode, self.model, 'model')
57+
self.state_modules = [self.accuracy_func]
58+
59+
def step(self, batch: any, batch_idx: BatchIndex):
60+
data, target = batch[0].to(self.device), batch[1].to(self.device)
61+
62+
if self.mode.is_train:
63+
tracker.add_global_step(len(data))
64+
65+
with self.mode.update(is_log_activations=batch_idx.is_last):
66+
output, *_ = self.model(data)
5867

59-
logger.log(log)
68+
loss = self.loss_func(output, target)
69+
self.accuracy_func(output, target)
70+
tracker.add("loss.", loss)
6071

61-
self.run_step()
72+
if self.mode.is_train:
73+
loss.backward()
6274

75+
self.optimizer.step()
76+
if batch_idx.is_last:
77+
tracker.add('model', self.model)
78+
self.optimizer.zero_grad()
6379

64-
class SimpleAccuracyFunc(Module):
65-
def __call__(self, output: torch.Tensor, target: torch.Tensor) -> int:
66-
pred = output.argmax(dim=-1)
67-
return pred.eq(target).sum().item() / target.shape[1]
80+
tracker.save()
6881

82+
def sample(self):
83+
prompt = 'def train('
84+
log = [(prompt, Text.subtle)]
85+
for i in monit.iterate('Sample', 25):
86+
data = self.text.text_to_i(prompt).unsqueeze(-1)
87+
data = data.to(self.device)
88+
output, *_ = self.model(data)
89+
output = output.argmax(dim=-1).squeeze()
90+
prompt += '' + self.text.itos[output[-1]]
91+
log += [('' + self.text.itos[output[-1]], Text.value)]
6992

70-
@option(Configs.accuracy_func)
71-
def simple_accuracy():
72-
return SimpleAccuracyFunc()
93+
logger.log(log)
7394

7495

7596
@option(Configs.transformer)
@@ -126,7 +147,7 @@ def lstm_model(c: Configs):
126147

127148
@option(Configs.model)
128149
def rhn_model(c: Configs):
129-
from python_autocomplete.models import RhnModel
150+
from python_autocomplete.models.highway import RhnModel
130151
m = RhnModel(n_tokens=c.n_tokens,
131152
embedding_size=c.d_model,
132153
hidden_size=c.rnn_size,
@@ -137,7 +158,7 @@ def rhn_model(c: Configs):
137158

138159
@option(Configs.model)
139160
def transformer_model(c: Configs):
140-
from python_autocomplete.models import TransformerModel
161+
from python_autocomplete.models.transformer import TransformerModel
141162
m = TransformerModel(n_tokens=c.n_tokens,
142163
d_model=c.d_model,
143164
encoder=c.transformer.encoder,
@@ -189,7 +210,7 @@ def main():
189210
'optimizer.optimizer': 'Adam',
190211
'optimizer.learning_rate': 2.5e-4,
191212
'device.cuda_device': 1
192-
}, 'run')
213+
})
193214
experiment.add_pytorch_models(get_modules(conf))
194215
# experiment.load('d5ba7f56d88911eaa6629b54a83956dc')
195216
with experiment.start():

readme.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ This repo trains deep learning models on source code.
1212
4. Run `python_autocomplete/extract_downloads.py` to extract the downloaded zip files to `data/source`.
1313
You can directly copy any python code to `data/source` to train on them.
1414
5. Run `python_autocomplete/remove_non_source_files.py` to remove all files except `.py` files.
15-
6. Run `create_dataset.py` to collect all python files.
15+
6. Run `python_autocomplete/create_dataset.py` to collect all python files.
1616
The collected code will be written to `data/train.py` and, `data/eval.py`.
17-
7. Run `train.py` to train the model.
17+
7. Run `python_autocomplete/train.py` to train the model.
1818
*Try changing hyper-parameters like model dimensions and number of layers*.
1919
8. Run `evaluate.py` to evaluate the model.
2020
9. Enjoy!

0 commit comments

Comments
 (0)