Skip to content

Commit 650df9a

Browse files
Fix bugs in ForceTokensLogitsProcessor and WhisperTokenizer to correct errors for non-english languages in ASR
1 parent 20add85 commit 650df9a

2 files changed

Lines changed: 56 additions & 0 deletions

File tree

src/Models/Pretrained/WhisperForConditionalGeneration.php

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,12 @@ public function extractTokenTimestamps(
170170

171171
// Apply median filter.
172172
$this->medianFilter($cTensor, $medianFilterWidth)->copyTo($cTensor);
173+
// $filtered = $this->medianFilter($cTensor, $medianFilterWidth);
174+
// for ($e = 0; $e < $filtered->count(); ++$e) {
175+
// $cTensor[$e] = $filtered[$e];
176+
// }
177+
178+
173179
}
174180
}
175181
}

src/Tensor/Tensor.php

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,6 +1096,56 @@ public function slice(...$slices): Tensor
10961096
}
10971097

10981098
return new Tensor($buffer, $this->dtype(), $newShape, $this->offset());
1099+
1100+
$newTensorShape = [];
1101+
$newOffsets = [];
1102+
1103+
for ($sliceIndex = 0; $sliceIndex < $this->ndim(); ++$sliceIndex) {
1104+
$slice = $slices[$sliceIndex] ?? null;
1105+
1106+
if ($slice === null) {
1107+
$newOffsets[] = [0, $this->shape()[$sliceIndex]];
1108+
$newTensorShape[] = $this->shape()[$sliceIndex];
1109+
1110+
} elseif (is_int($slice)) {
1111+
$slice = $this->safeIndex($slice, $this->shape()[$sliceIndex], $sliceIndex);
1112+
$newOffsets[] = [$slice, $slice + 1];
1113+
1114+
} elseif (is_array($slice) && count($slice) === 2) {
1115+
if ($slice[0] > $slice[1]) {
1116+
throw new Exception("Invalid slice: " . json_encode($slice));
1117+
}
1118+
$offsets = [
1119+
max($slice[0], 0),
1120+
min($slice[1], $this->shape()[$sliceIndex])
1121+
];
1122+
$newOffsets[] = $offsets;
1123+
$newTensorShape[] = $offsets[1] - $offsets[0];
1124+
1125+
} else {
1126+
throw new Exception("Invalid slice: " . json_encode($slice));
1127+
}
1128+
}
1129+
1130+
$newShape = array_map(fn($offsets) => $offsets[1] - $offsets[0], $newOffsets);
1131+
1132+
$newBufferSize = array_reduce($newShape, fn($a, $b) => $a * $b, 1);
1133+
1134+
$buffer = self::newBuffer($newBufferSize, $this->dtype());
1135+
$stride = $this->stride();
1136+
1137+
for ($i = 0; $i < $newBufferSize; ++$i) {
1138+
$originalIndex = 0;
1139+
for ($j = count($newShape) - 1, $num = $i; $j >= 0; --$j) {
1140+
$size = $newShape[$j];
1141+
$originalIndex += (($num % $size) + $newOffsets[$j][0]) * $stride[$j];
1142+
$num = floor($num / $size);
1143+
}
1144+
$buffer[$i] = $this->buffer[$originalIndex];
1145+
}
1146+
1147+
return new Tensor($buffer, $this->dtype(), $newTensorShape, $this->offset());
1148+
10991149
}
11001150

11011151
/**

0 commit comments

Comments
 (0)