@@ -19,9 +19,9 @@ abstract class AutoModelBase
1919{
2020 /**
2121 * Mapping from model type to model class.
22- * @var array<string, array< string, string >> The model class mappings.
22+ * @var array<string, class- string<PretrainedModel >> The model class mappings.
2323 */
24- const MODEL_CLASS_MAPPINGS = [];
24+ const MODELS = [];
2525
2626 /**
2727 * Whether to attempt to instantiate the base class (`PretrainedModel`) if
@@ -47,15 +47,12 @@ public static function fromPretrained(
4747 ?string $ cacheDir = null ,
4848 string $ revision = 'main ' ,
4949 ?string $ modelFilename = null ,
50- ?callable $ onProgress = null
50+ ?callable $ onProgress = null
5151 ): PretrainedModel {
5252 $ config = AutoConfig::fromPretrained ($ modelNameOrPath , $ config , $ cacheDir , $ revision , $ onProgress );
5353
54- foreach (static ::MODEL_CLASS_MAPPINGS as $ modelClassMapping ) {
55- $ modelClass = $ modelClassMapping [$ config ->modelType ] ?? null ;
56-
57-
58- if ($ modelClass === null ) continue ;
54+ foreach (static ::MODELS as $ modelType => $ modelClass ) {
55+ if ($ modelType != $ config ->modelType ) continue ;
5956
6057 $ modelArchitecture = self ::getModelArchitecture ($ modelClass );
6158
@@ -92,20 +89,20 @@ public static function fromPretrained(
9289 protected static function getModelArchitecture ($ modelClass ): ModelArchitecture
9390 {
9491 return match (true ) {
95- in_array ($ modelClass , AutoModel::ENCODER_ONLY_MODEL_MAPPING ) => ModelArchitecture::EncoderOnly,
96- in_array ($ modelClass , AutoModel::ENCODER_DECODER_MODEL_MAPPING ) => ModelArchitecture::EncoderDecoder,
97- in_array ($ modelClass , AutoModel::DECODER_ONLY_MODEL_MAPPING ) => ModelArchitecture::DecoderOnly,
98- in_array ($ modelClass , AutoModelForSequenceClassification::MODEL_CLASS_MAPPING ) => ModelArchitecture::EncoderOnly,
99- in_array ($ modelClass , AutoModelForSeq2SeqLM::MODEL_CLASS_MAPPING ) => ModelArchitecture::Seq2SeqLM,
100- in_array ($ modelClass , AutoModelForCausalLM::MODEL_CLASS_MAPPING ) => ModelArchitecture::DecoderOnly,
101- in_array ($ modelClass , AutoModelForTokenClassification::MODEL_CLASS_MAPPING ) => ModelArchitecture::EncoderOnly,
102- in_array ($ modelClass , AutoModelForQuestionAnswering::MODEL_CLASS_MAPPING ) => ModelArchitecture::EncoderOnly,
103- in_array ($ modelClass , AutoModelForMaskedLM::MODEL_CLASS_MAPPING ) => ModelArchitecture::EncoderOnly,
104- in_array ($ modelClass , AutoModelForVision2Seq::MODEL_CLASS_MAPPING ) => ModelArchitecture::Vision2Seq,
105- in_array ($ modelClass , AutoModelForImageClassification::MODEL_CLASS_MAPPING ) => ModelArchitecture::EncoderOnly,
106- in_array ($ modelClass , AutoModelForAudioClassification::MODEL_CLASS_MAPPING ) => ModelArchitecture::EncoderOnly,
107- in_array ($ modelClass , AutoModelForSpeechSeq2Seq::MODEL_CLASS_MAPPING ) => ModelArchitecture::Seq2SeqLM,
108- in_array ($ modelClass , AutoModelForCTC::MODEL_CLASS_MAPPING ) => ModelArchitecture::EncoderOnly,
92+ in_array ($ modelClass , AutoModel::ENCODER_ONLY_MODELS ) => ModelArchitecture::EncoderOnly,
93+ in_array ($ modelClass , AutoModel::ENCODER_DECODER_MODELS ) => ModelArchitecture::EncoderDecoder,
94+ in_array ($ modelClass , AutoModel::DECODER_ONLY_MODELS ) => ModelArchitecture::DecoderOnly,
95+ in_array ($ modelClass , AutoModelForSequenceClassification::MODELS ) => ModelArchitecture::EncoderOnly,
96+ in_array ($ modelClass , AutoModelForSeq2SeqLM::MODELS ) => ModelArchitecture::Seq2SeqLM,
97+ in_array ($ modelClass , AutoModelForCausalLM::MODELS ) => ModelArchitecture::DecoderOnly,
98+ in_array ($ modelClass , AutoModelForTokenClassification::MODELS ) => ModelArchitecture::EncoderOnly,
99+ in_array ($ modelClass , AutoModelForQuestionAnswering::MODELS ) => ModelArchitecture::EncoderOnly,
100+ in_array ($ modelClass , AutoModelForMaskedLM::MODELS ) => ModelArchitecture::EncoderOnly,
101+ in_array ($ modelClass , AutoModelForVision2Seq::MODELS ) => ModelArchitecture::Vision2Seq,
102+ in_array ($ modelClass , AutoModelForImageClassification::MODELS ) => ModelArchitecture::EncoderOnly,
103+ in_array ($ modelClass , AutoModelForAudioClassification::MODELS ) => ModelArchitecture::EncoderOnly,
104+ in_array ($ modelClass , AutoModelForSpeechSeq2Seq::MODELS ) => ModelArchitecture::Seq2SeqLM,
105+ in_array ($ modelClass , AutoModelForCTC::MODELS ) => ModelArchitecture::EncoderOnly,
109106
110107 default => ModelArchitecture::EncoderOnly,
111108 };
0 commit comments