Skip to content

Commit d014ef0

Browse files
Add dot and cross product to Tensor
1 parent 06940e4 commit d014ef0

3 files changed

Lines changed: 33 additions & 23 deletions

File tree

src/Pipelines/ImageToImagePipeline.php

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ public function __invoke(array|string $inputs, ...$args): array|Image
6060
foreach ($outputs['reconstruction'] as $i => $batch) {
6161
$output = $batch->squeeze()
6262
->clamp(0, 1)
63-
->multiplyScalar(255)
63+
->multiply(255)
6464
->round()
6565
->to(NDArray::uint8);
6666

src/Utils/Tensor.php

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -477,34 +477,21 @@ protected function calcUnsqueezeDims(array $dims, int $dim): array
477477
}
478478

479479
/**
480-
* Add two NDArrays element-wise, A + B
480+
* Add a tensor or scalar to this tensor. If it's a tensor, it must be the same shape, and it performs
481+
* an element-wise addition. If it's a scalar, it adds the scalar to every element in the tensor.
481482
*
482-
* @param Tensor $other The NDArray to add to this NDArray.
483+
* @param Tensor|float|int $other The NDArray to add to this NDArray.
483484
* @return static
484485
*/
485-
public function add(Tensor $other): static
486+
public function add(Tensor|float|int $other): static
486487
{
487488
$mo = self::getMo();
488489

489-
$ndArray = $mo->add($this, $other);
490+
$ndArray = is_scalar($other) ? $mo->op($this, '+', $other) : $mo->add($this, $other);
490491

491492
return new static($ndArray->buffer(), $ndArray->dtype(), $ndArray->shape(), $ndArray->offset());
492493
}
493494

494-
/**
495-
* Return a new Tensor with every element added by a constant.
496-
*
497-
* @param float|int $scalar The constant to add.
498-
* @return static
499-
*/
500-
public function addScalar(float|int $scalar): static
501-
{
502-
$mo = self::getMo();
503-
504-
$ndArray = $mo->op($this, '+', $scalar);
505-
506-
return new static($ndArray->buffer(), $ndArray->dtype(), $ndArray->shape(), $ndArray->offset());
507-
}
508495

509496
/**
510497
* Return a new Tensor with the sigmoid function applied to each element.
@@ -526,7 +513,7 @@ public function sigmoid(): self
526513
*
527514
* @return self
528515
*/
529-
public function multiplyScalar(float|int $scalar): self
516+
public function multiply(float|int $scalar): self
530517
{
531518
$mo = self::getMo();
532519

@@ -535,6 +522,29 @@ public function multiplyScalar(float|int $scalar): self
535522
return new static($ndArray->buffer(), $ndArray->dtype(), $ndArray->shape(), $ndArray->offset());
536523
}
537524

525+
/**
526+
* Calculate the dot product of this tensor and another tensor.
527+
*/
528+
public function dot(Tensor $other): float
529+
{
530+
$mo = self::getMo();
531+
532+
return $mo->dot($this, $other);
533+
}
534+
535+
/**
536+
* Calculate the cross product of this tensor and another tensor. The shapes of the tensors must be compatible for
537+
* cross product
538+
*/
539+
public function cross(Tensor $other): Tensor
540+
{
541+
$mo = self::getMo();
542+
543+
$crossProduct = $mo->cross($this, $other);
544+
545+
return new static($crossProduct->buffer(), $crossProduct->dtype(), $crossProduct->shape(), $crossProduct->offset());
546+
}
547+
538548
/**
539549
* Return a transposed version of this Tensor.
540550
* @return $this

tests/Utils/TensorTest.php

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666

6767
it('can add a scalar to each element of a tensor', function () {
6868
$tensor = Tensor::fromArray([[1, 2], [3, 4]]);
69-
$result = $tensor->addScalar(5);
69+
$result = $tensor->add(5);
7070

7171
expect($result)->toBeInstanceOf(Tensor::class)
7272
->and($result->toArray())->toBe([[6, 7], [8, 9]]);
@@ -82,15 +82,15 @@
8282

8383
it('can multiply each element of a tensor by a scalar', function () {
8484
$tensor = Tensor::fromArray([[1, 2], [3, 4]]);
85-
$result = $tensor->multiplyScalar(2);
85+
$result = $tensor->multiply(2);
8686

8787
expect($result)->toBeInstanceOf(Tensor::class)
8888
->and($result->toArray())->toBe([[2.0, 4.0], [6.0, 8.0]]);
8989
});
9090

9191
it('can compute the mean value of each row of the tensor', function () {
9292
$tensor = Tensor::fromArray([[1, 2], [3, 4]]);
93-
$result = $tensor->mean(dim: 1);
93+
$result = $tensor->mean(axis: 1);
9494

9595
expect($result)->toBeInstanceOf(Tensor::class)
9696
->and($result->toArray())->toBe([1.5, 3.5]);

0 commit comments

Comments
 (0)