@@ -604,32 +604,30 @@ public function normalize(int $p = 2, ?int $dim = null): static
604604 * Returns the matrix norm or vector norm of a given tensor.
605605 *
606606 * @param int $ord Order of the norm. Supported values are 1, 2, Infinity.
607- * @param int|null $dim The axis or axes along which to perform the reduction. If null (default), reduces all dimensions.
607+ * @param int|null $axis The axis or axes along which to perform the reduction. If null (default), reduces all dimensions.
608608 * @param bool $keepdims If true, retains reduced dimensions with length 1.
609609 *
610610 * @return static
611611 */
612- public function norm (int $ ord = 2 , ?int $ dim = null , bool $ keepdims = false ): static
612+ public function norm (int $ ord = 2 , ?int $ axis = null , bool $ keepdims = false ): static
613613 {
614614 $ mo = self ::getMo ();
615615
616- if ($ dim === null ) {
617- $ val = pow (array_reduce ($ this ->buffer , function ($ carry , $ item ) use ($ ord ) {
618- return $ carry + pow ($ item , $ ord );
619- }, 0 ), 1 / $ ord );
616+ if ($ axis === null ) {
617+ $ val = pow (array_reduce ($ this ->toBufferArray (), fn ($ carry , $ item ) => $ carry + pow ($ item , $ ord ), 0 ), 1 / $ ord );
620618
621619 return new Tensor ([$ val ], $ this ->dtype (), []);
622620 }
623621
624622 // Negative indexing
625- $ dim = $ this ->safeIndex ($ dim , $ this ->ndim ());
623+ $ axis = $ this ->safeIndex ($ axis , $ this ->ndim ());
626624
627625 // Calculate the shape of the resulting array after summation
628626 $ resultDims = $ this ->shape ();
629- $ resultDims [$ dim ] = 1 ; // Remove the specified axis
627+ $ resultDims [$ axis ] = 1 ; // Remove the specified axis
630628
631629 // Create a new array to store the accumulated values
632- $ result = $ this ->zeros ([count ($ this ->buffer ) / $ this ->shape ()[$ dim ]]);
630+ $ result = $ this ->zeros ([count ($ this ->buffer ) / $ this ->shape ()[$ axis ]]);
633631
634632 // Iterate over the data array
635633 foreach ($ this ->buffer as $ i => $ value ) {
@@ -641,7 +639,7 @@ public function norm(int $ord = 2, ?int $dim = null, bool $keepdims = false): st
641639 for ($ j = $ this ->ndim () - 1 ; $ j >= 0 ; --$ j ) {
642640 $ size = $ this ->shape ()[$ j ];
643641
644- if ($ j !== $ dim ) {
642+ if ($ j !== $ axis ) {
645643 $ index = $ num % $ size ;
646644 $ resultIndex += $ index * $ resultMultiplier ;
647645 $ resultMultiplier *= $ resultDims [$ j ];
@@ -659,12 +657,61 @@ public function norm(int $ord = 2, ?int $dim = null, bool $keepdims = false): st
659657 }
660658
661659 if (!$ keepdims ) {
662- array_splice ($ resultDims , $ dim , 1 );
660+ array_splice ($ resultDims , $ axis , 1 );
663661 }
664662
665663 return new static ($ result ->buffer (), $ result ->dtype (), $ resultDims , $ result ->offset ());
666664 }
667665
666+ /**
667+ * Convert the tensor into a flat array of the buffer contents.
668+ */
669+ public function toBufferArray ()
670+ {
671+ if ($ this ->buffer instanceof OpenBlasBuffer) {
672+ return $ this ->buffer ->dump ();
673+ } elseif ($ this ->buffer instanceof SplFixedArray) {
674+ return $ this ->buffer ->toArray ();
675+ } else {
676+ throw new RuntimeException ('Unknown buffer type is inconvertible: ' . get_class ($ this ->buffer ));
677+ }
678+ }
679+
680+ /**
681+ * Convert the tensor into an array.
682+ */
683+ public function toArray ()
684+ {
685+ if (count ($ this ->shape ) == 0 ) {
686+ return $ this ->buffer [$ this ->offset ];
687+ }
688+
689+ $ idx = $ this ->offset ;
690+
691+ return $ this ->unflattenArray ($ this ->buffer , $ idx , $ this ->shape );
692+ }
693+
694+ /**
695+ * Unflatten the given flat array into a nested array according to the given shape.
696+ */
697+ protected function unflattenArray ($ flatArray , &$ currentIndex , array $ shape ): array
698+ {
699+ $ size = array_shift ($ shape );
700+ $ nestedArray = [];
701+
702+ if (count ($ shape )) {
703+ for ($ i = 0 ; $ i < $ size ; $ i ++) {
704+ $ nestedArray [$ i ] = $ this ->unflattenArray ($ flatArray , $ currentIndex , $ shape );
705+ }
706+ } else {
707+ for ($ i = 0 ; $ i < $ size ; $ i ++) {
708+ $ nestedArray [$ i ] = $ flatArray [$ currentIndex ];
709+ $ currentIndex ++;
710+ }
711+ }
712+ return $ nestedArray ;
713+ }
714+
668715 /**
669716 * Return a zero matrix with the given dimensions.
670717 * @param array $shape The shape of the zero matrix to return.
@@ -879,7 +926,6 @@ public function stride(): array
879926 return array_reverse ($ stride , true );
880927 }
881928
882-
883929 /**
884930 * Permutes a tensor according to the provided axes.
885931 * @param array $axes The axes to permute the tensor along.
@@ -892,55 +938,6 @@ public function permute(...$axes): static
892938 return new Tensor ($ permutedData , $ this ->dtype (), $ shape );
893939 }
894940
895- /**
896- * Convert the tensor into a flat array of the buffer contents.
897- */
898- public function toBufferArray ()
899- {
900- if ($ this ->buffer instanceof OpenBlasBuffer) {
901- return $ this ->buffer ->dump ();
902- } elseif ($ this ->buffer instanceof SplFixedArray) {
903- return $ this ->buffer ->toArray ();
904- } else {
905- throw new RuntimeException ('Unknown buffer type is inconvertible: ' . get_class ($ this ->buffer ));
906- }
907- }
908-
909- /**
910- * Convert the tensor into an array.
911- */
912- public function toArray ()
913- {
914- if (count ($ this ->shape ) == 0 ) {
915- return $ this ->buffer [$ this ->offset ];
916- }
917-
918- $ idx = $ this ->offset ;
919-
920- return $ this ->unflattenArray ($ this ->buffer , $ idx , $ this ->shape );
921- }
922-
923- /**
924- * Unflatten the given flat array into a nested array according to the given shape.
925- */
926- protected function unflattenArray ($ flatArray , &$ currentIndex , array $ shape ): array
927- {
928- $ size = array_shift ($ shape );
929- $ nestedArray = [];
930-
931- if (count ($ shape )) {
932- for ($ i = 0 ; $ i < $ size ; $ i ++) {
933- $ nestedArray [$ i ] = $ this ->unflattenArray ($ flatArray , $ currentIndex , $ shape );
934- }
935- } else {
936- for ($ i = 0 ; $ i < $ size ; $ i ++) {
937- $ nestedArray [$ i ] = $ flatArray [$ currentIndex ];
938- $ currentIndex ++;
939- }
940- }
941- return $ nestedArray ;
942- }
943-
944941 /**
945942 * Calculate the softmax of the tensor.
946943 *
0 commit comments