Skip to content

Commit c33ab83

Browse files
Add llama model support
1 parent 5e310f9 commit c33ab83

14 files changed

Lines changed: 102 additions & 27 deletions

composer.json

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@
2121
"codewithkyrian/onnxruntime-downloader-plugin": "^1.1",
2222
"symfony/console": "^6.4|^7.0",
2323
"imagine/imagine": "^1.3",
24-
"rokka/imagine-vips": "^0.31.0",
25-
"spatie/fork": "^1.2"
24+
"rokka/imagine-vips": "^0.31.0"
2625
},
2726
"require-dev": {
2827
"pestphp/pest": "^2.31",

examples/pipelines/text-generation.php

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,14 @@
1212
ini_set('memory_limit', -1);
1313
//
1414
//$generator = pipeline('text-generation', 'Xenova/gpt2');
15-
$generator = pipeline('text-generation', 'Xenova/Qwen1.5-0.5B-Chat');
16-
//
15+
//$generator = pipeline('text-generation', 'Xenova/Qwen1.5-0.5B-Chat');
16+
$generator = pipeline('text-generation', 'Xenova/TinyLlama-1.1B-Chat-v1.0');
17+
1718
$streamer = StdOutStreamer::make();
1819

1920
$messages = [
2021
['role' => 'system', 'content' => 'You are a helpful assistant.'],
21-
['role' => 'user', 'content' => 'What is diffusion?'],
22+
['role' => 'user', 'content' => 'What is diffusion in chemistry?'],
2223
];
2324

2425
$input = $generator->tokenizer->applyChatTemplate($messages, addGenerationPrompt: true, tokenize: false);

src/Decoders/ByteFallback.php

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ public function __construct(array $config)
1818

1919
protected function decodeChain(array $tokens): array
2020
{
21-
$new_tokens = [];
22-
$previous_byte_tokens = [];
21+
$newTokens = [];
22+
$previousByteTokens = [];
2323

2424
foreach ($tokens as $token) {
2525
$bytes = null;
@@ -30,22 +30,22 @@ protected function decodeChain(array $tokens): array
3030
}
3131
}
3232
if ($bytes !== null) {
33-
$previous_byte_tokens[] = $bytes;
33+
$previousByteTokens[] = $bytes;
3434
} else {
35-
if (count($previous_byte_tokens) > 0) {
36-
$string = $this->bytesToString($previous_byte_tokens);
37-
$new_tokens[] = $string;
38-
$previous_byte_tokens = [];
35+
if (count($previousByteTokens) > 0) {
36+
$string = $this->bytesToString($previousByteTokens);
37+
$newTokens[] = $string;
38+
$previousByteTokens = [];
3939
}
40-
$new_tokens[] = $token;
40+
$newTokens[] = $token;
4141
}
4242
}
43-
if (count($previous_byte_tokens) > 0) {
44-
$string = $this->bytesToString($previous_byte_tokens);
45-
$new_tokens[] = $string;
43+
if (count($previousByteTokens) > 0) {
44+
$string = $this->bytesToString($previousByteTokens);
45+
$newTokens[] = $string;
4646
}
4747

48-
return $new_tokens;
48+
return $newTokens;
4949
}
5050

5151
/**
@@ -56,9 +56,7 @@ protected function decodeChain(array $tokens): array
5656
*/
5757
protected function bytesToString(array $bytes): string
5858
{
59-
$chars = array_map(function ($byte) {
60-
return chr($byte);
61-
}, $bytes);
62-
return implode('', $chars);
59+
$binaryString = pack('C*', ...$bytes);
60+
return mb_convert_encoding($binaryString, 'ISO-8859-1');
6361
}
6462
}

src/Decoders/DecoderSequence.php

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ protected function decodeChain(array $tokens): array
2929
{
3030
return array_reduce(
3131
$this->decoders,
32-
fn(array $tokens, Decoder $decoder) => $decoder->decode($tokens),
32+
fn(array $tokens, Decoder $decoder) => $decoder->decodeChain($tokens),
3333
$tokens
3434
);
3535
}

src/Decoders/FuseDecoder.php

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,6 @@ public function __construct(array $config)
1818

1919
protected function decodeChain(array $tokens): array
2020
{
21-
return [implode('', $tokens)];
21+
return [implode('', $tokens)];
2222
}
2323
}

src/Decoders/ReplaceDecoder.php

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,17 @@ protected function decodeChain(array $tokens): array
2020
{
2121
$pattern = $this->config['pattern'] ?? null;
2222

23+
2324
return $pattern == null ?
2425
$tokens :
2526
array_map(function ($token) use ($pattern) {
26-
return str_replace($pattern, $this->config['content'], $token);
27+
if (isset($pattern['Regex'])) {
28+
return preg_replace("/{$pattern['Regex']}/u", $this->config['content'], (string)$token);
29+
} elseif (isset($pattern['String'])) {
30+
return str_replace($pattern['String'], $this->config['content'], (string)$token);
31+
} else {
32+
return $token;
33+
}
2734
}, $tokens);
2835
}
2936
}

src/Models/Auto/AutoModel.php

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class AutoModel extends PretrainedMixin
3838
"gptj" => \Codewithkyrian\Transformers\Models\Pretrained\GPTJModel::class,
3939
"gpt_bigcode" => \Codewithkyrian\Transformers\Models\Pretrained\GPTBigCodeModel::class,
4040
"codegen" => \Codewithkyrian\Transformers\Models\Pretrained\CodeGenModel::class,
41+
"llama" => \Codewithkyrian\Transformers\Models\Pretrained\LlamaModel::class,
4142
"qwen2" => \Codewithkyrian\Transformers\Models\Pretrained\Qwen2Model::class,
4243
];
4344

src/Models/Auto/AutoModelForCausalLM.php

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ class AutoModelForCausalLM extends PretrainedMixin
1212
'gptj' => \Codewithkyrian\Transformers\Models\Pretrained\GPTJForCausalLM::class,
1313
'gpt_bigcode' => \Codewithkyrian\Transformers\Models\Pretrained\GPTBigCodeForCausalLM::class,
1414
'codegen' => \Codewithkyrian\Transformers\Models\Pretrained\CodeGenForCausalLM::class,
15+
'llama' => \Codewithkyrian\Transformers\Models\Pretrained\LlamaForCausalLM::class,
1516
'trocr' => \Codewithkyrian\Transformers\Models\Pretrained\TrOCRForCausalLM::class,
1617
'qwen2' => \Codewithkyrian\Transformers\Models\Pretrained\Qwen2ForCausalLM::class
1718
];

src/Models/ModelArchitecture.php

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ protected function decoderRunBeam(PretrainedModel $model, array &$beam): array
119119
'past_key_values' => $beam['prev_model_outputs']['past_key_values'] ?? null,
120120
];
121121

122+
123+
// 2. Run
122124
$output = $model->forward($modelInputs);
123125

124126
// 3. Update
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
6+
namespace Codewithkyrian\Transformers\Models\Pretrained;
7+
8+
class LlamaForCausalLM extends LlamaPretrainedModel
9+
{
10+
11+
}

0 commit comments

Comments
 (0)