Skip to content

Commit 894bc90

Browse files
feat: add support for new models and update text generation pipeline
- Introduced new model classes: Gemma, Gemma2, Gemma3, Qwen3, and Phi with their respective causal language models. - Enhanced model handling in auto models for better lookup
1 parent 9882719 commit 894bc90

24 files changed

Lines changed: 199 additions & 42 deletions

examples/pipelines/text-generation.php

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
require_once './bootstrap.php';
66

7-
use Codewithkyrian\Transformers\Generation\Streamers\StdOutStreamer;
87
use Codewithkyrian\Transformers\Generation\Streamers\TextStreamer;
98
use function Codewithkyrian\Transformers\Pipelines\pipeline;
109
use function Codewithkyrian\Transformers\Utils\memoryUsage;
@@ -14,8 +13,9 @@
1413

1514
//$generator = pipeline('text-generation', 'Xenova/gpt2');
1615
//$generator = pipeline('text-generation', 'Xenova/Qwen1.5-0.5B-Chat');
17-
$generator = pipeline('text-generation', 'Xenova/TinyLlama-1.1B-Chat-v1.0');
16+
// $generator = pipeline('text-generation', 'Xenova/TinyLlama-1.1B-Chat-v1.0');
1817
// $generator = pipeline('text-generation', 'onnx-community/Llama-3.2-1B-Instruct', modelFilename: 'model_q4');
18+
$generator = pipeline('text-generation', 'onnx-community/Qwen3-0.6B-ONNX');
1919

2020
$streamer = TextStreamer::make()->shouldSkipPrompt();
2121

@@ -26,14 +26,15 @@
2626

2727
$input = $generator->tokenizer->applyChatTemplate($messages, addGenerationPrompt: true, tokenize: false);
2828

29-
$output = $generator($input,
29+
$output = $generator(
30+
$input,
3031
streamer: $streamer,
3132
maxNewTokens: 256,
3233
doSample: true,
3334
returnFullText: false,
34-
// temperature: 0.7,
35-
// repetitionPenalty: 1.3,
36-
// earlyStopping: true
35+
// temperature: 0.7,
36+
// repetitionPenalty: 1.3,
37+
// earlyStopping: true
3738
);
3839

3940
//$generator = pipeline('text-generation', 'Xenova/codegen-350M-mono');
@@ -47,4 +48,4 @@
4748
// returnFullText: true,
4849
//);
4950

50-
dd($output[0]['generated_text'], $streamer->getTPS()." tps", timeUsage(), memoryUsage());
51+
dd($output[0]['generated_text'], $streamer->getTPS() . " tps", timeUsage(), memoryUsage());

src/Models/Auto/AutoModel.php

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ class AutoModel extends AutoModelBase
4343
"codegen" => \Codewithkyrian\Transformers\Models\Pretrained\CodeGenModel::class,
4444
"llama" => \Codewithkyrian\Transformers\Models\Pretrained\LlamaModel::class,
4545
"qwen2" => \Codewithkyrian\Transformers\Models\Pretrained\Qwen2Model::class,
46+
"gemma" => \Codewithkyrian\Transformers\Models\Pretrained\GemmaModel::class,
47+
"gemma2" => \Codewithkyrian\Transformers\Models\Pretrained\Gemma2Model::class,
48+
"gemma3" => \Codewithkyrian\Transformers\Models\Pretrained\Gemma3Model::class,
49+
"qwen3" => \Codewithkyrian\Transformers\Models\Pretrained\Qwen3Model::class,
50+
"phi" => \Codewithkyrian\Transformers\Models\Pretrained\PhiModel::class,
51+
"phi3" => \Codewithkyrian\Transformers\Models\Pretrained\Phi3Model::class,
4652
];
4753

4854
const MODELS = [
@@ -62,6 +68,5 @@ class AutoModel extends AutoModelBase
6268
...AutoModelForZeroShotObjectDetection::MODELS,
6369
];
6470

65-
6671
const BASE_IF_FAIL = true;
6772
}

src/Models/Auto/AutoModelBase.php

Lines changed: 20 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
declare(strict_types=1);
44

5-
65
namespace Codewithkyrian\Transformers\Models\Auto;
76

87
use Codewithkyrian\Transformers\Configs\AutoConfig;
@@ -51,39 +50,31 @@ public static function fromPretrained(
5150
): PretrainedModel {
5251
$config = AutoConfig::fromPretrained($modelNameOrPath, $config, $cacheDir, $revision, $onProgress);
5352

54-
foreach (static::MODELS as $modelType => $modelClass) {
55-
if ($modelType != $config->modelType) continue;
53+
$modelClass = static::MODELS[$config->modelType] ?? null;
5654

57-
$modelArchitecture = self::getModelArchitecture($modelClass);
55+
if ($modelClass === null) {
56+
if (static::BASE_IF_FAIL) {
57+
$logger = Transformers::getLogger();
58+
$logger->warning("Unknown model class for model type {$config->modelType}. Using base class PreTrainedModel.");
5859

59-
return $modelClass::fromPretrained(
60-
modelNameOrPath: $modelNameOrPath,
61-
quantized: $quantized,
62-
config: $config,
63-
cacheDir: $cacheDir,
64-
revision: $revision,
65-
modelFilename: $modelFilename,
66-
modelArchitecture: $modelArchitecture,
67-
onProgress: $onProgress
68-
);
60+
$modelClass = PretrainedModel::class;
61+
} else {
62+
throw UnsupportedModelTypeException::make($config->modelType);
63+
}
6964
}
7065

71-
if (static::BASE_IF_FAIL) {
72-
$logger = Transformers::getLogger();
73-
$logger->warning("Unknown model class for model type {$config->modelType}. Using base class PreTrainedModel.");
66+
$modelArchitecture = self::getModelArchitecture($modelClass);
7467

75-
return PretrainedModel::fromPretrained(
76-
modelNameOrPath: $modelNameOrPath,
77-
quantized: $quantized,
78-
config: $config,
79-
cacheDir: $cacheDir,
80-
revision: $revision,
81-
modelFilename: $modelFilename,
82-
onProgress: $onProgress
83-
);
84-
} else {
85-
throw UnsupportedModelTypeException::make($config->modelType);
86-
}
68+
return $modelClass::fromPretrained(
69+
modelNameOrPath: $modelNameOrPath,
70+
quantized: $quantized,
71+
config: $config,
72+
cacheDir: $cacheDir,
73+
revision: $revision,
74+
modelFilename: $modelFilename,
75+
modelArchitecture: $modelArchitecture,
76+
onProgress: $onProgress
77+
);
8778
}
8879

8980
protected static function getModelArchitecture($modelClass): ModelArchitecture

src/Models/Auto/AutoModelForCausalLM.php

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ class AutoModelForCausalLM extends AutoModelBase
1313
'codegen' => \Codewithkyrian\Transformers\Models\Pretrained\CodeGenForCausalLM::class,
1414
'llama' => \Codewithkyrian\Transformers\Models\Pretrained\LlamaForCausalLM::class,
1515
'trocr' => \Codewithkyrian\Transformers\Models\Pretrained\TrOCRForCausalLM::class,
16-
'qwen2' => \Codewithkyrian\Transformers\Models\Pretrained\Qwen2ForCausalLM::class
16+
'qwen2' => \Codewithkyrian\Transformers\Models\Pretrained\Qwen2ForCausalLM::class,
17+
'gemma' => \Codewithkyrian\Transformers\Models\Pretrained\GemmaForCausalLM::class,
18+
'gemma2' => \Codewithkyrian\Transformers\Models\Pretrained\Gemma2ForCausalLM::class,
19+
'gemma3' => \Codewithkyrian\Transformers\Models\Pretrained\Gemma3ForCausalLM::class,
20+
'qwen3' => \Codewithkyrian\Transformers\Models\Pretrained\Qwen3ForCausalLM::class,
21+
'phi' => \Codewithkyrian\Transformers\Models\Pretrained\PhiForCausalLM::class,
22+
'phi3' => \Codewithkyrian\Transformers\Models\Pretrained\Phi3ForCausalLM::class,
1723
];
1824
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Codewithkyrian\Transformers\Models\Pretrained;
6+
7+
class Gemma2ForCausalLM extends Gemma2PretrainedModel {}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Codewithkyrian\Transformers\Models\Pretrained;
6+
7+
/**
8+
* The bare Gemma2 Model outputting raw hidden-states without any specific head on top.
9+
*/
10+
class Gemma2Model extends Gemma2PretrainedModel {}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Codewithkyrian\Transformers\Models\Pretrained;
6+
7+
/**
8+
* The bare Gemma2 Model outputting raw hidden-states without any specific head on top.
9+
*/
10+
class Gemma2PretrainedModel extends PretrainedModel {}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Codewithkyrian\Transformers\Models\Pretrained;
6+
7+
class Gemma3ForCausalLM extends Gemma3PretrainedModel {}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Codewithkyrian\Transformers\Models\Pretrained;
6+
7+
/**
8+
* The bare Gemma3 Model outputting raw hidden-states without any specific head on top.
9+
*/
10+
class Gemma3Model extends Gemma3PretrainedModel {}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Codewithkyrian\Transformers\Models\Pretrained;
6+
7+
/**
8+
* The bare Gemma3 Model outputting raw hidden-states without any specific head on top.
9+
*/
10+
class Gemma3PretrainedModel extends PretrainedModel {}

0 commit comments

Comments
 (0)