Skip to content

Commit 7daeb10

Browse files
authored
add tokenize_and_pad function back (#70)
1 parent b282506 commit 7daeb10

1 file changed

Lines changed: 41 additions & 0 deletions

File tree

jetstream/engine/token_utils.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,47 @@ def take_nearest_length(lengths: list[int], length: int) -> int:
4141
return lengths[pos]
4242

4343

44+
def tokenize_and_pad(
    s: str,
    vocab: Vocabulary,
    is_bos: bool = True,
    prefill_lengths: Optional[List[int]] = None,
    max_prefill_length: Optional[int] = None,
    jax_padding: bool = True,
) -> Tuple[Union[jax.Array, np.ndarray], int]:
  """Tokenizes and pads a string.

  Args:
    s: String to tokenize.
    vocab: Vocabulary to tokenize with.
    is_bos: Whether or not this is the beginning of a sequence. Default to yes
      as prefill is typically used when beginning sequences.
    prefill_lengths: Buckets to pad the sequence to for static compilation.
    max_prefill_length: Maximum bucket to use.
    jax_padding: convert to JAX padded tokens if True.

  Returns:
    tokens: Tokenized into integers.
    true_length: Actual length of the non-padded sequence.

  Raises:
    ValueError: If the vocabulary's pad id is not 0 — the padding logic
      below assumes zero-padding.
  """

  tokens = np.array(vocab.encode_tf(s))  # [Length]
  bos_id = vocab.bos_id
  pad_id = vocab.pad_id
  # Explicit check rather than `assert`: asserts are stripped under
  # `python -O`, which would silently drop this guard.
  if pad_id != 0:
    raise ValueError("Further logic required if pad_id not 0.")

  padded_tokens, true_length = pad_tokens(
      tokens=tokens,
      bos_id=bos_id,
      pad_id=pad_id,
      is_bos=is_bos,
      prefill_lengths=prefill_lengths,
      max_prefill_length=max_prefill_length,
      jax_padding=jax_padding,
  )
  return padded_tokens, true_length
83+
84+
4485
def pad_tokens(
4586
tokens: np.ndarray,
4687
bos_id: int,

0 commit comments

Comments
 (0)