Skip to content

Commit 6c06c80

Browse files
committed
✨ ddp
1 parent 95562f6 commit 6c06c80

1 file changed

Lines changed: 110 additions & 0 deletions

File tree

python_autocomplete/distributed.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import datetime
2+
3+
import torch.distributed
4+
import torch.nn as nn
5+
6+
from labml import experiment, monit, logger, tracker
7+
from labml.configs import option
8+
from labml.logger import Text, inspect
9+
from labml_helpers.train_valid import hook_model_outputs, BatchIndex
10+
from python_autocomplete.train import Configs as Configs_
11+
12+
13+
class Configs(Configs_):
14+
ddp_model: nn.parallel.DistributedDataParallel
15+
16+
def init(self):
17+
tracker.set_queue("loss.*", 20, True)
18+
tracker.set_scalar("accuracy.*", True)
19+
hook_model_outputs(self.mode, self.ddp_model, 'model')
20+
self.state_modules = [self.accuracy_func]
21+
22+
def step(self, batch: any, batch_idx: BatchIndex):
23+
data, target = batch[0].to(self.device), batch[1].to(self.device)
24+
25+
if self.mode.is_train:
26+
tracker.add_global_step(len(data))
27+
28+
with self.mode.update(is_log_activations=batch_idx.is_last):
29+
output, *_ = self.ddp_model(data)
30+
31+
loss = self.loss_func(output, target)
32+
self.accuracy_func(output, target)
33+
self.accuracy_func.track()
34+
tracker.add("loss.", loss)
35+
36+
if self.mode.is_train:
37+
loss.backward()
38+
39+
self.optimizer.step()
40+
if batch_idx.is_last:
41+
tracker.add('model', self.ddp_model)
42+
self.optimizer.zero_grad()
43+
44+
tracker.save()
45+
46+
def sample(self):
47+
prompt = 'def train('
48+
log = [(prompt, Text.subtle)]
49+
for i in monit.iterate('Sample', 25):
50+
data = self.text.text_to_i(prompt).unsqueeze(-1)
51+
data = data.to(self.device)
52+
output, *_ = self.ddp_model(data)
53+
output = output.argmax(dim=-1).squeeze()
54+
prompt += '' + self.text.itos[output[-1]]
55+
log += [('' + self.text.itos[output[-1]], Text.value)]
56+
57+
logger.log(log)
58+
59+
60+
@option(Configs.ddp_model)
61+
def ddp_model(c: Configs):
62+
return nn.parallel.DistributedDataParallel(c.model, device_ids=[c.device])
63+
64+
65+
def main(local_rank, rank, world_size, uuid, init_method: str = 'tcp://localhost:23456'):
66+
with monit.section('Distributed'):
67+
torch.distributed.init_process_group("gloo",
68+
timeout=datetime.timedelta(seconds=30),
69+
init_method=init_method,
70+
rank=rank,
71+
world_size=world_size)
72+
conf = Configs()
73+
experiment.create(uuid=uuid,
74+
name="source_code_ddp",
75+
comment='lstm model')
76+
experiment.distributed(local_rank, world_size)
77+
experiment.configs(conf, {
78+
'model': 'transformer_model',
79+
'n_layers': 6,
80+
'batch_size': 12,
81+
'epochs': 32,
82+
'optimizer.optimizer': 'Noam',
83+
'optimizer.learning_rate': 1.0,
84+
'device.cuda_device': local_rank,
85+
'seq_len': 512,
86+
'train_loader': 'shuffled_train_loader',
87+
'valid_loader': 'shuffled_valid_loader'
88+
})
89+
experiment.add_pytorch_models(model=conf.ddp_model)
90+
with experiment.start():
91+
conf.run()
92+
93+
94+
def _launcher():
95+
import os
96+
world_size = int(os.environ['WORLD_SIZE'])
97+
run_uuid = os.environ['RUN_UUID']
98+
local_rank = int(os.environ['LOCAL_RANK'])
99+
rank = int(os.environ['RANK'])
100+
inspect(world_size=os.environ['WORLD_SIZE'],
101+
run_uuid=os.environ['RUN_UUID'],
102+
local_rank=os.environ['LOCAL_RANK'],
103+
rank=os.environ['RANK'],
104+
master_addr=os.environ['MASTER_ADDR'],
105+
master_port=os.environ['MASTER_PORT'])
106+
main(local_rank, rank, world_size, run_uuid, 'env://')
107+
108+
109+
if __name__ == '__main__':
110+
_launcher()

0 commit comments

Comments
 (0)