Skip to content

Commit 52d896c

Browse files
Add Audio Classification Pipeline and support for processing stereo audio
1 parent 2b7b495 commit 52d896c

10 files changed

Lines changed: 210 additions & 38 deletions

File tree

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
use function Codewithkyrian\Transformers\Pipelines\pipeline;
6+
use function Codewithkyrian\Transformers\Utils\memoryUsage;
7+
use function Codewithkyrian\Transformers\Utils\timeUsage;
8+
9+
require_once './bootstrap.php';
10+
11+
$classifier = pipeline('audio-classification', 'Xenova/ast-finetuned-audioset-10-10-0.4593');
12+
13+
//$audioUrl = __DIR__ . '/../sounds/dog_barking.wav';
14+
$audioUrl = __DIR__ . '/../sounds/cat_meow.wav';
15+
16+
$output = $classifier($audioUrl, topK: 4);
17+
18+
dd($output, timeUsage(), memoryUsage());

src/Models/Auto/AutoModel.php

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ class AutoModel extends PretrainedMixin
2020
"deit" => \Codewithkyrian\Transformers\Models\Pretrained\DeiTModel::class,
2121
"siglip" => \Codewithkyrian\Transformers\Models\Pretrained\SigLipModel::class,
2222

23+
"audio-spectrogram-transformer" => \Codewithkyrian\Transformers\Models\Pretrained\ASTModel::class,
24+
2325
'detr' => \Codewithkyrian\Transformers\Models\Pretrained\DETRModel::class,
2426
'yolos' => \Codewithkyrian\Transformers\Models\Pretrained\YOLOSModel::class,
2527
'owlvit' => \Codewithkyrian\Transformers\Models\Pretrained\OwlVitModel::class,
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
6+
namespace Codewithkyrian\Transformers\Models\Auto;
7+
8+
class AutoModelForAudioClassification extends PretrainedMixin
9+
{
10+
const MODEL_CLASS_MAPPING = [
11+
'audio-spectrogram-transformer' => \Codewithkyrian\Transformers\Models\Pretrained\ASTForAudioClassification::class,
12+
];
13+
14+
const MODEL_CLASS_MAPPINGS = [
15+
self::MODEL_CLASS_MAPPING,
16+
];
17+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
6+
namespace Codewithkyrian\Transformers\Models\Pretrained;
7+
8+
class ASTForAudioClassification extends ASTPretrainedModel
9+
{
10+
11+
}

src/Models/Pretrained/ASTModel.php

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
6+
namespace Codewithkyrian\Transformers\Models\Pretrained;
7+
8+
class ASTModel extends ASTPretrainedModel
9+
{
10+
11+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
6+
namespace Codewithkyrian\Transformers\Models\Pretrained;
7+
8+
/**
9+
* Audio Spectrogram Transformer (AST) models
10+
*/
11+
class ASTPretrainedModel extends PretrainedModel
12+
{
13+
14+
}
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
6+
namespace Codewithkyrian\Transformers\Pipelines;
7+
8+
use Codewithkyrian\Transformers\Utils\Audio;
9+
10+
/**
11+
* Audio classification pipeline using any `AutoModelForAudioClassification`.
12+
* This pipeline predicts the class of a raw waveform or an audio file.
13+
*
14+
* *Example:** Perform audio classification with `Xenova/wav2vec2-large-xlsr-53-gender-recognition-librispeech`.
15+
* ```php
16+
* $classifier = pipeline('audio-classification', 'Xenova/wav2vec2-large-xlsr-53-gender-recognition-librispeech');
17+
* $url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav';
18+
* $output = $classifier($url);
19+
* // [
20+
* // [ label: 'male', score: 0.9981542229652405 ],
21+
* // [ label: 'female', score: 0.001845747814513743 ]
22+
* // ]
23+
* ```
24+
*
25+
* *Example:** Perform audio classification with `Xenova/ast-finetuned-audioset-10-10-0.4593` and return top 4 results.
26+
* ```php
27+
* $classifier = await pipeline('audio-classification', 'Xenova/ast-finetuned-audioset-10-10-0.4593');
28+
* $url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cat_meow.wav';
29+
* $output = $classifier($url, topK: 4);
30+
* // [
31+
* // [ label: 'Meow', score: 0.5617874264717102 ],
32+
* // [ label: 'Cat', score: 0.22365376353263855 ],
33+
* // [ label: 'Domestic animals, pets', score: 0.1141069084405899 ],
34+
* // [ label: 'Animal', score: 0.08985692262649536 ],
35+
* // ]
36+
* ```
37+
*/
38+
class AudioClassificationPipeline extends Pipeline
39+
{
40+
public function __invoke(array|string $inputs, ...$args): array
41+
{
42+
$topK = $args["topK"] ?? 1;
43+
44+
$isBatched = is_array($inputs);
45+
46+
if (!$isBatched) {
47+
$inputs = [$inputs];
48+
}
49+
50+
$sampleRate = $this->processor->featureExtractor->config['sampling_rate'];
51+
$id2label = $this->model->config['id2label'];
52+
$toReturn = [];
53+
54+
foreach ($inputs as $input) {
55+
$audio = Audio::read($input);
56+
$audioTensor = $audio->toTensor(samplerate: $sampleRate);
57+
58+
$inputs = ($this->processor)($audioTensor);
59+
$outputs = ($this->model)($inputs);
60+
61+
$logits = $outputs['logits'][0];
62+
63+
[$scores, $indices] = $logits->softmax()->topk($topK, true);
64+
65+
$values = [];
66+
67+
foreach ($indices as $i => $index) {
68+
$values[] = ['label' => $id2label[$index], 'score' => $scores[$i]];
69+
}
70+
71+
if ($topK === 1) {
72+
$toReturn = array_merge($toReturn, $values);
73+
} else {
74+
$toReturn[] = $values;
75+
}
76+
}
77+
78+
return $isBatched || $topK === 1 ? $toReturn : $toReturn[0];
79+
}
80+
}

src/Pipelines/Task.php

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
namespace Codewithkyrian\Transformers\Pipelines;
66

77
use Codewithkyrian\Transformers\Models\Auto\AutoModel;
8+
use Codewithkyrian\Transformers\Models\Auto\AutoModelForAudioClassification;
89
use Codewithkyrian\Transformers\Models\Auto\AutoModelForCausalLM;
910
use Codewithkyrian\Transformers\Models\Auto\AutoModelForImageClassification;
1011
use Codewithkyrian\Transformers\Models\Auto\AutoModelForImageFeatureExtraction;
@@ -49,6 +50,8 @@ enum Task: string
4950
case ObjectDetection = 'object-detection';
5051
case ZeroShotObjectDetection = 'zero-shot-object-detection';
5152

53+
case AudioClassification = 'audio-classification';
54+
5255

5356
public function pipeline(PretrainedModel $model, ?PretrainedTokenizer $tokenizer, ?Processor $processor): Pipeline
5457
{
@@ -89,6 +92,8 @@ public function pipeline(PretrainedModel $model, ?PretrainedTokenizer $tokenizer
8992
self::ObjectDetection => new ObjectDetectionPipeline($this, $model, $tokenizer, $processor),
9093

9194
self::ZeroShotObjectDetection => new ZeroShotObjectDetectionPipeline($this, $model, $tokenizer, $processor),
95+
96+
self::AudioClassification => new AudioClassificationPipeline($this, $model, processor: $processor),
9297
};
9398
}
9499

@@ -129,16 +134,18 @@ public function defaultModelName(): string
129134
self::ObjectDetection => 'Xenova/detr-resnet-50', // Original: 'facebook/detr-resnet-50',
130135

131136
self::ZeroShotObjectDetection => 'Xenova/owlvit-base-patch32', // Original: 'google/owlvit-base-patch32',
137+
138+
self::AudioClassification => 'Xenova/wav2vec2-base-superb-ks', // Original: 'superb/wav2vec2-base-superb-ks',
132139
};
133140
}
134141

135142
public function autoModel(
136-
string $modelNameOrPath,
137-
bool $quantized = true,
138-
?array $config = null,
139-
?string $cacheDir = null,
140-
string $revision = 'main',
141-
?string $modelFilename = null,
143+
string $modelNameOrPath,
144+
bool $quantized = true,
145+
?array $config = null,
146+
?string $cacheDir = null,
147+
string $revision = 'main',
148+
?string $modelFilename = null,
142149
?callable $onProgress = null
143150
): PretrainedModel
144151
{
@@ -176,13 +183,15 @@ public function autoModel(
176183
self::ObjectDetection => AutoModelForObjectDetection::fromPretrained($modelNameOrPath, $quantized, $config, $cacheDir, $revision, $modelFilename, $onProgress),
177184

178185
self::ZeroShotObjectDetection => AutoModelForZeroShotObjectDetection::fromPretrained($modelNameOrPath, $quantized, $config, $cacheDir, $revision, $modelFilename, $onProgress),
186+
187+
self::AudioClassification => AutoModelForAudioClassification::fromPretrained($modelNameOrPath, $quantized, $config, $cacheDir, $revision, $modelFilename, $onProgress),
179188
};
180189
}
181190

182191
public function autoTokenizer(
183-
string $modelNameOrPath,
184-
?string $cacheDir = null,
185-
string $revision = 'main',
192+
string $modelNameOrPath,
193+
?string $cacheDir = null,
194+
string $revision = 'main',
186195
?callable $onProgress = null
187196
): ?PretrainedTokenizer
188197
{
@@ -191,7 +200,8 @@ public function autoTokenizer(
191200
self::ImageClassification,
192201
self::ImageToImage,
193202
self::ImageFeatureExtraction,
194-
self::ObjectDetection => null,
203+
self::ObjectDetection,
204+
self::AudioClassification => null,
195205

196206

197207
self::SentimentAnalysis,
@@ -214,10 +224,10 @@ public function autoTokenizer(
214224
}
215225

216226
public function autoProcessor(
217-
string $modelNameOrPath,
218-
?array $config = null,
219-
?string $cacheDir = null,
220-
string $revision = 'main',
227+
string $modelNameOrPath,
228+
?array $config = null,
229+
?string $cacheDir = null,
230+
string $revision = 'main',
221231
?callable $onProgress = null
222232
): ?Processor
223233
{
@@ -229,7 +239,8 @@ public function autoProcessor(
229239
self::ZeroShotImageClassification,
230240
self::ImageToImage,
231241
self::ObjectDetection,
232-
self::ZeroShotObjectDetection => AutoProcessor::fromPretrained($modelNameOrPath, $config, $cacheDir, $revision, $onProgress),
242+
self::ZeroShotObjectDetection,
243+
self::AudioClassification => AutoProcessor::fromPretrained($modelNameOrPath, $config, $cacheDir, $revision, $onProgress),
233244

234245

235246
self::SentimentAnalysis,

src/Tensor/Tensor.php

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,15 @@ class Tensor implements NDArray, Countable, Serializable, IteratorAggregate
4444
protected Buffer $buffer;
4545

4646
protected static $pack = [
47-
NDArray::bool => 'C',
48-
NDArray::int8 => 'c',
49-
NDArray::int16 => 's',
50-
NDArray::int32 => 'l',
51-
NDArray::int64 => 'q',
52-
NDArray::uint8 => 'C',
53-
NDArray::uint16 => 'S',
54-
NDArray::uint32 => 'L',
55-
NDArray::uint64 => 'Q',
47+
NDArray::bool => 'C',
48+
NDArray::int8 => 'c',
49+
NDArray::int16 => 's',
50+
NDArray::int32 => 'l',
51+
NDArray::int64 => 'q',
52+
NDArray::uint8 => 'C',
53+
NDArray::uint16 => 'S',
54+
NDArray::uint32 => 'L',
55+
NDArray::uint64 => 'Q',
5656
//NDArray::float8 => 'N/A',
5757
//NDArray::float16 => 'N/A',
5858
NDArray::float32 => 'g',
@@ -414,7 +414,7 @@ public function toString(): string
414414
*/
415415
public function toBufferArray(): array
416416
{
417-
$fmt = self::$pack[$this->dtype].'*';
417+
$fmt = self::$pack[$this->dtype] . '*';
418418

419419
return array_values(unpack($fmt, $this->buffer->dump()));
420420
}
@@ -848,10 +848,16 @@ public function mean(?int $axis = null, bool $keepShape = false): static|float|i
848848
{
849849
$mo = self::mo();
850850

851+
if ($axis !== null) {
852+
$axis = $this->safeIndex($axis, $this->ndim());
853+
}
854+
851855
$mean = $mo->mean($this, $axis);
852856

853857
if ($mean instanceof NDArray) {
854-
$shape = $mean->shape();
858+
$shape = $this->shape();
859+
860+
$shape[$axis] = 1;
855861

856862
if (!$keepShape) {
857863
array_splice($shape, $axis, 1);

0 commit comments

Comments
 (0)