@@ -223,6 +223,7 @@ def __init__(
223223 interleaved_mode : bool = False ,
224224 jax_padding : bool = True ,
225225 metrics_collector : JetstreamMetricsCollector | None = None ,
226+ is_ray_backend : bool = False ,
226227 ):
227228 if prefill_engines is None :
228229 prefill_engines = []
@@ -374,6 +375,7 @@ def __init__(
374375 )
375376 )
376377 self .live = True
378+ self ._is_ray_backend = is_ray_backend
377379 # Start all threads
378380 for t in self ._all_threads :
379381 t .start ()
@@ -508,6 +510,29 @@ def _prefill_thread(self, idx: int):
508510 del prefill_result
509511 del request
510512
def _jax_transfer_prefill_result(
    self, new_request: ActiveRequest, target_idx: int
):
    """Moves a request's prefill result onto the target generate engine.

    The prefill result is placed with `jax.device_put` using the sharding
    returned by the selected generate engine's
    `get_prefix_destination_sharding()`, and the placed value is written
    back to `new_request.prefill_result`.

    Args:
      new_request: Request whose `prefill_result` is transferred in place.
      target_idx: Index into `self._generate_engines` of the destination.
    """
    pending = new_request.prefill_result
    destination = self._generate_engines[
        target_idx
    ].get_prefix_destination_sharding()
    new_request.prefill_result = jax.device_put(pending, destination)
    # Wait for the transfer here so the generate thread that steps does not
    # have to block on it later.
    jax.block_until_ready(new_request.prefill_result)
522+
def _ray_transfer_prefill_result(
    self, new_request: ActiveRequest, target_idx: int
):
    """Hands a request's prefill result to the target engine's `transfer`.

    `new_request.prefill_result` is passed to `transfer()` on the generate
    engine at `target_idx`; the request itself is not modified and whatever
    `transfer()` returns is discarded.

    Args:
      new_request: Request carrying the prefill result to transfer.
      target_idx: Index into `self._generate_engines` of the destination.
    """
    destination_engine = self._generate_engines[target_idx]
    destination_engine.transfer(new_request.prefill_result)
527+
def _transfer_prefill_result(
    self, new_request: ActiveRequest, target_idx: int
):
    """Transfers a prefill result using the backend-appropriate method.

    Dispatches to `_ray_transfer_prefill_result` when `self._is_ray_backend`
    is truthy, otherwise to `_jax_transfer_prefill_result`.

    Args:
      new_request: Request whose prefill result should be transferred.
      target_idx: Index of the destination generate engine.
    """
    # Only the selected handler is looked up; the other is never touched.
    do_transfer = (
        self._ray_transfer_prefill_result
        if self._is_ray_backend
        else self._jax_transfer_prefill_result
    )
    do_transfer(new_request, target_idx)
535+
511536 def _transfer_thread (self , idx : int ):
512537 """Transfers the kv cache on an active request to the least full
513538 generate backlog."""
@@ -531,14 +556,7 @@ def _transfer_thread(self, idx: int):
531556 target_idx ,
532557 )
533558 # Transfer the info to the relevant generate slice.
534- new_request .prefill_result = jax .device_put (
535- new_request .prefill_result ,
536- self ._generate_engines [
537- target_idx
538- ].get_prefix_destination_sharding (),
539- )
540- # Block here so we don't block on the generate thread that steps.
541- jax .block_until_ready (new_request .prefill_result )
559+ self ._transfer_prefill_result (new_request , target_idx )
542560 # Place the request on the correct generate backlog and block if full.
543561 self ._generate_backlogs [target_idx ].put (new_request , block = True )
544562 logging .info (
0 commit comments