Skip to content

Commit 2200fa3

Browse files
Merge pull request #24 from CodeWithKyrian/modify-inference-session-tensor
Custom inference session for improved ONNX model handling
2 parents (1984cae + 5e310f9), commit 2200fa3

15 files changed

Lines changed: 868 additions & 116 deletions

composer.json

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,8 @@
2121
"codewithkyrian/onnxruntime-downloader-plugin": "^1.1",
2222
"symfony/console": "^6.4|^7.0",
2323
"imagine/imagine": "^1.3",
24-
"rokka/imagine-vips": "^0.31.0"
24+
"rokka/imagine-vips": "^0.31.0",
25+
"spatie/fork": "^1.2"
2526
},
2627
"require-dev": {
2728
"pestphp/pest": "^2.31",

examples/composer.json

Lines changed: 32 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -1,29 +1,34 @@
11
{
2-
"name": "kyrian/examples",
3-
"autoload": {
4-
"psr-4": {
5-
"Kyrian\\Examples\\": "/"
6-
}
7-
},
8-
"authors": [
9-
{
10-
"name": "Kyrian Obikwelu",
11-
"email": "koshnawaza@gmail.com"
12-
}
13-
],
14-
"require": {
15-
"php": "^8.1",
16-
"symfony/console": "^7.0",
17-
"codewithkyrian/transformers": "dev-change-init-process"
18-
},
19-
"minimum-stability": "dev",
20-
"require-dev": {
21-
"symfony/var-dumper": "^7.0"
22-
},
23-
"repositories": [
24-
{
25-
"type" : "path",
26-
"url": "../"
27-
}
28-
]
2+
"name": "kyrian/examples",
3+
"autoload": {
4+
"psr-4": {
5+
"Kyrian\\Examples\\": "/"
6+
}
7+
},
8+
"authors": [
9+
{
10+
"name": "Kyrian Obikwelu",
11+
"email": "koshnawaza@gmail.com"
12+
}
13+
],
14+
"require": {
15+
"php": "^8.1",
16+
"symfony/console": "^7.0",
17+
"codewithkyrian/transformers": "dev-main"
18+
},
19+
"minimum-stability": "dev",
20+
"require-dev": {
21+
"symfony/var-dumper": "^7.0"
22+
},
23+
"repositories": [
24+
{
25+
"type": "path",
26+
"url": "../"
27+
}
28+
],
29+
"config": {
30+
"allow-plugins": {
31+
"codewithkyrian/onnxruntime-downloader-plugin": true
32+
}
33+
}
2934
}

examples/pipelines/text-generation.php

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -13,12 +13,12 @@
1313
//
1414
//$generator = pipeline('text-generation', 'Xenova/gpt2');
1515
$generator = pipeline('text-generation', 'Xenova/Qwen1.5-0.5B-Chat');
16-
16+
//
1717
$streamer = StdOutStreamer::make();
1818

1919
$messages = [
2020
['role' => 'system', 'content' => 'You are a helpful assistant.'],
21-
['role' => 'user', 'content' => 'What is the product of 5 and 4'],
21+
['role' => 'user', 'content' => 'What is diffusion?'],
2222
];
2323

2424
$input = $generator->tokenizer->applyChatTemplate($messages, addGenerationPrompt: true, tokenize: false);

src/Models/ModelArchitecture.php

Lines changed: 0 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -119,8 +119,6 @@ protected function decoderRunBeam(PretrainedModel $model, array &$beam): array
119119
'past_key_values' => $beam['prev_model_outputs']['past_key_values'] ?? null,
120120
];
121121

122-
123-
// 2. Run
124122
$output = $model->forward($modelInputs);
125123

126124
// 3. Update
@@ -222,14 +220,6 @@ protected function decoderForward(PretrainedModel $model, array $modelInputs): a
222220
$model->preparePositionIds($inputNames, $decoderFeeds, $useCacheBranch);
223221
$model->addPastKeyValues($decoderFeeds, $pastKeyValues);
224222

225-
// The initial past key values should have a shape of 0 in one of the dimensions, which
226-
// is the sequence length. However, I haven't found a way to pass a tensor with a shape of 0
227-
// to the model, so I'm using a sequence length of 1 instead for the first step, and then
228-
// offsetting the sequence length by 1 for the subsequent steps. This is a workaround for now.
229-
$prevSequenceLength = $decoderFeeds['past_key_values.0.key']->shape()[2];
230-
$attnMaskLength = $prevSequenceLength == 1 ? 1 : $prevSequenceLength + 1;
231-
$decoderFeeds['attention_mask'] = Tensor::ones([1, $attnMaskLength], dtype: NDArray::int64);
232-
233223
$decoderResults = $model->runSession($model->session, $decoderFeeds);
234224

235225
$logits = $decoderResults['logits'];
@@ -372,7 +362,6 @@ protected function seq2seqForward(PretrainedModel $model, array $modelInputs): a
372362
$model->addPastKeyValues($decoderFeeds, $pastKeyValues);
373363

374364
$decoderResults = $model->runSession($model->decoderMergedSession, $decoderFeeds);
375-
376365
$logits = $decoderResults['logits'];
377366
$pastKeyValues = $model->getPastKeyValues($decoderResults, $pastKeyValues);
378367

src/Models/Pretrained/BartForConditionalGeneration.php

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
use Codewithkyrian\Transformers\Models\ModelArchitecture;
99
use Codewithkyrian\Transformers\Utils\AutoConfig;
1010
use Codewithkyrian\Transformers\Utils\GenerationConfig;
11-
use OnnxRuntime\InferenceSession;
11+
use Codewithkyrian\Transformers\Utils\InferenceSession;
1212

1313
/**
1414
* The BART Model with a language modeling head. Can be used for summarization.

src/Models/Pretrained/GPT2PretrainedModel.php

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
use Codewithkyrian\Transformers\Models\ModelArchitecture;
99
use Codewithkyrian\Transformers\Utils\AutoConfig;
1010
use Codewithkyrian\Transformers\Utils\GenerationConfig;
11-
use OnnxRuntime\InferenceSession;
11+
use Codewithkyrian\Transformers\Utils\InferenceSession;
1212

1313
class GPT2PretrainedModel extends PretrainedModel
1414
{

src/Models/Pretrained/M2M100ForConditionalGeneration.php

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
use Codewithkyrian\Transformers\Models\ModelArchitecture;
99
use Codewithkyrian\Transformers\Utils\AutoConfig;
1010
use Codewithkyrian\Transformers\Utils\GenerationConfig;
11-
use OnnxRuntime\InferenceSession;
11+
use Codewithkyrian\Transformers\Utils\InferenceSession;
1212

1313
class M2M100ForConditionalGeneration extends M2M100PretrainedModel
1414
{

src/Models/Pretrained/PretrainedModel.php

Lines changed: 26 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -27,10 +27,10 @@
2727
use Codewithkyrian\Transformers\Utils\AutoConfig;
2828
use Codewithkyrian\Transformers\Utils\GenerationConfig;
2929
use Codewithkyrian\Transformers\Utils\Hub;
30+
use Codewithkyrian\Transformers\Utils\InferenceSession;
3031
use Codewithkyrian\Transformers\Utils\Tensor;
3132
use Error;
3233
use Exception;
33-
use OnnxRuntime\InferenceSession;
3434
use Symfony\Component\Console\Output\OutputInterface;
3535
use function Codewithkyrian\Transformers\Utils\array_some;
3636

@@ -281,9 +281,7 @@ public function runSession(InferenceSession $session, array $inputs): array
281281

282282
$outputNames = array_column($session->outputs(), 'name');
283283

284-
$outputs = $session->run($outputNames, $inputs);
285-
286-
return array_combine($outputNames, array_map([Tensor::class, 'fromArray'], $outputs));
284+
return $session->run($outputNames, $inputs);
287285
} catch (MissingModelInputException $e) {
288286
throw $e;
289287
} catch (Exception $e) {
@@ -331,7 +329,8 @@ public function validateInputs(array $inputNames, array $inputs): array
331329
The following inputs will be ignored: "' . implode(', ', $ignored) . '".';
332330
}
333331

334-
return array_map(fn($i) => $i->toArray(), $inputs);
332+
// return array_map(fn($i) => $i->toArray(), $inputs);
333+
return $inputs;
335334
}
336335

337336
/**
@@ -468,50 +467,50 @@ public function addPastKeyValues(array &$decoderFeeds, ?array $pastKeyValues): v
468467
$decoderFeeds = array_merge($decoderFeeds, $pastKeyValues);
469468
} else {
470469
// TODO support batches (i.e., batch_size > 1)
471-
$batch_size = 1;
470+
$batchSize = 1;
472471

473472
if ($this->config->isEncoderDecoder && ($this->addEncoderPkv ?? true)) {
474-
$encoderShape = [$batch_size, $this->numEncoderHeads, 1, $this->encoderDimKv];
475-
$decoderShape = [$batch_size, $this->numDecoderHeads, 1, $this->decoderDimKv];
473+
$encoderShape = [$batchSize, $this->numEncoderHeads, 0, $this->encoderDimKv];
474+
$decoderShape = [$batchSize, $this->numDecoderHeads, 0, $this->decoderDimKv];
476475

477476

478477
for ($i = 0; $i < $this->numDecoderLayers; ++$i) {
479478
$decoderFeeds["past_key_values.$i.encoder.key"]
480479
= $decoderFeeds["past_key_values.$i.encoder.value"]
481-
= new Tensor(null, shape: $encoderShape);
480+
= new Tensor([], shape: $encoderShape);
482481
$decoderFeeds["past_key_values.$i.decoder.key"]
483482
= $decoderFeeds["past_key_values.$i.decoder.value"]
484-
= new Tensor(null, shape: $decoderShape);
483+
= new Tensor([], shape: $decoderShape);
485484
}
486485
} else if ($this->config->modelType === 'falcon') {
487486
// NOTE: Custom implementation for Falcon
488-
$shape = [$batch_size * $this->numHeads, 1, $this->dimKv];
487+
$shape = [$batchSize * $this->numHeads, 0, $this->dimKv];
489488

490489
for ($i = 0; $i < $this->numLayers; ++$i) {
491-
$decoderFeeds["past_key_values.$i.key"] = new Tensor(null, shape: $shape);
492-
$decoderFeeds["past_key_values.$i.value"] = new Tensor(null, shape: $shape);
490+
$decoderFeeds["past_key_values.$i.key"] = new Tensor([], shape: $shape);
491+
$decoderFeeds["past_key_values.$i.value"] = new Tensor([], shape: $shape);
493492
}
494493
} else if ($this->config['multi_query'] ?? null) { // e.g., for `gpt_bigcode`
495-
$shape = [$batch_size * $this->numHeads, 1, 2 * $this->dimKv];
494+
$shape = [$batchSize * $this->numHeads, 0, 2 * $this->dimKv];
496495

497496
for ($i = 0; $i < $this->numLayers; ++$i) {
498-
$decoderFeeds["past_key_values.$i.key_value"] = new Tensor(null, shape: $shape);
497+
$decoderFeeds["past_key_values.$i.key_value"] = new Tensor([], shape: $shape);
499498
}
500499
} else if ($this->config['model_type'] === 'bloom') {
501500
// NOTE: Custom implementation for Bloom
502-
$keyShape = [$batch_size * $this->numHeads, $this->dimKv, 1];
503-
$valueShape = [$batch_size * $this->numHeads, 1, $this->dimKv];
501+
$keyShape = [$batchSize * $this->numHeads, $this->dimKv, 0];
502+
$valueShape = [$batchSize * $this->numHeads, 0, $this->dimKv];
504503

505504
for ($i = 0; $i < $this->numLayers; ++$i) {
506-
$decoderFeeds["past_key_values.$i.key"] = new Tensor(null, shape: $keyShape);
507-
$decoderFeeds["past_key_values.$i.value"] = new Tensor(null, shape: $valueShape);
505+
$decoderFeeds["past_key_values.$i.key"] = new Tensor([], shape: $keyShape);
506+
$decoderFeeds["past_key_values.$i.value"] = new Tensor([], shape: $valueShape);
508507
}
509508
} else { // Decoder-only
510-
$shape = [$batch_size, $this->numHeads, 1, $this->dimKv];
509+
$shape = [$batchSize, $this->numHeads, 0, $this->dimKv];
511510

512511
for ($i = 0; $i < $this->numLayers; ++$i) {
513-
$decoderFeeds["past_key_values.$i.key"] = new Tensor(null, shape: $shape);
514-
$decoderFeeds["past_key_values.$i.value"] = new Tensor(null, shape: $shape);
512+
$decoderFeeds["past_key_values.$i.key"] = new Tensor([], shape: $shape);
513+
$decoderFeeds["past_key_values.$i.value"] = new Tensor([], shape: $shape);
515514
}
516515
}
517516
}
@@ -521,8 +520,10 @@ public function addPastKeyValues(array &$decoderFeeds, ?array $pastKeyValues): v
521520
* @param Tensor $inputs The input token ids.
522521
* @param GenerationConfig|null $generationConfig The generation configuration to use. If null, default configuration will be used.
523522
* @param LogitsProcessorList|null $logitsProcessor An optional logits processor to use. If null, a new LogitsProcessorList instance will be created.
524-
* @param array|null $inputsAttentionMask An optional attention mask for the inputs.
523+
* @param Tensor|null $inputsAttentionMask An optional attention mask for the inputs.
524+
* @param Streamer|null $streamer
525525
* @return array An array of generated output sequences, where each sequence is an array of token IDs.
526+
* @throws Exception
526527
*/
527528
public function generate(
528529
Tensor $inputs,
@@ -609,6 +610,7 @@ public function generate(
609610

610611
$output = $this->runBeam($beam);
611612

613+
612614
// add attentions/scores to beam only if user requested
613615
if ($generationConfig->output_attentions) {
614616
$this->addAttentionsToBeam($beam, $output);
@@ -626,6 +628,7 @@ public function generate(
626628
$logits = $output['logits']->slice(null, -1, null);
627629
// $logits = $output['logits'];
628630

631+
629632
// Apply logits processor
630633
$logitsProcessor($beam['output_token_ids'], $logits);
631634

@@ -649,7 +652,6 @@ public function generate(
649652

650653
}
651654

652-
653655
++$numOutputTokens;
654656

655657
// Group and select best beams
@@ -665,15 +667,13 @@ function ($group) use ($generationConfig) {
665667
$this->groupBeams($newestBeams)
666668
));
667669

668-
669670
// Flatten beams
670671
$beams = $newestBeams;
671672

672673
// Stream the beams if a streamer is provided
673674
$streamer?->put($beams);
674675
}
675676

676-
677677
// TODO: Ensure that we can return non-batched outputs
678678

679679
$groupedBeams = $this->groupBeams($beams);

src/Models/Pretrained/Qwen2PreTrainedModel.php

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
use Codewithkyrian\Transformers\Models\ModelArchitecture;
99
use Codewithkyrian\Transformers\Utils\AutoConfig;
1010
use Codewithkyrian\Transformers\Utils\GenerationConfig;
11-
use OnnxRuntime\InferenceSession;
11+
use Codewithkyrian\Transformers\Utils\InferenceSession;
1212

1313
/**
1414
* The bare Qwen2 Model outputting raw hidden-states without any specific head on top.
@@ -32,7 +32,7 @@ public function __construct(
3232
$this->config['pad_token_id'] = $this->config['eos_token_id'];
3333
$this->config->padTokenId = $this->config['eos_token_id'];
3434

35-
$this->numHeads = $this->config['num_key_value_heads'] ?? $this->config['num_attention_heads'];
35+
$this->numHeads = $this->config['num_key_value_heads'] ?? $this->config['num_attention_heads'];
3636
$this->numLayers = $this->config['num_hidden_layers'];
3737
$this->dimKv = $this->config['hidden_size'] / $this->config['num_attention_heads'];
3838
}

src/Models/Pretrained/T5ForConditionalGeneration.php

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
use Codewithkyrian\Transformers\Models\ModelArchitecture;
99
use Codewithkyrian\Transformers\Utils\AutoConfig;
1010
use Codewithkyrian\Transformers\Utils\GenerationConfig;
11-
use OnnxRuntime\InferenceSession;
11+
use Codewithkyrian\Transformers\Utils\InferenceSession;
1212

1313
/**
1414
* T5Model is a class representing a T5 model for conditional generation.

0 commit comments

Comments
 (0)