2727use Codewithkyrian \Transformers \Utils \AutoConfig ;
2828use Codewithkyrian \Transformers \Utils \GenerationConfig ;
2929use Codewithkyrian \Transformers \Utils \Hub ;
30+ use Codewithkyrian \Transformers \Utils \InferenceSession ;
3031use Codewithkyrian \Transformers \Utils \Tensor ;
3132use Error ;
3233use Exception ;
33- use OnnxRuntime \InferenceSession ;
3434use Symfony \Component \Console \Output \OutputInterface ;
3535use function Codewithkyrian \Transformers \Utils \array_some ;
36+ use function Codewithkyrian \Transformers \Utils \timeUsage ;
3637
3738/**
3839 * A base class for pre-trained models that provides the model configuration and an ONNX session.
@@ -281,9 +282,10 @@ public function runSession(InferenceSession $session, array $inputs): array
281282
282283 $ outputNames = array_column ($ session ->outputs (), 'name ' );
283284
284- $ outputs = $ session ->run ($ outputNames , $ inputs );
285-
286- return array_combine ($ outputNames , array_map ([Tensor::class, 'fromArray ' ], $ outputs ));
285+ timeUsage ();
286+ $ out = $ session ->run ($ outputNames , $ inputs );
287+ dump (timeUsage (true ));
288+ return $ out ;
287289 } catch (MissingModelInputException $ e ) {
288290 throw $ e ;
289291 } catch (Exception $ e ) {
@@ -331,7 +333,8 @@ public function validateInputs(array $inputNames, array $inputs): array
331333 The following inputs will be ignored: " ' . implode (', ' , $ ignored ) . '". ' ;
332334 }
333335
334- return array_map (fn ($ i ) => $ i ->toArray (), $ inputs );
336+ // return array_map(fn($i) => $i->toArray(), $inputs);
337+ return $ inputs ;
335338 }
336339
337340 /**
@@ -521,8 +524,10 @@ public function addPastKeyValues(array &$decoderFeeds, ?array $pastKeyValues): v
521524 * @param Tensor $inputs The input token ids.
522525 * @param GenerationConfig|null $generationConfig The generation configuration to use. If null, default configuration will be used.
523526 * @param LogitsProcessorList|null $logitsProcessor An optional logits processor to use. If null, a new LogitsProcessorList instance will be created.
524- * @param array|null $inputsAttentionMask An optional attention mask for the inputs.
527+ * @param Tensor|null $inputsAttentionMask An optional attention mask for the inputs.
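528+ * @param Streamer|null $streamer An optional streamer; if provided, its put() method is called with the current beams during generation.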
525529 * @return array An array of generated output sequences, where each sequence is an array of token IDs.
530+ * @throws Exception
526531 */
527532 public function generate (
528533 Tensor $ inputs ,
@@ -609,6 +614,7 @@ public function generate(
609614
610615 $ output = $ this ->runBeam ($ beam );
611616
617+
612618 // add attentions/scores to beam only if user requested
613619 if ($ generationConfig ->output_attentions ) {
614620 $ this ->addAttentionsToBeam ($ beam , $ output );
@@ -626,6 +632,7 @@ public function generate(
626632 $ logits = $ output ['logits ' ]->slice (null , -1 , null );
627633// $logits = $output['logits'];
628634
635+
629636 // Apply logits processor
630637 $ logitsProcessor ($ beam ['output_token_ids ' ], $ logits );
631638
@@ -649,7 +656,6 @@ public function generate(
649656
650657 }
651658
652-
653659 ++$ numOutputTokens ;
654660
655661 // Group and select best beams
@@ -665,15 +671,13 @@ function ($group) use ($generationConfig) {
665671 $ this ->groupBeams ($ newestBeams )
666672 ));
667673
668-
669674 // Flatten beams
670675 $ beams = $ newestBeams ;
671676
672677 // Stream the beams if a streamer is provided
673678 $ streamer ?->put($ beams );
674679 }
675680
676-
677681 // TODO: Ensure that we can return non-batched outputs
678682
679683 $ groupedBeams = $ this ->groupBeams ($ beams );
0 commit comments