Skip to content

Commit 68436d2

Browse files
Move Spectrogram calculation to C and FFI
1 parent 52d896c commit 68436d2

25 files changed

Lines changed: 1687 additions & 203 deletions

examples/pipelines/asr.php

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
use Codewithkyrian\Transformers\Generation\Streamers\StdOutStreamer;
6+
use function Codewithkyrian\Transformers\Pipelines\pipeline;
7+
use function Codewithkyrian\Transformers\Utils\memoryUsage;
8+
use function Codewithkyrian\Transformers\Utils\timeUsage;
9+
10+
require_once './bootstrap.php';
11+
12+
ini_set('memory_limit', '-1');
13+
14+
$transcriber = pipeline('automatic-speech-recognition', 'Xenova/whisper-tiny.en');
15+
16+
$audioUrl = __DIR__ . '/../sounds/kyrian-dev.wav';
17+
$audioUrl = __DIR__ . '/../sounds/jfk.wav';
18+
$audioUrl = __DIR__ . '/../sounds/preamble.wav';
19+
$audioUrl = __DIR__ . '/../sounds/taunt.wav';
20+
$audioUrl = __DIR__ . '/../sounds/gettysburg.wav';
21+
22+
$streamer = StdOutStreamer::make();
23+
$output = $transcriber($audioUrl, maxNewTokens: 256, streamer: $streamer);
24+
25+
dd( timeUsage(), memoryUsage());

examples/pipelines/image-to-image.php

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212

1313
$upscaler = pipeline('image-to-image', 'Xenova/swin2SR-classical-sr-x2-64');
1414

15-
$url = __DIR__ . '/../images/butterfly.jpg';
15+
$url = __DIR__ . '/../images/versus.jpeg';
1616

17-
$savePath = __DIR__ . '/../images/butterfly-super-resolution.jpg';
17+
$savePath = __DIR__ . '/../images/versus-x4.jpeg';
1818

1919
$output = $upscaler($url, saveTo: $savePath);
2020

examples/sounds/gettysburg.wav

757 KB
Binary file not shown.

examples/sounds/jfk.wav

1.85 MB
Binary file not shown.

examples/sounds/preamble.wav

823 KB
Binary file not shown.

examples/sounds/taunt.wav

89.1 KB
Binary file not shown.

src/DataStructures/NP2FFT.php

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public function __construct(int $fftLength)
3030
// Helper variables
3131
$a = 2 * ($fftLength - 1);
3232
$b = 2 * (2 * $fftLength - 1);
33-
$nextP2 = pow(2, ceil(log($b, 2)));
33+
$nextP2 = (int) (2 ** ceil(log($b, 2)));
3434
$this->bufferSize = $nextP2;
3535
$this->a = $a;
3636

@@ -68,6 +68,7 @@ public function __construct(int $fftLength)
6868
$ichirp[$i2] = $chirp[$i2];
6969
$ichirp[$i2 + 1] = -$chirp[$i2 + 1];
7070
}
71+
7172
$this->slicedChirpBuffer = SplFixedArray::fromArray(array_slice($chirp->toArray(), $a, $b - $a));
7273

7374
// create object to perform Fast Fourier Transforms
@@ -115,10 +116,10 @@ private function transform(SplFixedArray $output, SplFixedArray $input, bool $re
115116
$this->f->inverseTransform($ob3, $ib2);
116117

117118
for ($j = 0; $j < count($ob3); $j += 2) {
118-
$aReal = $ob3[$j + $a];
119-
$a_imag = $ob3[$j + $a + 1];
120-
$b_real = $sb[$j];
121-
$b_imag = $sb[$j + 1];
119+
$aReal = $ob3[$j + $a] ?? 0;
120+
$a_imag = $ob3[$j + $a + 1] ?? 0;
121+
$b_real = $sb[$j] ?? 0;
122+
$b_imag = $sb[$j + 1] ?? 0;
122123

123124
$output[$j] = $aReal * $b_real - $a_imag * $b_imag;
124125
$output[$j + 1] = $aReal * $b_imag + $a_imag * $b_real;
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
6+
namespace Codewithkyrian\Transformers\FeatureExtractors;
7+
8+
use Codewithkyrian\Transformers\Tensor\Tensor;
9+
use Codewithkyrian\Transformers\Utils\Audio;
10+
use function Codewithkyrian\Transformers\Utils\timeUsage;
11+
12+
class WhisperFeatureExtractor extends FeatureExtractor
13+
{
14+
protected Tensor $window;
15+
16+
public function __construct(array $config)
17+
{
18+
parent::__construct($config);
19+
20+
$this->config['mel_filters'] ??= Audio::melFilterBank(
21+
(int)(1 + $config['n_fft'] / 2),
22+
nMelFilters: $config['feature_size'],
23+
minFrequency: 0,
24+
maxFrequency: 8000,
25+
samplingRate: $config['sampling_rate'],
26+
norm: 'slaney',
27+
melScale: 'slaney',
28+
);
29+
30+
$this->window = Audio::windowFunction($config['n_fft'], 'hann', false);
31+
}
32+
33+
/**
34+
* Extracts features from a given audio using the provided configuration.
35+
* @param Tensor $waveform The audio tensor to extract features from.
36+
* @return Tensor[] The extracted features.
37+
*/
38+
public function __invoke(Tensor $waveform): array
39+
{
40+
if ($waveform->size() > $this->config['n_samples']) {
41+
trigger_error('Attempting to extract features for audio longer than 30 seconds.' .
42+
'If using a pipeline to extract transcript from a long audio clip,' .
43+
'remember to specify `chunkLengthSecs` and/or `strideLengthSecs` in the pipeline options.', E_USER_WARNING);
44+
45+
$waveform = $waveform->slice(0, $this->config['n_samples']);
46+
} else {
47+
$padding = $this->config['n_samples'] - $waveform->size();
48+
// create a new Tensor with the same data type as the input waveform
49+
$padding = Tensor::zeros([$padding], dtype: $waveform->dtype());
50+
$waveform = Tensor::concat([$waveform, $padding]);
51+
}
52+
53+
timeUsage();
54+
$features = Audio::spectrogram(
55+
$waveform,
56+
$this->window,
57+
frameLength: $this->config['n_fft'],
58+
hopLength: $this->config['hop_length'],
59+
power: 2.0,
60+
melFilters: $this->config['mel_filters'],
61+
logMel: 'log10',
62+
63+
maxNumFrames: $this->config['nb_max_frames'],
64+
);
65+
66+
$maxValue = $features->max();
67+
68+
$features->u(fn($x) => (max($x, $maxValue - 8.0) + 4.0) / 4.0);
69+
70+
return [
71+
'input_features' => $features->unsqueeze(0)
72+
];
73+
}
74+
}

src/Generation/LogitsProcessors/ForceTokensLogitsProcessor.php

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ class ForceTokensLogitsProcessor extends LogitsProcessor
1515

1616
public function __construct(array $forcedDecoderIds)
1717
{
18-
$this->forceTokenMap = array_fill_keys(array_keys($forcedDecoderIds), 0);
18+
foreach ($forcedDecoderIds[0] as $inputLength => $forcedId) {
19+
$this->forceTokenMap[$inputLength] = $forcedId;
20+
}
1921
}
2022

2123
/**
@@ -27,7 +29,7 @@ public function __construct(array $forcedDecoderIds)
2729
*/
2830
public function __invoke(array $inputIds, Tensor $logits): Tensor
2931
{
30-
$map = $this->forceTokenMap[count($inputIds) ?? 0]; // Access length from inputIds
32+
$map = $this->forceTokenMap[count($inputIds)] ?? null; // Access length from inputIds
3133

3234
if ($map) {
3335
Tensor::mo()->la()->fill(-INF, $logits);
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
6+
namespace Codewithkyrian\Transformers\Generation\LogitsProcessors;
7+
8+
use Codewithkyrian\Transformers\Tensor\Tensor;
9+
use Codewithkyrian\Transformers\Utils\GenerationConfig;
10+
11+
class WhisperTimeStampLogitsProcessor extends LogitsProcessor
12+
{
13+
/**
14+
* @var int|mixed The ID of the end-of-sequence token.
15+
*/
16+
protected int $eosTokenId;
17+
18+
/**
19+
* @var int The ID of the token used to indicate that a token should not have a timestamp.
20+
*/
21+
protected int $noTimestampsTokenId;
22+
23+
/**
24+
* @var int The ID at which timestamps begin.
25+
*/
26+
protected int $timestampBegin;
27+
28+
/**
29+
* @var int The index at which the first token can have a timestamp.
30+
*/
31+
protected int $beginIndex;
32+
33+
/**
34+
* @var ?int The maximum index at which an initial timestamp can appear.
35+
*/
36+
protected ?int $maxInitialTimestampIndex;
37+
38+
/**
39+
* Constructs a new WhisperTimeStampLogitsProcessor.
40+
*/
41+
public function __construct(GenerationConfig $generateConfig)
42+
{
43+
$this->eosTokenId = $generateConfig->eos_token_id;
44+
$this->noTimestampsTokenId = $generateConfig['no_timestamps_token_id'];
45+
$this->timestampBegin = $this->noTimestampsTokenId + 1;
46+
47+
$this->beginIndex = count($generateConfig['forced_decoder_ids'] ?? []) + 2;
48+
if (end($generateConfig['forced_decoder_ids'])[1] === $this->noTimestampsTokenId) {
49+
$this->beginIndex -= 1;
50+
}
51+
$this->maxInitialTimestampIndex = $generateConfig['max_initial_timestamp_index'] ?? null;
52+
}
53+
54+
/**
55+
* Modify the logits to handle timestamp tokens.
56+
* @param array $inputIds The input sequence of tokens.
57+
* @param Tensor $logits The logits output by the model.
58+
* @return Tensor The modified logits.
59+
*/
60+
public function __invoke(array $inputIds, Tensor $logits): Tensor
61+
{
62+
// suppress which is handled by without_timestamps
63+
$logits->buffer()[$this->noTimestampsTokenId] = -INF;
64+
65+
if (count($inputIds) === $this->beginIndex - 1) {
66+
Tensor::mo()->la()->fill(-INF, $logits);
67+
$logits->buffer()[$this->timestampBegin] = 0;
68+
return $logits;
69+
}
70+
71+
// timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
72+
$seq = array_slice($inputIds, $this->beginIndex);
73+
$lastWasTimestamp = count($seq) >= 1 && $seq[count($seq) - 1] >= $this->timestampBegin;
74+
$penultimateWasTimestamp = count($seq) < 2 || $seq[count($seq) - 2] >= $this->timestampBegin;
75+
76+
if ($lastWasTimestamp) {
77+
if ($penultimateWasTimestamp) { // has to be non-timestamp
78+
for ($i = $this->timestampBegin; $i < $logits->size(); $i++) {
79+
$logitsData[$i] = -INF;
80+
}
81+
} else { // cannot be normal text tokens
82+
for ($i = 0; $i < $this->eosTokenId; $i++) {
83+
$logitsData[$i] = -INF;
84+
}
85+
}
86+
}
87+
88+
// apply the `max_initial_timestamp` option
89+
if (count($inputIds) === $this->beginIndex && $this->maxInitialTimestampIndex !== null) {
90+
$lastAllowed = $this->timestampBegin + $this->maxInitialTimestampIndex;
91+
for ($i = $lastAllowed + 1; $i < $logits->size(); $i++) {
92+
$logitsData[$i] = -INF;
93+
}
94+
}
95+
96+
// if sum of probability over timestamps is above any other token, sample timestamp
97+
// $logProbs = log_softmax($logitsData);
98+
$logProbs = $logits->softmax()->log();
99+
$timestampLogProb = log(array_sum(array_map('exp', array_slice($logProbs, $this->timestampBegin))));
100+
$maxTextTokenLogProb = max(array_slice($logProbs, 0, $this->timestampBegin));
101+
102+
if ($timestampLogProb > $maxTextTokenLogProb) {
103+
for ($i = 0; $i < $this->timestampBegin; $i++) {
104+
$logitsData[$i] = -INF;
105+
}
106+
}
107+
108+
return $logits;
109+
}
110+
}

0 commit comments

Comments
 (0)