@@ -41,6 +41,47 @@ def take_nearest_length(lengths: list[int], length: int) -> int:
4141 return lengths [pos ]
4242
4343
def tokenize_and_pad(
    s: str,
    vocab: Vocabulary,
    is_bos: bool = True,
    prefill_lengths: Optional[List[int]] = None,
    max_prefill_length: Optional[int] = None,
    jax_padding: bool = True,
) -> Tuple[Union[jax.Array, np.ndarray], int]:
  """Tokenizes and pads a string for prefill.

  Encodes `s` with the vocabulary, then delegates to `pad_tokens` to
  (optionally) prepend BOS and pad out to a static bucket length so the
  prefill computation can be compiled once per bucket.

  Args:
    s: String to tokenize.
    vocab: Vocabulary to tokenize with.
    is_bos: Whether or not this is the beginning of a sequence. Default to yes
      as prefill is typically used when beginning sequences.
    prefill_lengths: Buckets to pad the sequence to for static compilation.
    max_prefill_length: Maximum bucket to use.
    jax_padding: convert to JAX padded tokens if True.

  Returns:
    tokens: Tokenized into integers.
    true_length: Actual length of the non-padded sequence.

  Raises:
    ValueError: If the vocabulary's pad id is not 0; the downstream padding
      logic assumes zero-padding.
  """
  # NOTE(review): encode_tf presumably returns a 1-D tensor of token ids
  # ([Length]) — confirm against the Vocabulary implementation.
  tokens = np.array(vocab.encode_tf(s))
  bos_id = vocab.bos_id
  pad_id = vocab.pad_id
  # Raise instead of assert: an assert is stripped under `python -O`, which
  # would silently allow a non-zero pad id through to the padding logic.
  if pad_id != 0:
    raise ValueError("Further logic required if pad_id not 0.")

  padded_tokens, true_length = pad_tokens(
      tokens=tokens,
      bos_id=bos_id,
      pad_id=pad_id,
      is_bos=is_bos,
      prefill_lengths=prefill_lengths,
      max_prefill_length=max_prefill_length,
      jax_padding=jax_padding,
  )
  return padded_tokens, true_length
83+
84+
4485def pad_tokens (
4586 tokens : np .ndarray ,
4687 bos_id : int ,
0 commit comments