@@ -60,9 +60,8 @@ public function generate(
6060 // Whisper has additional options for returning timestamps
6161 $ generationConfig ['return_timestamps ' ] ??= false ;
6262
63-
6463 if ($ generationConfig ['return_timestamps ' ]) {
65- $ logitsProcessor = new LogitsProcessorList ();
64+ $ logitsProcessor ?? = new LogitsProcessorList ();
6665 $ logitsProcessor ->push (new WhisperTimeStampLogitsProcessor ($ generationConfig ));
6766 }
6867
@@ -83,14 +82,13 @@ public function generate(
8382 }
8483 }
8584
86-
87- $ outputs = parent ::generate ($ inputs , $ generationConfig , $ logitsProcessor , $ inputsAttentionMask , $ streamer );
85+ $ outputs = parent ::generate ($ inputs , $ generationConfig , $ logitsProcessor , streamer: $ streamer );
8886
8987 if (isset ($ generationConfig ['return_token_timestamps ' ]) && isset ($ generationConfig ['alignment_heads ' ])) {
9088 $ outputs ['token_timestamps ' ] = $ this ->extractTokenTimestamps (
9189 $ outputs ,
9290 $ generationConfig ['alignment_heads ' ],
93- ( int ) $ generationConfig ['num_frames ' ] ?? null ,
91+ $ generationConfig ['num_frames ' ] ?? null ,
9492 );
9593 }
9694
@@ -109,10 +107,10 @@ public function generate(
109107 * @throws Exception If the model outputs do not contain cross attentions
110108 */
111109 public function extractTokenTimestamps (
112- array $ generateOutputs ,
113- array $ alignmentHeads ,
110+ array $ generateOutputs ,
111+ array $ alignmentHeads ,
114112 int |null $ numFrames = null ,
115- float $ timePrecision = 0.02
113+ float $ timePrecision = 0.02
116114 ): Tensor
117115 {
118116 if (!isset ($ generateOutputs ['cross_attentions ' ])) {
@@ -128,6 +126,7 @@ public function extractTokenTimestamps(
128126 $ medianFilterWidth = 7 ;
129127 }
130128
129+
131130 $ batchedMatrices = array_map (function ($ batch ) use ($ numFrames , $ alignmentHeads , $ medianFilterWidth ) {
132131 // Create a list with `decoder_layers` elements, each a tensor of shape
133132 // (batch size, attention_heads, output length, input length).
@@ -164,13 +163,18 @@ public function extractTokenTimestamps(
164163 /** @var Tensor $cTensor */
165164 $ cTensor = $ bTensor [$ c ]; // [1500]
166165
167- $ cTensor
168- ->add ($ meanTensor ->multiply (-1 ))
169- ->multiply ($ stdTensor ->reciprocal ())
170- ->copyTo ($ cTensor );
166+ for ($ d = 0 ; $ d < $ cTensor ->count (); ++$ d ) {
167+ $ cTensor [$ d ] = ($ cTensor [$ d ] - $ meanTensor [$ d ]) / $ stdTensor [$ d ];
168+ }
171169
172170 // Apply median filter.
173171 $ this ->medianFilter ($ cTensor , $ medianFilterWidth )->copyTo ($ cTensor );
172+ // $filtered = $this->medianFilter($cTensor, $medianFilterWidth);
173+ // for ($e = 0; $e < $filtered->count(); ++$e) {
174+ // $cTensor[$e] = $filtered[$e];
175+ // }
176+
177+
174178 }
175179 }
176180 }
@@ -181,7 +185,6 @@ public function extractTokenTimestamps(
181185
182186 $ timestampsShape = [count ($ generateOutputs ['sequences ' ]), count ($ generateOutputs ['sequences ' ][0 ])];
183187
184-
185188 $ timestamps = Tensor::zeros ($ timestampsShape , Tensor::float32);
186189
187190 // Perform dynamic time warping on each element of the batch.
@@ -194,14 +197,13 @@ public function extractTokenTimestamps(
194197 $ diffs = array_map (fn ($ i ) => $ textIndices [$ i + 1 ] - $ textIndices [$ i ], range (0 , count ($ textIndices ) - 2 ));
195198 $ jumps = array_map (fn ($ x ) => (bool )$ x , array_merge ([1 ], $ diffs ));
196199
197- dd ($ timeIndices );
198200 $ jumpTimes = [];
199201 for ($ i = 0 ; $ i < count ($ jumps ); ++$ i ) {
200202 if ($ jumps [$ i ]) {
201203 $ jumpTimes [] = $ timeIndices [$ i ] * $ timePrecision ;
202204 }
203205 }
204- dd ( $ jumpTimes );
206+
205207 for ($ i = 1 ; $ i < count ($ jumpTimes ); ++$ i ) {
206208 $ timestamps [$ batchIdx ][$ i ] = $ jumpTimes [$ i ];
207209 }
@@ -210,38 +212,54 @@ public function extractTokenTimestamps(
210212 return $ timestamps ;
211213 }
212214
/**
 * Applies a median filter of width `$windowSize` to a one-dimensional tensor.
 *
 * Each output element is the median of the `$windowSize` input elements centred on
 * the same position. Positions that fall outside the input are filled by reflecting
 * around the first/last element, so the output has the same length as the input.
 * The input tensor is not modified; a new tensor with the input's shape and dtype
 * is returned.
 *
 * Only 1-D input is handled: the tensor is walked with `count($input)` and read
 * with plain integer offsets.
 *
 * @param Tensor $input 1-D tensor to filter.
 * @param int $windowSize Width of the filter window; must be a positive odd number.
 * @return Tensor A new tensor holding the filtered values.
 * @throws InvalidArgumentException If `$windowSize` is even or not positive.
 */
public function medianFilter(Tensor $input, int $windowSize): Tensor
{
    if ($windowSize <= 0 || $windowSize % 2 === 0) {
        throw new InvalidArgumentException('Window size must be a positive odd number');
    }

    $length = count($input);
    $halfWindowSize = intdiv($windowSize, 2);

    $output = Tensor::fill($input->shape(), 0, $input->dtype());
    // NOTE(review): writes go through buffer() exactly as before; this presumes a
    // freshly filled tensor has a zero offset, so buffer index == logical index.
    $outputBuffer = $output->buffer();

    // Reflect padding needs more than $halfWindowSize elements to mirror from;
    // with a shorter input the reflected index would land outside the tensor.
    // Mirror the reference implementation and hand the values back unfiltered.
    if ($length <= $halfWindowSize) {
        for ($i = 0; $i < $length; ++$i) {
            $outputBuffer[$i] = $input[$i];
        }

        return $output;
    }

    $lastIndex = $length - 1;

    for ($i = 0; $i < $length; ++$i) {
        $window = [];

        for ($offset = -$halfWindowSize; $offset <= $halfWindowSize; ++$offset) {
            $index = $i + $offset;

            if ($index < 0) {
                // Reflect around the first element: -1 -> 1, -2 -> 2, ...
                $index = -$index;
            } elseif ($index > $lastIndex) {
                // Reflect around the last element: last+1 -> last-1, ...
                $index = 2 * $lastIndex - $index;
            }

            $window[] = $input[$index];
        }

        // The window has odd length, so the median is the middle sorted value.
        sort($window);

        $outputBuffer[$i] = $window[$halfWindowSize];
    }

    return $output;
}
244255
256+ /**
257+ * Measures similarity between two temporal sequences:
258+ * the input audio and the output tokens. Used to generate
259+ * token-level timestamps.
260+ * @param Tensor $tensor
261+ * @return array
262+ */
245263 private function dynamicTimeWarping (Tensor $ tensor ): array
246264 {
247265 [$ outputLength , $ inputLength ] = $ tensor ->shape ();
0 commit comments