Skip to content

Commit 6c091fb

Browse files
Create buffer directly when converting model output to a tensor
1 parent e7f55a4 commit 6c091fb

4 files changed

Lines changed: 13 additions & 17 deletions

File tree

examples/pipelines/text-generation.php

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
//
1414
//$generator = pipeline('text-generation', 'Xenova/gpt2');
1515
$generator = pipeline('text-generation', 'Xenova/Qwen1.5-0.5B-Chat');
16-
16+
//
1717
$streamer = StdOutStreamer::make();
1818

1919
$messages = [
@@ -24,7 +24,7 @@
2424
$input = $generator->tokenizer->applyChatTemplate($messages, addGenerationPrompt: true, tokenize: false);
2525

2626
$output = $generator($input,
27-
// streamer: $streamer,
27+
streamer: $streamer,
2828
maxNewTokens: 128,
2929
doSample: true,
3030
returnFullText: false,

src/Models/Pretrained/PretrainedModel.php

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
use Exception;
3434
use Symfony\Component\Console\Output\OutputInterface;
3535
use function Codewithkyrian\Transformers\Utils\array_some;
36-
use function Codewithkyrian\Transformers\Utils\timeUsage;
3736

3837
/**
3938
* A base class for pre-trained models that provides the model configuration and an ONNX session.
@@ -282,10 +281,7 @@ public function runSession(InferenceSession $session, array $inputs): array
282281

283282
$outputNames = array_column($session->outputs(), 'name');
284283

285-
timeUsage();
286-
$out = $session->run($outputNames, $inputs);
287-
dump(timeUsage(true));
288-
return $out;
284+
return $session->run($outputNames, $inputs);
289285
} catch (MissingModelInputException $e) {
290286
throw $e;
291287
} catch (Exception $e) {

src/Utils/InferenceSession.php

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ public function __destruct()
130130

131131
public function run($outputNames, $inputFeed, $logSeverityLevel = null, $logVerbosityLevel = null, $logid = null, $terminate = null): array
132132
{
133+
133134
// pointer references
134135
$refs = [];
135136

@@ -167,7 +168,6 @@ public function run($outputNames, $inputFeed, $logSeverityLevel = null, $logVerb
167168
$output = [];
168169

169170
foreach ($outputTensor as $i => $t) {
170-
// $output[] = $this->createFromOnnxValue($t);
171171
$output[$outputNames[$i]] = $this->createFromOnnxValue($t);
172172
}
173173

@@ -514,13 +514,13 @@ private function fillOutput($ptr, $shape): ?Tensor
514514
return null;
515515
}
516516

517-
$data = [];
517+
$buffer = Tensor::newBuffer($bufferSize);
518518

519519
for ($j = 0; $j < $bufferSize; $j++) {
520-
$data[] = $ptr[$j];
520+
$buffer[$j] = $ptr[$j];
521521
}
522522

523-
return new Tensor($data, shape: $shape);
523+
return new Tensor($buffer, shape: $shape, offset: 0);
524524
}
525525

526526
private function createStringsFromOnnxValue($outPtr, $outputTensorSize)

src/Utils/Tensor.php

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,15 @@ public function __construct($array = null, int $dtype = null, array $shape = nul
4242

4343
if (is_array($array) || $array instanceof ArrayObject) {
4444
$size = $this->countRecursive($array);
45-
$this->buffer = $this->newBuffer($size, $dtype);
45+
$this->buffer = self::newBuffer($size, $dtype);
4646
$this->flattenArray($array, $this->buffer);
4747
$this->offset = 0;
4848
$shape ??= $this->generateShape($array);
4949
} elseif (is_numeric($array) || is_bool($array)) {
5050
if (is_bool($array) && $dtype != NDArray::bool) {
5151
throw new InvalidArgumentException("Unmatched dtype with bool value");
5252
}
53-
$this->buffer = $this->newBuffer(1, $dtype);
53+
$this->buffer = self::newBuffer(1, $dtype);
5454
$this->buffer[0] = $array;
5555
$this->offset = 0;
5656
$shape = $shape ?? [];
@@ -61,7 +61,7 @@ public function __construct($array = null, int $dtype = null, array $shape = nul
6161
} elseif ($array === null && $shape !== null) {
6262
$this->assertShape($shape);
6363
$size = (int)array_product($shape);
64-
$this->buffer = $this->newBuffer($size, $dtype);
64+
$this->buffer = self::newBuffer($size, $dtype);
6565
$this->offset = 0;
6666
} elseif ($this->isBuffer($array)) {
6767
if (!is_int($offset))
@@ -107,7 +107,7 @@ function countRecursive($array): int
107107
* @param int $dtype The data type of the buffer.
108108
* @return SplFixedArray|OpenBlasBuffer
109109
*/
110-
protected function newBuffer(int $size, int $dtype): SplFixedArray|OpenBlasBuffer
110+
public static function newBuffer(int $size, ?int $dtype = null): SplFixedArray|OpenBlasBuffer
111111
{
112112
if (extension_loaded('rindow_openblas')) {
113113
return new OpenBlasBuffer($size, $dtype);
@@ -389,7 +389,7 @@ public static function cat(array $tensors, int $axis = 0): Tensor
389389
// Create a new array to store the accumulated values
390390
$resultSize = array_product($resultShape);
391391

392-
$result = $tensors[0]->newBuffer($resultSize, $resultType);
392+
$result = self::newBuffer($resultSize, $resultType);
393393

394394
// Create output tensor of same type as first
395395

@@ -921,7 +921,7 @@ public function slice(...$slices): Tensor
921921

922922
$newBufferSize = array_reduce($newShape, fn($a, $b) => $a * $b, 1);
923923

924-
$buffer = $this->newBuffer($newBufferSize, $this->dtype());
924+
$buffer = self::newBuffer($newBufferSize, $this->dtype());
925925
$stride = $this->stride();
926926

927927
for ($i = 0; $i < $newBufferSize; ++$i) {

0 commit comments

Comments
 (0)