Skip to content

Commit e19a790

Browse files
authored
Add ray disaggregated serving support (#87)
* add ray disaggregated serving support * function fix * fix lint error * refactor parameter * add ActiveRequest annotation in function
1 parent eaf0d6e commit e19a790

3 files changed

Lines changed: 28 additions & 8 deletions

File tree

jetstream/core/config_lib.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class ServerConfig:
3838
prefill_engine_create_fns: Tuple[CreateEngineFn, ...] = ()
3939
generate_engine_create_fns: Tuple[CreateEngineFn, ...] = ()
4040
interleaved_engine_create_fns: Tuple[CreateEngineFn, ...] = ()
41+
is_ray_backend: bool = False
4142

4243

4344
@dataclasses.dataclass

jetstream/core/orchestrator.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ def __init__(
223223
interleaved_mode: bool = False,
224224
jax_padding: bool = True,
225225
metrics_collector: JetstreamMetricsCollector | None = None,
226+
is_ray_backend: bool = False,
226227
):
227228
if prefill_engines is None:
228229
prefill_engines = []
@@ -374,6 +375,7 @@ def __init__(
374375
)
375376
)
376377
self.live = True
378+
self._is_ray_backend = is_ray_backend
377379
# Start all threads
378380
for t in self._all_threads:
379381
t.start()
@@ -508,6 +510,29 @@ def _prefill_thread(self, idx: int):
508510
del prefill_result
509511
del request
510512

513+
def _jax_transfer_prefill_result(
514+
self, new_request: ActiveRequest, target_idx: int
515+
):
516+
new_request.prefill_result = jax.device_put(
517+
new_request.prefill_result,
518+
self._generate_engines[target_idx].get_prefix_destination_sharding(),
519+
)
520+
# Block here so we don't block on the generate thread that steps.
521+
jax.block_until_ready(new_request.prefill_result)
522+
523+
def _ray_transfer_prefill_result(
524+
self, new_request: ActiveRequest, target_idx: int
525+
):
526+
self._generate_engines[target_idx].transfer(new_request.prefill_result)
527+
528+
def _transfer_prefill_result(
529+
self, new_request: ActiveRequest, target_idx: int
530+
):
531+
if self._is_ray_backend:
532+
self._ray_transfer_prefill_result(new_request, target_idx)
533+
else:
534+
self._jax_transfer_prefill_result(new_request, target_idx)
535+
511536
def _transfer_thread(self, idx: int):
512537
"""Transfers the kv cache on an active request to the least full
513538
generate backlog."""
@@ -531,14 +556,7 @@ def _transfer_thread(self, idx: int):
531556
target_idx,
532557
)
533558
# Transfer the info to the relevant generate slice.
534-
new_request.prefill_result = jax.device_put(
535-
new_request.prefill_result,
536-
self._generate_engines[
537-
target_idx
538-
].get_prefix_destination_sharding(),
539-
)
540-
# Block here so we don't block on the generate thread that steps.
541-
jax.block_until_ready(new_request.prefill_result)
559+
self._transfer_prefill_result(new_request, target_idx)
542560
# Place the request on the correct generate backlog and block if full.
543561
self._generate_backlogs[target_idx].put(new_request, block=True)
544562
logging.info(

jetstream/core/server_lib.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def run(
146146
interleaved_mode=interleaved_mode,
147147
jax_padding=jax_padding,
148148
metrics_collector=metrics_collector,
149+
is_ray_backend=config.is_ray_backend,
149150
)
150151
# We default threads to the total number of concurrent allowed decodes,
151152
# to make sure we can fully saturate the model. Set default minimum to 64.

0 commit comments

Comments
 (0)