99use Codewithkyrian \Transformers \Models \Pretrained \PretrainedModel ;
1010use Codewithkyrian \Transformers \Utils \GenerationConfig ;
1111use Codewithkyrian \Transformers \Utils \Tensor ;
12+ use Interop \Polite \Math \Matrix \NDArray ;
1213
1314enum ModelArchitecture: string
1415{
@@ -34,7 +35,7 @@ public function runBeam(PretrainedModel $model, array &$beam): array
3435 {
3536 return match ($ this ) {
3637 self ::DecoderOnly => $ this ->decoderRunBeam ($ model , $ beam ),
37- self ::Seq2SeqLM, self ::Vision2Seq => $ this ->seq2seqRunBeam ($ model , $ beam ),
38+ self ::Seq2SeqLM, self ::Vision2Seq => $ this ->seq2seqRunBeam ($ model , $ beam ),
3839 default => throw new \Error ('This model type does not support beam search ' ),
3940 };
4041 }
@@ -114,10 +115,11 @@ protected function decoderRunBeam(PretrainedModel $model, array &$beam): array
114115 // 1. Prepare
115116 $ modelInputs = [
116117 'input_ids ' => $ beam ['model_input_ids ' ],
117- 'attention_mask ' => new Tensor ($ attnMaskData , shape: [1 , $ attnMaskLength ]),
118+ 'attention_mask ' => new Tensor ($ attnMaskData , NDArray::int64, [1 , $ attnMaskLength ]),
118119 'past_key_values ' => $ beam ['prev_model_outputs ' ]['past_key_values ' ] ?? null ,
119120 ];
120121
122+
121123 // 2. Run
122124 $ output = $ model ->forward ($ modelInputs );
123125
@@ -155,7 +157,7 @@ protected function decoderStartBeams(
155157 $ attnMask = null ;
156158 if ($ inputsAttentionMask !== null ) {
157159 $ attnMask = $ inputsAttentionMask [$ beamId ];
158- $ attnMask ->reshape ([1 , ...$ attnMask ->shape ()]);
160+ $ attnMask = $ attnMask ->reshape ([1 , ...$ attnMask ->shape ()]);
159161 } else {
160162 $ attnMask = $ model ->prepareAttentionMask ($ tokens );
161163 }
@@ -189,8 +191,7 @@ protected function decoderStartBeams(
189191 protected function decoderUpdatebeam (array &$ beam , int $ newTokenId ): void
190192 {
191193 $ beam ['output_token_ids ' ][] = $ newTokenId ;
192-
193- $ beam ['model_input_ids ' ] = new Tensor ([$ newTokenId ], shape: [1 , 1 ]);
194+ $ beam ['model_input_ids ' ] = new Tensor ([$ newTokenId ], NDArray::int64, [1 , 1 ]);
194195 }
195196
196197 /**
@@ -221,6 +222,14 @@ protected function decoderForward(PretrainedModel $model, array $modelInputs): a
221222 $ model ->preparePositionIds ($ inputNames , $ decoderFeeds , $ useCacheBranch );
222223 $ model ->addPastKeyValues ($ decoderFeeds , $ pastKeyValues );
223224
225+ // The initial past key values should have a shape of 0 in one of the dimensions, which
226+ // is the sequence length. However, I haven't found a way to pass a tensor with a shape of 0
227+ // to the model, so I'm using a sequence length of 1 instead for the first step, and then
228+ // offsetting the sequence length by 1 for the subsequent steps. This is a workaround for now.
229+ $ prevSequenceLength = $ decoderFeeds ['past_key_values.0.key ' ]->shape ()[2 ];
230+ $ attnMaskLength = $ prevSequenceLength == 1 ? 1 : $ prevSequenceLength + 1 ;
231+ $ decoderFeeds ['attention_mask ' ] = Tensor::ones ([1 , $ attnMaskLength ], dtype: NDArray::int64);
232+
224233 $ decoderResults = $ model ->runSession ($ model ->session , $ decoderFeeds );
225234
226235 $ logits = $ decoderResults ['logits ' ];
0 commit comments