Skip to content

Commit 0f9610a

Browse files
Return old tensor slice for dimensions > 3
1 parent eb342ec commit 0f9610a

3 files changed

Lines changed: 162 additions & 30 deletions

File tree

examples/pipelines/asr.php

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@
2020
$audioUrl = __DIR__ . '/../sounds/gettysburg.wav';
2121
$audioUrl = __DIR__ . '/../sounds/kyrian-speaking-30.wav';
2222
$audioUrl = __DIR__ . '/../sounds/kyrian-speaking.wav';
23-
$audioUrl = __DIR__ . '/../sounds/dataset1.wav';
23+
//$audioUrl = __DIR__ . '/../sounds/dataset1.wav';
2424

2525
$streamer = StdOutStreamer::make();
26-
$output = $transcriber($audioUrl, maxNewTokens: 256, returnTimestamps: 'word');
26+
$output = $transcriber($audioUrl, maxNewTokens: 256, chunkLengthSecs: 30, strideLengthSecs: 6);
2727

2828
dd($output, timeUsage(), memoryUsage());

src/Models/Pretrained/WhisperForConditionalGeneration.php

Lines changed: 59 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
use Codewithkyrian\Transformers\Utils\AutoConfig;
1515
use Codewithkyrian\Transformers\Utils\GenerationConfig;
1616
use Exception;
17+
use InvalidArgumentException;
1718

1819
class WhisperForConditionalGeneration extends WhisperPretrainedModel
1920
{
@@ -66,7 +67,6 @@ public function generate(
6667
}
6768

6869

69-
7070
if (isset($generationConfig['return_token_timestamps'])) {
7171
$generationConfig['output_attentions'] = true;
7272
$generationConfig['return_dict_in_generate'] = true;
@@ -109,12 +109,13 @@ public function generate(
109109
* @throws Exception If the model outputs do not contain cross attentions
110110
*/
111111
public function extractTokenTimestamps(
112-
array $generateOutputs,
113-
array $alignmentHeads,
112+
array $generateOutputs,
113+
array $alignmentHeads,
114114
int|float|null $numFrames = null,
115-
float $timePrecision = 0.02
116-
): Tensor {
117-
$numFrames = (int) $numFrames;
115+
float $timePrecision = 0.02
116+
): Tensor
117+
{
118+
$numFrames = (int)$numFrames;
118119
if (!isset($generateOutputs['cross_attentions'])) {
119120
throw new Exception(
120121
"Model outputs must contain cross attentions to extract timestamps. " .
@@ -128,7 +129,7 @@ public function extractTokenTimestamps(
128129
$medianFilterWidth = 7;
129130
}
130131

131-
$batchedMatrices = array_map(function($batch) use ($numFrames, $alignmentHeads, $medianFilterWidth) {
132+
$batchedMatrices = array_map(function ($batch) use ($numFrames, $alignmentHeads, $medianFilterWidth) {
132133
// Create a list with `decoder_layers` elements, each a tensor of shape
133134
// (batch size, attention_heads, output length, input length).
134135
/** @var Tensor[] $crossAttentions */
@@ -137,61 +138,61 @@ public function extractTokenTimestamps(
137138
$crossAttentions[] = Tensor::concat(array_map(fn($x) => $x[$i], $batch), 2);
138139
}
139140

140-
$weights = Tensor::stack(array_map(function($alignmentHead) use ($crossAttentions, $numFrames) {
141+
$weights = Tensor::stack(array_map(function ($alignmentHead) use ($crossAttentions, $numFrames) {
141142
[$l, $h] = $alignmentHead;
142143
return $numFrames
143-
? $crossAttentions[$l]->slice(null, $h, null, [0, $numFrames])
144-
: $crossAttentions[$l]->slice(null, $h);
144+
? $crossAttentions[$l]->slice(null, $h, null, [0, $numFrames])->squeeze(1)
145+
: $crossAttentions[$l]->slice(null, $h)->squeeze(1); // experimental
145146
}, $alignmentHeads));
146-
dd($weights->shape());
147-
148-
$weights = $weights->permute( 1, 0, 2, 3);
149147

148+
$weights = $weights->permute(1, 0, 2, 3);
150149

151-
list($std, $calculatedMean) = std_mean($weights, -2, 0, true);
150+
[$std, $calculatedMean] = $weights->stdMean(-2, 0, true);
152151

153152
// Normalize and smoothen the weights.
154-
$smoothedWeights = $weights->clone(); // [1, 8, seqLength, 1500]
153+
$smoothedWeights = clone $weights; // [1, 8, seqLength, 1500]
155154

156-
for ($a = 0; $a < $smoothedWeights->dims[0]; ++$a) {
155+
for ($a = 0; $a < $smoothedWeights->shape()[0]; ++$a) {
157156
$aTensor = $smoothedWeights[$a]; // [8, seqLength, 1500]
158157

159-
for ($b = 0; $b < $aTensor->dims[0]; ++$b) {
158+
for ($b = 0; $b < $aTensor->shape()[0]; ++$b) {
160159
$bTensor = $aTensor[$b]; // [seqLength, 1500]
161160

162161
$stdTensor = $std[$a][$b][0]; // [1500]
163162
$meanTensor = $calculatedMean[$a][$b][0]; // [1500]
164163

165-
for ($c = 0; $c < $bTensor->dims[0]; ++$c) {
164+
for ($c = 0; $c < $bTensor->shape()[0]; ++$c) {
165+
/** @var Tensor $cTensor */
166166
$cTensor = $bTensor[$c]; // [1500]
167-
for ($d = 0; $d < count($cTensor->data); ++$d) {
168-
$cTensor->data[$d] = ($cTensor->data[$d] - $meanTensor->data[$d]) / $stdTensor->data[$d];
169-
}
167+
// for ($d = 0; $d < count($cTensor->buffer()); ++$d) {
168+
// $cTensor->buffer()[$d] = ($cTensor->buffer()[$d] - $meanTensor->buffer()[$d]) / $stdTensor->buffer()[$d];
169+
// }
170+
$cTensor = $cTensor->add($meanTensor->multiply(-1))->multiply($stdTensor->reciprocal());
170171

171172
// Apply median filter.
172-
$cTensor->data = medianFilter($cTensor->data, $medianFilterWidth);
173+
$cTensor = $this->medianFilter($cTensor, $medianFilterWidth);
173174
}
174175
}
175176
}
176177

177178
// Average the different cross-attention heads.
178-
$matrix = mean($smoothedWeights, 1);
179-
return $matrix;
179+
return $smoothedWeights->mean(1);
180180
}, $generateOutputs['cross_attentions']);
181181

182182
$timestampsShape = [count($generateOutputs['sequences']), count($generateOutputs['sequences'][0])];
183183

184+
184185
$timestamps = new Tensor(null, Tensor::float32, $timestampsShape);
185186

186187
// Perform dynamic time warping on each element of the batch.
187188
for ($batchIdx = 0; $batchIdx < $timestampsShape[0]; ++$batchIdx) {
188189
// NOTE: Since we run only one batch at a time, we can squeeze to get the same dimensions
189190
// as the python implementation
190-
$matrix = $batchedMatrices[$batchIdx]->neg()->squeeze_(0);
191+
$matrix = $batchedMatrices[$batchIdx]->multiply(-1)->squeeze(0);
191192
list($textIndices, $timeIndices) = dynamicTimeWarping($matrix);
192193

193194
$diffs = array_map(fn($i) => $textIndices[$i + 1] - $textIndices[$i], range(0, count($textIndices) - 2));
194-
$jumps = array_map(fn($x) => (bool) $x, array_merge([1], $diffs));
195+
$jumps = array_map(fn($x) => (bool)$x, array_merge([1], $diffs));
195196

196197
$jumpTimes = [];
197198
for ($i = 0; $i < count($jumps); ++$i) {
@@ -206,4 +207,36 @@ public function extractTokenTimestamps(
206207
return $timestamps;
207208
}
208209

210+
function medianFilter(Tensor $tensor, int $windowSize): Tensor
211+
{
212+
if ($windowSize % 2 === 0 || $windowSize <= 0) {
213+
throw new InvalidArgumentException('Window size must be a positive odd number');
214+
}
215+
216+
$outputArray = array_fill(0, count($tensor), 0);
217+
$buffer = array_fill(0, $windowSize, 0);
218+
219+
$halfWindowSize = (int) floor($windowSize / 2);
220+
221+
for ($i = 0; $i < count($tensor); ++$i) {
222+
$valuesIndex = 0;
223+
224+
for ($j = -$halfWindowSize; $j <= $halfWindowSize; ++$j) {
225+
$index = $i + $j;
226+
if ($index < 0) {
227+
$index = abs($index);
228+
} else if ($index >= count($tensor)) {
229+
$index = 2 * (count($tensor) - 1) - $index;
230+
}
231+
232+
$buffer[$valuesIndex++] = $tensor->buffer()[$index];
233+
}
234+
235+
sort($buffer);
236+
$outputArray[$i] = $buffer[$halfWindowSize];
237+
}
238+
239+
return Tensor::fromArray($outputArray, $tensor->dtype());
240+
}
241+
209242
}

src/Tensor/Tensor.php

Lines changed: 101 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -727,6 +727,15 @@ public function transpose(): self
727727
return new static($ndArray->buffer(), $ndArray->dtype(), $ndArray->shape(), $ndArray->offset());
728728
}
729729

730+
public function reciprocal(): self
731+
{
732+
$mo = self::mo();
733+
734+
$ndArray = $mo->la()->reciprocal($this);
735+
736+
return new static($ndArray->buffer(), $ndArray->dtype(), $ndArray->shape(), $ndArray->offset());
737+
}
738+
730739
/**
731740
* Performs `L_p` normalization of inputs over specified dimension.
732741
*
@@ -907,6 +916,74 @@ public function mean(?int $axis = null, bool $keepShape = false): static|float|i
907916
return $mean;
908917
}
909918

919+
/**
920+
* Calculates the standard deviation and mean over the dimensions specified by dim. dim can be a
921+
* single dimension or `null` to reduce over all dimensions.
922+
*
923+
* @param int|null $axis The dimension to reduce. If `null`, reduces over all dimensions.
924+
* @param int $correction The type of normalization. Default is 0.
925+
* @param bool $keepShape Whether to keep the reduced dimension or not.
926+
*
927+
* @return array The standard deviation and mean of the tensor.
928+
*/
929+
public function stdMean(?int $axis = null, int $correction = 1, bool $keepShape = false): array
930+
{
931+
$mo = self::mo();
932+
933+
if ($axis === null) {
934+
$mean = $mo->mean($this);
935+
$std = sqrt($mo->sum($mo->la()->pow($mo->la()->sub($this, $mean), 2)) / ($this->size() - $correction));
936+
937+
return [
938+
Tensor::fromArray([$mean], $this->dtype()),
939+
Tensor::fromArray([$std], $this->dtype())
940+
];
941+
}
942+
943+
$axis = $this->safeIndex($axis, $this->ndim());
944+
945+
$mean = $mo->mean($this, $axis);
946+
947+
$resultShape = $this->shape();
948+
$resultShape[$axis] = 1;
949+
950+
$result = $this->zeros([count($this->buffer) / $this->shape()[$axis]], $this->dtype());
951+
952+
for ($i = 0; $i < count($this->buffer); ++$i) {
953+
$resultIndex = 0;
954+
$num = $i;
955+
$resultMultiplier = 1;
956+
957+
for ($j = $this->ndim() - 1; $j >= 0; --$j) {
958+
$size = $this->shape()[$j];
959+
960+
if ($j !== $axis) {
961+
$index = $num % $size;
962+
$resultIndex += $index * $resultMultiplier;
963+
$resultMultiplier *= $resultShape[$j];
964+
}
965+
966+
$num = floor($num / $size);
967+
}
968+
969+
$result->buffer[$resultIndex] += pow($this->buffer[$i] - $mean->buffer()[$resultIndex], 2);
970+
}
971+
972+
for ($i = 0; $i < count($result->buffer); ++$i) {
973+
$result->buffer[$i] = sqrt($result->buffer[$i] / ($this->shape()[$axis] - $correction));
974+
}
975+
976+
if (!$keepShape) {
977+
array_splice($resultShape, $axis, 1);
978+
}
979+
980+
return [
981+
new static($result->buffer(), $result->dtype(), $resultShape, $result->offset()),
982+
new static($mean->buffer(), $mean->dtype(), $resultShape, $mean->offset()),
983+
];
984+
}
985+
986+
910987
/**
911988
* Perform mean pooling of the tensor followed by a normalization step.
912989
*
@@ -977,7 +1054,7 @@ public function slice(...$slices): Tensor
9771054
$slice = $this->safeIndex($slice, $this->shape()[$sliceIndex], $sliceIndex);
9781055

9791056
$start[] = $slice;
980-
$size[] = 1;
1057+
$size[] = 1;
9811058

9821059
} elseif (is_array($slice) && count($slice) === 2) {
9831060
// An array of length 2 means take a range of elements
@@ -993,7 +1070,29 @@ public function slice(...$slices): Tensor
9931070
}
9941071
}
9951072

996-
return $this->sliceWithBounds($start, $size);
1073+
if (count($size) <= 3) {
1074+
return $this->sliceWithBounds($start, $size);
1075+
}
1076+
1077+
// The sliceWithBounds method only supports up to 3 dimensions,
1078+
// so we need to slice manually for higher dimensions
1079+
$newShape = $size;
1080+
$newBufferSize = array_product($size);
1081+
1082+
$buffer = self::newBuffer($newBufferSize, $this->dtype());
1083+
$stride = $this->stride();
1084+
1085+
for ($i = 0; $i < $newBufferSize; ++$i) {
1086+
$originalIndex = 0;
1087+
for ($j = count($newShape) - 1, $num = $i; $j >= 0; --$j) {
1088+
$size = $newShape[$j];
1089+
$originalIndex += (($num % $size) + $start[$j]) * $stride[$j];
1090+
$num = floor($num / $size);
1091+
}
1092+
$buffer[$i] = $this->buffer[$originalIndex];
1093+
}
1094+
1095+
return new Tensor($buffer, $this->dtype(), $newShape, $this->offset());
9971096
}
9981097

9991098
/**

0 commit comments

Comments
 (0)