Skip to content

Commit 32d8cfd

Browse files
Refactor stack method in tensor for performance
1 parent dda9a70 commit 32d8cfd

6 files changed

Lines changed: 23 additions & 96 deletions

File tree

src/FeatureExtractors/ImageFeatureExtractor.php

Lines changed: 4 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -201,9 +201,7 @@ public function padImage(
201201
int $constantValues = 0
202202
): array
203203
{
204-
$imageHeight = $imgShape[0];
205-
$imageWidth = $imgShape[1];
206-
$imageChannels = $imgShape[2];
204+
[$imageHeight, $imageWidth, $imageChannels] = $imgShape;
207205

208206
if (is_array($padSize)) {
209207
$paddedImageWidth = $padSize['width'];
@@ -459,7 +457,6 @@ public function preprocess(
459457

460458
$imgShape = [$image->height(), $image->width(), $image->channels];
461459

462-
463460
if ($this->doRescale) {
464461
$this->rescale($pixelData);
465462
}
@@ -529,21 +526,12 @@ public function __invoke(Image|array $images, ...$args): array
529526
$imageData[] = $this->preprocess($image);
530527
}
531528

532-
// Stack pixel values
533-
$pixelValues = [];
534-
foreach ($imageData as $data) {
535-
$pixelValues[] = $data['pixel_values'];
536-
}
529+
$pixelValues = array_column($imageData, 'pixel_values');
530+
$originalSizes = array_column($imageData, 'original_size');
531+
$reshapedInputSizes = array_column($imageData, 'reshaped_input_size');
537532

538533
$stackedPixelValues = Tensor::stack($pixelValues, 0);
539534

540-
// Prepare metadata
541-
$originalSizes = [];
542-
$reshapedInputSizes = [];
543-
foreach ($imageData as $data) {
544-
$originalSizes[] = $data['original_size'];
545-
$reshapedInputSizes[] = $data['reshaped_input_size'];
546-
}
547535
return [
548536
'pixel_values' => $stackedPixelValues,
549537
'original_sizes' => $originalSizes,

src/Generation/Samplers/Sampler.php

Lines changed: 4 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -57,14 +57,11 @@ public function getLogits(Tensor $logits, int $index): Tensor
5757
// $logs = array_slice($logs, $startIndex, $startIndex + $vocabSize);
5858
// }
5959

60-
$start = array_fill(0, $logits->ndim() - 2, 0);
61-
$size = array_fill(0, $logits->ndim() - 2, 1);
60+
$start = array_fill(0, $logits->ndim(), 0);
61+
$size = array_fill(0, $logits->ndim(), 1);
6262

63-
$start[] = $index;
64-
$size[] = 1;
65-
66-
$start[] = -$vocabSize;
67-
$size[] = $vocabSize;
63+
array_splice($start, -2, replacement: [$index, 0]);
64+
array_splice($size, -2, replacement: [1, $vocabSize]);
6865

6966
$logs = $logits->newSlice($start, $size);
7067

src/Models/Pretrained/PretrainedModel.php

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -625,7 +625,7 @@ public function generate(
625625
// In most cases, this will be [batch_size, 1, vocab_size]
626626
// So, we select the last token's logits:
627627
// (equivalent to `logits = outputs.logits[:, -1, :]`)
628-
$logits = $output['logits']-lice(null, -1, null);
628+
$logits = $output['logits']->slice(null, -1, null);
629629

630630
// Apply logits processor
631631
$logitsProcessor($beam['output_token_ids'], $logits);

src/Pipelines/ImageClassificationPipeline.php

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -62,7 +62,6 @@ public function __invoke(array|string $inputs, ...$args): array
6262

6363
['pixel_values' => $pixelValues] = ($this->processor)($preparedImages);
6464

65-
6665
/** @var SequenceClassifierOutput $output */
6766
$output = $this->model->__invoke(['pixel_values' => $pixelValues]);
6867

@@ -86,6 +85,7 @@ public function __invoke(array|string $inputs, ...$args): array
8685
$toReturn[] = $values;
8786
}
8887
}
88+
8989
if ($isBatched || $topK === 1) {
9090
return $toReturn;
9191
} else {

src/Pipelines/ZeroShotImageClassificationPipeline.php

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,7 @@ public function __invoke(array|string $inputs, ...$args): array
5151
$output = $this->model->__invoke(array_merge($textInputs, ['pixel_values' => $pixelValues]));
5252

5353
$activationFn = $this->model->config['model_type'] === 'siglip' ?
54-
fn(Tensor $batch) => $batch->sigmoid()->toArray() :
54+
fn(Tensor $batch) => $batch->sigmoid():
5555
fn(Tensor $batch) => $batch->softmax();
5656

5757
// Compare each image with each candidate label

src/Utils/Tensor.php

Lines changed: 12 additions & 70 deletions
Original file line number | Diff line number | Diff line change
@@ -459,9 +459,11 @@ public static function zerosLike(Tensor $other): static
459459
*/
460460
public static function stack(array $tensors, int $axis = 0): Tensor
461461
{
462-
// TODO: Perform validation of shapes
463-
// NOTE: stack expects each tensor to be equal size
464-
return self::cat(array_map(fn($t) => $t->unsqueeze($axis), $tensors), $axis);
462+
$mo = self::mo();
463+
464+
$stacked = $mo->la()->stack($tensors, $axis);
465+
466+
return new Tensor($stacked->buffer(), $stacked->dtype(), $stacked->shape(), $stacked->offset());
465467
}
466468

467469
/**
@@ -473,58 +475,13 @@ public static function stack(array $tensors, int $axis = 0): Tensor
473475
* @return Tensor The concatenated tensor.
474476
* @throws Exception
475477
*/
476-
public static function cat(array $tensors, int $axis = 0): Tensor
478+
public static function concat(array $tensors, int $axis = 0): Tensor
477479
{
478-
$axis = self::safeIndex($axis, $tensors[0]->ndim());
479-
480-
// TODO: Perform validation of shapes
481-
482-
$resultShape = $tensors[0]->shape();
483-
$resultOffset = $tensors[0]->offset();
484-
$resultType = $tensors[0]->dtype();
485-
$resultShape[$axis] = array_reduce($tensors, fn($carry, $tensor) => $carry + $tensor->shape()[$axis], 0);
486-
487-
// Create a new array to store the accumulated values
488-
$resultSize = array_product($resultShape);
489-
490-
$result = self::newBuffer($resultSize, $resultType);
491-
492-
// Create output tensor of same type as first
493-
494-
if ($axis === 0) {
495-
// Handle special case for performance reasons
496-
497-
$offset = 0;
498-
foreach ($tensors as $t) {
499-
for ($i = 0; $i < $t->buffer->count(); $i++) {
500-
$result[$offset++] = $t->buffer()[$i];
501-
}
502-
}
503-
} else {
504-
$currentShape = 0;
505-
506-
foreach ($tensors as $tensor) {
507-
for ($i = 0; $i < $tensor->buffer->count(); $i++) {
508-
$resultIndex = 0;
509-
510-
for ($j = $tensor->ndim() - 1, $num = $i, $resultMultiplier = 1; $j >= 0; --$j) {
511-
$size = $tensor->shape()[$j];
512-
$index = $num % $size;
513-
if ($j === $axis) {
514-
$index += $currentShape;
515-
}
516-
$resultIndex += $index * $resultMultiplier;
517-
$resultMultiplier *= $resultShape[$j];
518-
$num = (int)floor($num / $size);
519-
}
520-
$result[$resultIndex] = $tensor->buffer()[$i];
521-
}
480+
$mo = self::mo();
522481

523-
$currentShape += $tensor->shape()[$axis];
524-
}
525-
}
482+
$ndArray = $mo->la()->concat($tensors, $axis);
526483

527-
return new Tensor($result, $resultType, $resultShape, $resultOffset);
484+
return new static($ndArray->buffer(), $ndArray->dtype(), $ndArray->shape(), $ndArray->offset());
528485
}
529486

530487
/**
@@ -577,31 +534,16 @@ public function squeeze(?int $axis = null): static
577534
*/
578535
public function unsqueeze(?int $axis = null): static
579536
{
580-
return new Tensor(
581-
$this->buffer(),
582-
$this->dtype,
583-
$this->calcUnsqueezeShape($this->shape(), $axis),
584-
$this->offset
585-
);
586-
}
537+
$shape = $this->shape();
587538

588-
/**
589-
* Helper function to calculate new shape when performing an unsqueeze operation.
590-
* @param array $shape The shape of the tensor.
591-
* @param int $axis The axis to unsqueeze.
592-
* @return array The new shape.
593-
*/
594-
protected function calcUnsqueezeShape(array $shape, int $axis): array
595-
{
596-
// Dimension out of range (e.g., "expected to be in range of [-4, 3], but got 4")
597-
// + 1 since we allow inserting at the end (i.e. $axis = -1)
598539
$axis = self::safeIndex($axis, count($shape) + 1);
599540

600541
array_splice($shape, $axis, 0, 1);
601542

602-
return $shape;
543+
return new Tensor($this->buffer(), $this->dtype, $shape, $this->offset);
603544
}
604545

546+
605547
/**
606548
* Add a tensor or scalar to this tensor. If it's a tensor, it must be the same shape, and it performs
607549
* an element-wise addition. If it's a scalar, it adds the scalar to every element in the tensor.

0 commit comments

Comments
 (0)