Skip to content

Commit 6932252

Browse files
Merge pull request #25 from CodeWithKyrian/add-llama-support
Add support for Llama models
2 parents 2200fa3 + 0e332a6 commit 6932252

13 files changed

Lines changed: 95 additions & 27 deletions

File tree

composer.json

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -21,8 +21,7 @@
2121
"codewithkyrian/onnxruntime-downloader-plugin": "^1.1",
2222
"symfony/console": "^6.4|^7.0",
2323
"imagine/imagine": "^1.3",
24-
"rokka/imagine-vips": "^0.31.0",
25-
"spatie/fork": "^1.2"
24+
"rokka/imagine-vips": "^0.31.0"
2625
},
2726
"require-dev": {
2827
"pestphp/pest": "^2.31",

examples/pipelines/text-generation.php

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -12,13 +12,14 @@
1212
ini_set('memory_limit', -1);
1313
//
1414
//$generator = pipeline('text-generation', 'Xenova/gpt2');
15-
$generator = pipeline('text-generation', 'Xenova/Qwen1.5-0.5B-Chat');
16-
//
15+
//$generator = pipeline('text-generation', 'Xenova/Qwen1.5-0.5B-Chat');
16+
$generator = pipeline('text-generation', 'Xenova/TinyLlama-1.1B-Chat-v1.0');
17+
1718
$streamer = StdOutStreamer::make();
1819

1920
$messages = [
2021
['role' => 'system', 'content' => 'You are a helpful assistant.'],
21-
['role' => 'user', 'content' => 'What is diffusion?'],
22+
['role' => 'user', 'content' => 'What is diffusion in chemistry?'],
2223
];
2324

2425
$input = $generator->tokenizer->applyChatTemplate($messages, addGenerationPrompt: true, tokenize: false);

src/Decoders/ByteFallback.php

Lines changed: 14 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -18,8 +18,8 @@ public function __construct(array $config)
1818

1919
protected function decodeChain(array $tokens): array
2020
{
21-
$new_tokens = [];
22-
$previous_byte_tokens = [];
21+
$newTokens = [];
22+
$previousByteTokens = [];
2323

2424
foreach ($tokens as $token) {
2525
$bytes = null;
@@ -30,22 +30,22 @@ protected function decodeChain(array $tokens): array
3030
}
3131
}
3232
if ($bytes !== null) {
33-
$previous_byte_tokens[] = $bytes;
33+
$previousByteTokens[] = $bytes;
3434
} else {
35-
if (count($previous_byte_tokens) > 0) {
36-
$string = $this->bytesToString($previous_byte_tokens);
37-
$new_tokens[] = $string;
38-
$previous_byte_tokens = [];
35+
if (count($previousByteTokens) > 0) {
36+
$string = $this->bytesToString($previousByteTokens);
37+
$newTokens[] = $string;
38+
$previousByteTokens = [];
3939
}
40-
$new_tokens[] = $token;
40+
$newTokens[] = $token;
4141
}
4242
}
43-
if (count($previous_byte_tokens) > 0) {
44-
$string = $this->bytesToString($previous_byte_tokens);
45-
$new_tokens[] = $string;
43+
if (count($previousByteTokens) > 0) {
44+
$string = $this->bytesToString($previousByteTokens);
45+
$newTokens[] = $string;
4646
}
4747

48-
return $new_tokens;
48+
return $newTokens;
4949
}
5050

5151
/**
@@ -56,9 +56,7 @@ protected function decodeChain(array $tokens): array
5656
*/
5757
protected function bytesToString(array $bytes): string
5858
{
59-
$chars = array_map(function ($byte) {
60-
return chr($byte);
61-
}, $bytes);
62-
return implode('', $chars);
59+
$binaryString = pack('C*', ...$bytes);
60+
return mb_convert_encoding($binaryString, 'ISO-8859-1');
6361
}
6462
}

src/Decoders/FuseDecoder.php

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,6 @@ public function __construct(array $config)
1818

1919
protected function decodeChain(array $tokens): array
2020
{
21-
return [implode('', $tokens)];
21+
return [implode('', $tokens)];
2222
}
2323
}

src/Decoders/StripDecoder.php

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ protected function decodeChain(array $tokens): array
2828
return array_map(function ($token) {
2929
$startCut = 0;
3030
for ($i = 0; $i < $this->start; ++$i) {
31-
if ($token[$i] === $this->content) {
31+
if ($token[$i] ?? null === $this->content) {
3232
$startCut = $i + 1;
3333
continue;
3434
} else {
@@ -39,7 +39,7 @@ protected function decodeChain(array $tokens): array
3939
$stopCut = strlen($token);
4040
for ($i = 0; $i < $this->stop; ++$i) {
4141
$index = strlen($token) - $i - 1;
42-
if ($token[$index] === $this->content) {
42+
if ($token[$index] ?? null === $this->content) {
4343
$stopCut = $index;
4444
continue;
4545
} else {

src/Models/Auto/AutoModel.php

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,7 @@ class AutoModel extends PretrainedMixin
3838
"gptj" => \Codewithkyrian\Transformers\Models\Pretrained\GPTJModel::class,
3939
"gpt_bigcode" => \Codewithkyrian\Transformers\Models\Pretrained\GPTBigCodeModel::class,
4040
"codegen" => \Codewithkyrian\Transformers\Models\Pretrained\CodeGenModel::class,
41+
"llama" => \Codewithkyrian\Transformers\Models\Pretrained\LlamaModel::class,
4142
"qwen2" => \Codewithkyrian\Transformers\Models\Pretrained\Qwen2Model::class,
4243
];
4344

src/Models/Auto/AutoModelForCausalLM.php

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@ class AutoModelForCausalLM extends PretrainedMixin
1212
'gptj' => \Codewithkyrian\Transformers\Models\Pretrained\GPTJForCausalLM::class,
1313
'gpt_bigcode' => \Codewithkyrian\Transformers\Models\Pretrained\GPTBigCodeForCausalLM::class,
1414
'codegen' => \Codewithkyrian\Transformers\Models\Pretrained\CodeGenForCausalLM::class,
15+
'llama' => \Codewithkyrian\Transformers\Models\Pretrained\LlamaForCausalLM::class,
1516
'trocr' => \Codewithkyrian\Transformers\Models\Pretrained\TrOCRForCausalLM::class,
1617
'qwen2' => \Codewithkyrian\Transformers\Models\Pretrained\Qwen2ForCausalLM::class
1718
];

src/Models/ModelArchitecture.php

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -119,6 +119,8 @@ protected function decoderRunBeam(PretrainedModel $model, array &$beam): array
119119
'past_key_values' => $beam['prev_model_outputs']['past_key_values'] ?? null,
120120
];
121121

122+
123+
// 2. Run
122124
$output = $model->forward($modelInputs);
123125

124126
// 3. Update
src/Models/Pretrained/LlamaForCausalLM.php

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,11 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
6+
namespace Codewithkyrian\Transformers\Models\Pretrained;
7+
8+
class LlamaForCausalLM extends LlamaPretrainedModel
9+
{
10+
11+
}
src/Models/Pretrained/LlamaModel.php

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,14 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
6+
namespace Codewithkyrian\Transformers\Models\Pretrained;
7+
8+
/**
9+
* The bare LLaMA Model outputting raw hidden-states without any specific head on top.
10+
*/
11+
class LlamaModel extends LlamaPretrainedModel
12+
{
13+
14+
}

0 commit comments

Comments
 (0)