Skip to content

Commit 0cbe8ba

Browse files
authored
add jax_padding support driver and server lib (#54)
1 parent ee90d08 commit 0cbe8ba

3 files changed

Lines changed: 31 additions & 0 deletions

File tree

jetstream/core/orchestrator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,12 +201,16 @@ class Driver:
201201
_generate_slots: list[queue.Queue[int]] = []
202202
_active_requests: list[queue.Queue[tuple[int, ActiveRequest | None]]] = []
203203

204+
# TODO: remove jax_padding after all the engines migrate to np padding.
205+
_jax_padding = True
206+
204207
def __init__(
205208
self,
206209
prefill_engines: Optional[list[engine_api.Engine]] = None,
207210
generate_engines: Optional[list[engine_api.Engine]] = None,
208211
prefill_params: Optional[list[Any]] = None,
209212
generate_params: Optional[list[Any]] = None,
213+
jax_padding: bool = True,
210214
):
211215
if prefill_engines is None:
212216
prefill_engines = []
@@ -283,6 +287,8 @@ def __init__(
283287
for idx, engine in enumerate(self._generate_engines)
284288
]
285289

290+
self._jax_padding = jax_padding
291+
286292
# Create all threads
287293
self._prefill_threads = [
288294
JetThread(
@@ -428,6 +434,7 @@ def _prefill_thread(self, idx: int):
428434
vocab,
429435
is_bos=is_bos,
430436
max_prefill_length=prefill_engine.max_prefill_length,
437+
jax_padding=self._jax_padding,
431438
)
432439
# Compute new kv cache for the prefill_text, conditional on
433440
# history.

jetstream/core/server_lib.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ def run(
9191
devices: Any,
9292
credentials: Any = grpc.insecure_server_credentials(),
9393
threads: int | None = None,
94+
jax_padding: bool = True,
9495
) -> JetStreamServer:
9596
"""Runs a server with a specified config.
9697
@@ -116,6 +117,7 @@ def run(
116117
generate_engines=engines.generate_engines + engines.interleaved_engines,
117118
prefill_params=prefill_params + shared_params,
118119
generate_params=generate_params + shared_params,
120+
jax_padding=jax_padding,
119121
)
120122
# We default threads to the total number of concurrent allowed decodes,
121123
# to make sure we can fully saturate the model. Set default minimum to 64.

jetstream/tests/engine/test_mock_engine.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,21 @@ def _prefill(self):
5959
)
6060
return engine, params, prefill_result, true_length
6161

62+
def _prefill_np(self):
  """Performs prefill on tokens built with jax_padding=False.

  Mirrors the flow of a regular prefill helper, except that the text is
  tokenized through token_utils.tokenize_and_pad with jax_padding=False.

  Returns:
    A tuple (engine, params, prefill_result, true_length), where
    true_length is the value reported by token_utils.tokenize_and_pad.
  """
  engine, params = self._setup()
  # A 2 will be pre-pended as 'bos' token from the vocab.
  text = "AB"
  metadata = engine.get_tokenizer()
  vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)
  tokens, true_length = token_utils.tokenize_and_pad(
      text, vocab, is_bos=True, jax_padding=False
  )
  # Feed prefill the length reported by the tokenizer instead of a literal 3,
  # so the value used for prefill and the value returned to callers cannot
  # drift apart if the text or bos handling changes.
  prefill_result = engine.prefill(
      params=params, padded_tokens=tokens, true_length=true_length
  )
  return engine, params, prefill_result, true_length
76+
6277
def _generate(self, slot=1):
6378
"""Performs a single generation step."""
6479
engine, params, prefill_result, _ = self._prefill()
@@ -83,6 +98,13 @@ def test_prefill(self):
8398
prefill_result[:, :true_length], np.array([[4.0, 130.0, 132.0]])
8499
)
85100

101+
def test_prefill_np(self):
  """Tests prefill with weight = 2 via the `_prefill_np` helper."""
  outcome = self._prefill_np()
  # The helper returns (engine, params, prefill_result, true_length); only the
  # last two entries are needed here.
  kv_result, valid_len = outcome[2], outcome[3]
  expected = np.array([[4.0, 130.0, 132.0]])
  # Compare only the first `valid_len` columns of the result.
  np.testing.assert_array_equal(kv_result[:, :valid_len], expected)
107+
86108
def test_generate(self, slot=1):
87109
"""Tests multiple generation steps."""
88110
engine, params, decode_state, sampled_tokens = self._generate(slot=slot)

0 commit comments

Comments
 (0)