2727use Codewithkyrian \Transformers \Utils \AutoConfig ;
2828use Codewithkyrian \Transformers \Utils \GenerationConfig ;
2929use Codewithkyrian \Transformers \Utils \Hub ;
30+ use Codewithkyrian \Transformers \Utils \InferenceSession ;
3031use Codewithkyrian \Transformers \Utils \Tensor ;
3132use Error ;
3233use Exception ;
33- use OnnxRuntime \InferenceSession ;
3434use Symfony \Component \Console \Output \OutputInterface ;
3535use function Codewithkyrian \Transformers \Utils \array_some ;
3636
@@ -281,9 +281,7 @@ public function runSession(InferenceSession $session, array $inputs): array
281281
282282 $ outputNames = array_column ($ session ->outputs (), 'name ' );
283283
284- $ outputs = $ session ->run ($ outputNames , $ inputs );
285-
286- return array_combine ($ outputNames , array_map ([Tensor::class, 'fromArray ' ], $ outputs ));
284+ return $ session ->run ($ outputNames , $ inputs );
287285 } catch (MissingModelInputException $ e ) {
288286 throw $ e ;
289287 } catch (Exception $ e ) {
@@ -331,7 +329,8 @@ public function validateInputs(array $inputNames, array $inputs): array
331329 The following inputs will be ignored: " ' . implode (', ' , $ ignored ) . '". ' ;
332330 }
333331
334- return array_map (fn ($ i ) => $ i ->toArray (), $ inputs );
332+ // return array_map(fn($i) => $i->toArray(), $inputs);
333+ return $ inputs ;
335334 }
336335
337336 /**
@@ -468,50 +467,50 @@ public function addPastKeyValues(array &$decoderFeeds, ?array $pastKeyValues): v
468467 $ decoderFeeds = array_merge ($ decoderFeeds , $ pastKeyValues );
469468 } else {
470469 // TODO support batches (i.e., batch_size > 1)
471- $ batch_size = 1 ;
470+ $ batchSize = 1 ;
472471
473472 if ($ this ->config ->isEncoderDecoder && ($ this ->addEncoderPkv ?? true )) {
474- $ encoderShape = [$ batch_size , $ this ->numEncoderHeads , 1 , $ this ->encoderDimKv ];
475- $ decoderShape = [$ batch_size , $ this ->numDecoderHeads , 1 , $ this ->decoderDimKv ];
473+ $ encoderShape = [$ batchSize , $ this ->numEncoderHeads , 0 , $ this ->encoderDimKv ];
474+ $ decoderShape = [$ batchSize , $ this ->numDecoderHeads , 0 , $ this ->decoderDimKv ];
476475
477476
478477 for ($ i = 0 ; $ i < $ this ->numDecoderLayers ; ++$ i ) {
479478 $ decoderFeeds ["past_key_values. $ i.encoder.key " ]
480479 = $ decoderFeeds ["past_key_values. $ i.encoder.value " ]
481- = new Tensor (null , shape: $ encoderShape );
480+ = new Tensor ([] , shape: $ encoderShape );
482481 $ decoderFeeds ["past_key_values. $ i.decoder.key " ]
483482 = $ decoderFeeds ["past_key_values. $ i.decoder.value " ]
484- = new Tensor (null , shape: $ decoderShape );
483+ = new Tensor ([] , shape: $ decoderShape );
485484 }
486485 } else if ($ this ->config ->modelType === 'falcon ' ) {
487486 // NOTE: Custom implementation for Falcon
488- $ shape = [$ batch_size * $ this ->numHeads , 1 , $ this ->dimKv ];
487+ $ shape = [$ batchSize * $ this ->numHeads , 0 , $ this ->dimKv ];
489488
490489 for ($ i = 0 ; $ i < $ this ->numLayers ; ++$ i ) {
491- $ decoderFeeds ["past_key_values. $ i.key " ] = new Tensor (null , shape: $ shape );
492- $ decoderFeeds ["past_key_values. $ i.value " ] = new Tensor (null , shape: $ shape );
490+ $ decoderFeeds ["past_key_values. $ i.key " ] = new Tensor ([] , shape: $ shape );
491+ $ decoderFeeds ["past_key_values. $ i.value " ] = new Tensor ([] , shape: $ shape );
493492 }
494493 } else if ($ this ->config ['multi_query ' ] ?? null ) { // e.g., for `gpt_bigcode`
495- $ shape = [$ batch_size * $ this ->numHeads , 1 , 2 * $ this ->dimKv ];
494+ $ shape = [$ batchSize * $ this ->numHeads , 0 , 2 * $ this ->dimKv ];
496495
497496 for ($ i = 0 ; $ i < $ this ->numLayers ; ++$ i ) {
498- $ decoderFeeds ["past_key_values. $ i.key_value " ] = new Tensor (null , shape: $ shape );
497+ $ decoderFeeds ["past_key_values. $ i.key_value " ] = new Tensor ([] , shape: $ shape );
499498 }
500499 } else if ($ this ->config ['model_type ' ] === 'bloom ' ) {
501500 // NOTE: Custom implementation for Bloom
502- $ keyShape = [$ batch_size * $ this ->numHeads , $ this ->dimKv , 1 ];
503- $ valueShape = [$ batch_size * $ this ->numHeads , 1 , $ this ->dimKv ];
501+ $ keyShape = [$ batchSize * $ this ->numHeads , $ this ->dimKv , 0 ];
502+ $ valueShape = [$ batchSize * $ this ->numHeads , 0 , $ this ->dimKv ];
504503
505504 for ($ i = 0 ; $ i < $ this ->numLayers ; ++$ i ) {
506- $ decoderFeeds ["past_key_values. $ i.key " ] = new Tensor (null , shape: $ keyShape );
507- $ decoderFeeds ["past_key_values. $ i.value " ] = new Tensor (null , shape: $ valueShape );
505+ $ decoderFeeds ["past_key_values. $ i.key " ] = new Tensor ([] , shape: $ keyShape );
506+ $ decoderFeeds ["past_key_values. $ i.value " ] = new Tensor ([] , shape: $ valueShape );
508507 }
509508 } else { // Decoder-only
510- $ shape = [$ batch_size , $ this ->numHeads , 1 , $ this ->dimKv ];
509+ $ shape = [$ batchSize , $ this ->numHeads , 0 , $ this ->dimKv ];
511510
512511 for ($ i = 0 ; $ i < $ this ->numLayers ; ++$ i ) {
513- $ decoderFeeds ["past_key_values. $ i.key " ] = new Tensor (null , shape: $ shape );
514- $ decoderFeeds ["past_key_values. $ i.value " ] = new Tensor (null , shape: $ shape );
512+ $ decoderFeeds ["past_key_values. $ i.key " ] = new Tensor ([] , shape: $ shape );
513+ $ decoderFeeds ["past_key_values. $ i.value " ] = new Tensor ([] , shape: $ shape );
515514 }
516515 }
517516 }
@@ -521,8 +520,10 @@ public function addPastKeyValues(array &$decoderFeeds, ?array $pastKeyValues): v
521520 * @param Tensor $inputs The input token ids.
522521 * @param GenerationConfig|null $generationConfig The generation configuration to use. If null, default configuration will be used.
523522 * @param LogitsProcessorList|null $logitsProcessor An optional logits processor to use. If null, a new LogitsProcessorList instance will be created.
524- * @param array|null $inputsAttentionMask An optional attention mask for the inputs.
523+ * @param Tensor|null $inputsAttentionMask An optional attention mask for the inputs.
524+ * @param Streamer|null $streamer
525525 * @return array An array of generated output sequences, where each sequence is an array of token IDs.
526+ * @throws Exception
526527 */
527528 public function generate (
528529 Tensor $ inputs ,
@@ -609,6 +610,7 @@ public function generate(
609610
610611 $ output = $ this ->runBeam ($ beam );
611612
613+
612614 // add attentions/scores to beam only if user requested
613615 if ($ generationConfig ->output_attentions ) {
614616 $ this ->addAttentionsToBeam ($ beam , $ output );
@@ -626,6 +628,7 @@ public function generate(
626628 $ logits = $ output ['logits ' ]->slice (null , -1 , null );
627629// $logits = $output['logits'];
628630
631+
629632 // Apply logits processor
630633 $ logitsProcessor ($ beam ['output_token_ids ' ], $ logits );
631634
@@ -649,7 +652,6 @@ public function generate(
649652
650653 }
651654
652-
653655 ++$ numOutputTokens ;
654656
655657 // Group and select best beams
@@ -665,15 +667,13 @@ function ($group) use ($generationConfig) {
665667 $ this ->groupBeams ($ newestBeams )
666668 ));
667669
668-
669670 // Flatten beams
670671 $ beams = $ newestBeams ;
671672
672673 // Stream the beams if a streamer is provided
673674 $ streamer ?->put($ beams );
674675 }
675676
676-
677677 // TODO: Ensure that we can return non-batched outputs
678678
679679 $ groupedBeams = $ this ->groupBeams ($ beams );
0 commit comments