Skip to content

Commit a815a89

Browse files
Add chat mode detection
1 parent 0c5bf22 commit a815a89

2 files changed

Lines changed: 14 additions & 1 deletion

File tree

src/Pipelines/TextGenerationPipeline.php

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,12 @@ public function __invoke(array|string $texts, ...$args): array
6666

6767
$generationConfig = new GenerationConfig($snakeCasedArgs);
6868

69+
$isChatMode = $this->isChatMode($texts);
70+
71+
if ($isChatMode) {
72+
$texts = $this->tokenizer->applyChatTemplate($texts, addGenerationPrompt: true, tokenize: false);
73+
}
74+
6975
$isBatched = is_array($texts);
7076

7177
if (!$isBatched) {
@@ -87,6 +93,7 @@ public function __invoke(array|string $texts, ...$args): array
8793

8894
$decoded = $this->tokenizer->batchDecode($outputTokenIds, skipSpecialTokens: true);
8995

96+
9097
$toReturn = array_fill(0, count($texts), []);
9198

9299
for ($i = 0; $i < count($decoded); ++$i) {
@@ -104,4 +111,11 @@ protected function camelCaseToSnakeCase(string $input): string
104111
{
105112
return strtolower(preg_replace('/(?<!^)[A-Z]/', '_$0', $input));
106113
}
114+
115+
// Detect chat mode
116+
protected function isChatMode(string|array $texts): bool
117+
{
118+
return is_array($texts) && isset($texts[0]) && is_array($texts[0]) && !array_is_list($texts[0]);
119+
120+
}
107121
}

src/Utils/Helpers.php

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ function array_every(array $array, callable $callback): bool
5858
return true;
5959
}
6060

61-
6261
function joinPaths(string ...$args): string
6362
{
6463
$paths = [];

0 commit comments

Comments
 (0)