Skip to content

Commit a326386

Browse files
Optimize Tensor softmax for 1 dimensional tensors
1 parent f8bf3b5 commit a326386

4 files changed

Lines changed: 31 additions & 34 deletions

File tree

examples/pipelines/text-generation.php

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,35 +13,35 @@
1313
//
1414
//$generator = pipeline('text-generation', 'Xenova/gpt2');
1515
//$generator = pipeline('text-generation', 'Xenova/Qwen1.5-0.5B-Chat');
16-
//$generator = pipeline('text-generation', 'Xenova/TinyLlama-1.1B-Chat-v1.0');
17-
//
18-
//$streamer = StdOutStreamer::make();
19-
//
20-
//$messages = [
21-
// ['role' => 'system', 'content' => 'You are a helpful assistant.'],
22-
// ['role' => 'user', 'content' => 'What is diffusion in chemistry?'],
23-
//];
24-
//
25-
//$input = $generator->tokenizer->applyChatTemplate($messages, addGenerationPrompt: true, tokenize: false);
26-
//
27-
//$output = $generator($input,
28-
// streamer: $streamer,
29-
// maxNewTokens: 128,
30-
// doSample: true,
31-
// returnFullText: false,
32-
//// temperature: 0.7,
33-
//// repetitionPenalty: 1.3,
34-
//// earlyStopping: true
35-
//);
16+
$generator = pipeline('text-generation', 'Xenova/TinyLlama-1.1B-Chat-v1.0');
3617

37-
$generator = pipeline('text-generation', 'Xenova/codegen-350M-mono');
38-
$streamer = StdOutStreamer::make($generator->tokenizer);
18+
$streamer = StdOutStreamer::make();
3919

40-
$output = $generator(
41-
'def fib(n):',
20+
$messages = [
21+
['role' => 'system', 'content' => 'You are a helpful assistant.'],
22+
['role' => 'user', 'content' => 'What is diffusion?'],
23+
];
24+
25+
$input = $generator->tokenizer->applyChatTemplate($messages, addGenerationPrompt: true, tokenize: false);
26+
27+
$output = $generator($input,
4228
streamer: $streamer,
43-
maxNewTokens: 100,
44-
doSample: true
29+
maxNewTokens: 256,
30+
doSample: true,
31+
returnFullText: false,
32+
// temperature: 0.7,
33+
// repetitionPenalty: 1.3,
34+
// earlyStopping: true
4535
);
4636

37+
//$generator = pipeline('text-generation', 'Xenova/codegen-350M-mono');
38+
//$streamer = StdOutStreamer::make($generator->tokenizer);
39+
//
40+
//$output = $generator(
41+
// 'def fib(n):',
42+
// streamer: $streamer,
43+
// maxNewTokens: 100,
44+
// doSample: true
45+
//);
46+
4747
dd($output[0]['generated_text'], timeUsage(), memoryUsage());

src/Generation/Samplers/Sampler.php

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ abstract public function sample(Tensor $logits, int $index);
4747
public function getLogits(Tensor $logits, int $index): array
4848
{
4949
$vocabSize = $logits->shape()[count($logits->shape()) - 1];
50-
$logs = $logits->buffer()->toArray();
50+
$logs = $logits->toBufferArray();
5151

5252
if ($index === -1) {
5353
$logs = array_slice($logs, -$vocabSize);
@@ -76,7 +76,6 @@ public function randomSelect(array $probabilities): int
7676

7777
// Generate a random number between 0 and the sum of probabilities
7878
$r = mt_rand() / mt_getrandmax() * $sumProbabilities;
79-
8079
foreach ($probabilities as $i => $probability) {
8180
$r -= $probability;
8281

src/Generation/Streamers/TextStreamer.php

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@
8989
class TextStreamer extends Streamer
9090
{
9191
protected PretrainedTokenizer $tokenizer;
92-
protected array $inputTokens = [];
9392
protected bool $excludeInput = true;
9493
protected string $printedText = '';
9594
protected mixed $onStreamCallback = null;
@@ -111,14 +110,14 @@ public static function make(): self
111110
public function init(PretrainedTokenizer $tokenizer, array $inputTokens, bool $excludeInput = false): void
112111
{
113112
$this->tokenizer = $tokenizer;
114-
$this->inputTokens = $inputTokens;
115113
$this->excludeInput = $excludeInput;
116114

117115
if ($this->excludeInput) {
118-
$this->printedText = $this->tokenizer->decode($this->inputTokens, skipSpecialTokens: true);
116+
$this->printedText = $this->tokenizer->decode($inputTokens, skipSpecialTokens: true);
119117
$this->printedLength = mb_strlen($this->printedText);
120118

121-
$this->lastDecodedCheckpointForToken = count($this->inputTokens);
119+
120+
$this->lastDecodedCheckpointForToken = count($inputTokens) - 1;
122121
$this->lastDecodedCheckpointForText = mb_strlen($this->printedText);
123122
}
124123
}
@@ -154,7 +153,6 @@ public function put(mixed $value): void
154153
// Check for punctuation marks indicating the end of a word or sentence
155154
$punctuationMarks = ['.', ',', '!', '?', ';', ':'];
156155

157-
158156
$this->printedText = mb_substr($this->printedText, 0, $this->lastDecodedCheckpointForText)
159157
. ($this->lastDecodedCheckpointForToken == 0 ? '' : ' ')
160158
. $decodedText;

src/Utils/Tensor.php

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -974,7 +974,7 @@ public function permute(...$axes): static
974974
public function softmax(): array|static
975975
{
976976
return match ($this->ndim()) {
977-
1 => Math::softmax($this->toArray()),
977+
1 => $this->unsqueeze(0)->softmax2D(),
978978
2 => $this->softmax2D(),
979979
default => throw new InvalidArgumentException("Softmax is only supported for 1D and 2D tensors.")
980980
};

0 commit comments

Comments
 (0)