|
2 | 2 |
|
3 | 3 | declare(strict_types=1); |
4 | 4 |
|
5 | | - |
6 | 5 | namespace Codewithkyrian\Transformers\Models\Auto; |
7 | 6 |
|
8 | 7 | use Codewithkyrian\Transformers\Configs\AutoConfig; |
@@ -51,39 +50,31 @@ public static function fromPretrained( |
51 | 50 | ): PretrainedModel { |
52 | 51 | $config = AutoConfig::fromPretrained($modelNameOrPath, $config, $cacheDir, $revision, $onProgress); |
53 | 52 |
|
54 | | - foreach (static::MODELS as $modelType => $modelClass) { |
55 | | - if ($modelType != $config->modelType) continue; |
| 53 | + $modelClass = static::MODELS[$config->modelType] ?? null; |
56 | 54 |
|
57 | | - $modelArchitecture = self::getModelArchitecture($modelClass); |
| 55 | + if ($modelClass === null) { |
| 56 | + if (static::BASE_IF_FAIL) { |
| 57 | + $logger = Transformers::getLogger(); |
| 58 | + $logger->warning("Unknown model class for model type {$config->modelType}. Using base class PreTrainedModel."); |
58 | 59 |
|
59 | | - return $modelClass::fromPretrained( |
60 | | - modelNameOrPath: $modelNameOrPath, |
61 | | - quantized: $quantized, |
62 | | - config: $config, |
63 | | - cacheDir: $cacheDir, |
64 | | - revision: $revision, |
65 | | - modelFilename: $modelFilename, |
66 | | - modelArchitecture: $modelArchitecture, |
67 | | - onProgress: $onProgress |
68 | | - ); |
| 60 | + $modelClass = PretrainedModel::class; |
| 61 | + } else { |
| 62 | + throw UnsupportedModelTypeException::make($config->modelType); |
| 63 | + } |
69 | 64 | } |
70 | 65 |
|
71 | | - if (static::BASE_IF_FAIL) { |
72 | | - $logger = Transformers::getLogger(); |
73 | | - $logger->warning("Unknown model class for model type {$config->modelType}. Using base class PreTrainedModel."); |
| 66 | + $modelArchitecture = self::getModelArchitecture($modelClass); |
74 | 67 |
|
75 | | - return PretrainedModel::fromPretrained( |
76 | | - modelNameOrPath: $modelNameOrPath, |
77 | | - quantized: $quantized, |
78 | | - config: $config, |
79 | | - cacheDir: $cacheDir, |
80 | | - revision: $revision, |
81 | | - modelFilename: $modelFilename, |
82 | | - onProgress: $onProgress |
83 | | - ); |
84 | | - } else { |
85 | | - throw UnsupportedModelTypeException::make($config->modelType); |
86 | | - } |
| 68 | + return $modelClass::fromPretrained( |
| 69 | + modelNameOrPath: $modelNameOrPath, |
| 70 | + quantized: $quantized, |
| 71 | + config: $config, |
| 72 | + cacheDir: $cacheDir, |
| 73 | + revision: $revision, |
| 74 | + modelFilename: $modelFilename, |
| 75 | + modelArchitecture: $modelArchitecture, |
| 76 | + onProgress: $onProgress |
| 77 | + ); |
87 | 78 | } |
88 | 79 |
|
89 | 80 | protected static function getModelArchitecture($modelClass): ModelArchitecture |
|
0 commit comments