Skip to content

Commit eb342ec

Browse files
Refactor Tensor slicing and unsqueeze for speed, improve generation config serialization, new Tensor methods - sum and maximum, improve whisper
1 parent 68436d2 commit eb342ec

12 files changed

Lines changed: 125 additions & 113 deletions

File tree

examples/pipelines/asr.php

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,11 @@
1818
$audioUrl = __DIR__ . '/../sounds/preamble.wav';
1919
$audioUrl = __DIR__ . '/../sounds/taunt.wav';
2020
$audioUrl = __DIR__ . '/../sounds/gettysburg.wav';
21+
$audioUrl = __DIR__ . '/../sounds/kyrian-speaking-30.wav';
22+
$audioUrl = __DIR__ . '/../sounds/kyrian-speaking.wav';
23+
$audioUrl = __DIR__ . '/../sounds/dataset1.wav';
2124

2225
$streamer = StdOutStreamer::make();
23-
$output = $transcriber($audioUrl, maxNewTokens: 256, streamer: $streamer);
26+
$output = $transcriber($audioUrl, maxNewTokens: 256, returnTimestamps: 'word');
2427

25-
dd( timeUsage(), memoryUsage());
28+
dd($output, timeUsage(), memoryUsage());

src/FeatureExtractors/WhisperFeatureExtractor.php

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,13 @@ public function __invoke(Tensor $waveform): array
4242
'If using a pipeline to extract transcript from a long audio clip,' .
4343
'remember to specify `chunkLengthSecs` and/or `strideLengthSecs` in the pipeline options.', E_USER_WARNING);
4444

45-
$waveform = $waveform->slice(0, $this->config['n_samples']);
46-
} else {
47-
$padding = $this->config['n_samples'] - $waveform->size();
48-
// create a new Tensor with the same data type as the input waveform
49-
$padding = Tensor::zeros([$padding], dtype: $waveform->dtype());
45+
$waveform = $waveform->sliceWithBounds([0], [$this->config['n_samples']]);
46+
} else if ($waveform->size() < $this->config['n_samples']) {
47+
$padLength = $this->config['n_samples'] - $waveform->size();
48+
$padding = Tensor::zeros([$padLength], dtype: $waveform->dtype());
5049
$waveform = Tensor::concat([$waveform, $padding]);
5150
}
5251

53-
timeUsage();
5452
$features = Audio::spectrogram(
5553
$waveform,
5654
$this->window,
@@ -59,13 +57,15 @@ public function __invoke(Tensor $waveform): array
5957
power: 2.0,
6058
melFilters: $this->config['mel_filters'],
6159
logMel: 'log10',
62-
6360
maxNumFrames: $this->config['nb_max_frames'],
6461
);
6562

6663
$maxValue = $features->max();
6764

68-
$features->u(fn($x) => (max($x, $maxValue - 8.0) + 4.0) / 4.0);
65+
$features = $features
66+
->maximum($maxValue - 8.0)
67+
->add(4.0)
68+
->multiply(1.0 / 4.0);
6969

7070
return [
7171
'input_features' => $features->unsqueeze(0)

src/Generation/LogitsProcessors/LogitsProcessorList.php

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,6 @@ public function extend(traversable $items): void
4343
*/
4444
public function __invoke(array $inputIds, Tensor &$batchedLogits): void
4545
{
46-
// foreach ($batchedLogits as &$logits) {
47-
// foreach ($this->processors as $processor) {
48-
// $processor($inputIds, $logits); // Apply processors in-place
49-
// }
50-
// }
5146
for ($i = 0; $i < count($batchedLogits); $i++) {
5247
foreach ($this->processors as $processor) {
5348
$processor($inputIds, $batchedLogits[$i]); // Apply processors in-place
@@ -62,7 +57,6 @@ public function __invoke(array $inputIds, Tensor &$batchedLogits): void
6257
*/
6358
public function getIterator(): Traversable
6459
{
65-
// return new \ArrayIterator($this->processors);
6660
yield from $this->processors;
6761
}
6862
}

src/Generation/LogitsProcessors/WhisperTimeStampLogitsProcessor.php

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,12 @@ public function __construct(GenerationConfig $generateConfig)
4545
$this->timestampBegin = $this->noTimestampsTokenId + 1;
4646

4747
$this->beginIndex = count($generateConfig['forced_decoder_ids'] ?? []) + 2;
48-
if (end($generateConfig['forced_decoder_ids'])[1] === $this->noTimestampsTokenId) {
48+
49+
$forcedDecoderIds = $generateConfig['forced_decoder_ids'] ?? [];
50+
if (count($forcedDecoderIds) > 0 && end($forcedDecoderIds)[1] === $this->noTimestampsTokenId) {
4951
$this->beginIndex -= 1;
5052
}
53+
5154
$this->maxInitialTimestampIndex = $generateConfig['max_initial_timestamp_index'] ?? null;
5255
}
5356

@@ -94,10 +97,10 @@ public function __invoke(array $inputIds, Tensor $logits): Tensor
9497
}
9598

9699
// if sum of probability over timestamps is above any other token, sample timestamp
97-
// $logProbs = log_softmax($logitsData);
98100
$logProbs = $logits->softmax()->log();
99-
$timestampLogProb = log(array_sum(array_map('exp', array_slice($logProbs, $this->timestampBegin))));
100-
$maxTextTokenLogProb = max(array_slice($logProbs, 0, $this->timestampBegin));
101+
$a = $logProbs->sliceWithBounds([0, $this->timestampBegin], [1, $logProbs->size() - $this->timestampBegin]);
102+
$timestampLogProb = log($a->exp()->sum());
103+
$maxTextTokenLogProb = $logProbs->sliceWithBounds([0, 0], [1, $this->timestampBegin])->max();
101104

102105
if ($timestampLogProb > $maxTextTokenLogProb) {
103106
for ($i = 0; $i < $this->timestampBegin; $i++) {

src/Generation/Samplers/Sampler.php

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -46,31 +46,24 @@ abstract public function sample(Tensor $logits, int $index);
4646
*/
4747
public function getLogits(Tensor $logits, int $index): Tensor
4848
{
49-
$vocabSize = $logits->shape()[$logits->ndim() - 1];
49+
// $vocabSize = $logits->shape()[$logits->ndim() - 1];
5050

51-
// $logs = $logits->buffer()->toArray();
51+
// $start = array_fill(0, $logits->ndim(), 0);
52+
// $size = array_fill(0, $logits->ndim(), 1);
5253
//
53-
// if ($index === -1) {
54-
// $logs = array_slice($logs, -$vocabSize);
55-
// } else {
56-
// $startIndex = $index * $vocabSize;
57-
// $logs = array_slice($logs, $startIndex, $startIndex + $vocabSize);
58-
// }
59-
60-
$start = array_fill(0, $logits->ndim(), 0);
61-
$size = array_fill(0, $logits->ndim(), 1);
62-
63-
array_splice($start, -2, replacement: [$index, 0]);
64-
array_splice($size, -2, replacement: [1, $vocabSize]);
54+
// array_splice($start, -2, replacement: [$index, 0]);
55+
// array_splice($size, -2, replacement: [1, $vocabSize]);
56+
//
57+
// $logs = $logits->sliceWithBounds($start, $size);
6558

66-
$logs = $logits->newSlice($start, $size);
59+
$logits = $logits->slice($index);
6760

6861
if ($this->generationConfig->temperature > 0) {
69-
$logs = $logs->multiply(1 / $this->generationConfig->temperature);
62+
$logits = $logits->multiply(1 / $this->generationConfig->temperature);
7063
}
7164

7265
// Remove all dimensions of 1, leaving a flat 1D array of vocab_size
73-
return $logs->squeeze();
66+
return $logits->squeeze();
7467
}
7568

7669
/**

src/Models/Pretrained/PretrainedModel.php

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -770,7 +770,6 @@ protected function getGenerationConfig(?GenerationConfig $generationConfig): Gen
770770
$genConfigArray = array_merge($genConfigArray, $this->generationConfig->toArray());
771771
}
772772

773-
774773
// Finally, use any generation config specified by the user
775774
// when calling `generate`
776775
if ($generationConfig !== null) {

src/Models/Pretrained/WhisperForConditionalGeneration.php

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,19 @@ public function generate(
5959
// Whisper has additional options for returning timestamps
6060
$generationConfig['return_timestamps'] ??= false;
6161

62+
6263
if ($generationConfig['return_timestamps']) {
63-
$logitsProcessor = [new WhisperTimeStampLogitsProcessor($generationConfig)];
64+
$logitsProcessor = new LogitsProcessorList();
65+
$logitsProcessor->push(new WhisperTimeStampLogitsProcessor($generationConfig));
6466
}
6567

68+
69+
6670
if (isset($generationConfig['return_token_timestamps'])) {
67-
$generationConfig->output_attentions = true;
68-
$generationConfig->return_dict_in_generate = true;
71+
$generationConfig['output_attentions'] = true;
72+
$generationConfig['return_dict_in_generate'] = true;
6973

70-
if ($generationConfig['task'] === 'translate') {
74+
if ($generationConfig['task'] ?? '' === 'translate') {
7175
trigger_error("Token-level timestamps may not be reliable for task 'translate'.", E_USER_WARNING);
7276
}
7377

@@ -79,13 +83,14 @@ public function generate(
7983
}
8084
}
8185

86+
8287
$outputs = parent::generate($inputs, $generationConfig, $logitsProcessor, $inputsAttentionMask, $streamer);
8388

8489
if (isset($generationConfig['return_token_timestamps']) && isset($generationConfig['alignment_heads'])) {
8590
$outputs['token_timestamps'] = $this->extractTokenTimestamps(
8691
$outputs,
8792
$generationConfig['alignment_heads'],
88-
$generationConfig['num_frames']
93+
$generationConfig['num_frames'] ?? null,
8994
);
9095
}
9196

@@ -106,9 +111,10 @@ public function generate(
106111
public function extractTokenTimestamps(
107112
array $generateOutputs,
108113
array $alignmentHeads,
109-
?int $numFrames = null,
114+
int|float|null $numFrames = null,
110115
float $timePrecision = 0.02
111116
): Tensor {
117+
$numFrames = (int) $numFrames;
112118
if (!isset($generateOutputs['cross_attentions'])) {
113119
throw new Exception(
114120
"Model outputs must contain cross attentions to extract timestamps. " .
@@ -125,18 +131,22 @@ public function extractTokenTimestamps(
125131
$batchedMatrices = array_map(function($batch) use ($numFrames, $alignmentHeads, $medianFilterWidth) {
126132
// Create a list with `decoder_layers` elements, each a tensor of shape
127133
// (batch size, attention_heads, output length, input length).
134+
/** @var Tensor[] $crossAttentions */
128135
$crossAttentions = [];
129136
for ($i = 0; $i < $this->config['decoder_layers']; $i++) {
130-
$crossAttentions[] = cat(array_map(fn($x) => $x[$i], $batch), 2);
137+
$crossAttentions[] = Tensor::concat(array_map(fn($x) => $x[$i], $batch), 2);
131138
}
132139

133-
$weights = stack(array_map(function($alignmentHead) use ($crossAttentions, $numFrames) {
134-
list($l, $h) = $alignmentHead;
140+
$weights = Tensor::stack(array_map(function($alignmentHead) use ($crossAttentions, $numFrames) {
141+
[$l, $h] = $alignmentHead;
135142
return $numFrames
136143
? $crossAttentions[$l]->slice(null, $h, null, [0, $numFrames])
137144
: $crossAttentions[$l]->slice(null, $h);
138145
}, $alignmentHeads));
139-
$weights = $weights->transpose(1, 0, 2, 3);
146+
dd($weights->shape());
147+
148+
$weights = $weights->permute( 1, 0, 2, 3);
149+
140150

141151
list($std, $calculatedMean) = std_mean($weights, -2, 0, true);
142152

src/Pipelines/AutomaticSpeechRecognitionPipeline.php

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ private function __invokeWhisper(array|string $inputs, ...$args): array|Tensor|I
105105
$strideLengthSecs = $args['strideLengthSecs'] ?? null;
106106

107107
if ($returnTimestamps == 'word') {
108-
$args['returnTimestamps'] = true;
108+
$args['return_token_timestamps'] = true;
109109
}
110110

111111
$language = array_pop_key($args, 'language');
@@ -150,7 +150,6 @@ private function __invokeWhisper(array|string $inputs, ...$args): array|Tensor|I
150150

151151
$chunks = [];
152152

153-
154153
if ($chunkLengthSecs > 0) {
155154

156155
if ($strideLengthSecs === null) {
@@ -164,12 +163,18 @@ private function __invokeWhisper(array|string $inputs, ...$args): array|Tensor|I
164163
$jump = $window - 2 * $stride;
165164
$offset = 0;
166165

166+
167167
while ($offset < $audioTensor->size()) {
168-
$subAudio = $audioTensor->slice($offset, $offset + $window);
168+
169+
if ($offset + $window > $audioTensor->size()) {
170+
$window = $audioTensor->size() - $offset;
171+
}
172+
173+
$subAudio = $audioTensor->sliceWithBounds([$offset], [$window]);
169174
$feature = ($this->processor)($subAudio);
170175

171176
$isFirstChunk = $offset === 0;
172-
$isLastChunk = $offset + $window >= $audioTensor->size();
177+
$isLastChunk = $offset + $jump >= $audioTensor->size();
173178

174179
$chunks[] = [
175180
'stride' => [
@@ -194,7 +199,6 @@ private function __invokeWhisper(array|string $inputs, ...$args): array|Tensor|I
194199

195200
}
196201

197-
198202
// Generate for each set of input features
199203
foreach ($chunks as &$chunk) {
200204
$generationConfig['num_frames'] = floor($chunk['stride'][0] / $hopLength);
@@ -203,7 +207,7 @@ private function __invokeWhisper(array|string $inputs, ...$args): array|Tensor|I
203207
$data = $this->model->generate($chunk['input_features'], generationConfig: $generationConfig, streamer: $streamer);
204208

205209
// TODO: Right now we only get top beam
206-
if ($returnTimestamps == 'word') {
210+
if ($returnTimestamps === 'word') {
207211
$chunk['tokens'] = $data['sequences'][0];
208212
$chunk['token_timestamps'] = array_map(fn($x) => round($x, 2), $data['token_timestamps'][0]);
209213
} else {

src/PretrainedTokenizers/PretrainedTokenizer.php

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,7 @@ public function batchDecode(array $batch, bool $skipSpecialTokens = false, ?bool
539539
public function decode(array $tokenIds, bool $skipSpecialTokens = false, ?bool $cleanUpTokenizationSpaces = null): string
540540
{
541541
if (empty($tokenIds) || !is_int($tokenIds[0])) {
542+
dd($tokenIds);
542543
throw new Exception("token_ids must be a non-empty array of integers.");
543544
}
544545

src/PretrainedTokenizers/WhisperTokenizer.php

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ public function decodeASR(
157157
bool $forceFullSequences = true
158158
): array
159159
{
160-
// Set force_full_sequences=false if you want streaming
160+
// Set forceFullSequences=false if you want streaming
161161
// TODO add support for `returnLanguage`
162162

163163
// Internal method meant to only be used by ASR pipeline.
@@ -323,6 +323,7 @@ public function decodeASR(
323323
}
324324
}
325325

326+
// dump($this->decode($currentTokens), empty($previousTokens) ? '': $this->decode($previousTokens[0]));
326327
if (isset($output['stride'])) {
327328
[$chunkLen, $strideLeft, $strideRight] = $output['stride'];
328329
$timeOffset += $chunkLen - $strideRight;
@@ -418,9 +419,11 @@ private function findLongestCommonSequence(array $sequences, array $tokenTimesta
418419
$rightSequence = $sequences[$i];
419420
$max = 0.0;
420421
$maxIndices = [$leftLength, $leftLength, 0, 0];
422+
// dd($this->decode($leftSequence), $this->decode($rightSequence));
421423

422424
$rightLength = count($rightSequence);
423425
for ($j = 1; $j < $leftLength + $rightLength; ++$j) {
426+
// epsilon to favor long perfect matches
424427
$eps = $j / 10000.0;
425428
$leftStart = max(0, $leftLength - $j);
426429
$leftStop = min($leftLength, $leftLength + $rightLength - $j);
@@ -430,10 +433,13 @@ private function findLongestCommonSequence(array $sequences, array $tokenTimesta
430433
$right = array_slice($rightSequence, $rightStart, $rightStop - $rightStart);
431434

432435
if (count($left) !== count($right)) {
433-
throw new Exception("There is a bug within whisper `decode_asr` function, please report it. Dropping to prevent bad inference.");
436+
throw new Exception("There is a bug within whisper `decodeASR` function, please report it. Dropping to prevent bad inference.");
434437
}
435438

436-
$matches = count(array_filter(array_map(fn($elem, $idx) => $elem === $right[$idx], $left, array_keys($left))));
439+
$matches = count(array_filter(
440+
array_map(fn($elem, $idx) => $elem === $right[$idx], $left, array_keys($left))
441+
)
442+
);
437443

438444
$matching = $matches / $j + $eps;
439445
if ($matches > 1 && $matching > $max) {
@@ -443,8 +449,8 @@ private function findLongestCommonSequence(array $sequences, array $tokenTimesta
443449
}
444450

445451
[$leftStart, $leftStop, $rightStart, $rightStop] = $maxIndices;
446-
$leftMid = intval(($leftStop + $leftStart) / 2);
447-
$rightMid = intval(($rightStop + $rightStart) / 2);
452+
$leftMid = (int)floor(($leftStop + $leftStart) / 2);
453+
$rightMid = (int)floor(($rightStop + $rightStart) / 2);
448454
$totalSequence = array_merge($totalSequence, array_slice($leftSequence, 0, $leftMid));
449455
$leftSequence = array_slice($rightSequence, $rightMid);
450456
$leftLength = count($leftSequence);

0 commit comments

Comments
 (0)