|
9 | 9 | use Codewithkyrian\Transformers\Models\Auto\AutoModelForCausalLM; |
10 | 10 | use Codewithkyrian\Transformers\Models\Auto\AutoModelForSeq2SeqLM; |
11 | 11 | use Codewithkyrian\Transformers\Models\Auto\AutoModelForSequenceClassification; |
| 12 | +use Codewithkyrian\Transformers\Pipelines\Task; |
12 | 13 | use Codewithkyrian\Transformers\Transformers; |
13 | 14 | use Symfony\Component\Console\Attribute\AsCommand; |
14 | 15 | use Symfony\Component\Console\Command\Command; |
15 | 16 | use Symfony\Component\Console\Input\InputArgument; |
16 | 17 | use Symfony\Component\Console\Input\InputInterface; |
17 | 18 | use Symfony\Component\Console\Input\InputOption; |
18 | 19 | use Symfony\Component\Console\Output\OutputInterface; |
| 20 | +use function Codewithkyrian\Transformers\Pipelines\pipeline; |
19 | 21 |
|
20 | 22 | #[AsCommand( |
21 | 23 | name: 'download-model', |
@@ -63,13 +65,14 @@ protected function execute(InputInterface $input, OutputInterface $output): int |
63 | 65 |
|
64 | 66 | // Download the model |
65 | 67 | try { |
66 | | - // TODO: Verify the tasks and corresponding AutoModel classes |
67 | | - $model = match ($task) { |
68 | | - 'text-generation' => AutoModelForCausalLM::fromPretrained($model, $quantized, cacheDir: $cacheDir), |
69 | | - 'text-classification', 'sentiment-analysis' => AutoModelForSequenceClassification::fromPretrained($model, $quantized, cacheDir: $cacheDir), |
70 | | - 'translation' => AutoModelForSeq2SeqLM::fromPretrained($model, $quantized, cacheDir: $cacheDir), |
71 | | - default => AutoModel::fromPretrained($model, $quantized, cacheDir: $cacheDir), |
72 | | - }; |
| 68 | + $task = Task::tryFrom($task); |
| 69 | + |
| 70 | + if ($task != null) { |
| 71 | + pipeline($task, $model); |
| 72 | + } else { |
| 73 | + AutoModel::fromPretrained($model, $quantized, cacheDir: $cacheDir); |
| 74 | + } |
| 75 | + |
73 | 76 |
|
74 | 77 | $output->writeln('✔ Model downloaded successfully.'); |
75 | 78 |
|
|
0 commit comments