@@ -59,15 +59,19 @@ public function generate(
5959 // Whisper has additional options for returning timestamps
6060 $ generationConfig ['return_timestamps ' ] ??= false ;
6161
62+
6263 if ($ generationConfig ['return_timestamps ' ]) {
63- $ logitsProcessor = [new WhisperTimeStampLogitsProcessor ($ generationConfig )];
64+ $ logitsProcessor = new LogitsProcessorList ();
65+ $ logitsProcessor ->push (new WhisperTimeStampLogitsProcessor ($ generationConfig ));
6466 }
6567
68+
69+
6670 if (isset ($ generationConfig ['return_token_timestamps ' ])) {
67- $ generationConfig-> output_attentions = true ;
68- $ generationConfig-> return_dict_in_generate = true ;
71+ $ generationConfig[ ' output_attentions ' ] = true ;
72+ $ generationConfig[ ' return_dict_in_generate ' ] = true ;
6973
// Warn only when the task is explicitly 'translate'. The null-coalesce must be
// parenthesised: `===` binds tighter than `??`, so the unparenthesised form
// `$generationConfig['task'] ?? '' === 'translate'` evaluates to the task value
// itself whenever the key is set, and the warning fires for every truthy task.
if (($generationConfig['task'] ?? '') === 'translate') {
    trigger_error("Token-level timestamps may not be reliable for task 'translate'.", E_USER_WARNING);
}
7377
@@ -79,13 +83,14 @@ public function generate(
7983 }
8084 }
8185
86+
8287 $ outputs = parent ::generate ($ inputs , $ generationConfig , $ logitsProcessor , $ inputsAttentionMask , $ streamer );
8388
8489 if (isset ($ generationConfig ['return_token_timestamps ' ]) && isset ($ generationConfig ['alignment_heads ' ])) {
8590 $ outputs ['token_timestamps ' ] = $ this ->extractTokenTimestamps (
8691 $ outputs ,
8792 $ generationConfig ['alignment_heads ' ],
88- $ generationConfig ['num_frames ' ]
93+ $ generationConfig ['num_frames ' ] ?? null ,
8994 );
9095 }
9196
@@ -106,9 +111,10 @@ public function generate(
106111 public function extractTokenTimestamps (
107112 array $ generateOutputs ,
108113 array $ alignmentHeads ,
109- ? int $ numFrames = null ,
114+ int | float | null $ numFrames = null ,
110115 float $ timePrecision = 0.02
111116 ): Tensor {
117+ $ numFrames = (int ) $ numFrames ;
112118 if (!isset ($ generateOutputs ['cross_attentions ' ])) {
113119 throw new Exception (
114120 "Model outputs must contain cross attentions to extract timestamps. " .
@@ -125,18 +131,22 @@ public function extractTokenTimestamps(
125131 $ batchedMatrices = array_map (function ($ batch ) use ($ numFrames , $ alignmentHeads , $ medianFilterWidth ) {
126132 // Create a list with `decoder_layers` elements, each a tensor of shape
127133 // (batch size, attention_heads, output length, input length).
134+ /** @var Tensor[] $crossAttentions */
128135 $ crossAttentions = [];
129136 for ($ i = 0 ; $ i < $ this ->config ['decoder_layers ' ]; $ i ++) {
130- $ crossAttentions [] = cat (array_map (fn ($ x ) => $ x [$ i ], $ batch ), 2 );
137+ $ crossAttentions [] = Tensor:: concat (array_map (fn ($ x ) => $ x [$ i ], $ batch ), 2 );
131138 }
132139
133- $ weights = stack (array_map (function ($ alignmentHead ) use ($ crossAttentions , $ numFrames ) {
134- list ( $ l , $ h) = $ alignmentHead ;
140+ $ weights = Tensor:: stack (array_map (function ($ alignmentHead ) use ($ crossAttentions , $ numFrames ) {
141+ [ $ l , $ h] = $ alignmentHead ;
135142 return $ numFrames
136143 ? $ crossAttentions [$ l ]->slice (null , $ h , null , [0 , $ numFrames ])
137144 : $ crossAttentions [$ l ]->slice (null , $ h );
138145 }, $ alignmentHeads ));
// Swap the first two axes of the stacked weights. The leftover
// `dd($weights->shape())` debug call was removed: dd() dumps its argument and
// terminates the script, so nothing after it would ever execute.
$weights = $weights->permute(1, 0, 2, 3);
149+
140150
141151 list ($ std , $ calculatedMean ) = std_mean ($ weights , -2 , 0 , true );
142152
0 commit comments