Skip to content

Commit d29cd6e

Browse files
bugfix: quantized option not being passed when downloading from Hub.
1 parent 56ea0ee commit d29cd6e

4 files changed

Lines changed: 347 additions & 306 deletions

File tree

src/Commands/DownloadModelCommand.php

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,12 @@
66
namespace Codewithkyrian\Transformers\Commands;
77

88
use Codewithkyrian\Transformers\Models\Auto\AutoModel;
9-
use Codewithkyrian\Transformers\Models\Auto\AutoModelForCausalLM;
10-
use Codewithkyrian\Transformers\Models\Auto\AutoModelForSeq2SeqLM;
11-
use Codewithkyrian\Transformers\Models\Auto\AutoModelForSequenceClassification;
129
use Codewithkyrian\Transformers\Pipelines\Task;
1310
use Codewithkyrian\Transformers\PretrainedTokenizers\AutoTokenizer;
1411
use Codewithkyrian\Transformers\Transformers;
12+
use Exception;
1513
use Symfony\Component\Console\Attribute\AsCommand;
1614
use Symfony\Component\Console\Command\Command;
17-
use Symfony\Component\Console\Helper\ProgressBar;
1815
use Symfony\Component\Console\Input\InputArgument;
1916
use Symfony\Component\Console\Input\InputInterface;
2017
use Symfony\Component\Console\Input\InputOption;
@@ -60,7 +57,7 @@ protected function execute(InputInterface $input, OutputInterface $output): int
6057

6158
$model = $input->getArgument('model');
6259
$cacheDir = $input->getOption('cache-dir');
63-
$quantized = $input->getOption('quantized');
60+
$quantized = filter_var($input->getOption('quantized'), FILTER_VALIDATE_BOOLEAN);
6461
$task = $input->getArgument('task');
6562

6663
Transformers::setup()
@@ -71,9 +68,9 @@ protected function execute(InputInterface $input, OutputInterface $output): int
7168
$task = $task ? Task::tryFrom($task) : null;
7269

7370
if ($task != null) {
74-
pipeline($task, $model, output: $output);
71+
pipeline($task, $model, quantized: $quantized, output: $output);
7572
} else {
76-
AutoTokenizer::fromPretrained($model, quantized: $quantized, output: $output);
73+
AutoTokenizer::fromPretrained($model, output: $output);
7774
AutoModel::fromPretrained($model, $quantized, output: $output);
7875
}
7976

@@ -83,8 +80,8 @@ protected function execute(InputInterface $input, OutputInterface $output): int
8380
$this->askToStar($input, $output);
8481

8582
return Command::SUCCESS;
86-
} catch (\Exception $e) {
87-
$output->writeln(''. $e->getMessage());
83+
} catch (Exception $e) {
84+
$output->writeln('' . $e->getMessage());
8885
return Command::FAILURE;
8986
}
9087
}

src/Models/Auto/PretrainedMixin.php

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
use Codewithkyrian\Transformers\Models\Pretrained\PretrainedModel;
1111
use Codewithkyrian\Transformers\Utils\AutoConfig;
1212
use Symfony\Component\Console\Output\OutputInterface;
13-
use function Codewithkyrian\Transformers\Utils\timeUsage;
1413

1514
/**
1615
* Base class of all AutoModels. Contains the `from_pretrained` function
@@ -42,16 +41,15 @@ abstract class PretrainedMixin
4241
* @return PretrainedModel The instantiated pretrained model.
4342
*/
4443
public static function fromPretrained(
45-
string $modelNameOrPath,
46-
bool $quantized = true,
47-
?array $config = null,
48-
?string $cacheDir = null,
49-
string $revision = 'main',
50-
?string $modelFilename = null,
44+
string $modelNameOrPath,
45+
bool $quantized = true,
46+
?array $config = null,
47+
?string $cacheDir = null,
48+
string $revision = 'main',
49+
?string $modelFilename = null,
5150
?OutputInterface $output = null
5251
): PretrainedModel
5352
{
54-
5553
$config = AutoConfig::fromPretrained($modelNameOrPath, $config, $cacheDir, $revision, $output);
5654

5755
foreach (static::MODEL_CLASS_MAPPINGS as $modelClassMapping) {

0 commit comments

Comments
 (0)