@@ -59,6 +59,21 @@ def _prefill(self):
5959 )
6060 return engine , params , prefill_result , true_length
6161
62+ def _prefill_np (self ):
63+ """Performs prefill and returns a kv cache."""
64+ engine , params = self ._setup ()
65+ # A 2 will be pre-pended as 'bos' token from the vocab.
66+ text = "AB"
67+ metadata = engine .get_tokenizer ()
68+ vocab = token_utils .load_vocab (metadata .path , metadata .extra_ids )
69+ tokens , true_length = token_utils .tokenize_and_pad (
70+ text , vocab , is_bos = True , jax_padding = False
71+ )
72+ prefill_result = engine .prefill (
73+ params = params , padded_tokens = tokens , true_length = 3
74+ )
75+ return engine , params , prefill_result , true_length
76+
6277 def _generate (self , slot = 1 ):
6378 """Performs a single generation step."""
6479 engine , params , prefill_result , _ = self ._prefill ()
@@ -83,6 +98,13 @@ def test_prefill(self):
8398 prefill_result [:, :true_length ], np .array ([[4.0 , 130.0 , 132.0 ]])
8499 )
85100
101+ def test_prefill_np (self ):
102+ """Tests prefill with weight = 2."""
103+ _ , _ , prefill_result , true_length = self ._prefill_np ()
104+ np .testing .assert_array_equal (
105+ prefill_result [:, :true_length ], np .array ([[4.0 , 130.0 , 132.0 ]])
106+ )
107+
86108 def test_generate (self , slot = 1 ):
87109 """Tests multiple generation steps."""
88110 engine , params , decode_state , sampled_tokens = self ._generate (slot = slot )
0 commit comments