Skip to content

Commit 813bc4d

Browse files
feat: refactor model generation for simplicity and efficiency
- New Stopping Criteria: Added MaxLength, MaxTime, and Interruptable stopping criteria for more flexible generation control. - Streamers Refactor: Simplified streamer implementation to improve clarity. - Performance Benchmarking: Introduced a token-per-second (TPS) metric to benchmark model performance across updates. - Bug Fixes: - Fixed an error in the Tensor::slice() method - Corrected the RepetitionPenaltyLogitsProcessor to properly utilize tokens.
1 parent 89ab0f1 commit 813bc4d

24 files changed

Lines changed: 791 additions & 712 deletions

examples/pipelines/summarization.php

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,4 @@
5858

5959
$summary = $summarizer($article, streamer: $streamer, maxNewTokens: 512);
6060

61-
dd("Done", timeUsage(), memoryUsage());
61+
dd("Done", timeUsage(), memoryUsage());

examples/pipelines/text-generation.php

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,4 +47,4 @@
4747
// returnFullText: true,
4848
//);
4949

50-
dd($output[0]['generated_text'], timeUsage(), memoryUsage());
50+
dd($output[0]['generated_text'], $streamer->getTPS()." tps", timeUsage(), memoryUsage());

examples/pipelines/text2text-generation.php

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,4 @@
2121

2222
$output = $generator($query, streamer: $streamer, maxNewTokens: 256, doSample: true, repetitionPenalty: 1.1, temperature: 0.7);
2323

24-
//dd($output);
25-
dd('Done', timeUsage(), memoryUsage());
24+
dd('Done', $streamer->getTPS()." tps", timeUsage(), memoryUsage());

src/Generation/LogitsProcessors/LogitsProcessorList.php

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ public function push(LogitsProcessor $item): void
2828
*
2929
* @param LogitsProcessor[] $items The logits processor functions to add.
3030
*/
31-
public function extend(traversable $items): void
31+
public function extend(Traversable $items): void
3232
{
3333
foreach ($items as $item) {
3434
$this->processors[] = $item;
@@ -41,13 +41,15 @@ public function extend(traversable $items): void
4141
* @param array $inputIds The input IDs for the language model.
4242
* @param Tensor $batchedLogits A 2D array of logits, where each row corresponds to a single input sequence.
4343
*/
44-
public function __invoke(array $inputIds, Tensor &$batchedLogits): void
44+
public function __invoke(array $inputIds, Tensor &$batchedLogits): Tensor
4545
{
46-
for ($i = 0; $i < count($batchedLogits); $i++) {
47-
foreach ($this->processors as $processor) {
48-
$processor($inputIds, $batchedLogits[$i]); // Apply processors in-place
49-
}
46+
$toReturn = $batchedLogits;
47+
48+
foreach ($this->processors as $processor) {
49+
$toReturn = $processor($inputIds, $toReturn); // Some apply processors in-place
5050
}
51+
52+
return $toReturn;
5153
}
5254

5355
/**
@@ -59,4 +61,4 @@ public function getIterator(): Traversable
5961
{
6062
yield from $this->processors;
6163
}
62-
}
64+
}

src/Generation/LogitsProcessors/RepetitionPenaltyLogitsProcessor.php

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@
1212
*/
1313
class RepetitionPenaltyLogitsProcessor extends LogitsProcessor
1414
{
15-
public function __construct(protected float $penalty)
16-
{
17-
}
15+
public function __construct(protected float $penalty) {}
1816

1917
/**
2018
* Apply the repetition penalty to the logits.
@@ -24,13 +22,16 @@ public function __invoke(array $inputIds, Tensor $logits): Tensor
2422
// Modify the logits corresponding to each element in `input_ids`.
2523
// As a consequence, the logits corresponding to tokens that appear
2624
// many times in the output will be penalised more.
27-
foreach ($inputIds as $inputId) {
28-
if ($logits->buffer()[$inputId] < 0) {
29-
$logits->buffer()[$inputId] *= $this->penalty;
30-
} else {
31-
$logits->buffer()[$inputId] /= $this->penalty;
25+
for ($i = 0; $i < count($inputIds); $i++) {
26+
foreach ($inputIds[$i] as $inputId) {
27+
if ($logits[$i]->buffer()[$inputId] < 0) {
28+
$logits[$i]->buffer()[$inputId] *= $this->penalty;
29+
} else {
30+
$logits[$i]->buffer()[$inputId] /= $this->penalty;
31+
}
3232
}
3333
}
34+
3435
return $logits;
3536
}
36-
}
37+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Codewithkyrian\Transformers\Generation\StoppingCriteria;
6+
7+
/**
8+
* This class stops generation whenever the "end-of-sequence" token is generated.
9+
*/
10+
class EosTokenCriteria extends StoppingCriteria
11+
{
12+
private array $eosTokenIds;
13+
14+
/**
15+
* @param int|int[] $eosTokenId The id of the *end-of-sequence* token.
16+
*/
17+
public function __construct(int|array $eosTokenId)
18+
{
19+
$this->eosTokenIds = is_array($eosTokenId) ? $eosTokenId : [$eosTokenId];
20+
}
21+
22+
public function __invoke(array $inputIds, array $scores): array
23+
{
24+
return array_map(function ($ids) {
25+
$lastToken = end($ids);
26+
return in_array($lastToken, $this->eosTokenIds, true);
27+
}, $inputIds);
28+
}
29+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Codewithkyrian\Transformers\Generation\StoppingCriteria;
6+
7+
/**
8+
* This class stops generation whenever the user interrupts the process.
9+
*/
10+
class InterruptableStoppingCriteria extends StoppingCriteria
11+
{
12+
private bool $interrupted = false;
13+
14+
public function interrupt(): void
15+
{
16+
$this->interrupted = true;
17+
}
18+
19+
public function reset(): void
20+
{
21+
$this->interrupted = false;
22+
}
23+
24+
public function __invoke(array $inputIds, array $scores): array
25+
{
26+
return array_fill(0, count($inputIds), $this->interrupted);
27+
}
28+
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Codewithkyrian\Transformers\Generation\StoppingCriteria;
6+
7+
use Codewithkyrian\Transformers\Transformers;
8+
9+
/**
10+
* This class stops generation whenever the full generated number of tokens exceeds `maxLength`.
11+
* For decoder-only transformers, this includes the initial prompted tokens.
12+
*/
13+
class MaxLengthCriteria extends StoppingCriteria
14+
{
15+
16+
17+
/**
18+
* @param int $maxLength The maximum length that the output sequence can have in number of tokens.
19+
* @param int|null $maxPositionEmbeddings The maximum model length,
20+
*/
21+
public function __construct(protected int $maxLength, protected ?int $maxPositionEmbeddings = null) {}
22+
23+
/**
24+
* Evaluates whether generation should stop based on token count.
25+
*
26+
* @param array $inputIds Array of input IDs (2D array where each sub-array is a sequence of token IDs).
27+
* @param array $scores Optional scores for the generated tokens.
28+
*
29+
* @return array|bool[]
30+
*/
31+
public function __invoke(array $inputIds, array $scores): array
32+
{
33+
// return array_map(fn ($ids) => count($ids) >= $this->maxLength, $inputIds);
34+
$results = [];
35+
foreach ($inputIds as $ids) {
36+
$currentLength = count($ids);
37+
$isDone = $currentLength >= $this->maxLength;
38+
39+
if ($this->maxPositionEmbeddings !== null && !$isDone && $currentLength >= $this->maxPositionEmbeddings) {
40+
echo
41+
"This is a friendly reminder - the current text generation call will exceed the model's predefined " .
42+
"maximum length ({$this->maxPositionEmbeddings}). Depending on the model, you may observe " .
43+
"exceptions, performance degradation, or nothing at all."
44+
;
45+
}
46+
47+
$results[] = $isDone;
48+
}
49+
50+
return $results;
51+
}
52+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Codewithkyrian\Transformers\Generation\StoppingCriteria;
6+
7+
class MaxTimeCriteria extends StoppingCriteria
8+
{
9+
private float $maxTime;
10+
private float $initialTimestamp;
11+
12+
/**
13+
* @param float $maxTime The maximum allowed time in seconds for the generation.
14+
* @param float|null $initialTimestamp The start of the generation allowed time. Defaults to the current time.
15+
*/
16+
public function __construct(float $maxTime, ?float $initialTimestamp = null)
17+
{
18+
$this->maxTime = $maxTime;
19+
$this->initialTimestamp = $initialTimestamp ?? microtime(true);
20+
}
21+
22+
/**
23+
* Evaluates whether generation should stop based on elapsed time.
24+
*
25+
* @param array $inputIds Array of input IDs (2D array where each sub-array is a sequence of token IDs).
26+
* @param array $scores Scores for the generated tokens.
27+
*
28+
* @return array Boolean array indicating whether generation should stop for each sequence.
29+
*/
30+
public function __invoke(array $inputIds, array $scores): array
31+
{
32+
$elapsedTime = microtime(true) - $this->initialTimestamp;
33+
$isDone = $elapsedTime > $this->maxTime;
34+
35+
// Return the same stopping criteria for all sequences
36+
return array_fill(0, count($inputIds), $isDone);
37+
}
38+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Codewithkyrian\Transformers\Generation\StoppingCriteria;
6+
7+
use Traversable;
8+
9+
/**
10+
* Abstract base class for all stopping criteria that can be applied during generation.
11+
*/
12+
abstract class StoppingCriteria
13+
{
14+
/**
15+
* @param int[][] $inputIds Indices of input sequence tokens in the vocabulary of shape `(batch_size, sequence_length)`.
16+
* @param float[][] $scores Prediction scores of a language modeling head of shape `(batch_size, vocab_size)`.
17+
*
18+
* @return bool[] A list of booleans indicating whether each sequence should be stopped.
19+
*/
20+
abstract public function __invoke(array $inputIds, array $scores): array;
21+
}

0 commit comments

Comments
 (0)