@@ -44,9 +44,10 @@ abstract public function sample(Tensor $logits, int $index);
4444 * @param int $index
4545 * @return array
4646 */
47- public function getLogits (Tensor $ logits , int $ index ): array
47+ public function getLogits (Tensor $ logits , int $ index ): Tensor
4848 {
49- $ vocabSize = $ logits ->shape ()[count ($ logits ->shape ()) - 1 ];
49+ $ vocabSize = $ logits ->shape ()[$ logits ->ndim () - 1 ];
50+
5051// $logs = $logits->buffer()->toArray();
5152//
5253// if ($index === -1) {
@@ -56,21 +57,23 @@ public function getLogits(Tensor $logits, int $index): array
5657// $logs = array_slice($logs, $startIndex, $startIndex + $vocabSize);
5758// }
5859
59- $ start = $ index === - 1 ? $ logits ->buffer ()-> count () - $ vocabSize : $ index * $ vocabSize ;
60- $ end = $ start + $ vocabSize ;
60+ $ start = array_fill ( 0 , $ logits ->ndim () - 2 , 0 ) ;
61+ $ size = array_fill ( 0 , $ logits -> ndim () - 2 , 1 ) ;
6162
62- $ logs = [];
63+ $ start [] = $ index ;
64+ $ size [] = 1 ;
6365
64- for ($ i = $ start ; $ i < $ end ; $ i ++) {
65- $ logs [] = $ logits ->buffer ()[$ i ];
66- }
66+ $ start [] = -$ vocabSize ;
67+ $ size [] = $ vocabSize ;
68+
69+ $ logs = $ logits ->newSlice ($ start , $ size );
6770
68- // add temperature
6971 if ($ this ->generationConfig ->temperature > 0 ) {
70- $ logs = array_map ( fn ( $ x ) => $ x / $ this ->generationConfig ->temperature , $ logs );
72+ $ logs = $ logs -> divide ( $ this ->generationConfig ->temperature );
7173 }
7274
73- return $ logs ;
75+ // Remove all dimensions of 1, leaving a flat 1D array of vocab_size
76+ return $ logs ->squeeze ();
7477 }
7578
7679 /**
@@ -85,6 +88,7 @@ public function randomSelect(array $probabilities): int
8588
8689 // Generate a random number between 0 and the sum of probabilities
8790 $ r = mt_rand () / mt_getrandmax () * $ sumProbabilities ;
91+
8892 foreach ($ probabilities as $ i => $ probability ) {
8993 $ r -= $ probability ;
9094
0 commit comments