@@ -459,9 +459,11 @@ public static function zerosLike(Tensor $other): static
459459 */
460460 public static function stack (array $ tensors , int $ axis = 0 ): Tensor
461461 {
462- // TODO: Perform validation of shapes
463- // NOTE: stack expects each tensor to be equal size
464- return self ::cat (array_map (fn ($ t ) => $ t ->unsqueeze ($ axis ), $ tensors ), $ axis );
462+ $ mo = self ::mo ();
463+
464+ $ stacked = $ mo ->la ()->stack ($ tensors , $ axis );
465+
466+ return new Tensor ($ stacked ->buffer (), $ stacked ->dtype (), $ stacked ->shape (), $ stacked ->offset ());
465467 }
466468
467469 /**
@@ -473,58 +475,13 @@ public static function stack(array $tensors, int $axis = 0): Tensor
473475 * @return Tensor The concatenated tensor.
474476 * @throws Exception
475477 */
476- public static function cat (array $ tensors , int $ axis = 0 ): Tensor
478+ public static function concat (array $ tensors , int $ axis = 0 ): Tensor
477479 {
478- $ axis = self ::safeIndex ($ axis , $ tensors [0 ]->ndim ());
479-
480- // TODO: Perform validation of shapes
481-
482- $ resultShape = $ tensors [0 ]->shape ();
483- $ resultOffset = $ tensors [0 ]->offset ();
484- $ resultType = $ tensors [0 ]->dtype ();
485- $ resultShape [$ axis ] = array_reduce ($ tensors , fn ($ carry , $ tensor ) => $ carry + $ tensor ->shape ()[$ axis ], 0 );
486-
487- // Create a new array to store the accumulated values
488- $ resultSize = array_product ($ resultShape );
489-
490- $ result = self ::newBuffer ($ resultSize , $ resultType );
491-
492- // Create output tensor of same type as first
493-
494- if ($ axis === 0 ) {
495- // Handle special case for performance reasons
496-
497- $ offset = 0 ;
498- foreach ($ tensors as $ t ) {
499- for ($ i = 0 ; $ i < $ t ->buffer ->count (); $ i ++) {
500- $ result [$ offset ++] = $ t ->buffer ()[$ i ];
501- }
502- }
503- } else {
504- $ currentShape = 0 ;
505-
506- foreach ($ tensors as $ tensor ) {
507- for ($ i = 0 ; $ i < $ tensor ->buffer ->count (); $ i ++) {
508- $ resultIndex = 0 ;
509-
510- for ($ j = $ tensor ->ndim () - 1 , $ num = $ i , $ resultMultiplier = 1 ; $ j >= 0 ; --$ j ) {
511- $ size = $ tensor ->shape ()[$ j ];
512- $ index = $ num % $ size ;
513- if ($ j === $ axis ) {
514- $ index += $ currentShape ;
515- }
516- $ resultIndex += $ index * $ resultMultiplier ;
517- $ resultMultiplier *= $ resultShape [$ j ];
518- $ num = (int )floor ($ num / $ size );
519- }
520- $ result [$ resultIndex ] = $ tensor ->buffer ()[$ i ];
521- }
480+ $ mo = self ::mo ();
522481
523- $ currentShape += $ tensor ->shape ()[$ axis ];
524- }
525- }
482+ $ ndArray = $ mo ->la ()->concat ($ tensors , $ axis );
526483
527- return new Tensor ( $ result , $ resultType , $ resultShape , $ resultOffset );
484+ return new static ( $ ndArray -> buffer () , $ ndArray -> dtype () , $ ndArray -> shape () , $ ndArray -> offset () );
528485 }
529486
530487 /**
@@ -577,31 +534,16 @@ public function squeeze(?int $axis = null): static
577534 */
578535 public function unsqueeze (?int $ axis = null ): static
579536 {
580- return new Tensor (
581- $ this ->buffer (),
582- $ this ->dtype ,
583- $ this ->calcUnsqueezeShape ($ this ->shape (), $ axis ),
584- $ this ->offset
585- );
586- }
537+ $ shape = $ this ->shape ();
587538
588- /**
589- * Helper function to calculate new shape when performing an unsqueeze operation.
590- * @param array $shape The shape of the tensor.
591- * @param int $axis The axis to unsqueeze.
592- * @return array The new shape.
593- */
594- protected function calcUnsqueezeShape (array $ shape , int $ axis ): array
595- {
596- // Dimension out of range (e.g., "expected to be in range of [-4, 3], but got 4")
597- // + 1 since we allow inserting at the end (i.e. $axis = -1)
598539 $ axis = self ::safeIndex ($ axis , count ($ shape ) + 1 );
599540
600541 array_splice ($ shape , $ axis , 0 , 1 );
601542
602- return $ shape ;
543+ return new Tensor ( $ this -> buffer (), $ this -> dtype , $ shape, $ this -> offset ) ;
603544 }
604545
546+
605547 /**
606548 * Add a tensor or scalar to this tensor. If it's a tensor, it must be the same shape, and it performs
607549 * an element-wise addition. If it's a scalar, it adds the scalar to every element in the tensor.
0 commit comments