Skip to content

Commit 8128c8a

Browse files
authored
Allow tokenizer to customize stop_tokens (#84)
1 parent 87aa565 commit 8128c8a

4 files changed

Lines changed: 12 additions & 1 deletion

File tree

benchmarks/__init__.py

Whitespace-only changes.

jetstream/engine/mock_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class TestVocab(Vocabulary):
5151
eos_id = 1
5252
bos_id = 2
5353
unk_id = 3
54+
stop_tokens = {pad_id, eos_id}
5455
_base_vocab_size = 2**16
5556
tokenizer: TestTokenizer = TestTokenizer()
5657

jetstream/engine/token_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def process_result_tokens(
187187
slot_valid = slot_data.valid
188188
slot_lengths = slot_data.lengths
189189
samples, speculations = slot_tokens.shape
190-
stop_tokens = [tokenizer.eos_id, tokenizer.pad_id]
190+
stop_tokens = tokenizer.stop_tokens
191191
# Stop anything which has reached its max length.
192192
complete = complete | (slot_lengths > slot_max_length)
193193
if debug:
@@ -395,6 +395,11 @@ def decode(self, token_ids: list[int]) -> str:
395395
"""
396396
return self.tokenizer.decode(token_ids)
397397

398+
@property
399+
def stop_tokens(self) -> set[int]:
400+
"""IDs of the stop tokens."""
401+
return self.tokenizer.stop_tokens
402+
398403
@property
399404
def pad_id(self) -> int:
400405
"""ID of the pad token."""

jetstream/engine/tokenizer_api.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,8 @@ def eos_id(self) -> int:
6565
@abc.abstractmethod
6666
def bos_id(self) -> int:
6767
"""ID of BOS token."""
68+
69+
@property
70+
def stop_tokens(self) -> set[int]:
71+
"""IDs of the stop tokens."""
72+
return {self.eos_id, self.pad_id}

0 commit comments

Comments
 (0)