Skip to content

Commit 78bcd16

Browse files
Fix bugs in ForceTokensLogitsProcessor and WhisperTokenizer to correct errors for non-English languages in ASR
1 parent cbfd758 commit 78bcd16

13 files changed

Lines changed: 158 additions & 87 deletions

File tree

examples/pipelines/asr.php

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -12,20 +12,26 @@
1212
ini_set('memory_limit', '-1');
1313

1414
//$transcriber = pipeline('automatic-speech-recognition', 'Xenova/whisper-tiny.en');
15-
$transcriber = pipeline('automatic-speech-recognition', 'Xenova/whisper-base');
15+
$transcriber = pipeline('automatic-speech-recognition', 'Xenova/whisper-tiny');
16+
//$transcriber = pipeline('automatic-speech-recognition', 'Xenova/whisper-base');
1617
//$transcriber = pipeline('automatic-speech-recognition', 'Xenova/wav2vec2-large-xlsr-53-english');
1718

1819
$audioUrl = __DIR__ . '/../sounds/kyrian-dev.wav';
1920
//$audioUrl = __DIR__ . '/../sounds/jfk.wav';
2021
//$audioUrl = __DIR__ . '/../sounds/preamble.wav';
2122
//$audioUrl = __DIR__ . '/../sounds/taunt.wav';
2223
//$audioUrl = __DIR__ . '/../sounds/gettysburg.wav';
23-
//$audioUrl = __DIR__ . '/../sounds/kyrian-speaking-30.wav';
2424
//$audioUrl = __DIR__ . '/../sounds/kyrian-speaking.wav';
25-
//$audioUrl = __DIR__ . '/../sounds/kyrian-speaking2.wav';
26-
//$audioUrl = __DIR__ . '/../sounds/dataset1.wav';
25+
//$audioUrl = __DIR__ . '/../sounds/ted_60.wav';
26+
$audioUrl = __DIR__ . '/../sounds/french-audio.wav';
2727

2828
$streamer = StdOutStreamer::make();
29-
$output = $transcriber($audioUrl, maxNewTokens: 256, chunkLengthSecs: 20, returnTimestamps: 'word');
29+
$output = $transcriber($audioUrl,
30+
maxNewTokens: 256,
31+
chunkLengthSecs: 24,
32+
task: 'translate'
33+
// returnTimestamps: true,
34+
// streamer: $streamer
35+
);
3036

3137
dd($output, timeUsage(), memoryUsage());

examples/sounds/french-audio.mp3

97.6 KB
Binary file not shown.

examples/sounds/french-audio.wav

538 KB
Binary file not shown.

examples/sounds/kyrian-dev.wav

616 KB
Binary file not shown.
4.97 MB
Binary file not shown.

examples/sounds/ted_60.wav

11 MB
Binary file not shown.

src/Generation/LogitsProcessors/ForceTokensLogitsProcessor.php

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
namespace Codewithkyrian\Transformers\Generation\LogitsProcessors;
77

88
use Codewithkyrian\Transformers\Tensor\Tensor;
9+
use function Codewithkyrian\Transformers\Utils\timeUsage;
910

1011
class ForceTokensLogitsProcessor extends LogitsProcessor
1112
{
@@ -15,9 +16,7 @@ class ForceTokensLogitsProcessor extends LogitsProcessor
1516

1617
public function __construct(array $forcedDecoderIds)
1718
{
18-
foreach ($forcedDecoderIds[0] as $inputLength => $forcedId) {
19-
$this->forceTokenMap[$inputLength] = $forcedId;
20-
}
19+
$this->forceTokenMap = array_column($forcedDecoderIds, 1, 0);
2120
}
2221

2322
/**

src/Generation/LogitsProcessors/WhisperTimeStampLogitsProcessor.php

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77

88
use Codewithkyrian\Transformers\Tensor\Tensor;
99
use Codewithkyrian\Transformers\Utils\GenerationConfig;
10+
use function Codewithkyrian\Transformers\Utils\timeUsage;
1011

1112
class WhisperTimeStampLogitsProcessor extends LogitsProcessor
1213
{
@@ -40,7 +41,7 @@ class WhisperTimeStampLogitsProcessor extends LogitsProcessor
4041
*/
4142
public function __construct(GenerationConfig $generateConfig)
4243
{
43-
$this->eosTokenId = $generateConfig->eos_token_id;
44+
$this->eosTokenId = $generateConfig['eos_token_id'];
4445
$this->noTimestampsTokenId = $generateConfig['no_timestamps_token_id'];
4546
$this->timestampBegin = $this->noTimestampsTokenId + 1;
4647

@@ -79,11 +80,11 @@ public function __invoke(array $inputIds, Tensor $logits): Tensor
7980
if ($lastWasTimestamp) {
8081
if ($penultimateWasTimestamp) { // has to be non-timestamp
8182
for ($i = $this->timestampBegin; $i < $logits->size(); $i++) {
82-
$logitsData[$i] = -INF;
83+
$logits->buffer()[$i] = -INF;
8384
}
8485
} else { // cannot be normal text tokens
8586
for ($i = 0; $i < $this->eosTokenId; $i++) {
86-
$logitsData[$i] = -INF;
87+
$logits->buffer()[$i] = -INF;
8788
}
8889
}
8990
}
@@ -92,19 +93,19 @@ public function __invoke(array $inputIds, Tensor $logits): Tensor
9293
if (count($inputIds) === $this->beginIndex && $this->maxInitialTimestampIndex !== null) {
9394
$lastAllowed = $this->timestampBegin + $this->maxInitialTimestampIndex;
9495
for ($i = $lastAllowed + 1; $i < $logits->size(); $i++) {
95-
$logitsData[$i] = -INF;
96+
$logits->buffer()[$i] = -INF;
9697
}
9798
}
9899

99100
// if sum of probability over timestamps is above any other token, sample timestamp
100101
$logProbs = $logits->softmax()->log();
101-
$a = $logProbs->sliceWithBounds([0, $this->timestampBegin], [1, $logProbs->size() - $this->timestampBegin]);
102-
$timestampLogProb = log($a->exp()->sum());
102+
$timestampProbs = $logProbs->sliceWithBounds([0, $this->timestampBegin], [1, $logProbs->size() - $this->timestampBegin]);
103+
$timestampLogProb = log($timestampProbs->exp()->sum());
103104
$maxTextTokenLogProb = $logProbs->sliceWithBounds([0, 0], [1, $this->timestampBegin])->max();
104105

105106
if ($timestampLogProb > $maxTextTokenLogProb) {
106107
for ($i = 0; $i < $this->timestampBegin; $i++) {
107-
$logitsData[$i] = -INF;
108+
$logits->buffer()[$i] = -INF;
108109
}
109110
}
110111

src/Generation/Streamers/TextStreamer.php

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -148,6 +148,11 @@ public function put(mixed $value): void
148148

149149
$tokensToDecode = array_slice($value[0]['output_token_ids'], $this->lastDecodedCheckpointForToken);
150150

151+
if (empty($tokensToDecode))
152+
{
153+
return;
154+
}
155+
151156
$decodedText = $this->tokenizer->decode($tokensToDecode, skipSpecialTokens: true);
152157

153158
// Check for punctuation marks indicating the end of a word or sentence

src/Models/Pretrained/WhisperForConditionalGeneration.php

Lines changed: 41 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -60,9 +60,8 @@ public function generate(
6060
// Whisper has additional options for returning timestamps
6161
$generationConfig['return_timestamps'] ??= false;
6262

63-
6463
if ($generationConfig['return_timestamps']) {
65-
$logitsProcessor = new LogitsProcessorList();
64+
$logitsProcessor ??= new LogitsProcessorList();
6665
$logitsProcessor->push(new WhisperTimeStampLogitsProcessor($generationConfig));
6766
}
6867

@@ -83,14 +82,13 @@ public function generate(
8382
}
8483
}
8584

86-
87-
$outputs = parent::generate($inputs, $generationConfig, $logitsProcessor, $inputsAttentionMask, $streamer);
85+
$outputs = parent::generate($inputs, $generationConfig, $logitsProcessor, streamer: $streamer);
8886

8987
if (isset($generationConfig['return_token_timestamps']) && isset($generationConfig['alignment_heads'])) {
9088
$outputs['token_timestamps'] = $this->extractTokenTimestamps(
9189
$outputs,
9290
$generationConfig['alignment_heads'],
93-
(int)$generationConfig['num_frames'] ?? null,
91+
$generationConfig['num_frames'] ?? null,
9492
);
9593
}
9694

@@ -109,10 +107,10 @@ public function generate(
109107
* @throws Exception If the model outputs do not contain cross attentions
110108
*/
111109
public function extractTokenTimestamps(
112-
array $generateOutputs,
113-
array $alignmentHeads,
110+
array $generateOutputs,
111+
array $alignmentHeads,
114112
int|null $numFrames = null,
115-
float $timePrecision = 0.02
113+
float $timePrecision = 0.02
116114
): Tensor
117115
{
118116
if (!isset($generateOutputs['cross_attentions'])) {
@@ -128,6 +126,7 @@ public function extractTokenTimestamps(
128126
$medianFilterWidth = 7;
129127
}
130128

129+
131130
$batchedMatrices = array_map(function ($batch) use ($numFrames, $alignmentHeads, $medianFilterWidth) {
132131
// Create a list with `decoder_layers` elements, each a tensor of shape
133132
// (batch size, attention_heads, output length, input length).
@@ -164,13 +163,18 @@ public function extractTokenTimestamps(
164163
/** @var Tensor $cTensor */
165164
$cTensor = $bTensor[$c]; // [1500]
166165

167-
$cTensor
168-
->add($meanTensor->multiply(-1))
169-
->multiply($stdTensor->reciprocal())
170-
->copyTo($cTensor);
166+
for ($d = 0; $d < $cTensor->count(); ++$d) {
167+
$cTensor[$d] = ($cTensor[$d] - $meanTensor[$d]) / $stdTensor[$d];
168+
}
171169

172170
// Apply median filter.
173171
$this->medianFilter($cTensor, $medianFilterWidth)->copyTo($cTensor);
172+
// $filtered = $this->medianFilter($cTensor, $medianFilterWidth);
173+
// for ($e = 0; $e < $filtered->count(); ++$e) {
174+
// $cTensor[$e] = $filtered[$e];
175+
// }
176+
177+
174178
}
175179
}
176180
}
@@ -181,7 +185,6 @@ public function extractTokenTimestamps(
181185

182186
$timestampsShape = [count($generateOutputs['sequences']), count($generateOutputs['sequences'][0])];
183187

184-
185188
$timestamps = Tensor::zeros($timestampsShape, Tensor::float32);
186189

187190
// Perform dynamic time warping on each element of the batch.
@@ -194,14 +197,13 @@ public function extractTokenTimestamps(
194197
$diffs = array_map(fn($i) => $textIndices[$i + 1] - $textIndices[$i], range(0, count($textIndices) - 2));
195198
$jumps = array_map(fn($x) => (bool)$x, array_merge([1], $diffs));
196199

197-
dd($timeIndices);
198200
$jumpTimes = [];
199201
for ($i = 0; $i < count($jumps); ++$i) {
200202
if ($jumps[$i]) {
201203
$jumpTimes[] = $timeIndices[$i] * $timePrecision;
202204
}
203205
}
204-
dd($jumpTimes);
206+
205207
for ($i = 1; $i < count($jumpTimes); ++$i) {
206208
$timestamps[$batchIdx][$i] = $jumpTimes[$i];
207209
}
@@ -210,38 +212,54 @@ public function extractTokenTimestamps(
210212
return $timestamps;
211213
}
212214

213-
function medianFilter(Tensor $tensor, int $windowSize): Tensor
215+
/**
216+
* Applies a median filter of width `$windowSize` along the last dimension of the input.
217+
*
218+
* The `$input` tensor is assumed to be 3- or 4-dimensional.
219+
* @param Tensor $input
220+
* @param int $windowSize
221+
* @return Tensor
222+
*/
223+
function medianFilter(Tensor $input, int $windowSize): Tensor
214224
{
215225
if ($windowSize % 2 === 0 || $windowSize <= 0) {
216226
throw new InvalidArgumentException('Window size must be a positive odd number');
217227
}
218228

219-
$outputArray = array_fill(0, count($tensor), 0);
229+
$output = Tensor::fill($input->shape(), 0, $input->dtype());
220230
$buffer = array_fill(0, $windowSize, 0);
221231

222232
$halfWindowSize = (int)floor($windowSize / 2);
223233

224-
for ($i = 0; $i < count($tensor); ++$i) {
234+
for ($i = 0; $i < count($input); ++$i) {
225235
$valuesIndex = 0;
226236

227237
for ($j = -$halfWindowSize; $j <= $halfWindowSize; ++$j) {
228238
$index = $i + $j;
229239
if ($index < 0) {
230240
$index = abs($index);
231-
} else if ($index >= count($tensor)) {
232-
$index = 2 * (count($tensor) - 1) - $index;
241+
} else if ($index >= count($input)) {
242+
$index = 2 * (count($input) - 1) - $index;
233243
}
234244

235-
$buffer[$valuesIndex++] = $tensor->buffer()[$index];
245+
$buffer[$valuesIndex++] = $input[$index];
236246
}
237247

238248
sort($buffer);
239-
$outputArray[$i] = $buffer[$halfWindowSize];
249+
250+
$output->buffer()[$i] = $buffer[$halfWindowSize];
240251
}
241252

242-
return Tensor::fromArray($outputArray, $tensor->dtype());
253+
return $output;
243254
}
244255

256+
/**
257+
* Measures the similarity between two temporal sequences: the input audio and
258+
* the output tokens. Used to generate
259+
* token-level timestamps.
260+
* @param Tensor $tensor
261+
* @return array
262+
*/
245263
private function dynamicTimeWarping(Tensor $tensor): array
246264
{
247265
[$outputLength, $inputLength] = $tensor->shape();

0 commit comments

Comments
 (0)