Skip to content

Commit a6d7fd5

Browse files
Add a simple(but limited) jinja template parser
1 parent 715e2f7 commit a6d7fd5

4 files changed

Lines changed: 494 additions & 13 deletions

File tree

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
use Codewithkyrian\Transformers\PretrainedTokenizers\AutoTokenizer;
6+
7+
require_once './bootstrap.php';
8+
9+
//$tokenizer = AutoTokenizer::fromPretrained('mistralai/Mistral-7B-Instruct-v0.1');
10+
$tokenizer = AutoTokenizer::fromPretrained('facebook/blenderbot-400M-distill');
11+
$messages = [
12+
['role' => 'user', 'content' => 'Hello!'],
13+
['role' => 'assistant', 'content' => 'Hi! How are you?'],
14+
['role' => 'user', 'content' => 'I am doing great.'],
15+
['role' => 'assistant', 'content' => 'That is great to hear.'],
16+
];
17+
18+
$text = $tokenizer->applyChatTemplate($messages, addGenerationPrompt: true, tokenize: false);
19+
20+
dd($text);
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
6+
namespace Codewithkyrian\Transformers\Exceptions;
7+
8+
class TemplateParseException extends \Exception implements TransformersException
9+
{
10+
public static function undefinedVariable($variableName): TemplateParseException
11+
{
12+
return new self("Undefined variable: $variableName");
13+
}
14+
}

src/PretrainedTokenizers/PretrainedTokenizer.php

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,23 @@
55

66
namespace Codewithkyrian\Transformers\PretrainedTokenizers;
77

8+
use ByJG\JinjaPhp\Template;
9+
use ByJG\JinjaPhp\Undefined\DebugUndefined;
810
use Codewithkyrian\Transformers\Decoders\Decoder;
911
use Codewithkyrian\Transformers\Normalizers\Normalizer;
1012
use Codewithkyrian\Transformers\PostProcessors\PostProcessedOutput;
1113
use Codewithkyrian\Transformers\PostProcessors\PostProcessor;
1214
use Codewithkyrian\Transformers\PreTokenizers\PreTokenizer;
1315
use Codewithkyrian\Transformers\Tokenizers\AddedToken;
1416
use Codewithkyrian\Transformers\Tokenizers\Tokenizer;
17+
use Codewithkyrian\Transformers\Utils\JinjaTemplate;
1518
use Codewithkyrian\Transformers\Utils\Tensor;
1619

1720
class PretrainedTokenizer
1821
{
1922
protected bool $returnTokenTypeIds = false;
2023

21-
protected bool $warnedAboutChatTemplate;
24+
protected bool $warnedAboutChatTemplate = false;
2225

2326
protected string $defaultChatTemplate = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}";
2427

@@ -68,7 +71,7 @@ class PretrainedTokenizer
6871
protected bool $legacy;
6972

7073
protected mixed $chatTemplate;
71-
protected \SplObjectStorage $compiledTemplateCache;
74+
protected array $compiledTemplateCache = [];
7275

7376
/**
7477
* @param array $tokenizerJSON The JSON of the tokenizer.
@@ -145,8 +148,6 @@ public function __construct(protected array $tokenizerJSON, protected array $tok
145148
$this->legacy = false;
146149

147150
$this->chatTemplate = $tokenizerConfig['chat_template'] ?? null;
148-
$this->compiledTemplateCache = new \SplObjectStorage();
149-
150151
}
151152

152153
/**
@@ -335,7 +336,7 @@ function ($key) {
335336
if (
336337
array_reduce($encodedTokens, function ($carry, $x) use ($encodedTokens) {
337338
foreach ($x as $key => $value) {
338-
if (count($value) !== count($encodedTokens[0][$key] ?? [])) {
339+
if (count($value ?? []) !== count($encodedTokens[0][$key] ?? [])) {
339340
return true;
340341
}
341342
}
@@ -610,10 +611,10 @@ function truncateHelper(array &$item, int $length): void
610611

611612
protected function getDefaultChatTemplate(): string
612613
{
613-
if (!$this->warnedAboutChatTemplate) {
614-
trigger_error("The default chat template is deprecated and will be removed in a future version. Please use the `chat_template` option instead.", E_USER_WARNING);
615-
$this->warnedAboutChatTemplate = true;
616-
}
614+
// if (!$this->warnedAboutChatTemplate) {
615+
// trigger_error("The default chat template is deprecated and will be removed in a future version. Please use the `chat_template` option instead.", E_USER_WARNING);
616+
// $this->warnedAboutChatTemplate = true;
617+
// }
617618

618619
return $this->defaultChatTemplate;
619620
}
@@ -670,7 +671,8 @@ public function applyChatTemplate(
670671

671672
if ($compiledTemplate === null) {
672673
// TODO: Use Jinja to compile the template
673-
$compiledTemplate = null;
674+
$compiledTemplate = new JinjaTemplate($chatTemplate);
675+
// $compiledTemplate->withUndefined(new DebugUndefined());
674676
$this->compiledTemplateCache[$chatTemplate] = $compiledTemplate;
675677
}
676678

@@ -682,7 +684,6 @@ public function applyChatTemplate(
682684
}
683685
}
684686

685-
686687
$rendered = $compiledTemplate->render(array_merge([
687688
'messages' => $conversation,
688689
'add_generation_prompt' => $addGenerationPrompt,
@@ -695,10 +696,10 @@ public function applyChatTemplate(
695696
addSpecialTokens: false,
696697
truncation: $truncation,
697698
maxLength: $maxLength
698-
)['input_ids'];
699+
)['input_ids']->toArray();
699700
}
700701

701-
return $rendered;
702+
return stripcslashes($rendered);
702703
}
703704

704705
}

0 commit comments

Comments
 (0)