Skip to content

Commit e7f55a4

Browse files
Use a custom Inference session class to make Tensor a first class output
1 parent 08ab987 commit e7f55a4

15 files changed

Lines changed: 841 additions & 81 deletions

composer.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
"codewithkyrian/onnxruntime-downloader-plugin": "^1.1",
2222
"symfony/console": "^6.4|^7.0",
2323
"imagine/imagine": "^1.3",
24-
"rokka/imagine-vips": "^0.31.0"
24+
"rokka/imagine-vips": "^0.31.0",
25+
"spatie/fork": "^1.2"
2526
},
2627
"require-dev": {
2728
"pestphp/pest": "^2.31",

examples/composer.json

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,34 @@
11
{
2-
"name": "kyrian/examples",
3-
"autoload": {
4-
"psr-4": {
5-
"Kyrian\\Examples\\": "/"
6-
}
7-
},
8-
"authors": [
9-
{
10-
"name": "Kyrian Obikwelu",
11-
"email": "koshnawaza@gmail.com"
12-
}
13-
],
14-
"require": {
15-
"php": "^8.1",
16-
"symfony/console": "^7.0",
17-
"codewithkyrian/transformers": "dev-change-init-process"
18-
},
19-
"minimum-stability": "dev",
20-
"require-dev": {
21-
"symfony/var-dumper": "^7.0"
22-
},
23-
"repositories": [
24-
{
25-
"type" : "path",
26-
"url": "../"
27-
}
28-
]
2+
"name": "kyrian/examples",
3+
"autoload": {
4+
"psr-4": {
5+
"Kyrian\\Examples\\": "/"
6+
}
7+
},
8+
"authors": [
9+
{
10+
"name": "Kyrian Obikwelu",
11+
"email": "koshnawaza@gmail.com"
12+
}
13+
],
14+
"require": {
15+
"php": "^8.1",
16+
"symfony/console": "^7.0",
17+
"codewithkyrian/transformers": "dev-main"
18+
},
19+
"minimum-stability": "dev",
20+
"require-dev": {
21+
"symfony/var-dumper": "^7.0"
22+
},
23+
"repositories": [
24+
{
25+
"type": "path",
26+
"url": "../"
27+
}
28+
],
29+
"config": {
30+
"allow-plugins": {
31+
"codewithkyrian/onnxruntime-downloader-plugin": true
32+
}
33+
}
2934
}

examples/pipelines/text-generation.php

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@
1818

1919
$messages = [
2020
['role' => 'system', 'content' => 'You are a helpful assistant.'],
21-
['role' => 'user', 'content' => 'What is the product of 5 and 4'],
21+
['role' => 'user', 'content' => 'What is diffusion?'],
2222
];
2323

2424
$input = $generator->tokenizer->applyChatTemplate($messages, addGenerationPrompt: true, tokenize: false);
2525

2626
$output = $generator($input,
27-
streamer: $streamer,
27+
// streamer: $streamer,
2828
maxNewTokens: 128,
2929
doSample: true,
3030
returnFullText: false,

src/Models/ModelArchitecture.php

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,6 @@ protected function decoderRunBeam(PretrainedModel $model, array &$beam): array
119119
'past_key_values' => $beam['prev_model_outputs']['past_key_values'] ?? null,
120120
];
121121

122-
123-
// 2. Run
124122
$output = $model->forward($modelInputs);
125123

126124
// 3. Update
@@ -372,7 +370,6 @@ protected function seq2seqForward(PretrainedModel $model, array $modelInputs): a
372370
$model->addPastKeyValues($decoderFeeds, $pastKeyValues);
373371

374372
$decoderResults = $model->runSession($model->decoderMergedSession, $decoderFeeds);
375-
376373
$logits = $decoderResults['logits'];
377374
$pastKeyValues = $model->getPastKeyValues($decoderResults, $pastKeyValues);
378375

src/Models/Pretrained/BartForConditionalGeneration.php

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
use Codewithkyrian\Transformers\Models\ModelArchitecture;
99
use Codewithkyrian\Transformers\Utils\AutoConfig;
1010
use Codewithkyrian\Transformers\Utils\GenerationConfig;
11-
use OnnxRuntime\InferenceSession;
11+
use Codewithkyrian\Transformers\Utils\InferenceSession;
1212

1313
/**
1414
* The BART Model with a language modeling head. Can be used for summarization.

src/Models/Pretrained/GPT2PretrainedModel.php

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
use Codewithkyrian\Transformers\Models\ModelArchitecture;
99
use Codewithkyrian\Transformers\Utils\AutoConfig;
1010
use Codewithkyrian\Transformers\Utils\GenerationConfig;
11-
use OnnxRuntime\InferenceSession;
11+
use Codewithkyrian\Transformers\Utils\InferenceSession;
1212

1313
class GPT2PretrainedModel extends PretrainedModel
1414
{

src/Models/Pretrained/M2M100ForConditionalGeneration.php

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
use Codewithkyrian\Transformers\Models\ModelArchitecture;
99
use Codewithkyrian\Transformers\Utils\AutoConfig;
1010
use Codewithkyrian\Transformers\Utils\GenerationConfig;
11-
use OnnxRuntime\InferenceSession;
11+
use Codewithkyrian\Transformers\Utils\InferenceSession;
1212

1313
class M2M100ForConditionalGeneration extends M2M100PretrainedModel
1414
{

src/Models/Pretrained/PretrainedModel.php

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,13 @@
2727
use Codewithkyrian\Transformers\Utils\AutoConfig;
2828
use Codewithkyrian\Transformers\Utils\GenerationConfig;
2929
use Codewithkyrian\Transformers\Utils\Hub;
30+
use Codewithkyrian\Transformers\Utils\InferenceSession;
3031
use Codewithkyrian\Transformers\Utils\Tensor;
3132
use Error;
3233
use Exception;
33-
use OnnxRuntime\InferenceSession;
3434
use Symfony\Component\Console\Output\OutputInterface;
3535
use function Codewithkyrian\Transformers\Utils\array_some;
36+
use function Codewithkyrian\Transformers\Utils\timeUsage;
3637

3738
/**
3839
* A base class for pre-trained models that provides the model configuration and an ONNX session.
@@ -281,9 +282,10 @@ public function runSession(InferenceSession $session, array $inputs): array
281282

282283
$outputNames = array_column($session->outputs(), 'name');
283284

284-
$outputs = $session->run($outputNames, $inputs);
285-
286-
return array_combine($outputNames, array_map([Tensor::class, 'fromArray'], $outputs));
285+
timeUsage();
286+
$out = $session->run($outputNames, $inputs);
287+
dump(timeUsage(true));
288+
return $out;
287289
} catch (MissingModelInputException $e) {
288290
throw $e;
289291
} catch (Exception $e) {
@@ -331,7 +333,8 @@ public function validateInputs(array $inputNames, array $inputs): array
331333
The following inputs will be ignored: "' . implode(', ', $ignored) . '".';
332334
}
333335

334-
return array_map(fn($i) => $i->toArray(), $inputs);
336+
// return array_map(fn($i) => $i->toArray(), $inputs);
337+
return $inputs;
335338
}
336339

337340
/**
@@ -521,8 +524,10 @@ public function addPastKeyValues(array &$decoderFeeds, ?array $pastKeyValues): v
521524
* @param Tensor $inputs The input token ids.
522525
* @param GenerationConfig|null $generationConfig The generation configuration to use. If null, default configuration will be used.
523526
* @param LogitsProcessorList|null $logitsProcessor An optional logits processor to use. If null, a new LogitsProcessorList instance will be created.
524-
* @param array|null $inputsAttentionMask An optional attention mask for the inputs.
527+
* @param Tensor|null $inputsAttentionMask An optional attention mask for the inputs.
528+
* @param Streamer|null $streamer
525529
* @return array An array of generated output sequences, where each sequence is an array of token IDs.
530+
* @throws Exception
526531
*/
527532
public function generate(
528533
Tensor $inputs,
@@ -609,6 +614,7 @@ public function generate(
609614

610615
$output = $this->runBeam($beam);
611616

617+
612618
// add attentions/scores to beam only if user requested
613619
if ($generationConfig->output_attentions) {
614620
$this->addAttentionsToBeam($beam, $output);
@@ -626,6 +632,7 @@ public function generate(
626632
$logits = $output['logits']->slice(null, -1, null);
627633
// $logits = $output['logits'];
628634

635+
629636
// Apply logits processor
630637
$logitsProcessor($beam['output_token_ids'], $logits);
631638

@@ -649,7 +656,6 @@ public function generate(
649656

650657
}
651658

652-
653659
++$numOutputTokens;
654660

655661
// Group and select best beams
@@ -665,15 +671,13 @@ function ($group) use ($generationConfig) {
665671
$this->groupBeams($newestBeams)
666672
));
667673

668-
669674
// Flatten beams
670675
$beams = $newestBeams;
671676

672677
// Stream the beams if a streamer is provided
673678
$streamer?->put($beams);
674679
}
675680

676-
677681
// TODO: Ensure that we can return non-batched outputs
678682

679683
$groupedBeams = $this->groupBeams($beams);

src/Models/Pretrained/Qwen2PreTrainedModel.php

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
use Codewithkyrian\Transformers\Models\ModelArchitecture;
99
use Codewithkyrian\Transformers\Utils\AutoConfig;
1010
use Codewithkyrian\Transformers\Utils\GenerationConfig;
11-
use OnnxRuntime\InferenceSession;
11+
use Codewithkyrian\Transformers\Utils\InferenceSession;
1212

1313
/**
1414
* The bare Qwen2 Model outputting raw hidden-states without any specific head on top.
@@ -32,7 +32,7 @@ public function __construct(
3232
$this->config['pad_token_id'] = $this->config['eos_token_id'];
3333
$this->config->padTokenId = $this->config['eos_token_id'];
3434

35-
$this->numHeads = $this->config['num_key_value_heads'] ?? $this->config['num_attention_heads'];
35+
$this->numHeads = $this->config['num_key_value_heads'] ?? $this->config['num_attention_heads'];
3636
$this->numLayers = $this->config['num_hidden_layers'];
3737
$this->dimKv = $this->config['hidden_size'] / $this->config['num_attention_heads'];
3838
}

src/Models/Pretrained/T5ForConditionalGeneration.php

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
use Codewithkyrian\Transformers\Models\ModelArchitecture;
99
use Codewithkyrian\Transformers\Utils\AutoConfig;
1010
use Codewithkyrian\Transformers\Utils\GenerationConfig;
11-
use OnnxRuntime\InferenceSession;
11+
use Codewithkyrian\Transformers\Utils\InferenceSession;
1212

1313
/**
1414
* T5Model is a class representing a T5 model for conditional generation.

0 commit comments

Comments
 (0)