
Commit 68396f5

fix bug for windows
1 parent 46c516f commit 68396f5

1 file changed: tutorial-contents/306_optimizer.py

36 additions & 35 deletions
@@ -43,43 +43,44 @@ def forward(self, x):
         x = self.predict(x)             # linear output
         return x
 
-# different nets
-net_SGD = Net()
-net_Momentum = Net()
-net_RMSprop = Net()
-net_Adam = Net()
-nets = [net_SGD, net_Momentum, net_RMSprop, net_Adam]
+if __name__ == '__main__':
+    # different nets
+    net_SGD = Net()
+    net_Momentum = Net()
+    net_RMSprop = Net()
+    net_Adam = Net()
+    nets = [net_SGD, net_Momentum, net_RMSprop, net_Adam]
 
-# different optimizers
-opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR)
-opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)
-opt_RMSprop = torch.optim.RMSprop(net_RMSprop.parameters(), lr=LR, alpha=0.9)
-opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))
-optimizers = [opt_SGD, opt_Momentum, opt_RMSprop, opt_Adam]
+    # different optimizers
+    opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR)
+    opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)
+    opt_RMSprop = torch.optim.RMSprop(net_RMSprop.parameters(), lr=LR, alpha=0.9)
+    opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))
+    optimizers = [opt_SGD, opt_Momentum, opt_RMSprop, opt_Adam]
 
-loss_func = torch.nn.MSELoss()
-losses_his = [[], [], [], []]   # record loss
+    loss_func = torch.nn.MSELoss()
+    losses_his = [[], [], [], []]   # record loss
 
-# training
-for epoch in range(EPOCH):
-    print('Epoch: ', epoch)
-    for step, (batch_x, batch_y) in enumerate(loader):          # for each training step
-        b_x = Variable(batch_x)
-        b_y = Variable(batch_y)
+    # training
+    for epoch in range(EPOCH):
+        print('Epoch: ', epoch)
+        for step, (batch_x, batch_y) in enumerate(loader):          # for each training step
+            b_x = Variable(batch_x)
+            b_y = Variable(batch_y)
 
-        for net, opt, l_his in zip(nets, optimizers, losses_his):
-            output = net(b_x)              # get output for every net
-            loss = loss_func(output, b_y)  # compute loss for every net
-            opt.zero_grad()                # clear gradients for next train
-            loss.backward()                # backpropagation, compute gradients
-            opt.step()                     # apply gradients
-            l_his.append(loss.data[0])     # loss recoder
+            for net, opt, l_his in zip(nets, optimizers, losses_his):
+                output = net(b_x)              # get output for every net
+                loss = loss_func(output, b_y)  # compute loss for every net
+                opt.zero_grad()                # clear gradients for next train
+                loss.backward()                # backpropagation, compute gradients
+                opt.step()                     # apply gradients
+                l_his.append(loss.data[0])     # loss recoder
 
-labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']
-for i, l_his in enumerate(losses_his):
-    plt.plot(l_his, label=labels[i])
-plt.legend(loc='best')
-plt.xlabel('Steps')
-plt.ylabel('Loss')
-plt.ylim((0, 0.2))
-plt.show()
+    labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']
+    for i, l_his in enumerate(losses_his):
+        plt.plot(l_his, label=labels[i])
+    plt.legend(loc='best')
+    plt.xlabel('Steps')
+    plt.ylabel('Loss')
+    plt.ylim((0, 0.2))
+    plt.show()
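
For context beyond the commit itself: on Windows, Python's multiprocessing uses the "spawn" start method, so every DataLoader worker process re-imports the script from the top. Any training code left at module level then runs again inside each worker, and trying to start workers from within a worker fails with a bootstrapping error. Moving the training loop under if __name__ == '__main__': confines it to the parent process; the tutorial's loader uses worker processes, which is what triggered the bug. Below is a minimal sketch of the pattern written against the current torch.utils.data API; BATCH_SIZE and the toy regression data are stand-ins for the tutorial's own definitions, not part of this commit.

import torch
import torch.utils.data as Data

BATCH_SIZE = 32          # assumed value; the tutorial defines its own constants

# toy regression data in the spirit of the tutorial
x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1)
y = x.pow(2) + 0.1 * torch.randn(x.size())

# module-level setup like this is re-executed in every spawned worker,
# which is harmless because it only defines objects
loader = Data.DataLoader(
    dataset=Data.TensorDataset(x, y),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,       # worker processes are what make the guard necessary
)

if __name__ == '__main__':                      # True only in the parent process
    for step, (b_x, b_y) in enumerate(loader):  # iterating the loader spawns the workers
        pass                                    # the guarded training step goes here

On Linux the default start method is fork, which copies the parent process instead of re-importing the script, so the unguarded version happened to work there; that is why the bug surfaced only on Windows.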
