Skip to content

Commit 3b2bcd6

Browse files
Fix bug with attention mask shape and refactor Tensor for more consistent names
- Fixed the bug with the initial beam requiring the past key values to have a tensor with 0 in the 3rd dimension. - Renamed methods and parameters in Tensor for more consistent naming (dims -> shape)
1 parent 7c6cc25 commit 3b2bcd6

5 files changed

Lines changed: 66 additions & 59 deletions

File tree

examples/pipelines/text-generation.php

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@
1111

1212
ini_set('memory_limit', -1);
1313
//
14+
//$generator = pipeline('text-generation', 'Xenova/gpt2');
1415
$generator = pipeline('text-generation', 'Xenova/Qwen1.5-0.5B-Chat');
1516

1617
$streamer = StdOutStreamer::make($generator->tokenizer);
1718

1819
$messages = [
19-
['role' => 'user', 'content' => 'Hello!'],
20-
['role' => 'assistant', 'content' => 'Hi! How are you?'],
21-
['role' => 'user', 'content' => 'I am doing great. What about you?'],
20+
['role' => 'system', 'content' => 'You are a helpful assistant.'],
21+
['role' => 'user', 'content' => 'Who are you'],
2222
];
2323

2424
$input = $generator->tokenizer->applyChatTemplate($messages, addGenerationPrompt: true, tokenize: false);
@@ -27,9 +27,9 @@
2727
streamer: $streamer,
2828
maxNewTokens: 128,
2929
doSample: true,
30-
temperature: 0.7,
31-
repetitionPenalty: 1.3,
32-
earlyStopping: true
30+
// temperature: 0.7,
31+
// repetitionPenalty: 1.3,
32+
// earlyStopping: true
3333
);
3434

3535
//$generator = pipeline('text-generation', 'Xenova/codegen-350M-mono');

src/Models/ModelArchitecture.php

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
use Codewithkyrian\Transformers\Models\Pretrained\PretrainedModel;
1010
use Codewithkyrian\Transformers\Utils\GenerationConfig;
1111
use Codewithkyrian\Transformers\Utils\Tensor;
12+
use Interop\Polite\Math\Matrix\NDArray;
1213

1314
enum ModelArchitecture: string
1415
{
@@ -34,7 +35,7 @@ public function runBeam(PretrainedModel $model, array &$beam): array
3435
{
3536
return match ($this) {
3637
self::DecoderOnly => $this->decoderRunBeam($model, $beam),
37-
self::Seq2SeqLM, self::Vision2Seq => $this->seq2seqRunBeam($model, $beam),
38+
self::Seq2SeqLM, self::Vision2Seq => $this->seq2seqRunBeam($model, $beam),
3839
default => throw new \Error('This model type does not support beam search'),
3940
};
4041
}
@@ -114,10 +115,11 @@ protected function decoderRunBeam(PretrainedModel $model, array &$beam): array
114115
// 1. Prepare
115116
$modelInputs = [
116117
'input_ids' => $beam['model_input_ids'],
117-
'attention_mask' => new Tensor($attnMaskData, shape: [1, $attnMaskLength]),
118+
'attention_mask' => new Tensor($attnMaskData, NDArray::int64, [1, $attnMaskLength]),
118119
'past_key_values' => $beam['prev_model_outputs']['past_key_values'] ?? null,
119120
];
120121

122+
121123
// 2. Run
122124
$output = $model->forward($modelInputs);
123125

@@ -155,7 +157,7 @@ protected function decoderStartBeams(
155157
$attnMask = null;
156158
if ($inputsAttentionMask !== null) {
157159
$attnMask = $inputsAttentionMask[$beamId];
158-
$attnMask->reshape([1, ...$attnMask->shape()]);
160+
$attnMask = $attnMask->reshape([1, ...$attnMask->shape()]);
159161
} else {
160162
$attnMask = $model->prepareAttentionMask($tokens);
161163
}
@@ -189,8 +191,7 @@ protected function decoderStartBeams(
189191
protected function decoderUpdatebeam(array &$beam, int $newTokenId): void
190192
{
191193
$beam['output_token_ids'][] = $newTokenId;
192-
193-
$beam['model_input_ids'] = new Tensor([$newTokenId], shape: [1, 1]);
194+
$beam['model_input_ids'] = new Tensor([$newTokenId], NDArray::int64, [1, 1]);
194195
}
195196

196197
/**
@@ -221,6 +222,14 @@ protected function decoderForward(PretrainedModel $model, array $modelInputs): a
221222
$model->preparePositionIds($inputNames, $decoderFeeds, $useCacheBranch);
222223
$model->addPastKeyValues($decoderFeeds, $pastKeyValues);
223224

225+
// The initial past key values should have size 0 along the sequence-length dimension
226+
// (index 2 of the shape). However, I haven't found a way to pass a tensor with a zero-sized
227+
// dimension to the model, so I'm using a sequence length of 1 for the first step instead, and
228+
// then offsetting the sequence length by 1 for the subsequent steps. This is a workaround for now.
229+
$prevSequenceLength = $decoderFeeds['past_key_values.0.key']->shape()[2];
230+
$attnMaskLength = $prevSequenceLength == 1 ? 1 : $prevSequenceLength + 1;
231+
$decoderFeeds['attention_mask'] = Tensor::ones([1, $attnMaskLength], dtype: NDArray::int64);
232+
224233
$decoderResults = $model->runSession($model->session, $decoderFeeds);
225234

226235
$logits = $decoderResults['logits'];

src/Models/Pretrained/PretrainedModel.php

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ public static function constructSession(
263263
}
264264

265265
/**
266-
* @param InferenceSession $session
266+
* @param array $inputNames
267267
* @param Tensor[] $inputs
268268
* @return Tensor[]
269269
* @throws MissingModelInputException
@@ -318,8 +318,6 @@ public function runSession(InferenceSession $session, array $inputs): array
318318

319319
$outputNames = array_column($session->outputs(), 'name');
320320

321-
file_put_contents('inputs.json', json_encode($inputs));
322-
323321
$outputs = $session->run($outputNames, $inputs);
324322

325323
return array_combine($outputNames, array_map([Tensor::class, 'fromArray'], $outputs));
@@ -495,7 +493,6 @@ public function preparePositionIds(array $inputNames, array &$feeds, bool $useCa
495493
$feeds['position_ids'] = new Tensor($data, shape: $feeds['attention_mask']->shape());
496494

497495
if ($useCacheBranch) {
498-
// TODO: Fix this
499496
$feeds['position_ids'] = $feeds['position_ids']->slice(null, -1)->unsqueeze(-1);
500497
}
501498
}
@@ -677,8 +674,10 @@ public function addPastKeyValues(array &$decoderFeeds, ?array $pastKeyValues): v
677674
* @param Tensor $inputs The input token ids.
678675
* @param GenerationConfig|null $generationConfig The generation configuration to use. If null, default configuration will be used.
679676
* @param LogitsProcessorList|null $logitsProcessor An optional logits processor to use. If null, a new LogitsProcessorList instance will be created.
680-
* @param array|null $inputsAttentionMask An optional attention mask for the inputs.
677+
* @param Tensor|null $inputsAttentionMask An optional attention mask for the inputs.
678+
* @param Streamer|null $streamer
681679
* @return array An array of generated output sequences, where each sequence is an array of token IDs.
680+
* @throws Exception
682681
*/
683682
public function generate(
684683
Tensor $inputs,
@@ -793,7 +792,6 @@ public function generate(
793792

794793
// update new beam
795794
$this->updateBeam($newBeam, $newTokenId);
796-
797795
$newBeam['score'] += $logProb;
798796

799797
if ($eosTokenIds && in_array($newTokenId, $eosTokenIds, true)) {
@@ -812,16 +810,11 @@ public function generate(
812810
$newestBeams = array_merge(...array_map(
813811
function ($group) use ($generationConfig) {
814812
usort($group, fn($a, $b) => $b['score'] <=> $a['score']);
815-
return array_slice(
816-
$group,
817-
0,
818-
$generationConfig->num_beams
819-
);
813+
return array_slice($group, 0, $generationConfig->num_beams);
820814
},
821815
$this->groupBeams($newestBeams)
822816
));
823817

824-
825818
// Flatten beams
826819
$beams = $newestBeams;
827820

src/Pipelines/TextGenerationPipeline.php

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,12 @@ public function __invoke(array|string $inputs, ...$args): array
9090
truncation: true
9191
);
9292

93-
$outputTokenIds = $this->model->generate($inputIds, generationConfig: $generationConfig, streamer: $streamer);
93+
$outputTokenIds = $this->model->generate(
94+
$inputIds,
95+
generationConfig: $generationConfig,
96+
inputsAttentionMask: $attentionMask,
97+
streamer: $streamer
98+
);
9499

95100
$decoded = $this->tokenizer->batchDecode($outputTokenIds, skipSpecialTokens: true);
96101

src/Utils/Tensor.php

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -261,10 +261,10 @@ public static function onesLike(Tensor $other): static
261261
* Return a one matrix with the given dimensions.
262262
*
263263
* @param array $shape The shape of the one matrix to return.
264-
* @param string|null $dtype The data type of the one matrix to return. Eg: float32, int32, etc. If null, defaults to float32.
264+
* @param ?int $dtype The data type of the one matrix to return, given as an NDArray dtype constant. Eg: NDArray::int64. If null, defaults to float32.
265265
* @return static
266266
*/
267-
public static function ones(array $shape, ?string $dtype = null): static
267+
public static function ones(array $shape, ?int $dtype = null): static
268268
{
269269
$mo = self::getMo();
270270

@@ -303,7 +303,7 @@ public static function fromArray(array|NDArray $array, ?string $dtype = null, $s
303303
/**
304304
* Reshape the tensor into the given shape.
305305
*/
306-
public function reshape(array $shape): NDArray
306+
public function reshape(array $shape): static
307307
{
308308
$this->assertShape($shape);
309309

@@ -443,37 +443,37 @@ public function count(): int
443443
}
444444

445445
/**
446-
* Returns a tensor with all specified dimensions of input of size 1 removed.
446+
* Returns a tensor with all specified axis of input of size 1 removed.
447447
*
448-
* @param ?int $dim If given, the input will be squeezed only in the specified dimensions.
448+
* @param ?int $axis If given, the input will be squeezed only in the specified axis.
449449
*
450450
* @return static The squeezed tensor.
451451
*/
452-
public function unsqueeze(?int $dim = null): static
452+
public function unsqueeze(?int $axis = null): static
453453
{
454454
return new Tensor(
455455
$this->buffer(),
456456
$this->dtype,
457-
$this->calcUnsqueezeDims($this->shape(), $dim),
457+
$this->calcUnsqueezeShape($this->shape(), $axis),
458458
$this->offset
459459
);
460460
}
461461

462462
/**
463-
* Helper function to calculate new dimensions when performing an unsqueeze operation.
464-
* @param array $dims The dimensions of the tensor.
465-
* @param int $dim The dimension to unsqueeze.
466-
* @return array The new dimensions.
463+
* Helper function to calculate new shape when performing an unsqueeze operation.
464+
* @param array $shape The shape of the tensor.
465+
* @param int $axis The axis to unsqueeze.
466+
* @return array The new shape.
467467
*/
468-
protected function calcUnsqueezeDims(array $dims, int $dim): array
468+
protected function calcUnsqueezeShape(array $shape, int $axis): array
469469
{
470470
// Dimension out of range (e.g., "expected to be in range of [-4, 3], but got 4")
471471
// + 1 since we allow inserting at the end (i.e. axis = -1)
472-
$dim = self::safeIndex($dim, count($dims) + 1);
473-
$newDims = $dims;
474-
// Insert 1 into specified dimension
475-
array_splice($newDims, $dim, 0, [1]);
476-
return $newDims;
472+
$axis = self::safeIndex($axis, count($shape) + 1);
473+
474+
$shape[$axis - 1] = 1;
475+
476+
return $shape;
477477
}
478478

479479
/**
@@ -605,11 +605,11 @@ public function normalize(int $p = 2, ?int $dim = null): static
605605
*
606606
* @param int $ord Order of the norm. Supported values are 1, 2, Infinity.
607607
* @param int|null $axis The axis or axes along which to perform the reduction. If null (default), reduces all dimensions.
608-
* @param bool $keepdims If true, retains reduced dimensions with length 1.
608+
* @param bool $keepShape If true, retains reduced shape with length 1.
609609
*
610610
* @return static
611611
*/
612-
public function norm(int $ord = 2, ?int $axis = null, bool $keepdims = false): static
612+
public function norm(int $ord = 2, ?int $axis = null, bool $keepShape = false): static
613613
{
614614
$mo = self::getMo();
615615

@@ -623,8 +623,8 @@ public function norm(int $ord = 2, ?int $axis = null, bool $keepdims = false): s
623623
$axis = $this->safeIndex($axis, $this->ndim());
624624

625625
// Calculate the shape of the resulting array after summation
626-
$resultDims = $this->shape();
627-
$resultDims[$axis] = 1; // Remove the specified axis
626+
$resultShape = $this->shape();
627+
$resultShape[$axis] = 1; // Remove the specified axis
628628

629629
// Create a new array to store the accumulated values
630630
$result = $this->zeros([count($this->buffer) / $this->shape()[$axis]]);
@@ -642,7 +642,7 @@ public function norm(int $ord = 2, ?int $axis = null, bool $keepdims = false): s
642642
if ($j !== $axis) {
643643
$index = $num % $size;
644644
$resultIndex += $index * $resultMultiplier;
645-
$resultMultiplier *= $resultDims[$j];
645+
$resultMultiplier *= $resultShape[$j];
646646
}
647647

648648
$num = floor($num / $size);
@@ -656,11 +656,11 @@ public function norm(int $ord = 2, ?int $axis = null, bool $keepdims = false): s
656656
$result = $mo->op($result, '**', 1 / $ord);
657657
}
658658

659-
if (!$keepdims) {
660-
array_splice($resultDims, $axis, 1);
659+
if (!$keepShape) {
660+
array_splice($resultShape, $axis, 1);
661661
}
662662

663-
return new static($result->buffer(), $result->dtype(), $resultDims, $result->offset());
663+
return new static($result->buffer(), $result->dtype(), $resultShape, $result->offset());
664664
}
665665

666666
/**
@@ -794,7 +794,7 @@ public function to(int $dtype): static
794794
/**
795795
* Returns the mean value of the tensor along the given axis.
796796
*/
797-
public function mean(?int $axis = null, bool $keepdims = false): static|float|int
797+
public function mean(?int $axis = null, bool $keepShape = false): static|float|int
798798
{
799799
$mo = self::getMo();
800800

@@ -803,7 +803,7 @@ public function mean(?int $axis = null, bool $keepdims = false): static|float|in
803803
if ($mean instanceof NDArray) {
804804
$shape = $mean->shape();
805805

806-
if (!$keepdims) {
806+
if (!$keepShape) {
807807
array_splice($shape, $axis, 1);
808808
}
809809

@@ -858,15 +858,15 @@ public function meanPooling(Tensor $other): Tensor
858858

859859
public function slice(...$slices): Tensor
860860
{
861-
$newTensorDims = [];
861+
$newTensorShape = [];
862862
$newOffsets = [];
863863

864864
for ($sliceIndex = 0; $sliceIndex < $this->ndim(); ++$sliceIndex) {
865865
$slice = $slices[$sliceIndex] ?? null;
866866

867867
if ($slice === null) {
868868
$newOffsets[] = [0, $this->shape()[$sliceIndex]];
869-
$newTensorDims[] = $this->shape()[$sliceIndex];
869+
$newTensorShape[] = $this->shape()[$sliceIndex];
870870

871871
} elseif (is_int($slice)) {
872872
$slice = $this->safeIndex($slice, $this->shape()[$sliceIndex], $sliceIndex);
@@ -881,31 +881,31 @@ public function slice(...$slices): Tensor
881881
min($slice[1], $this->shape()[$sliceIndex])
882882
];
883883
$newOffsets[] = $offsets;
884-
$newTensorDims[] = $offsets[1] - $offsets[0];
884+
$newTensorShape[] = $offsets[1] - $offsets[0];
885885

886886
} else {
887887
throw new Exception("Invalid slice: " . json_encode($slice));
888888
}
889889
}
890890

891-
$newDims = array_map(fn($offsets) => $offsets[1] - $offsets[0], $newOffsets);
891+
$newShape = array_map(fn($offsets) => $offsets[1] - $offsets[0], $newOffsets);
892892

893-
$newBufferSize = array_reduce($newDims, fn($a, $b) => $a * $b, 1);
893+
$newBufferSize = array_reduce($newShape, fn($a, $b) => $a * $b, 1);
894894

895895
$buffer = $this->newBuffer($newBufferSize, $this->dtype());
896896
$stride = $this->stride();
897897

898898
for ($i = 0; $i < $newBufferSize; ++$i) {
899899
$originalIndex = 0;
900-
for ($j = count($newDims) - 1, $num = $i; $j >= 0; --$j) {
901-
$size = $newDims[$j];
900+
for ($j = count($newShape) - 1, $num = $i; $j >= 0; --$j) {
901+
$size = $newShape[$j];
902902
$originalIndex += (($num % $size) + $newOffsets[$j][0]) * $stride[$j];
903903
$num = floor($num / $size);
904904
}
905905
$buffer[$i] = $this->buffer[$originalIndex];
906906
}
907907

908-
return new Tensor($buffer, $this->dtype(), $newDims, $this->offset());
908+
return new Tensor($buffer, $this->dtype(), $newShape, $this->offset());
909909
}
910910

911911
/**

0 commit comments

Comments
 (0)