Skip to content

Commit 5e310f9

Browse files
Accept zero in tensor dimensions
1 parent 0f1c44e commit 5e310f9

4 files changed

Lines changed: 24 additions & 28 deletions

File tree

src/Models/ModelArchitecture.php

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -220,14 +220,6 @@ protected function decoderForward(PretrainedModel $model, array $modelInputs): a
220220
$model->preparePositionIds($inputNames, $decoderFeeds, $useCacheBranch);
221221
$model->addPastKeyValues($decoderFeeds, $pastKeyValues);
222222

223-
// The initial past key values should have a shape of 0 in one of the dimensions, which
224-
// is the sequence length. However, I haven't found a way to pass a tensor with a shape of 0
225-
// to the model, so I'm using a sequence length of 1 instead for the first step, and then
226-
// offsetting the sequence length by 1 for the subsequent steps. This is a workaround for now.
227-
$prevSequenceLength = $decoderFeeds['past_key_values.0.key']->shape()[2];
228-
$attnMaskLength = $prevSequenceLength == 1 ? 1 : $prevSequenceLength + 1;
229-
$decoderFeeds['attention_mask'] = Tensor::ones([1, $attnMaskLength], dtype: NDArray::int64);
230-
231223
$decoderResults = $model->runSession($model->session, $decoderFeeds);
232224

233225
$logits = $decoderResults['logits'];

src/Models/Pretrained/PretrainedModel.php

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -467,50 +467,50 @@ public function addPastKeyValues(array &$decoderFeeds, ?array $pastKeyValues): v
467467
$decoderFeeds = array_merge($decoderFeeds, $pastKeyValues);
468468
} else {
469469
// TODO support batches (i.e., batch_size > 1)
470-
$batch_size = 1;
470+
$batchSize = 1;
471471

472472
if ($this->config->isEncoderDecoder && ($this->addEncoderPkv ?? true)) {
473-
$encoderShape = [$batch_size, $this->numEncoderHeads, 1, $this->encoderDimKv];
474-
$decoderShape = [$batch_size, $this->numDecoderHeads, 1, $this->decoderDimKv];
473+
$encoderShape = [$batchSize, $this->numEncoderHeads, 0, $this->encoderDimKv];
474+
$decoderShape = [$batchSize, $this->numDecoderHeads, 0, $this->decoderDimKv];
475475

476476

477477
for ($i = 0; $i < $this->numDecoderLayers; ++$i) {
478478
$decoderFeeds["past_key_values.$i.encoder.key"]
479479
= $decoderFeeds["past_key_values.$i.encoder.value"]
480-
= new Tensor(null, shape: $encoderShape);
480+
= new Tensor([], shape: $encoderShape);
481481
$decoderFeeds["past_key_values.$i.decoder.key"]
482482
= $decoderFeeds["past_key_values.$i.decoder.value"]
483-
= new Tensor(null, shape: $decoderShape);
483+
= new Tensor([], shape: $decoderShape);
484484
}
485485
} else if ($this->config->modelType === 'falcon') {
486486
// NOTE: Custom implementation for Falcon
487-
$shape = [$batch_size * $this->numHeads, 1, $this->dimKv];
487+
$shape = [$batchSize * $this->numHeads, 0, $this->dimKv];
488488

489489
for ($i = 0; $i < $this->numLayers; ++$i) {
490-
$decoderFeeds["past_key_values.$i.key"] = new Tensor(null, shape: $shape);
491-
$decoderFeeds["past_key_values.$i.value"] = new Tensor(null, shape: $shape);
490+
$decoderFeeds["past_key_values.$i.key"] = new Tensor([], shape: $shape);
491+
$decoderFeeds["past_key_values.$i.value"] = new Tensor([], shape: $shape);
492492
}
493493
} else if ($this->config['multi_query'] ?? null) { // e.g., for `gpt_bigcode`
494-
$shape = [$batch_size * $this->numHeads, 1, 2 * $this->dimKv];
494+
$shape = [$batchSize * $this->numHeads, 0, 2 * $this->dimKv];
495495

496496
for ($i = 0; $i < $this->numLayers; ++$i) {
497-
$decoderFeeds["past_key_values.$i.key_value"] = new Tensor(null, shape: $shape);
497+
$decoderFeeds["past_key_values.$i.key_value"] = new Tensor([], shape: $shape);
498498
}
499499
} else if ($this->config['model_type'] === 'bloom') {
500500
// NOTE: Custom implementation for Bloom
501-
$keyShape = [$batch_size * $this->numHeads, $this->dimKv, 1];
502-
$valueShape = [$batch_size * $this->numHeads, 1, $this->dimKv];
501+
$keyShape = [$batchSize * $this->numHeads, $this->dimKv, 0];
502+
$valueShape = [$batchSize * $this->numHeads, 0, $this->dimKv];
503503

504504
for ($i = 0; $i < $this->numLayers; ++$i) {
505-
$decoderFeeds["past_key_values.$i.key"] = new Tensor(null, shape: $keyShape);
506-
$decoderFeeds["past_key_values.$i.value"] = new Tensor(null, shape: $valueShape);
505+
$decoderFeeds["past_key_values.$i.key"] = new Tensor([], shape: $keyShape);
506+
$decoderFeeds["past_key_values.$i.value"] = new Tensor([], shape: $valueShape);
507507
}
508508
} else { // Decoder-only
509-
$shape = [$batch_size, $this->numHeads, 1, $this->dimKv];
509+
$shape = [$batchSize, $this->numHeads, 0, $this->dimKv];
510510

511511
for ($i = 0; $i < $this->numLayers; ++$i) {
512-
$decoderFeeds["past_key_values.$i.key"] = new Tensor(null, shape: $shape);
513-
$decoderFeeds["past_key_values.$i.value"] = new Tensor(null, shape: $shape);
512+
$decoderFeeds["past_key_values.$i.key"] = new Tensor([], shape: $shape);
513+
$decoderFeeds["past_key_values.$i.value"] = new Tensor([], shape: $shape);
514514
}
515515
}
516516
}

src/Utils/InferenceSession.php

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,11 @@ private function createInputTensor($inputFeed, &$refs)
372372
if (isset($inputTypes[$inp['type']])) {
373373
$typeEnum = $inputTypes[$inp['type']];
374374
$castType = $this->castTypes()[$typeEnum];
375-
$inputTensorValues = $this->ffi->new("{$castType}[$size]");
375+
if ($size == 0) {
376+
$inputTensorValues = $this->ffi->new("void *");
377+
} else {
378+
$inputTensorValues = $this->ffi->new("{$castType}[$size]");
379+
}
376380
} else {
377381
$this->unsupportedType('input', $inp['type']);
378382
}
@@ -494,7 +498,7 @@ private function createFromOnnxValue($outPtr)
494498
$values = $this->createFromOnnxValue($mapValues);
495499
return array_combine($keys, $values);
496500
} else {
497-
$this->unsupported_type('element', $elemType);
501+
$this->unsupportedType('element', $elemType);
498502
}
499503
} else {
500504
$this->unsupportedType('ONNX', $outType->cdata);

src/Utils/Tensor.php

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ protected function assertShape(array $shape): void
201201
throw new InvalidArgumentException(
202202
"Invalid shape numbers. It gives " . gettype($num));
203203
}
204-
if ($num <= 0) {
204+
if ($num < 0) {
205205
throw new InvalidArgumentException(
206206
"Invalid shape numbers. It gives " . $num);
207207
}

0 commit comments

Comments
 (0)