@@ -11,7 +11,6 @@
 from torch import nn
 from torch.nn import init
 import torch.utils.data as Data
-import torch.nn.functional as F
 import matplotlib.pyplot as plt
 import numpy as np
 
@@ -24,7 +23,7 @@
 EPOCH = 12
 LR = 0.03
 N_HIDDEN = 8
-ACTIVATION = F.tanh
+ACTIVATION = torch.tanh
 B_INIT = -0.2   # use a bad bias constant initializer
 
 # training data
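Note on the `ACTIVATION` change above: `torch.nn.functional.tanh` was deprecated in PyTorch 0.4.1 in favor of `torch.tanh`, which is also why the commit drops the now-unused `import torch.nn.functional as F`. The two compute the same element-wise function; a quick check, assuming only stock PyTorch:

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-3, 3, steps=5)
# Identical values; F.tanh merely emits a deprecation warning on the
# PyTorch versions this commit targets.
assert torch.equal(torch.tanh(x), F.tanh(x))
```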
@@ -48,6 +47,7 @@
 plt.scatter(train_x.numpy(), train_y.numpy(), c='#FF9359', s=50, alpha=0.2, label='train')
 plt.legend(loc='upper left')
 
+
 class Net(nn.Module):
     def __init__(self, batch_normalization=False):
         super(Net, self).__init__()
@@ -89,20 +89,20 @@ def forward(self, x):
 
 nets = [Net(batch_normalization=False), Net(batch_normalization=True)]
 
-print(*nets)    # print net architecture
+# print(*nets)    # print net architecture
 
 opts = [torch.optim.Adam(net.parameters(), lr=LR) for net in nets]
 
 loss_func = torch.nn.MSELoss()
 
-f, axs = plt.subplots(4, N_HIDDEN + 1, figsize=(10, 5))
-plt.ion()    # interactive plotting mode
-plt.show()
+
 def plot_histogram(l_in, l_in_bn, pre_ac, pre_ac_bn):
-    for i, (ax_pa, ax_pa_bn, ax, ax_bn) in enumerate(zip(axs[0, :], axs[1, :], axs[2, :], axs[3, :])):
+    for i, (ax_pa, ax_pa_bn, ax, ax_bn) in enumerate(zip(axs[0, :], axs[1, :], axs[2, :], axs[3, :])):
         [a.clear() for a in [ax_pa, ax_pa_bn, ax, ax_bn]]
-        if i == 0: p_range = (-7, 10); the_range = (-7, 10)
-        else: p_range = (-4, 4); the_range = (-1, 1)
+        if i == 0:
+            p_range = (-7, 10); the_range = (-7, 10)
+        else:
+            p_range = (-4, 4); the_range = (-1, 1)
         ax_pa.set_title('L' + str(i))
         ax_pa.hist(pre_ac[i].data.numpy().ravel(), bins=10, range=p_range, color='#FF9359', alpha=0.5); ax_pa_bn.hist(pre_ac_bn[i].data.numpy().ravel(), bins=10, range=p_range, color='#74BCFF', alpha=0.5)
         ax.hist(l_in[i].data.numpy().ravel(), bins=10, range=the_range, color='#FF9359'); ax_bn.hist(l_in_bn[i].data.numpy().ravel(), bins=10, range=the_range, color='#74BCFF')
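For context, the `Net` body that this diff elides builds `N_HIDDEN` fully connected layers, optionally wrapping each of them (and the raw input) in `nn.BatchNorm1d`, and its `forward` returns the prediction together with the per-layer inputs and pre-activations that `plot_histogram` consumes. A minimal sketch consistent with the visible code (the class name and the hidden width of 10 are assumptions, not the file's exact body):

```python
class NetSketch(nn.Module):                  # hypothetical name, for illustration
    def __init__(self, batch_normalization=False):
        super().__init__()
        self.do_bn = batch_normalization
        self.bn_input = nn.BatchNorm1d(1, momentum=0.5)       # BN for the raw input
        self.fcs, self.bns = nn.ModuleList(), nn.ModuleList()
        for i in range(N_HIDDEN):            # one Linear (+ optional BN) per hidden layer
            self.fcs.append(nn.Linear(1 if i == 0 else 10, 10))
            if self.do_bn:
                self.bns.append(nn.BatchNorm1d(10, momentum=0.5))
        self.predict = nn.Linear(10, 1)

    def forward(self, x):
        pre_activation = [x]                 # N_HIDDEN + 1 entries -> histogram columns
        if self.do_bn:
            x = self.bn_input(x)
        layer_input = [x]
        for i, fc in enumerate(self.fcs):
            x = fc(x)
            pre_activation.append(x)         # collected before BN / activation
            if self.do_bn:
                x = self.bns[i](x)
            x = ACTIVATION(x)
            layer_input.append(x)            # collected after activation
        return self.predict(x), layer_input, pre_activation
```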
@@ -111,44 +111,50 @@ def plot_histogram(l_in, l_in_bn, pre_ac, pre_ac_bn):
         axs[0, 0].set_ylabel('PreAct'); axs[1, 0].set_ylabel('BN PreAct'); axs[2, 0].set_ylabel('Act'); axs[3, 0].set_ylabel('BN Act')
     plt.pause(0.01)
 
-# training
-losses = [[], []]   # record loss for the two networks
-for epoch in range(EPOCH):
-    print('Epoch: ', epoch)
-    layer_inputs, pre_acts = [], []
-    for net, l in zip(nets, losses):
-        net.eval()     # set eval mode to fix moving_mean and moving_var
-        pred, layer_input, pre_act = net(test_x)
-        l.append(loss_func(pred, test_y).data[0])
-        layer_inputs.append(layer_input)
-        pre_acts.append(pre_act)
-        net.train()    # free moving_mean and moving_var
-    plot_histogram(*layer_inputs, *pre_acts)    # plot histogram
-
-    for step, (b_x, b_y) in enumerate(train_loader):
-        for net, opt in zip(nets, opts):    # train each network
-            pred, _, _ = net(b_x)
-            loss = loss_func(pred, b_y)
-            opt.zero_grad()
-            loss.backward()
-            opt.step()    # this also learns the parameters in Batch Normalization
-
-
-plt.ioff()
-
-# plot training loss
-plt.figure(2)
-plt.plot(losses[0], c='#FF9359', lw=3, label='Original')
-plt.plot(losses[1], c='#74BCFF', lw=3, label='Batch Normalization')
-plt.xlabel('step'); plt.ylabel('test loss'); plt.ylim((0, 2000)); plt.legend(loc='best')
-
-# evaluation
-# set net to eval mode to freeze the parameters in batch normalization layers
-[net.eval() for net in nets]    # set eval mode to fix moving_mean and moving_var
-preds = [net(test_x)[0] for net in nets]
-plt.figure(3)
-plt.plot(test_x.data.numpy(), preds[0].data.numpy(), c='#FF9359', lw=4, label='Original')
-plt.plot(test_x.data.numpy(), preds[1].data.numpy(), c='#74BCFF', lw=4, label='Batch Normalization')
-plt.scatter(test_x.data.numpy(), test_y.data.numpy(), c='r', s=50, alpha=0.2, label='test')
-plt.legend(loc='best')
-plt.show()
+
+if __name__ == "__main__":
+    f, axs = plt.subplots(4, N_HIDDEN + 1, figsize=(10, 5))
+    plt.ion()    # interactive plotting mode
+    plt.show()
+
+    # training
+    losses = [[], []]   # record loss for the two networks
+
+    for epoch in range(EPOCH):
+        print('Epoch: ', epoch)
+        layer_inputs, pre_acts = [], []
+        for net, l in zip(nets, losses):
+            net.eval()     # set eval mode to fix moving_mean and moving_var
+            pred, layer_input, pre_act = net(test_x)
+            l.append(loss_func(pred, test_y).data.item())
+            layer_inputs.append(layer_input)
+            pre_acts.append(pre_act)
+            net.train()    # free moving_mean and moving_var
+        plot_histogram(*layer_inputs, *pre_acts)    # plot histogram
+
+        for step, (b_x, b_y) in enumerate(train_loader):
+            for net, opt in zip(nets, opts):    # train each network
+                pred, _, _ = net(b_x)
+                loss = loss_func(pred, b_y)
+                opt.zero_grad()
+                loss.backward()
+                opt.step()    # this also learns the parameters in Batch Normalization
+
+    plt.ioff()
+
+    # plot training loss
+    plt.figure(2)
+    plt.plot(losses[0], c='#FF9359', lw=3, label='Original')
+    plt.plot(losses[1], c='#74BCFF', lw=3, label='Batch Normalization')
+    plt.xlabel('step'); plt.ylabel('test loss'); plt.ylim((0, 2000)); plt.legend(loc='best')
+
+    # evaluation
+    # set net to eval mode to freeze the parameters in batch normalization layers
+    [net.eval() for net in nets]    # set eval mode to fix moving_mean and moving_var
+    preds = [net(test_x)[0] for net in nets]
+    plt.figure(3)
+    plt.plot(test_x.data.numpy(), preds[0].data.numpy(), c='#FF9359', lw=4, label='Original')
+    plt.plot(test_x.data.numpy(), preds[1].data.numpy(), c='#74BCFF', lw=4, label='Batch Normalization')
+    plt.scatter(test_x.data.numpy(), test_y.data.numpy(), c='r', s=50, alpha=0.2, label='test')
+    plt.legend(loc='best')
+    plt.show()
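Two PyTorch details in the rewritten block above are worth spelling out. First, `loss_func(pred, test_y).data[0]` stopped working in PyTorch 0.4, where indexing a 0-dim tensor was removed; `.item()` is the supported way to read a scalar loss. Second, the `eval()`/`train()` toggling around the test-time forward pass matters because `BatchNorm1d` keeps running estimates (the `moving_mean`/`moving_var` of the comments, `running_mean`/`running_var` in PyTorch): `train()` updates them from each batch, while `eval()` freezes them and normalizes with the stored values. A small check of that behavior, assuming only stock PyTorch:

```python
import torch
from torch import nn

bn = nn.BatchNorm1d(1, momentum=0.5)
batch = torch.randn(64, 1) * 3 + 5   # batch mean around 5

bn.train()
bn(batch)                            # forward pass updates the running stats:
                                     # new = (1 - momentum) * old + momentum * batch_stat
print(bn.running_mean)               # moved from 0 toward ~2.5 after one step

bn.eval()
bn(batch)                            # normalizes with the frozen stats; no update
print(bn.running_mean)               # unchanged
```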