1414use Codewithkyrian \Transformers \Utils \AutoConfig ;
1515use Codewithkyrian \Transformers \Utils \GenerationConfig ;
1616use Exception ;
17+ use InvalidArgumentException ;
1718
1819class WhisperForConditionalGeneration extends WhisperPretrainedModel
1920{
@@ -66,7 +67,6 @@ public function generate(
6667 }
6768
6869
69-
7070 if (isset ($ generationConfig ['return_token_timestamps ' ])) {
7171 $ generationConfig ['output_attentions ' ] = true ;
7272 $ generationConfig ['return_dict_in_generate ' ] = true ;
@@ -109,12 +109,13 @@ public function generate(
109109 * @throws Exception If the model outputs do not contain cross attentions
110110 */
111111 public function extractTokenTimestamps (
112- array $ generateOutputs ,
113- array $ alignmentHeads ,
112+ array $ generateOutputs ,
113+ array $ alignmentHeads ,
114114 int |float |null $ numFrames = null ,
115- float $ timePrecision = 0.02
116- ): Tensor {
117- $ numFrames = (int ) $ numFrames ;
115+ float $ timePrecision = 0.02
116+ ): Tensor
117+ {
118+ $ numFrames = (int )$ numFrames ;
118119 if (!isset ($ generateOutputs ['cross_attentions ' ])) {
119120 throw new Exception (
120121 "Model outputs must contain cross attentions to extract timestamps. " .
@@ -128,7 +129,7 @@ public function extractTokenTimestamps(
128129 $ medianFilterWidth = 7 ;
129130 }
130131
131- $ batchedMatrices = array_map (function ($ batch ) use ($ numFrames , $ alignmentHeads , $ medianFilterWidth ) {
132+ $ batchedMatrices = array_map (function ($ batch ) use ($ numFrames , $ alignmentHeads , $ medianFilterWidth ) {
132133 // Create a list with `decoder_layers` elements, each a tensor of shape
133134 // (batch size, attention_heads, output length, input length).
134135 /** @var Tensor[] $crossAttentions */
@@ -137,61 +138,61 @@ public function extractTokenTimestamps(
137138 $ crossAttentions [] = Tensor::concat (array_map (fn ($ x ) => $ x [$ i ], $ batch ), 2 );
138139 }
139140
140- $ weights = Tensor::stack (array_map (function ($ alignmentHead ) use ($ crossAttentions , $ numFrames ) {
141+ $ weights = Tensor::stack (array_map (function ($ alignmentHead ) use ($ crossAttentions , $ numFrames ) {
141142 [$ l , $ h ] = $ alignmentHead ;
142143 return $ numFrames
143- ? $ crossAttentions [$ l ]->slice (null , $ h , null , [0 , $ numFrames ])
144- : $ crossAttentions [$ l ]->slice (null , $ h );
144+ ? $ crossAttentions [$ l ]->slice (null , $ h , null , [0 , $ numFrames ])-> squeeze ( 1 )
145+ : $ crossAttentions [$ l ]->slice (null , $ h )-> squeeze ( 1 ); // experimental
145146 }, $ alignmentHeads ));
146- dd ($ weights ->shape ());
147-
148- $ weights = $ weights ->permute ( 1 , 0 , 2 , 3 );
149147
148+ $ weights = $ weights ->permute (1 , 0 , 2 , 3 );
150149
151- list ( $ std , $ calculatedMean) = std_mean ( $ weights, -2 , 0 , true );
150+ [ $ std , $ calculatedMean] = $ weights-> stdMean ( -2 , 0 , true );
152151
153152 // Normalize and smoothen the weights.
154- $ smoothedWeights = $ weights-> clone () ; // [1, 8, seqLength, 1500]
153+ $ smoothedWeights = clone $ weights ; // [1, 8, seqLength, 1500]
155154
156- for ($ a = 0 ; $ a < $ smoothedWeights ->dims [0 ]; ++$ a ) {
155+ for ($ a = 0 ; $ a < $ smoothedWeights ->shape () [0 ]; ++$ a ) {
157156 $ aTensor = $ smoothedWeights [$ a ]; // [8, seqLength, 1500]
158157
159- for ($ b = 0 ; $ b < $ aTensor ->dims [0 ]; ++$ b ) {
158+ for ($ b = 0 ; $ b < $ aTensor ->shape () [0 ]; ++$ b ) {
160159 $ bTensor = $ aTensor [$ b ]; // [seqLength, 1500]
161160
162161 $ stdTensor = $ std [$ a ][$ b ][0 ]; // [1500]
163162 $ meanTensor = $ calculatedMean [$ a ][$ b ][0 ]; // [1500]
164163
165- for ($ c = 0 ; $ c < $ bTensor ->dims [0 ]; ++$ c ) {
164+ for ($ c = 0 ; $ c < $ bTensor ->shape ()[0 ]; ++$ c ) {
165+ /** @var Tensor $cTensor */
166166 $ cTensor = $ bTensor [$ c ]; // [1500]
167- for ($ d = 0 ; $ d < count ($ cTensor ->data ); ++$ d ) {
168- $ cTensor ->data [$ d ] = ($ cTensor ->data [$ d ] - $ meanTensor ->data [$ d ]) / $ stdTensor ->data [$ d ];
169- }
167+ // for ($d = 0; $d < count($cTensor->buffer()); ++$d) {
168+ // $cTensor->buffer()[$d] = ($cTensor->buffer()[$d] - $meanTensor->buffer()[$d]) / $stdTensor->buffer()[$d];
169+ // }
170+ $ cTensor = $ cTensor ->add ($ meanTensor ->multiply (-1 ))->multiply ($ stdTensor ->reciprocal ());
170171
171172 // Apply median filter.
172- $ cTensor-> data = medianFilter ($ cTensor-> data , $ medianFilterWidth );
173+ $ cTensor = $ this -> medianFilter ($ cTensor , $ medianFilterWidth );
173174 }
174175 }
175176 }
176177
177178 // Average the different cross-attention heads.
178- $ matrix = mean ($ smoothedWeights , 1 );
179- return $ matrix ;
179+ return $ smoothedWeights ->mean (1 );
180180 }, $ generateOutputs ['cross_attentions ' ]);
181181
182182 $ timestampsShape = [count ($ generateOutputs ['sequences ' ]), count ($ generateOutputs ['sequences ' ][0 ])];
183183
184+
184185 $ timestamps = new Tensor (null , Tensor::float32, $ timestampsShape );
185186
186187 // Perform dynamic time warping on each element of the batch.
187188 for ($ batchIdx = 0 ; $ batchIdx < $ timestampsShape [0 ]; ++$ batchIdx ) {
188189 // NOTE: Since we run only one batch at a time, we can squeeze to get the same dimensions
189190 // as the python implementation
190- $ matrix = $ batchedMatrices [$ batchIdx ]->neg ( )->squeeze_ (0 );
191+ $ matrix = $ batchedMatrices [$ batchIdx ]->multiply (- 1 )->squeeze (0 );
191192 list ($ textIndices , $ timeIndices ) = dynamicTimeWarping ($ matrix );
192193
193194 $ diffs = array_map (fn ($ i ) => $ textIndices [$ i + 1 ] - $ textIndices [$ i ], range (0 , count ($ textIndices ) - 2 ));
194- $ jumps = array_map (fn ($ x ) => (bool ) $ x , array_merge ([1 ], $ diffs ));
195+ $ jumps = array_map (fn ($ x ) => (bool )$ x , array_merge ([1 ], $ diffs ));
195196
196197 $ jumpTimes = [];
197198 for ($ i = 0 ; $ i < count ($ jumps ); ++$ i ) {
@@ -206,4 +207,36 @@ public function extractTokenTimestamps(
206207 return $ timestamps ;
207208 }
208209
/**
 * Applies a one-dimensional median filter to the values of a tensor.
 *
 * Every output element is the median of the `$windowSize` input values centred on
 * the same position. Window positions that fall outside the input are mirrored back
 * across the nearest edge (the edge element itself is not repeated). The mirroring
 * is folded, so it stays in range even when the input is shorter than half the window.
 *
 * NOTE(review): values are read from `$tensor->buffer()` by flat index, which presumably
 * assumes a 1-D tensor whose data starts at the beginning of its buffer — confirm that
 * this holds for tensors obtained by slicing another tensor.
 *
 * @param Tensor $tensor The values to filter.
 * @param int $windowSize Width of the sliding window; must be a positive odd number.
 *
 * @return Tensor A new tensor built from the filtered values, using the dtype of `$tensor`.
 *
 * @throws InvalidArgumentException If `$windowSize` is even or not positive.
 */
public function medianFilter(Tensor $tensor, int $windowSize): Tensor
{
    if ($windowSize <= 0 || $windowSize % 2 === 0) {
        throw new InvalidArgumentException('Window size must be a positive odd number ');
    }

    // Loop invariants: look these up once instead of on every iteration.
    $length = count($tensor);
    $values = $tensor->buffer();
    $halfWindowSize = intdiv($windowSize, 2);

    // Mirroring without repeating the edge element repeats every 2 * (length - 1) steps.
    $period = 2 * ($length - 1);

    $filtered = [];

    for ($i = 0; $i < $length; ++$i) {
        $window = [];

        for ($j = -$halfWindowSize; $j <= $halfWindowSize; ++$j) {
            if ($period === 0) {
                // A single-element input can only ever mirror onto itself.
                $index = 0;
            } else {
                // Fold the position into one mirror period, then reflect the upper half.
                // For positions a single reflection could already handle this yields
                // exactly abs($index) or 2 * (length - 1) - $index.
                $index = abs($i + $j) % $period;
                if ($index >= $length) {
                    $index = $period - $index;
                }
            }

            $window[] = $values[$index];
        }

        // The window size is odd, so the median is the middle element after sorting.
        sort($window);
        $filtered[] = $window[$halfWindowSize];
    }

    return Tensor::fromArray($filtered, $tensor->dtype());
}
241+
209242}
0 commit comments