@@ -261,10 +261,10 @@ public static function onesLike(Tensor $other): static
261261 * Return a one matrix with the given dimensions.
262262 *
263263 * @param array $shape The shape of the one matrix to return.
264- * @param string|null $dtype The data type of the one matrix to return. Eg: float32, int32, etc. If null, defaults to float32.
264+ * @param ?int $dtype The data type of the one matrix to return. Eg: float32, int32, etc. If null, defaults to float32.
265265 * @return static
266266 */
267- public static function ones (array $ shape , ?string $ dtype = null ): static
267+ public static function ones (array $ shape , ?int $ dtype = null ): static
268268 {
269269 $ mo = self ::getMo ();
270270
@@ -303,7 +303,7 @@ public static function fromArray(array|NDArray $array, ?string $dtype = null, $s
303303 /**
304304 * Reshape the tensor into the given shape.
305305 */
306- public function reshape (array $ shape ): NDArray
306+ public function reshape (array $ shape ): static
307307 {
308308 $ this ->assertShape ($ shape );
309309
@@ -443,37 +443,37 @@ public function count(): int
443443 }
444444
445445 /**
446- * Returns a tensor with all specified dimensions of input of size 1 removed.
446+ * Returns a tensor with all specified axis of input of size 1 removed.
447447 *
448- * @param ?int $dim If given, the input will be squeezed only in the specified dimensions .
448+ * @param ?int $axis If given, the input will be squeezed only in the specified axis .
449449 *
450450 * @return static The squeezed tensor.
451451 */
452- public function unsqueeze (?int $ dim = null ): static
452+ public function unsqueeze (?int $ axis = null ): static
453453 {
454454 return new Tensor (
455455 $ this ->buffer (),
456456 $ this ->dtype ,
457- $ this ->calcUnsqueezeDims ($ this ->shape (), $ dim ),
457+ $ this ->calcUnsqueezeShape ($ this ->shape (), $ axis ),
458458 $ this ->offset
459459 );
460460 }
461461
462462 /**
463- * Helper function to calculate new dimensions when performing an unsqueeze operation.
464- * @param array $dims The dimensions of the tensor.
465- * @param int $dim The dimension to unsqueeze.
466- * @return array The new dimensions .
463+ * Helper function to calculate new shape when performing an unsqueeze operation.
464+ * @param array $shape The shape of the tensor.
465+ * @param int $axis The axis to unsqueeze.
466+ * @return array The new shape .
467467 */
468- protected function calcUnsqueezeDims (array $ dims , int $ dim ): array
468+ protected function calcUnsqueezeShape (array $ shape , int $ axis ): array
469469 {
470470 // Dimension out of range (e.g., "expected to be in range of [-4, 3], but got 4")
471471 // + 1 since we allow inserting at the end (i.e. dim = -1)
472- $ dim = self ::safeIndex ($ dim , count ($ dims ) + 1 );
473- $ newDims = $ dims ;
474- // Insert 1 into specified dimension
475- array_splice ( $ newDims , $ dim , 0 , [ 1 ]);
476- return $ newDims ;
472+ $ axis = self ::safeIndex ($ axis , count ($ shape ) + 1 );
473+
474+ $ shape [ $ axis - 1 ] = 1 ;
475+
476+ return $ shape ;
477477 }
478478
479479 /**
@@ -605,11 +605,11 @@ public function normalize(int $p = 2, ?int $dim = null): static
605605 *
606606 * @param int $ord Order of the norm. Supported values are 1, 2, Infinity.
607607 * @param int|null $axis The axis or axes along which to perform the reduction. If null (default), reduces all dimensions.
608- * @param bool $keepdims If true, retains reduced dimensions with length 1.
608+ * @param bool $keepShape If true, retains reduced shape with length 1.
609609 *
610610 * @return static
611611 */
612- public function norm (int $ ord = 2 , ?int $ axis = null , bool $ keepdims = false ): static
612+ public function norm (int $ ord = 2 , ?int $ axis = null , bool $ keepShape = false ): static
613613 {
614614 $ mo = self ::getMo ();
615615
@@ -623,8 +623,8 @@ public function norm(int $ord = 2, ?int $axis = null, bool $keepdims = false): s
623623 $ axis = $ this ->safeIndex ($ axis , $ this ->ndim ());
624624
625625 // Calculate the shape of the resulting array after summation
626- $ resultDims = $ this ->shape ();
627- $ resultDims [$ axis ] = 1 ; // Remove the specified axis
626+ $ resultShape = $ this ->shape ();
627+ $ resultShape [$ axis ] = 1 ; // Remove the specified axis
628628
629629 // Create a new array to store the accumulated values
630630 $ result = $ this ->zeros ([count ($ this ->buffer ) / $ this ->shape ()[$ axis ]]);
@@ -642,7 +642,7 @@ public function norm(int $ord = 2, ?int $axis = null, bool $keepdims = false): s
642642 if ($ j !== $ axis ) {
643643 $ index = $ num % $ size ;
644644 $ resultIndex += $ index * $ resultMultiplier ;
645- $ resultMultiplier *= $ resultDims [$ j ];
645+ $ resultMultiplier *= $ resultShape [$ j ];
646646 }
647647
648648 $ num = floor ($ num / $ size );
@@ -656,11 +656,11 @@ public function norm(int $ord = 2, ?int $axis = null, bool $keepdims = false): s
656656 $ result = $ mo ->op ($ result , '** ' , 1 / $ ord );
657657 }
658658
659- if (!$ keepdims ) {
660- array_splice ($ resultDims , $ axis , 1 );
659+ if (!$ keepShape ) {
660+ array_splice ($ resultShape , $ axis , 1 );
661661 }
662662
663- return new static ($ result ->buffer (), $ result ->dtype (), $ resultDims , $ result ->offset ());
663+ return new static ($ result ->buffer (), $ result ->dtype (), $ resultShape , $ result ->offset ());
664664 }
665665
666666 /**
@@ -794,7 +794,7 @@ public function to(int $dtype): static
794794 /**
795795 * Returns the mean value of each row of the tensor in the given dimension dim.
796796 */
797- public function mean (?int $ axis = null , bool $ keepdims = false ): static |float |int
797+ public function mean (?int $ axis = null , bool $ keepShape = false ): static |float |int
798798 {
799799 $ mo = self ::getMo ();
800800
@@ -803,7 +803,7 @@ public function mean(?int $axis = null, bool $keepdims = false): static|float|in
803803 if ($ mean instanceof NDArray) {
804804 $ shape = $ mean ->shape ();
805805
806- if (!$ keepdims ) {
806+ if (!$ keepShape ) {
807807 array_splice ($ shape , $ axis , 1 );
808808 }
809809
@@ -858,15 +858,15 @@ public function meanPooling(Tensor $other): Tensor
858858
859859 public function slice (...$ slices ): Tensor
860860 {
861- $ newTensorDims = [];
861+ $ newTensorShape = [];
862862 $ newOffsets = [];
863863
864864 for ($ sliceIndex = 0 ; $ sliceIndex < $ this ->ndim (); ++$ sliceIndex ) {
865865 $ slice = $ slices [$ sliceIndex ] ?? null ;
866866
867867 if ($ slice === null ) {
868868 $ newOffsets [] = [0 , $ this ->shape ()[$ sliceIndex ]];
869- $ newTensorDims [] = $ this ->shape ()[$ sliceIndex ];
869+ $ newTensorShape [] = $ this ->shape ()[$ sliceIndex ];
870870
871871 } elseif (is_int ($ slice )) {
872872 $ slice = $ this ->safeIndex ($ slice , $ this ->shape ()[$ sliceIndex ], $ sliceIndex );
@@ -881,31 +881,31 @@ public function slice(...$slices): Tensor
881881 min ($ slice [1 ], $ this ->shape ()[$ sliceIndex ])
882882 ];
883883 $ newOffsets [] = $ offsets ;
884- $ newTensorDims [] = $ offsets [1 ] - $ offsets [0 ];
884+ $ newTensorShape [] = $ offsets [1 ] - $ offsets [0 ];
885885
886886 } else {
887887 throw new Exception ("Invalid slice: " . json_encode ($ slice ));
888888 }
889889 }
890890
891- $ newDims = array_map (fn ($ offsets ) => $ offsets [1 ] - $ offsets [0 ], $ newOffsets );
891+ $ newShape = array_map (fn ($ offsets ) => $ offsets [1 ] - $ offsets [0 ], $ newOffsets );
892892
893- $ newBufferSize = array_reduce ($ newDims , fn ($ a , $ b ) => $ a * $ b , 1 );
893+ $ newBufferSize = array_reduce ($ newShape , fn ($ a , $ b ) => $ a * $ b , 1 );
894894
895895 $ buffer = $ this ->newBuffer ($ newBufferSize , $ this ->dtype ());
896896 $ stride = $ this ->stride ();
897897
898898 for ($ i = 0 ; $ i < $ newBufferSize ; ++$ i ) {
899899 $ originalIndex = 0 ;
900- for ($ j = count ($ newDims ) - 1 , $ num = $ i ; $ j >= 0 ; --$ j ) {
901- $ size = $ newDims [$ j ];
900+ for ($ j = count ($ newShape ) - 1 , $ num = $ i ; $ j >= 0 ; --$ j ) {
901+ $ size = $ newShape [$ j ];
902902 $ originalIndex += (($ num % $ size ) + $ newOffsets [$ j ][0 ]) * $ stride [$ j ];
903903 $ num = floor ($ num / $ size );
904904 }
905905 $ buffer [$ i ] = $ this ->buffer [$ originalIndex ];
906906 }
907907
908- return new Tensor ($ buffer , $ this ->dtype (), $ newDims , $ this ->offset ());
908+ return new Tensor ($ buffer , $ this ->dtype (), $ newShape , $ this ->offset ());
909909 }
910910
911911 /**
0 commit comments