@@ -11,54 +11,6 @@ def forward(self, input):
         return input.view(input.size(0), -1)
 
 
-
-
-class AttLSTM(nn.Module):
-    def __init__(self, mask_activation="softmax", **kwargs):
-        """
-        Attentional LSTM.
-        input_size: input dimension
-        hidden_size: hidden dimension, also the output dimension of the LSTM
-        The remaining kwargs are passed straight to nn.LSTM; most of the following is taken from the nn.LSTM docs:
-            input_size: mentioned above, only needs to be specified once
-            hidden_size: mentioned above, only needs to be specified once
-            num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
-                would mean stacking two LSTMs together to form a `stacked LSTM`,
-                with the second LSTM taking in outputs of the first LSTM and
-                computing the final results. Default: 1
-            bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
-                Default: ``True``
-            batch_first: If ``True``, then the input and output tensors are provided
-                as (batch, seq, feature). Default: ``False``
-            dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
-                LSTM layer except the last layer, with dropout probability equal to
-                :attr:`dropout`. Default: 0
-            bidirectional: If ``True``, becomes a bidirectional LSTM. Default: ``False``
-        """
-        super(AttLSTM, self).__init__()
-        self.input_size = kwargs["input_size"]
-        self.hidden_size = kwargs["hidden_size"]
-        self.mask_maker = nn.Linear(self.input_size, 1)  # scores the raw input, so it must match input_size
-        self.lstm = nn.LSTM(**kwargs)
-        if mask_activation == "softmax":
-            self.mask_act = nn.Softmax(dim=1)
-        elif mask_activation == "sigmoid":
-            self.mask_act = nn.Sigmoid()
-        elif mask_activation == "relu":
-            self.mask_act = nn.ReLU()
-        elif mask_activation == "passon":
-            self.mask_act = passon()
-        else:
-            print("Activation type %s not found, should be one of the following:\n softmax\n sigmoid\n relu\n passon" % (
-                mask_activation))
-
-    def forward(self, x):
-        mask = self.mask_act(self.mask_maker(x).squeeze(-1)).unsqueeze(1)  # mask shape (bs, 1, seq_len)
-        output, (h_n, c_n) = self.lstm(x)
-        output = mask.bmm(output).squeeze(1)  # output shape (bs, hidden_size)
-        return output, (h_n, c_n), mask.squeeze(1)
-
-
 class passon(nn.Module):
     def __init__(self):
         """
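For context on the removed class, here is a minimal usage sketch (not part of the commit). It assumes standard PyTorch, that AttLSTM as defined above is importable, and that it is constructed with batch_first=True so the (bs, 1, seq_len) attention mask can be bmm-ed against the (bs, seq_len, hidden_size) LSTM output; all sizes are illustrative.

import torch
import torch.nn as nn

bs, seq_len, input_size, hidden_size = 4, 12, 64, 64   # illustrative sizes

att = AttLSTM(mask_activation="softmax",
              input_size=input_size,
              hidden_size=hidden_size,
              batch_first=True)                         # keeps tensors as (bs, seq, feature)

x = torch.randn(bs, seq_len, input_size)
pooled, (h_n, c_n), mask = att(x)

# pooled: (bs, hidden_size) -- attention-weighted sum of the LSTM outputs over the sequence
# mask:   (bs, seq_len)     -- per-timestep attention weights from mask_maker + softmax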