Skip to content

Commit bb16e7c

Browse files
Change streamer init argument name
1 parent dc50b7c commit bb16e7c

4 files changed

Lines changed: 7 additions & 7 deletions

File tree

src/Generation/Streamers/Streamer.php

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
*/
1313
abstract class Streamer
1414
{
15-
abstract public function init(PretrainedTokenizer $tokenizer, array $inputTokens, bool $includeInput): void;
15+
abstract public function init(PretrainedTokenizer $tokenizer, array $inputTokens, bool $excludeInput = false): void;
1616

1717
abstract public function put(mixed $value): void;
1818

src/Generation/Streamers/TextStreamer.php

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ class TextStreamer extends Streamer
9090
{
9191
protected PretrainedTokenizer $tokenizer;
9292
protected array $inputTokens = [];
93-
protected bool $includeInput = false;
93+
protected bool $excludeInput = true;
9494
protected string $printedText = '';
9595
protected mixed $onStreamCallback = null;
9696
protected mixed $onStreamEndCallback = null;
@@ -108,13 +108,13 @@ public static function make(): self
108108
return new static();
109109
}
110110

111-
public function init(PretrainedTokenizer $tokenizer, array $inputTokens, bool $includeInput): void
111+
public function init(PretrainedTokenizer $tokenizer, array $inputTokens, bool $excludeInput = false): void
112112
{
113113
$this->tokenizer = $tokenizer;
114114
$this->inputTokens = $inputTokens;
115-
$this->includeInput = $includeInput;
115+
$this->excludeInput = $excludeInput;
116116

117-
if (!$this->includeInput) {
117+
if ($this->excludeInput) {
118118
$this->printedText = $this->tokenizer->decode($this->inputTokens, skipSpecialTokens: true);
119119
$this->printedLength = mb_strlen($this->printedText);
120120

src/Pipelines/Text2TextGenerationPipeline.php

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ public function __invoke(array|string $inputs, ...$args): array
7979

8080

8181
// Streamer can only handle one input at a time for now, so we only pass the first input
82-
$streamer->init($this->tokenizer, $inputIds[0]->toArray(), true);
82+
$streamer->init($this->tokenizer, $inputIds[0]->toArray());
8383

8484
// Generate output token ids
8585
$outputTokenIds = $this->model->generate($inputIds, generationConfig: $generateKwargs, streamer: $streamer);

src/Pipelines/TextGenerationPipeline.php

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ public function __invoke(array|string $inputs, ...$args): array
114114
);
115115

116116
// Streamer can only handle one input at a time for now, so we only pass the first input
117-
$streamer->init($this->tokenizer, $inputIds[0]->toArray(), false);
117+
$streamer->init($this->tokenizer, $inputIds[0]->toArray(), true);
118118

119119
$outputTokenIds = $this->model->generate($inputIds,
120120
generationConfig: $generationConfig,

0 commit comments

Comments
 (0)