Skip to content

Commit 9882719

Browse files
refactor: rename model class mappings for consistency and clarity
- Updated model class mapping constants in AutoModel and its subclasses to use a unified naming convention. - Removed redundant MODEL_CLASS_MAPPINGS constants in favor of direct usage of MODELS. - Improved code readability and maintainability by consolidating model definitions.
1 parent a1da846 commit 9882719

19 files changed

Lines changed: 101 additions & 128 deletions

src/Models/Auto/AutoModel.php

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
class AutoModel extends AutoModelBase
88
{
9-
const ENCODER_ONLY_MODEL_MAPPING = [
9+
const ENCODER_ONLY_MODELS = [
1010
"albert" => \Codewithkyrian\Transformers\Models\Pretrained\AlbertModel::class,
1111
"bert" => \Codewithkyrian\Transformers\Models\Pretrained\BertModel::class,
1212
"distilbert" => \Codewithkyrian\Transformers\Models\Pretrained\DistilBertModel::class,
@@ -18,7 +18,7 @@ class AutoModel extends AutoModelBase
1818
"clip" => \Codewithkyrian\Transformers\Models\Pretrained\CLIPModel::class,
1919
"vit" => \Codewithkyrian\Transformers\Models\Pretrained\ViTModel::class,
2020
"deit" => \Codewithkyrian\Transformers\Models\Pretrained\DeiTModel::class,
21-
"siglip" => \Codewithkyrian\Transformers\Models\Pretrained\SigLipModel::class,
21+
"siglip" => \Codewithkyrian\Transformers\Models\Pretrained\SiglipModel::class,
2222

2323
"audio-spectrogram-transformer" => \Codewithkyrian\Transformers\Models\Pretrained\ASTModel::class,
2424
"wav2vec2" => \Codewithkyrian\Transformers\Models\Pretrained\Wav2Vec2Model::class,
@@ -30,13 +30,13 @@ class AutoModel extends AutoModelBase
3030
'swin2sr' => \Codewithkyrian\Transformers\Models\Pretrained\Swin2SRModel::class,
3131
];
3232

33-
const ENCODER_DECODER_MODEL_MAPPING = [
33+
const ENCODER_DECODER_MODELS = [
3434
"t5" => \Codewithkyrian\Transformers\Models\Pretrained\T5Model::class,
3535
"bart" => \Codewithkyrian\Transformers\Models\Pretrained\BartModel::class,
3636
"m2m_100" => \Codewithkyrian\Transformers\Models\Pretrained\M2M100Model::class,
3737
];
3838

39-
const DECODER_ONLY_MODEL_MAPPING = [
39+
const DECODER_ONLY_MODELS = [
4040
"gpt2" => \Codewithkyrian\Transformers\Models\Pretrained\GPT2Model::class,
4141
"gptj" => \Codewithkyrian\Transformers\Models\Pretrained\GPTJModel::class,
4242
"gpt_bigcode" => \Codewithkyrian\Transformers\Models\Pretrained\GPTBigCodeModel::class,
@@ -45,21 +45,21 @@ class AutoModel extends AutoModelBase
4545
"qwen2" => \Codewithkyrian\Transformers\Models\Pretrained\Qwen2Model::class,
4646
];
4747

48-
const MODEL_CLASS_MAPPINGS = [
49-
self::ENCODER_ONLY_MODEL_MAPPING,
50-
self::ENCODER_DECODER_MODEL_MAPPING,
51-
self::DECODER_ONLY_MODEL_MAPPING,
48+
const MODELS = [
49+
...self::ENCODER_ONLY_MODELS,
50+
...self::ENCODER_DECODER_MODELS,
51+
...self::DECODER_ONLY_MODELS,
5252

53-
AutoModelForSequenceClassification::MODEL_CLASS_MAPPING,
54-
AutoModelForTokenClassification::MODEL_CLASS_MAPPING,
55-
AutoModelForSeq2SeqLM::MODEL_CLASS_MAPPING,
56-
AutoModelForCausalLM::MODEL_CLASS_MAPPING,
57-
AutoModelForMaskedLM::MODEL_CLASS_MAPPING,
58-
AutoModelForQuestionAnswering::MODEL_CLASS_MAPPING,
59-
AutoModelForImageClassification::MODEL_CLASS_MAPPING,
60-
AutoModelForVision2Seq::MODEL_CLASS_MAPPING,
61-
AutoModelForObjectDetection::MODEL_CLASS_MAPPING,
62-
AutoModelForZeroShotObjectDetection::MODEL_CLASS_MAPPING,
53+
...AutoModelForSequenceClassification::MODELS,
54+
...AutoModelForTokenClassification::MODELS,
55+
...AutoModelForSeq2SeqLM::MODELS,
56+
...AutoModelForCausalLM::MODELS,
57+
...AutoModelForMaskedLM::MODELS,
58+
...AutoModelForQuestionAnswering::MODELS,
59+
...AutoModelForImageClassification::MODELS,
60+
...AutoModelForVision2Seq::MODELS,
61+
...AutoModelForObjectDetection::MODELS,
62+
...AutoModelForZeroShotObjectDetection::MODELS,
6363
];
6464

6565

src/Models/Auto/AutoModelBase.php

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ abstract class AutoModelBase
1919
{
2020
/**
2121
* Mapping from model type to model class.
22-
* @var array<string, array<string, string>> The model class mappings.
22+
* @var array<string, class-string<PretrainedModel>> The model class mappings.
2323
*/
24-
const MODEL_CLASS_MAPPINGS = [];
24+
const MODELS = [];
2525

2626
/**
2727
* Whether to attempt to instantiate the base class (`PretrainedModel`) if
@@ -47,15 +47,12 @@ public static function fromPretrained(
4747
?string $cacheDir = null,
4848
string $revision = 'main',
4949
?string $modelFilename = null,
50-
?callable $onProgress = null
50+
?callable $onProgress = null
5151
): PretrainedModel {
5252
$config = AutoConfig::fromPretrained($modelNameOrPath, $config, $cacheDir, $revision, $onProgress);
5353

54-
foreach (static::MODEL_CLASS_MAPPINGS as $modelClassMapping) {
55-
$modelClass = $modelClassMapping[$config->modelType] ?? null;
56-
57-
58-
if ($modelClass === null) continue;
54+
foreach (static::MODELS as $modelType => $modelClass) {
55+
if ($modelType != $config->modelType) continue;
5956

6057
$modelArchitecture = self::getModelArchitecture($modelClass);
6158

@@ -92,20 +89,20 @@ public static function fromPretrained(
9289
protected static function getModelArchitecture($modelClass): ModelArchitecture
9390
{
9491
return match (true) {
95-
in_array($modelClass, AutoModel::ENCODER_ONLY_MODEL_MAPPING) => ModelArchitecture::EncoderOnly,
96-
in_array($modelClass, AutoModel::ENCODER_DECODER_MODEL_MAPPING) => ModelArchitecture::EncoderDecoder,
97-
in_array($modelClass, AutoModel::DECODER_ONLY_MODEL_MAPPING) => ModelArchitecture::DecoderOnly,
98-
in_array($modelClass, AutoModelForSequenceClassification::MODEL_CLASS_MAPPING) => ModelArchitecture::EncoderOnly,
99-
in_array($modelClass, AutoModelForSeq2SeqLM::MODEL_CLASS_MAPPING) => ModelArchitecture::Seq2SeqLM,
100-
in_array($modelClass, AutoModelForCausalLM::MODEL_CLASS_MAPPING) => ModelArchitecture::DecoderOnly,
101-
in_array($modelClass, AutoModelForTokenClassification::MODEL_CLASS_MAPPING) => ModelArchitecture::EncoderOnly,
102-
in_array($modelClass, AutoModelForQuestionAnswering::MODEL_CLASS_MAPPING) => ModelArchitecture::EncoderOnly,
103-
in_array($modelClass, AutoModelForMaskedLM::MODEL_CLASS_MAPPING) => ModelArchitecture::EncoderOnly,
104-
in_array($modelClass, AutoModelForVision2Seq::MODEL_CLASS_MAPPING) => ModelArchitecture::Vision2Seq,
105-
in_array($modelClass, AutoModelForImageClassification::MODEL_CLASS_MAPPING) => ModelArchitecture::EncoderOnly,
106-
in_array($modelClass, AutoModelForAudioClassification::MODEL_CLASS_MAPPING) => ModelArchitecture::EncoderOnly,
107-
in_array($modelClass, AutoModelForSpeechSeq2Seq::MODEL_CLASS_MAPPING) => ModelArchitecture::Seq2SeqLM,
108-
in_array($modelClass, AutoModelForCTC::MODEL_CLASS_MAPPING) => ModelArchitecture::EncoderOnly,
92+
in_array($modelClass, AutoModel::ENCODER_ONLY_MODELS) => ModelArchitecture::EncoderOnly,
93+
in_array($modelClass, AutoModel::ENCODER_DECODER_MODELS) => ModelArchitecture::EncoderDecoder,
94+
in_array($modelClass, AutoModel::DECODER_ONLY_MODELS) => ModelArchitecture::DecoderOnly,
95+
in_array($modelClass, AutoModelForSequenceClassification::MODELS) => ModelArchitecture::EncoderOnly,
96+
in_array($modelClass, AutoModelForSeq2SeqLM::MODELS) => ModelArchitecture::Seq2SeqLM,
97+
in_array($modelClass, AutoModelForCausalLM::MODELS) => ModelArchitecture::DecoderOnly,
98+
in_array($modelClass, AutoModelForTokenClassification::MODELS) => ModelArchitecture::EncoderOnly,
99+
in_array($modelClass, AutoModelForQuestionAnswering::MODELS) => ModelArchitecture::EncoderOnly,
100+
in_array($modelClass, AutoModelForMaskedLM::MODELS) => ModelArchitecture::EncoderOnly,
101+
in_array($modelClass, AutoModelForVision2Seq::MODELS) => ModelArchitecture::Vision2Seq,
102+
in_array($modelClass, AutoModelForImageClassification::MODELS) => ModelArchitecture::EncoderOnly,
103+
in_array($modelClass, AutoModelForAudioClassification::MODELS) => ModelArchitecture::EncoderOnly,
104+
in_array($modelClass, AutoModelForSpeechSeq2Seq::MODELS) => ModelArchitecture::Seq2SeqLM,
105+
in_array($modelClass, AutoModelForCTC::MODELS) => ModelArchitecture::EncoderOnly,
109106

110107
default => ModelArchitecture::EncoderOnly,
111108
};

src/Models/Auto/AutoModelForAudioClassification.php

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,8 @@
77

88
class AutoModelForAudioClassification extends AutoModelBase
99
{
10-
const MODEL_CLASS_MAPPING = [
10+
const MODELS = [
1111
'audio-spectrogram-transformer' => \Codewithkyrian\Transformers\Models\Pretrained\ASTForAudioClassification::class,
1212
'wav2vec2' => \Codewithkyrian\Transformers\Models\Pretrained\Wav2Vec2ForSequenceClassification::class,
1313
];
14-
15-
const MODEL_CLASS_MAPPINGS = [
16-
self::MODEL_CLASS_MAPPING,
17-
];
1814
}

src/Models/Auto/AutoModelForCTC.php

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,7 @@
77

88
class AutoModelForCTC extends AutoModelBase
99
{
10-
const MODEL_CLASS_MAPPING = [
10+
const MODELS = [
1111
'wav2vec2' => \Codewithkyrian\Transformers\Models\Pretrained\Wav2Vec2ForCTC::class,
1212
];
13-
14-
const MODEL_CLASS_MAPPINGS = [
15-
self::MODEL_CLASS_MAPPING,
16-
];
1713
}

src/Models/Auto/AutoModelForCausalLM.php

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22

33
declare(strict_types=1);
44

5-
65
namespace Codewithkyrian\Transformers\Models\Auto;
76

87
class AutoModelForCausalLM extends AutoModelBase
98
{
10-
const MODEL_CLASS_MAPPING = [
9+
const MODELS = [
1110
'gpt2' => \Codewithkyrian\Transformers\Models\Pretrained\GPT2LMHeadModel::class,
1211
'gptj' => \Codewithkyrian\Transformers\Models\Pretrained\GPTJForCausalLM::class,
1312
'gpt_bigcode' => \Codewithkyrian\Transformers\Models\Pretrained\GPTBigCodeForCausalLM::class,
@@ -16,8 +15,4 @@ class AutoModelForCausalLM extends AutoModelBase
1615
'trocr' => \Codewithkyrian\Transformers\Models\Pretrained\TrOCRForCausalLM::class,
1716
'qwen2' => \Codewithkyrian\Transformers\Models\Pretrained\Qwen2ForCausalLM::class
1817
];
19-
20-
const MODEL_CLASS_MAPPINGS = [
21-
self::MODEL_CLASS_MAPPING,
22-
];
2318
}

src/Models/Auto/AutoModelForImageClassification.php

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,12 @@
22

33
declare(strict_types=1);
44

5-
65
namespace Codewithkyrian\Transformers\Models\Auto;
76

87
class AutoModelForImageClassification extends AutoModelBase
98
{
10-
const MODEL_CLASS_MAPPING = [
9+
const MODELS = [
1110
'vit' => \Codewithkyrian\Transformers\Models\Pretrained\ViTForImageClassification::class,
1211
'deit' => \Codewithkyrian\Transformers\Models\Pretrained\DeiTForImageClassification::class,
1312
];
14-
15-
const MODEL_CLASS_MAPPINGS = [
16-
self::MODEL_CLASS_MAPPING,
17-
];
1813
}

src/Models/Auto/AutoModelForImageFeatureExtraction.php

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,8 @@
77

88
class AutoModelForImageFeatureExtraction extends AutoModelBase
99
{
10-
const MODEL_CLASS_MAPPING = [
10+
const MODELS = [
1111
'clip' => \Codewithkyrian\Transformers\Models\Pretrained\CLIPVisionModelWithProjection::class,
1212
'siglip' => \Codewithkyrian\Transformers\Models\Pretrained\SiglipVisionModel::class,
1313
];
14-
15-
const MODEL_CLASS_MAPPINGS = [
16-
self::MODEL_CLASS_MAPPING,
17-
AutoModel::ENCODER_ONLY_MODEL_MAPPING,
18-
AutoModel::DECODER_ONLY_MODEL_MAPPING,
19-
];
2014
}

src/Models/Auto/AutoModelForImageToImage.php

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,7 @@
77

88
class AutoModelForImageToImage extends AutoModelBase
99
{
10-
const MODEL_CLASS_MAPPING = [
10+
const MODELS = [
1111
'swin2sr' => \Codewithkyrian\Transformers\Models\Pretrained\Swin2SRForImageSuperResolution::class,
1212
];
13-
14-
const MODEL_CLASS_MAPPINGS = [
15-
self::MODEL_CLASS_MAPPING,
16-
];
1713
}

src/Models/Auto/AutoModelForMaskedLM.php

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22

33
declare(strict_types=1);
44

5-
65
namespace Codewithkyrian\Transformers\Models\Auto;
76

87
class AutoModelForMaskedLM extends AutoModelBase
98
{
10-
const MODEL_CLASS_MAPPING = [
9+
const MODELS = [
1110
"albert" => \Codewithkyrian\Transformers\Models\Pretrained\AlbertForMaskedLM::class,
1211
"bert" => \Codewithkyrian\Transformers\Models\Pretrained\BertForMaskedLM::class,
1312
"deberta" => \Codewithkyrian\Transformers\Models\Pretrained\DebertaForMaskedLM::class,
@@ -17,8 +16,4 @@ class AutoModelForMaskedLM extends AutoModelBase
1716
"roberta" => \Codewithkyrian\Transformers\Models\Pretrained\RobertaForMaskedLM::class,
1817
"roformer" => \Codewithkyrian\Transformers\Models\Pretrained\RoFormerForMaskedLM::class,
1918
];
20-
21-
const MODEL_CLASS_MAPPINGS = [
22-
self::MODEL_CLASS_MAPPING,
23-
];
2419
}

src/Models/Auto/AutoModelForObjectDetection.php

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,12 @@
22

33
declare(strict_types=1);
44

5-
65
namespace Codewithkyrian\Transformers\Models\Auto;
76

87
class AutoModelForObjectDetection extends AutoModelBase
98
{
10-
const MODEL_CLASS_MAPPING = [
9+
const MODELS = [
1110
'detr' => \Codewithkyrian\Transformers\Models\Pretrained\DetrForObjectDetection::class,
1211
'yolos' => \Codewithkyrian\Transformers\Models\Pretrained\YolosForObjectDetection::class,
1312
];
14-
15-
const MODEL_CLASS_MAPPINGS = [
16-
self::MODEL_CLASS_MAPPING,
17-
];
1813
}

0 commit comments

Comments
 (0)