Skip to content

Commit c91a888

Browse files
bugfix: Tensor norms not handling right buffer
1 parent d014ef0 commit c91a888

1 file changed

Lines changed: 58 additions & 61 deletions

File tree

src/Utils/Tensor.php

Lines changed: 58 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -604,32 +604,30 @@ public function normalize(int $p = 2, ?int $dim = null): static
604604
* Returns the matrix norm or vector norm of a given tensor.
605605
*
606606
* @param int $ord Order of the norm. Supported values are 1, 2, Infinity.
607-
* @param int|null $dim The axis or axes along which to perform the reduction. If null (default), reduces all dimensions.
607+
* @param int|null $axis The axis or axes along which to perform the reduction. If null (default), reduces all dimensions.
608608
* @param bool $keepdims If true, retains reduced dimensions with length 1.
609609
*
610610
* @return static
611611
*/
612-
public function norm(int $ord = 2, ?int $dim = null, bool $keepdims = false): static
612+
public function norm(int $ord = 2, ?int $axis = null, bool $keepdims = false): static
613613
{
614614
$mo = self::getMo();
615615

616-
if ($dim === null) {
617-
$val = pow(array_reduce($this->buffer, function ($carry, $item) use ($ord) {
618-
return $carry + pow($item, $ord);
619-
}, 0), 1 / $ord);
616+
if ($axis === null) {
617+
$val = pow(array_reduce($this->toBufferArray(), fn($carry, $item) => $carry + pow($item, $ord), 0), 1 / $ord);
620618

621619
return new Tensor([$val], $this->dtype(), []);
622620
}
623621

624622
// Negative indexing
625-
$dim = $this->safeIndex($dim, $this->ndim());
623+
$axis = $this->safeIndex($axis, $this->ndim());
626624

627625
// Calculate the shape of the resulting array after summation
628626
$resultDims = $this->shape();
629-
$resultDims[$dim] = 1; // Remove the specified axis
627+
$resultDims[$axis] = 1; // Remove the specified axis
630628

631629
// Create a new array to store the accumulated values
632-
$result = $this->zeros([count($this->buffer) / $this->shape()[$dim]]);
630+
$result = $this->zeros([count($this->buffer) / $this->shape()[$axis]]);
633631

634632
// Iterate over the data array
635633
foreach ($this->buffer as $i => $value) {
@@ -641,7 +639,7 @@ public function norm(int $ord = 2, ?int $dim = null, bool $keepdims = false): st
641639
for ($j = $this->ndim() - 1; $j >= 0; --$j) {
642640
$size = $this->shape()[$j];
643641

644-
if ($j !== $dim) {
642+
if ($j !== $axis) {
645643
$index = $num % $size;
646644
$resultIndex += $index * $resultMultiplier;
647645
$resultMultiplier *= $resultDims[$j];
@@ -659,12 +657,61 @@ public function norm(int $ord = 2, ?int $dim = null, bool $keepdims = false): st
659657
}
660658

661659
if (!$keepdims) {
662-
array_splice($resultDims, $dim, 1);
660+
array_splice($resultDims, $axis, 1);
663661
}
664662

665663
return new static($result->buffer(), $result->dtype(), $resultDims, $result->offset());
666664
}
667665

666+
/**
667+
* Convert the tensor into a flat array of the buffer contents.
668+
*/
669+
public function toBufferArray()
670+
{
671+
if ($this->buffer instanceof OpenBlasBuffer) {
672+
return $this->buffer->dump();
673+
} elseif ($this->buffer instanceof SplFixedArray) {
674+
return $this->buffer->toArray();
675+
} else {
676+
throw new RuntimeException('Unknown buffer type is inconvertible:' . get_class($this->buffer));
677+
}
678+
}
679+
680+
/**
681+
* Convert the tensor into an array.
682+
*/
683+
public function toArray()
684+
{
685+
if (count($this->shape) == 0) {
686+
return $this->buffer[$this->offset];
687+
}
688+
689+
$idx = $this->offset;
690+
691+
return $this->unflattenArray($this->buffer, $idx, $this->shape);
692+
}
693+
694+
/**
695+
* Unflatten the given flat array into a nested array according to the given shape.
696+
*/
697+
protected function unflattenArray($flatArray, &$currentIndex, array $shape): array
698+
{
699+
$size = array_shift($shape);
700+
$nestedArray = [];
701+
702+
if (count($shape)) {
703+
for ($i = 0; $i < $size; $i++) {
704+
$nestedArray[$i] = $this->unflattenArray($flatArray, $currentIndex, $shape);
705+
}
706+
} else {
707+
for ($i = 0; $i < $size; $i++) {
708+
$nestedArray[$i] = $flatArray[$currentIndex];
709+
$currentIndex++;
710+
}
711+
}
712+
return $nestedArray;
713+
}
714+
668715
/**
669716
* Return a zero matrix with the given dimensions.
670717
* @param array $shape The shape of the zero matrix to return.
@@ -879,7 +926,6 @@ public function stride(): array
879926
return array_reverse($stride, true);
880927
}
881928

882-
883929
/**
884930
* Permutes a tensor according to the provided axes.
885931
* @param array $axes The axes to permute the tensor along.
@@ -892,55 +938,6 @@ public function permute(...$axes): static
892938
return new Tensor($permutedData, $this->dtype(), $shape);
893939
}
894940

895-
/**
896-
* Convert the tensor into a flat array of the buffer contents.
897-
*/
898-
public function toBufferArray()
899-
{
900-
if ($this->buffer instanceof OpenBlasBuffer) {
901-
return $this->buffer->dump();
902-
} elseif ($this->buffer instanceof SplFixedArray) {
903-
return $this->buffer->toArray();
904-
} else {
905-
throw new RuntimeException('Unknown buffer type is inconvertible:' . get_class($this->buffer));
906-
}
907-
}
908-
909-
/**
910-
* Convert the tensor into an array.
911-
*/
912-
public function toArray()
913-
{
914-
if (count($this->shape) == 0) {
915-
return $this->buffer[$this->offset];
916-
}
917-
918-
$idx = $this->offset;
919-
920-
return $this->unflattenArray($this->buffer, $idx, $this->shape);
921-
}
922-
923-
/**
924-
* Unflatten the given flat array into a nested array according to the given shape.
925-
*/
926-
protected function unflattenArray($flatArray, &$currentIndex, array $shape): array
927-
{
928-
$size = array_shift($shape);
929-
$nestedArray = [];
930-
931-
if (count($shape)) {
932-
for ($i = 0; $i < $size; $i++) {
933-
$nestedArray[$i] = $this->unflattenArray($flatArray, $currentIndex, $shape);
934-
}
935-
} else {
936-
for ($i = 0; $i < $size; $i++) {
937-
$nestedArray[$i] = $flatArray[$currentIndex];
938-
$currentIndex++;
939-
}
940-
}
941-
return $nestedArray;
942-
}
943-
944941
/**
945942
* Calculate the softmax of the tensor.
946943
*

0 commit comments

Comments
 (0)