Skip to content

Commit 3d8257a

Browse files
fix: Window and stride calculation error for whisper
1 parent 096f7cc, commit 3d8257a

8 files changed

Lines changed: 31 additions & 88 deletions

File tree

examples/pipelines/asr.php

Lines changed: 5 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -18,23 +18,18 @@
1818

1919
$audioUrl = __DIR__ . '/../sounds/kyrian-dev.wav';
2020
$audioUrl = __DIR__ . '/../sounds/jfk.wav';
21-
//$audioUrl = __DIR__ . '/../sounds/preamble.wav';
22-
//$audioUrl = __DIR__ . '/../sounds/taunt.wav';
23-
//$audioUrl = __DIR__ . '/../sounds/gettysburg.wav';
24-
//$audioUrl = __DIR__ . '/../sounds/kyrian-speaking.wav';
21+
$audioUrl = __DIR__ . '/../sounds/preamble.wav';
22+
$audioUrl = __DIR__ . '/../sounds/taunt.wav';
23+
$audioUrl = __DIR__ . '/../sounds/gettysburg.wav';
24+
$audioUrl = __DIR__ . '/../sounds/kyrian-speaking.wav';
2525
$audioUrl = __DIR__ . '/../sounds/ted_60.wav';
2626
//$audioUrl = __DIR__ . '/../sounds/french-audio.wav';
2727

28-
$streamer = WhisperTextStreamer::make()
29-
// ->onTimestampStart(fn($time) => print("$time: "))
30-
// ->onTimestampEnd(fn($time) => print("\n"))
31-
;
3228

3329
$output = $transcriber($audioUrl,
3430
maxNewTokens: 256,
3531
chunkLengthSecs: 24,
3632
// returnTimestamps: true,
37-
streamer: $streamer
3833
);
3934

40-
//dd($output, timeUsage(), memoryUsage());
35+
dd($output, timeUsage(), memoryUsage());

examples/pipelines/text-generation.php

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@
1111
use function Codewithkyrian\Transformers\Utils\timeUsage;
1212

1313
ini_set('memory_limit', -1);
14-
//
14+
1515
//$generator = pipeline('text-generation', 'Xenova/gpt2');
1616
//$generator = pipeline('text-generation', 'Xenova/Qwen1.5-0.5B-Chat');
1717
$generator = pipeline('text-generation', 'Xenova/TinyLlama-1.1B-Chat-v1.0');
1818

19-
$streamer = TextStreamer::make();
19+
$streamer = TextStreamer::make()->shouldSkipPrompt();
2020

2121
$messages = [
2222
['role' => 'system', 'content' => 'You are a helpful assistant.'],
@@ -36,14 +36,14 @@
3636
);
3737

3838
//$generator = pipeline('text-generation', 'Xenova/codegen-350M-mono');
39-
//$streamer = StdOutStreamer::make();
40-
//
39+
//$streamer = TextStreamer::make();
40+
4141
//$output = $generator(
4242
// 'def fib(n):',
4343
// streamer: $streamer,
4444
// maxNewTokens: 100,
4545
// doSample: true,
46-
// returnFullText: false,
46+
// returnFullText: true,
4747
//);
48-
//
48+
4949
dd($output[0]['generated_text'], timeUsage(), memoryUsage());

src/Generation/Streamers/Streamer.php

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
abstract class Streamer
1414
{
1515
protected array $promptTokens = [];
16-
protected bool $skipPrompt;
16+
protected bool $skipPrompt = false;
1717
protected bool $nextTokensArePrompt;
1818

1919
protected PretrainedTokenizer $tokenizer;

src/Pipelines/AutomaticSpeechRecognitionPipeline.php

Lines changed: 13 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
use Codewithkyrian\Transformers\Utils\GenerationConfig;
1212
use Codewithkyrian\Transformers\Utils\Image;
1313
use function Codewithkyrian\Transformers\Utils\array_pop_key;
14-
use function Codewithkyrian\Transformers\Utils\array_to_snake_case;
14+
use function Codewithkyrian\Transformers\Utils\array_keys_to_snake_case;
1515

1616
/**
1717
* Pipeline that aims at extracting spoken text contained within some audio.
@@ -111,11 +111,13 @@ private function __invokeWhisper(array|string $inputs, ...$args): array|Tensor|I
111111
$task = array_pop_key($args, 'task');
112112
$streamer = array_pop_key($args, 'streamer');
113113

114-
if (!is_null($streamer) && !is_a($streamer, WhisperTextStreamer::class)) {
115-
throw new \InvalidArgumentException('`streamer` must be an instance of `WhisperTextStreamer`');
116-
}
114+
// if (!is_null($streamer) && !is_a($streamer, WhisperTextStreamer::class)) {
115+
// throw new \InvalidArgumentException('`streamer` must be an instance of `WhisperTextStreamer`');
116+
// }
117+
118+
if (!is_null($streamer)) trigger_error('`streamer` is not supported yet for Whisper', E_USER_WARNING);
117119

118-
$kwargs = array_to_snake_case($args);
120+
$kwargs = array_keys_to_snake_case($args);
119121

120122
$generationConfig = new GenerationConfig($kwargs);
121123

@@ -139,14 +141,12 @@ private function __invokeWhisper(array|string $inputs, ...$args): array|Tensor|I
139141
$timePrecision = $this->processor->featureExtractor->config['chunk_length'] / $this->model->config['max_source_positions'];
140142
$hopLength = $this->processor->featureExtractor->config['hop_length'];
141143
$samplingRate = $this->processor->featureExtractor->config['sampling_rate'];
142-
$timestampBegin = $this->tokenizer->tokenizer->convertTokensToIds(["<|notimestamps|>"])[0] + 1;
143144

144145
$toReturn = [];
145146

146-
$streamer?->setTokenizer($this->tokenizer)
147-
?->shouldSkipPrompt(false)
148-
?->setTimePrecision($timePrecision)
149-
?->setTimestampBegin($timestampBegin);
147+
// $streamer?->setTokenizer($this->tokenizer)
148+
// ?->setTimePrecision($timePrecision)
149+
// ?->setTimestampBegin($timestampBegin);
150150

151151
foreach ($inputs as $input) {
152152
$audio = Audio::read($input);
@@ -168,9 +168,9 @@ private function __invokeWhisper(array|string $inputs, ...$args): array|Tensor|I
168168
$offset = 0;
169169

170170
while ($offset < $audioTensor->size()) {
171-
172171
if ($offset + $window > $audioTensor->size()) {
173172
$window = $audioTensor->size() - $offset;
173+
$jump = $window;
174174
}
175175

176176
$subAudio = $audioTensor->sliceWithBounds([$offset], [$window]);
@@ -206,7 +206,7 @@ private function __invokeWhisper(array|string $inputs, ...$args): array|Tensor|I
206206
foreach ($chunks as &$chunk) {
207207
$generationConfig['num_frames'] = (int)floor($chunk['stride'][0] / $hopLength);
208208

209-
$data = $this->model->generate($chunk['input_features'], generationConfig: $generationConfig, streamer: $streamer);
209+
$data = $this->model->generate($chunk['input_features'], generationConfig: $generationConfig);
210210

211211
// TODO: Right now we only get top beam
212212
if ($returnTimestamps === 'word') {
@@ -219,7 +219,7 @@ private function __invokeWhisper(array|string $inputs, ...$args): array|Tensor|I
219219
// convert stride to seconds
220220
$chunk['stride'] = array_map(fn($x) => $x / $samplingRate, $chunk['stride']);
221221

222-
$streamer?->notifyChunkEnd($chunk['stride'][0]);
222+
// $streamer?->notifyChunkEnd($chunk['stride'][0]);
223223
}
224224

225225
// Merge text chunks

src/Pipelines/Text2TextGenerationPipeline.php

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
use Codewithkyrian\Transformers\Generation\Streamers\Streamer;
99
use Codewithkyrian\Transformers\Utils\GenerationConfig;
1010
use function Codewithkyrian\Transformers\Utils\array_pop_key;
11-
use function Codewithkyrian\Transformers\Utils\array_to_snake_case;
11+
use function Codewithkyrian\Transformers\Utils\array_keys_to_snake_case;
1212

1313
/**
1414
* A pipeline for generating text using a model that performs text-to-text generation tasks.
@@ -33,7 +33,7 @@ public function __invoke(array|string $inputs, ...$args): array
3333
/** @var Streamer $streamer */
3434
$streamer = array_pop_key($args, 'streamer');
3535

36-
$kwargs = array_to_snake_case($args);
36+
$kwargs = array_keys_to_snake_case($args);
3737

3838
$generateKwargs = new GenerationConfig($kwargs);
3939

src/Pipelines/TextGenerationPipeline.php

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
use Codewithkyrian\Transformers\Utils\GenerationConfig;
1010
use function Codewithkyrian\Transformers\Utils\array_every;
1111
use function Codewithkyrian\Transformers\Utils\array_pop_key;
12-
use function Codewithkyrian\Transformers\Utils\array_to_snake_case;
12+
use function Codewithkyrian\Transformers\Utils\array_keys_to_snake_case;
1313
use function Codewithkyrian\Transformers\Utils\camelCaseToSnakeCase;
1414
use function Codewithkyrian\Transformers\Utils\timeUsage;
1515

@@ -62,7 +62,7 @@ public function __invoke(array|string $inputs, ...$args): array
6262

6363
$returnFullText = array_pop_key($args, 'returnFullText', true);
6464

65-
$kwargs = array_to_snake_case($args);
65+
$kwargs = array_keys_to_snake_case($args);
6666

6767
$generationConfig = new GenerationConfig($kwargs);
6868

@@ -104,9 +104,7 @@ public function __invoke(array|string $inputs, ...$args): array
104104
truncation: true
105105
);
106106

107-
$streamer?->setTokenizer($this->tokenizer)
108-
?->shouldSkipPrompt()
109-
?->setPromptTokens($inputIds[0]->toArray());
107+
$streamer?->setTokenizer($this->tokenizer)?->setPromptTokens($inputIds[0]->toArray());
110108

111109
$outputTokenIds = $this->model->generate($inputIds,
112110
generationConfig: $generationConfig,

src/Tensor/Tensor.php

Lines changed: 0 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -1096,56 +1096,6 @@ public function slice(...$slices): Tensor
10961096
}
10971097

10981098
return new Tensor($buffer, $this->dtype(), $newShape, $this->offset());
1099-
1100-
$newTensorShape = [];
1101-
$newOffsets = [];
1102-
1103-
for ($sliceIndex = 0; $sliceIndex < $this->ndim(); ++$sliceIndex) {
1104-
$slice = $slices[$sliceIndex] ?? null;
1105-
1106-
if ($slice === null) {
1107-
$newOffsets[] = [0, $this->shape()[$sliceIndex]];
1108-
$newTensorShape[] = $this->shape()[$sliceIndex];
1109-
1110-
} elseif (is_int($slice)) {
1111-
$slice = $this->safeIndex($slice, $this->shape()[$sliceIndex], $sliceIndex);
1112-
$newOffsets[] = [$slice, $slice + 1];
1113-
1114-
} elseif (is_array($slice) && count($slice) === 2) {
1115-
if ($slice[0] > $slice[1]) {
1116-
throw new Exception("Invalid slice: " . json_encode($slice));
1117-
}
1118-
$offsets = [
1119-
max($slice[0], 0),
1120-
min($slice[1], $this->shape()[$sliceIndex])
1121-
];
1122-
$newOffsets[] = $offsets;
1123-
$newTensorShape[] = $offsets[1] - $offsets[0];
1124-
1125-
} else {
1126-
throw new Exception("Invalid slice: " . json_encode($slice));
1127-
}
1128-
}
1129-
1130-
$newShape = array_map(fn($offsets) => $offsets[1] - $offsets[0], $newOffsets);
1131-
1132-
$newBufferSize = array_reduce($newShape, fn($a, $b) => $a * $b, 1);
1133-
1134-
$buffer = self::newBuffer($newBufferSize, $this->dtype());
1135-
$stride = $this->stride();
1136-
1137-
for ($i = 0; $i < $newBufferSize; ++$i) {
1138-
$originalIndex = 0;
1139-
for ($j = count($newShape) - 1, $num = $i; $j >= 0; --$j) {
1140-
$size = $newShape[$j];
1141-
$originalIndex += (($num % $size) + $newOffsets[$j][0]) * $stride[$j];
1142-
$num = floor($num / $size);
1143-
}
1144-
$buffer[$i] = $this->buffer[$originalIndex];
1145-
}
1146-
1147-
return new Tensor($buffer, $this->dtype(), $newTensorShape, $this->offset());
1148-
11491099
}
11501100

11511101
/**

src/Utils/Helpers.php

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -71,7 +71,7 @@ function array_pop_key(array &$array, string|int $key, mixed $default = null)
7171
return $default;
7272
}
7373

74-
function array_to_snake_case(array $array): array
74+
function array_keys_to_snake_case(array $array): array
7575
{
7676
$snakeCasedArray = [];
7777

0 commit comments

Comments
 (0)