55
66namespace Codewithkyrian \Transformers \Pipelines ;
77
8+ use Codewithkyrian \Transformers \Generation \Streamers \Streamer ;
89use Codewithkyrian \Transformers \Utils \GenerationConfig ;
10+ use function Codewithkyrian \Transformers \Utils \array_every ;
911use function Codewithkyrian \Transformers \Utils \camelCaseToSnakeCase ;
1012
1113/**
@@ -53,12 +55,18 @@ class TextGenerationPipeline extends Pipeline
5355 public function __invoke (array |string $ inputs , ...$ args ): array
5456 {
5557 $ streamer = null ;
56-
5758 if (array_key_exists ('streamer ' , $ args )) {
59+ /** @var Streamer $streamer */
5860 $ streamer = $ args ['streamer ' ];
5961 unset($ args ['streamer ' ]);
6062 }
6163
64+ $ returnFullText = true ; // By default, return full text
65+ if (array_key_exists ('returnFullText ' , $ args )) {
66+ $ returnFullText = $ args ['returnFullText ' ];
67+ unset($ args ['returnFullText ' ]);
68+ }
69+
6270 // Convert the rest of the arguments key names from camelCase to snake_case
6371 $ snakeCasedArgs = [];
6472 foreach ($ args as $ key => $ value ) {
@@ -67,45 +75,79 @@ public function __invoke(array|string $inputs, ...$args): array
6775
6876 $ generationConfig = new GenerationConfig ($ snakeCasedArgs );
6977
70- $ isChatMode = $ this ->isChatMode ($ inputs );
71-
72- if ($ isChatMode ) {
73- $ inputs = $ this ->tokenizer ->applyChatTemplate ($ inputs , addGenerationPrompt: true , tokenize: false );
74- }
75-
76- $ isBatched = is_array ($ inputs );
77-
78- if (!$ isBatched ) {
79- $ inputs = [$ inputs ];
78+ $ isBatched = false ;
79+ $ isChatInput = false ;
80+
81+ // Normalize inputs
82+ $ texts = [];
83+
84+ if (is_string ($ inputs )) {
85+ $ texts = $ inputs = [$ inputs ];
86+ } elseif (is_array ($ inputs ) && array_every ($ inputs , fn ($ x ) => is_string ($ x ))) {
87+ $ isBatched = true ;
88+ $ texts = $ inputs ;
89+ } else {
90+ if ($ this ->isChat ($ inputs )) {
91+ $ inputs = [$ inputs ];
92+ } elseif (is_array ($ inputs ) && array_every ($ inputs , [$ this , 'isChat ' ])) {
93+ $ isBatched = true ;
94+ } else {
95+ throw new \Exception ('Input must be a string, an array of strings, a Chat, or an array of Chats ' );
96+ }
97+ $ isChatInput = true ;
98+
99+ // If the input is a chat, apply the chat template
100+ $ texts = array_map (fn ($ x ) => $ this ->tokenizer ->applyChatTemplate ($ x , addGenerationPrompt: true , tokenize: false ), $ inputs );
80101 }
81102
82103 // By default, do not add special tokens
83- $ addSpecialTokens = $ this ->model ->config ['add_special_tokens ' ] ?? false ;
104+ $ addSpecialTokens = $ generationConfig ['add_special_tokens ' ] ?? false ;
105+
106+ $ returnFullText = $ isChatInput ? false : $ returnFullText ;
84107
85108 $ this ->tokenizer ->paddingSide = 'left ' ;
86109 ['input_ids ' => $ inputIds , 'attention_mask ' => $ attentionMask ] = $ this ->tokenizer ->tokenize (
87- $ inputs ,
110+ $ texts ,
88111 padding: true ,
89112 addSpecialTokens: $ addSpecialTokens ,
90113 truncation: true
91114 );
92115
93- $ outputTokenIds = $ this ->model ->generate (
94- $ inputIds ,
116+ // Streamer can only handle one input at a time for now, so we only pass the first input
117+ $ streamer ->init ($ this ->tokenizer , $ inputIds [0 ]->toArray (), true );
118+
119+ $ outputTokenIds = $ this ->model ->generate ($ inputIds ,
95120 generationConfig: $ generationConfig ,
96121 inputsAttentionMask: $ attentionMask ,
97122 streamer: $ streamer
98123 );
99124
100125 $ decoded = $ this ->tokenizer ->batchDecode ($ outputTokenIds , skipSpecialTokens: true );
101126
127+ $ promptLengths = null ;
128+ if (!$ returnFullText && $ inputIds ->shape ()[count ($ inputIds ->shape ()) - 1 ] > 0 ) {
129+ $ promptLengths = array_map (fn ($ x ) => mb_strlen ($ x ), $ this ->tokenizer ->batchDecode ($ inputIds ->toArray (), skipSpecialTokens: true ));
130+ }
102131
103132 $ toReturn = array_fill (0 , count ($ inputs ), []);
104133
105134 for ($ i = 0 ; $ i < count ($ decoded ); ++$ i ) {
106135 $ textIndex = floor ($ i / count ($ outputTokenIds ) * count ($ inputs ));
136+
137+ if ($ promptLengths !== null ) {
138+ // Trim the decoded text to only include the generated part
139+ $ decoded [$ i ] = substr ($ decoded [$ i ], $ promptLengths [$ textIndex ]);
140+
141+ // Remove the leading space
142+ $ decoded [$ i ] = ltrim ($ decoded [$ i ]);
143+ }
144+
107145 $ toReturn [$ textIndex ][] = [
108- 'generated_text ' => $ decoded [$ i ]
146+ 'generated_text ' => $ isChatInput
147+ ? array_merge ($ inputs [$ textIndex ], [
148+ ['role ' => 'assistant ' , 'content ' => $ decoded [$ i ]]
149+ ])
150+ : $ decoded [$ i ],
109151 ];
110152 }
111153
@@ -114,9 +156,9 @@ public function __invoke(array|string $inputs, ...$args): array
114156 }
115157
/**
 * Tell whether a value is shaped like a chat: an array in which every
 * element is itself an array holding both a 'role' and a 'content' entry.
 *
 * The method carries no visibility keyword and is therefore public. It has
 * to stay callable from outside the class because __invoke() passes it
 * around as the callable [$this, 'isChat'].
 *
 * @param mixed $x Candidate pipeline input.
 *
 * @return bool True when $x is an array and none of its elements breaks
 *              the role/content rule (an array without elements therefore
 *              passes); false for every other value.
 */
function isChat($x): bool
{
    if (!is_array($x)) {
        return false;
    }

    foreach ($x as $item) {
        // Guard with is_array() first: isset($item['role']) on an object
        // that does not implement ArrayAccess raises an Error rather than
        // evaluating to false, which would mask the caller's own
        // "invalid input" exception.
        if (!is_array($item) || !isset($item['role'], $item['content'])) {
            return false;
        }
    }

    return true;
}
163+
122164}
0 commit comments