66namespace Codewithkyrian \Transformers \Commands ;
77
88use Codewithkyrian \Transformers \Models \Auto \AutoModel ;
9- use Codewithkyrian \Transformers \Models \Auto \AutoModelForCausalLM ;
10- use Codewithkyrian \Transformers \Models \Auto \AutoModelForSeq2SeqLM ;
11- use Codewithkyrian \Transformers \Models \Auto \AutoModelForSequenceClassification ;
129use Codewithkyrian \Transformers \Pipelines \Task ;
1310use Codewithkyrian \Transformers \PretrainedTokenizers \AutoTokenizer ;
1411use Codewithkyrian \Transformers \Transformers ;
12+ use Exception ;
1513use Symfony \Component \Console \Attribute \AsCommand ;
1614use Symfony \Component \Console \Command \Command ;
17- use Symfony \Component \Console \Helper \ProgressBar ;
1815use Symfony \Component \Console \Input \InputArgument ;
1916use Symfony \Component \Console \Input \InputInterface ;
2017use Symfony \Component \Console \Input \InputOption ;
@@ -60,7 +57,7 @@ protected function execute(InputInterface $input, OutputInterface $output): int
6057
6158 $ model = $ input ->getArgument ('model ' );
6259 $ cacheDir = $ input ->getOption ('cache-dir ' );
63- $ quantized = $ input ->getOption ('quantized ' );
60+ $ quantized = filter_var ( $ input ->getOption ('quantized ' ), FILTER_VALIDATE_BOOLEAN );
6461 $ task = $ input ->getArgument ('task ' );
6562
6663 Transformers::setup ()
@@ -71,9 +68,9 @@ protected function execute(InputInterface $input, OutputInterface $output): int
7168 $ task = $ task ? Task::tryFrom ($ task ) : null ;
7269
7370 if ($ task != null ) {
74- pipeline ($ task , $ model , output: $ output );
71+ pipeline ($ task , $ model , quantized: $ quantized , output: $ output );
7572 } else {
76- AutoTokenizer::fromPretrained ($ model , quantized: $ quantized , output: $ output );
73+ AutoTokenizer::fromPretrained ($ model , output: $ output );
7774 AutoModel::fromPretrained ($ model , $ quantized , output: $ output );
7875 }
7976
@@ -83,8 +80,8 @@ protected function execute(InputInterface $input, OutputInterface $output): int
8380 $ this ->askToStar ($ input , $ output );
8481
8582 return Command::SUCCESS ;
86- } catch (\ Exception $ e ) {
87- $ output ->writeln ('✘ ' . $ e ->getMessage ());
83+ } catch (Exception $ e ) {
84+ $ output ->writeln ('✘ ' . $ e ->getMessage ());
8885 return Command::FAILURE ;
8986 }
9087 }
0 commit comments