Commit ec78937

fix pad_tokens description (#80)
1 parent 6205d81 commit ec78937

1 file changed

Lines changed: 6 additions & 5 deletions

jetstream/engine/token_utils.py

@@ -91,13 +91,14 @@ def pad_tokens(
     max_prefill_length: Optional[int] = None,
     jax_padding: bool = True,
 ) -> Tuple[Union[jax.Array, np.ndarray], int]:
-  """Tokenize and pads a string.
+  """Pads tokens to the nearest prefill length that is equal to or greater
+  than the token length.

   Args:
-    s: String to tokenize.
-    vocab: Vocabulary to tokenize with.
-    is_bos: Whether or not this is the beginning of a sequence. Default to yes
-      as prefill is typically used when beginning sequences.
+    tokens: Tokens.
+    bos_id: Bos ID.
+    pad_id: Pad ID.
+    is_bos: Add a beginning of sequence token if this is ture.
     prefill_lengths: Buckets to pad the sequence to for static compilation.
     max_prefill_length: Maximum bucket to use.
     jax_padding: convert to JAX padded tokens if True.
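
For readers of this diff, the behaviour the new docstring describes can be sketched roughly as follows. This is a simplified illustration, not the code in jetstream/engine/token_utils.py: it uses plain Python lists instead of NumPy/JAX arrays, leaves out jax_padding, and the name pad_tokens_sketch, the example bucket sizes, the meaning of the returned integer (taken here to be the unpadded length), and the handling of over-long inputs are all assumptions made for the example.

```python
from bisect import bisect_left
from typing import List, Optional, Sequence, Tuple

# Hypothetical bucket sizes, chosen only for this illustration.
EXAMPLE_PREFILL_LENGTHS = [16, 32, 64, 128]


def pad_tokens_sketch(
    tokens: Sequence[int],
    bos_id: int,
    pad_id: int,
    is_bos: bool = True,
    prefill_lengths: Optional[List[int]] = None,
    max_prefill_length: Optional[int] = None,
) -> Tuple[List[int], int]:
  """Pads `tokens` to the smallest bucket that is >= the token length."""
  buckets = sorted(prefill_lengths or EXAMPLE_PREFILL_LENGTHS)
  if max_prefill_length is not None:
    # Assumption: buckets above the maximum are simply ignored.
    buckets = [b for b in buckets if b <= max_prefill_length]

  seq = list(tokens)
  if is_bos:
    # Prepend the beginning-of-sequence token.
    seq = [bos_id] + seq
  true_length = len(seq)

  # Index of the first bucket that is equal to or greater than true_length.
  idx = bisect_left(buckets, true_length)
  if idx == len(buckets):
    # Assumption: this sketch rejects sequences longer than every bucket.
    raise ValueError(
        f"Sequence of length {true_length} exceeds all prefill buckets."
    )

  padded_length = buckets[idx]
  padded = seq + [pad_id] * (padded_length - true_length)
  return padded, true_length
```

With buckets [4, 8], bos_id=1 and pad_id=0, the tokens [5, 6, 7, 8] become five tokens once the BOS token is added, so they are padded to the 8 bucket. Padding to a small fixed set of lengths is what the "static compilation" wording in the docstring refers to: the compiled prefill function sees only a few distinct input shapes.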
