Skip to content

Commit 0e4b0e2

Browse files
More improvements for returning timestamps in ASR
1 parent 0f9610a commit 0e4b0e2

4 files changed

Lines changed: 77 additions & 6 deletions

File tree

examples/pipelines/asr.php

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,6 @@
2323
//$audioUrl = __DIR__ . '/../sounds/dataset1.wav';
2424

2525
$streamer = StdOutStreamer::make();
26-
$output = $transcriber($audioUrl, maxNewTokens: 256, chunkLengthSecs: 30, strideLengthSecs: 6);
26+
$output = $transcriber($audioUrl, maxNewTokens: 256, chunkLengthSecs: 30, returnTimestamps: true);
2727

2828
dd($output, timeUsage(), memoryUsage());

src/Models/Pretrained/WhisperForConditionalGeneration.php

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ public function extractTokenTimestamps(
189189
// NOTE: Since we run only one batch at a time, we can squeeze to get the same dimensions
190190
// as the python implementation
191191
$matrix = $batchedMatrices[$batchIdx]->multiply(-1)->squeeze(0);
192-
list($textIndices, $timeIndices) = dynamicTimeWarping($matrix);
192+
list($textIndices, $timeIndices) = $this->dynamicTimeWarping($matrix);
193193

194194
$diffs = array_map(fn($i) => $textIndices[$i + 1] - $textIndices[$i], range(0, count($textIndices) - 2));
195195
$jumps = array_map(fn($x) => (bool)$x, array_merge([1], $diffs));
@@ -216,7 +216,7 @@ function medianFilter(Tensor $tensor, int $windowSize): Tensor
216216
$outputArray = array_fill(0, count($tensor), 0);
217217
$buffer = array_fill(0, $windowSize, 0);
218218

219-
$halfWindowSize = (int) floor($windowSize / 2);
219+
$halfWindowSize = (int)floor($windowSize / 2);
220220

221221
for ($i = 0; $i < count($tensor); ++$i) {
222222
$valuesIndex = 0;
@@ -239,4 +239,75 @@ function medianFilter(Tensor $tensor, int $windowSize): Tensor
239239
return Tensor::fromArray($outputArray, $tensor->dtype());
240240
}
241241

242+
/**
 * Finds the minimum-cost monotonic alignment path through a 2-D cost matrix
 * using dynamic time warping.
 *
 * A cumulative-cost table one row and one column larger than the input is
 * filled so that each cell holds the matrix entry plus the cheapest of its
 * diagonal, upper and left neighbours. The move chosen for every cell is
 * recorded and then walked back from the bottom-right corner to recover
 * the path.
 *
 * @param Tensor $tensor Two-dimensional cost matrix (rows x cols); lower values are cheaper.
 *
 * @return array{0: int[], 1: int[]} Row indices and column indices of the path,
 *                                   ordered from the start of the path to its end.
 */
private function dynamicTimeWarping(Tensor $tensor): array
{
    [$rows, $cols] = $tensor->shape();

    $outputShape = [$rows + 1, $cols + 1];

    // The extra border row/column stands for "outside the matrix". Because the
    // recurrence below takes the MINIMUM neighbour, the border has to be +INF so
    // that it is never picked as a predecessor. (Filling with -INF makes the
    // border always win and turns every cumulative cost into -INF.)
    $cost = Tensor::fill($outputShape, INF, Tensor::float32);
    $traceback = Tensor::fill($outputShape, -1, Tensor::int32);

    // The only valid starting point is the corner just outside the matrix.
    $cost[0][0] = 0;

    for ($i = 1; $i < $rows + 1; ++$i) {
        for ($j = 1; $j < $cols + 1; ++$j) {
            $c0 = $cost[$i - 1][$j - 1]; // diagonal predecessor
            $c1 = $cost[$i - 1][$j];     // predecessor one row up
            $c2 = $cost[$i][$j - 1];     // predecessor one column left

            // Pick the cheapest predecessor; on ties the diagonal is preferred,
            // then the row move, then the column move.
            if ($c0 <= $c1 && $c0 <= $c2) {
                $c = $c0;
                $t = 0;
            } elseif ($c1 <= $c0 && $c1 <= $c2) {
                $c = $c1;
                $t = 1;
            } else {
                $c = $c2;
                $t = 2;
            }

            $cost[$i][$j] = $tensor[$i - 1][$j - 1] + $c;
            $traceback[$i][$j] = $t;
        }
    }

    // Traceback
    $i = $rows;
    $j = $cols;

    // Force moves along the border so the walk below always terminates at
    // (0, 0): the first row can only move left, the first column only up.
    for ($k = 0; $k < $outputShape[1]; ++$k) {
        $traceback[0][$k] = 2;
    }

    for ($k = 0; $k < $outputShape[0]; ++$k) {
        $traceback[$k][0] = 1;
    }

    $textIndices = [];
    $timeIndices = [];

    while ($i > 0 || $j > 0) {
        // Table coordinates are offset by one relative to the input matrix.
        $textIndices[] = $i - 1;
        $timeIndices[] = $j - 1;

        $t = $traceback[$i][$j];

        if ($t === 0) {
            $i--;
            $j--;
        } elseif ($t === 1) {
            $i--;
        } else {
            $j--;
        }
    }

    // The walk collected the path end-to-start; flip it into forward order.
    $textIndices = array_reverse($textIndices);
    $timeIndices = array_reverse($timeIndices);

    return [$textIndices, $timeIndices];
}
312+
242313
}

src/Pipelines/AutomaticSpeechRecognitionPipeline.php

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ private function __invokeWhisper(array|string $inputs, ...$args): array|Tensor|I
104104
$forceFullSequences = $args['forceFullSequences'] ?? false;
105105
$strideLengthSecs = $args['strideLengthSecs'] ?? null;
106106

107-
if ($returnTimestamps == 'word') {
107+
if ($returnTimestamps === 'word') {
108108
$args['return_token_timestamps'] = true;
109109
}
110110

@@ -209,7 +209,7 @@ private function __invokeWhisper(array|string $inputs, ...$args): array|Tensor|I
209209
// TODO: Right now we only get top beam
210210
if ($returnTimestamps === 'word') {
211211
$chunk['tokens'] = $data['sequences'][0];
212-
$chunk['token_timestamps'] = array_map(fn($x) => round($x, 2), $data['token_timestamps'][0]);
212+
$chunk['token_timestamps'] = $data['token_timestamps'][0]->round(2);
213213
} else {
214214
$chunk['tokens'] = $data[0];
215215
}

src/Tensor/Tensor.php

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1122,7 +1122,7 @@ public function permute(...$axes): static
11221122
{
11231123
$permuted = self::mo()->transpose($this, $axes);
11241124

1125-
return Tensor::fromArray($permuted);
1125+
return new static($permuted->buffer(), $permuted->dtype(), $permuted->shape(), $permuted->offset());
11261126
}
11271127

11281128
/**

0 commit comments

Comments
 (0)