@@ -60,6 +60,7 @@ def get_config(self):
         }
 
     def build(self, inputs_shape):
+
         # Transformation for linearly projecting the queries, keys, and values.
         self.q_transformation = self._get_weights(
             "q_project", shape=(self.hidden_size, self.hidden_size), init=tf.initializers.get('glorot_uniform')
@@ -75,20 +76,7 @@ def build(self, inputs_shape):
         )
 
     def split_heads(self, x):
-        """Split x into different heads, and transpose the resulting value.
-
-        The tensor is transposed to insure the inner dimensions hold the correct
-        values during the matrix multiplication.
 
-        Parameters
-        -----------
-
-        x: A tensor with shape [batch_size, length, hidden_size]
-
-        Returns:
-        -----------
-        A tensor with shape [batch_size, num_heads, length, hidden_size/num_heads]
-        """
         with tf.name_scope("split_heads"):
             batch_size = tf.shape(x)[0]
             length = tf.shape(x)[1]
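
The lines elided between these two hunks presumably compute the per-head depth and reshape before the transpose shown below. A standalone sketch of the whole split_heads transformation, with illustrative num_heads and hidden_size values:

import tensorflow as tf

def split_heads(x, num_heads, hidden_size):
    batch_size = tf.shape(x)[0]
    length = tf.shape(x)[1]
    depth = hidden_size // num_heads                 # width of each head
    x = tf.reshape(x, [batch_size, length, num_heads, depth])
    return tf.transpose(x, [0, 2, 1, 3])             # [batch, num_heads, length, depth]

x = tf.zeros([2, 7, 512])
print(split_heads(x, num_heads=8, hidden_size=512).shape)  # (2, 8, 7, 64)
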
@@ -103,40 +91,15 @@ def split_heads(self, x):
             return tf.transpose(x, [0, 2, 1, 3])
 
     def combine_heads(self, x):
-        """Combine tensor that has been split.
-
-        Args:
-        x: A tensor [batch_size, num_heads, length, hidden_size/num_heads]
 
-        Returns:
-        -----------
-        A tensor with shape [batch_size, length, hidden_size]
-        """
         with tf.name_scope("combine_heads"):
             batch_size = tf.shape(x)[0]
             length = tf.shape(x)[2]
             x = tf.transpose(x, [0, 2, 1, 3])  # --> [batch, length, num_heads, depth]
             return tf.reshape(x, [batch_size, length, self.hidden_size])
 
     def forward(self, x, y, mask, cache=None):
-        """Apply attention mechanism to x and y.
-
-        Args:
-        x: a tensor with shape [batch_size, length_x, hidden_size]
-        y: a tensor with shape [batch_size, length_y, hidden_size]
-        mask: attention bias that will be added to the result of the dot product.
-        training: boolean, whether in training mode or not.
-        cache: (Used during prediction) dictionary with tensors containing results
-        of previous attentions. The dictionary must have the items:
-            {"k": tensor with shape [batch_size, i, key_channels],
-             "v": tensor with shape [batch_size, i, value_channels]}
-        where i is the current decoded length.
-
-        Returns:
-        -----------
-        Attention layer output with shape [batch_size, length_x, hidden_size]
-        Attention weights with shape [batch_size, number_of_head, length_x, length_y]
-        """
+        """Apply attention mechanism to x and y."""
         # Linearly project the query (q), key (k) and value (v) using different
         # learned projections. This is in preparation of splitting them into
         # multiple heads. Multi-head attention uses multiple queries, keys, and
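
A minimal end-to-end sketch of what this forward pass does, under the assumption that it follows standard multi-head scaled dot-product attention (the cache path used during decoding is omitted). The wq/wk/wv arguments stand in for the projection weights created in build(); this is illustrative, not the layer's exact implementation.

import tensorflow as tf

def attention_forward(x, y, mask, wq, wk, wv, num_heads, hidden_size):
    depth = hidden_size // num_heads

    def split(t):  # [batch, length, hidden] -> [batch, num_heads, length, depth]
        b, l = tf.shape(t)[0], tf.shape(t)[1]
        return tf.transpose(tf.reshape(t, [b, l, num_heads, depth]), [0, 2, 1, 3])

    # Project q from x, and k/v from y, with separate learned weights.
    q = split(tf.tensordot(x, wq, axes=[[2], [0]]))
    k = split(tf.tensordot(y, wk, axes=[[2], [0]]))
    v = split(tf.tensordot(y, wv, axes=[[2], [0]]))

    # Scaled dot-product attention; mask acts as an additive bias on the logits.
    logits = tf.matmul(q, k, transpose_b=True) / tf.sqrt(float(depth)) + mask
    weights = tf.nn.softmax(logits)              # [batch, num_heads, len_x, len_y]
    out = tf.matmul(weights, v)                  # [batch, num_heads, len_x, depth]

    # combine_heads: invert the split and merge heads back into hidden_size.
    b, l = tf.shape(out)[0], tf.shape(out)[2]
    out = tf.reshape(tf.transpose(out, [0, 2, 1, 3]), [b, l, hidden_size])
    return out, weights

The return shapes match the removed docstring: output [batch_size, length_x, hidden_size] and weights [batch_size, num_heads, length_x, length_y].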