@@ -188,17 +188,23 @@ class Driver:
188188 # Stage 1
189189 _prefill_backlog : queue .Queue [ActiveRequest | None ]
190190 # Stage 2
191+ _transfer_backlogs : list [queue .Queue [ActiveRequest ]] = []
192+ # Stage 3
191193 # We keep this as a dict to avoid a possibly expensive object comparison
192194 # when logging the index of the generate engine we send a prefill result
193195 # to, it allows us to natively have the index from the min operation, rather
194196 # than have to call .index()
195- _generate_backlogs : dict [int , queue .Queue [ActiveRequest | None ]] = {}
196- # Stage 3
197+ _generate_backlogs : dict [int , queue .Queue [ActiveRequest ]] = {}
198+ # Stage 4
197199 # This can be a list because we can pass it as an arg to generate and
198200 # detokenize threads. It is a list of tokens to be detokenized.
199201 _detokenize_backlogs : list [queue .Queue [engine_api .ResultTokens ]] = []
200202 _generate_slots : list [queue .Queue [int ]] = []
201- _active_requests : list [queue .Queue [tuple [int , ActiveRequest | None ]]] = []
203+ _active_requests : list [queue .Queue [tuple [int , ActiveRequest ]]] = []
204+
205+ # For interleaved_mode, only generate if all slots are full
206+ # or the corresponding prefill queue is empty.
207+ _interleaved_mode : bool = False
202208
203209 # todo: remove jax_padding after all the engines migrate to np padding
204210 _jax_padding = True
@@ -209,6 +215,7 @@ def __init__(
209215 generate_engines : Optional [list [engine_api .Engine ]] = None ,
210216 prefill_params : Optional [list [Any ]] = None ,
211217 generate_params : Optional [list [Any ]] = None ,
218+ interleaved_mode : bool = False ,
212219 jax_padding : bool = True ,
213220 ):
214221 if prefill_engines is None :
@@ -229,22 +236,39 @@ def __init__(
229236 self ._generate_engines = generate_engines
230237 self ._prefill_params = prefill_params
231238 self ._generate_params = generate_params
239+ self ._interleaved_mode = interleaved_mode
240+
232241 # Stages 1-4 represent the life cycle of a request.
233242 # Stage 1
234243 # At first, a request is placed here in order to get prefilled.
235244 self ._prefill_backlog = queue .Queue ()
236- # _ready_to_prefill event will block the prefill thread until there is
237- # available decode slot to insert the prefill result.
238- self ._ready_to_prefill = threading .Event ()
239245 # Stage 2
246+ # After prefilling, it is placed here in order to get transferred to
247+ # one of the generate backlogs.
248+ # Interleaved Mode: Max size is 1 to increase the HBM utilization
249+ # during generate.
250+ # Disaggregated Mode: Max size is 4 to allow for 2 prefills to be enqueued
251+ # while 1 transfer is enqueued and 1 is being transferred.
252+ # TODO: Make queue size configurable.
253+ self ._transfer_backlogs = [
254+ queue .Queue (1 if self ._interleaved_mode else 4 )
255+ for i in range (len (self ._prefill_engines ))
256+ ]
257+ # Stage 3
240258 # Each generate engine accesses its own generate backlog.
259+ # Interleaved Mode: Max size is 1 to increase the HBM utilization
260+ # during generate.
261+ # Disaggregated Mode: Set as 1/3 the number of concurrent decodes.
262+ # TODO: Calculate the backlog to saturate the generate engine while
263+ # minimizing the memory usage for disaggregated mode.
264+ # TODO: Make queue size configurable.
241265 self ._generate_backlogs = {
242- # Don't receive more than 1/3 the number of concurrent decodes to avoid
243- # OOM for single host.
244- idx : queue . Queue ( engine . max_concurrent_decodes // 3 )
266+ idx : queue . Queue (
267+ 1 if self . _interleaved_mode else engine . max_concurrent_decodes // 3
268+ )
245269 for idx , engine in enumerate (self ._generate_engines )
246270 }
247- # Stage 3
271+ # Stage 4
248272 # After generation, ActiveRequests are placed on the detokenization backlog
249273 # for tokens to be sent into each ActiveRequest's return channel.
250274 # We have one of these per generate engine to simplify the logic keeping
@@ -293,6 +317,18 @@ def __init__(
293317 JetThread (
294318 target = functools .partial (self ._prefill_thread , idx ),
295319 name = f"prefill-{ idx } " ,
320+ daemon = True ,
321+ )
322+ for idx in range (len (self ._prefill_engines ))
323+ ]
324+ self ._transfer_threads = [
325+ JetThread (
326+ target = functools .partial (
327+ self ._transfer_thread ,
328+ idx ,
329+ ),
330+ name = f"transfer-{ idx } " ,
331+ daemon = True ,
296332 )
297333 for idx in range (len (self ._prefill_engines ))
298334 ]
@@ -303,6 +339,7 @@ def __init__(
303339 idx ,
304340 ),
305341 name = f"generate-{ idx } " ,
342+ daemon = True ,
306343 )
307344 for idx in range (len (self ._generate_engines ))
308345 ]
@@ -319,6 +356,7 @@ def __init__(
319356 self ._all_threads = list (
320357 itertools .chain (
321358 self ._prefill_threads ,
359+ self ._transfer_threads ,
322360 self ._generate_threads ,
323361 self .detokenize_threads ,
324362 )
@@ -336,6 +374,7 @@ def stop(self):
336374 all_backlogs = list (
337375 itertools .chain (
338376 [self ._prefill_backlog ],
377+ self ._transfer_backlogs ,
339378 self ._generate_backlogs .values (),
340379 self ._detokenize_backlogs ,
341380 )
@@ -400,24 +439,11 @@ def _prefill_thread(self, idx: int):
400439 logging .info ("---------Prefill params %d loaded.---------" , idx )
401440
402441 while self .live :
403- # The prefill thread can wait until there is available decode slot to
404- # insert.
405- if self ._generate_slots [idx ].qsize () == 0 :
406- logging .info (
407- "Prefill waits for available slot; prefill queue size %d" ,
408- self ._prefill_backlog .qsize (),
409- )
410- self ._ready_to_prefill .wait ()
411- logging .info (
412- "Prefill continues; prefill queue size %d" ,
413- self ._prefill_backlog .qsize (),
414- )
442+ my_transfer_backlog = self ._transfer_backlogs [idx ]
415443 # The prefill thread can just sleep until it has work to do.
416444 request = self ._prefill_backlog .get (block = True )
417445 if request is None :
418446 break
419- # TODO: Implement hot/cold cache for history.
420- history = self ._load_cache_history (request .history_path ) # pylint: disable = assignment-from-none
421447 # Tokenize, and introduce a leading dimension
422448 is_bos = not bool (request .history_path )
423449 logging .info (
@@ -434,21 +460,60 @@ def _prefill_thread(self, idx: int):
434460 max_prefill_length = prefill_engine .max_prefill_length ,
435461 jax_padding = self ._jax_padding ,
436462 )
437- # Compute new kv cache for the prefill_text, conditional on
438- # history.
463+ # Compute new kv cache for the prefill_text.
439464 prefill_result = prefill_engine .prefill (
440465 params = prefill_params ,
441- existing_prefix = history ,
442466 padded_tokens = padded_tokens ,
443467 true_length = true_length ,
444468 )
445469 request .prefill_result = prefill_result
446470 # Once prefill is complete, place it on the generation queue and block if
447471 # full.
448- self ._generate_backlogs [idx ].put (request , block = True )
472+ my_transfer_backlog .put (request , block = True )
473+ logging .info (
474+ "Placed request on transfer queue %d, %d queued requests." ,
475+ idx ,
476+ my_transfer_backlog .qsize (),
477+ )
478+ del prefill_result
479+ del request
480+
481+ def _transfer_thread (self , idx : int ):
482+ """Transfers the kv cache on an active request to the least full
483+ generate backlog."""
484+ transfer_backlog = self ._transfer_backlogs [idx ]
485+
486+ while self .live :
487+ # The transfer thread can just sleep until it has work to do.
488+ new_request = transfer_backlog .get (block = True )
489+ target_idx = min (
490+ self ._generate_backlogs .items (), key = lambda q : q [1 ].qsize ()
491+ )[0 ]
492+ # Only transfer the KVCache for the disaggregated serving.
493+ # TODO: Remove the conditional after fixing the compatibility.
494+ if not self ._interleaved_mode :
495+ logging .info (
496+ "Transferring prefill from prefill engine %d "
497+ "to generate engine %d." ,
498+ idx ,
499+ target_idx ,
500+ )
501+ # Transfer the info to the relevant generate slice.
502+ new_request .prefill_result = jax .device_put (
503+ new_request .prefill_result ,
504+ self ._generate_engines [
505+ target_idx
506+ ].get_prefix_destination_sharding (),
507+ )
508+ # Block here so we don't block on the generate thread that steps.
509+ jax .block_until_ready (new_request .prefill_result )
510+ # Place the request on the correct generate backlog and block if full.
511+ self ._generate_backlogs [target_idx ].put (new_request , block = True )
449512 logging .info (
450- "Placed request on the generate queue, generate_backlogs=%d" ,
451- self ._generate_backlogs [idx ].qsize (),
513+ "Successfully transferred prefill "
514+ "from prefill engine %d to generate engine %d." ,
515+ idx ,
516+ target_idx ,
452517 )
453518
454519 def _generate_thread (self , idx : int ):
@@ -463,6 +528,7 @@ def _generate_thread(self, idx: int):
463528 generate_timestep = 0
464529 # State to store things like running kv cache in.
465530 decode_state = generate_engine .init_decode_state ()
531+
466532 generate_params = self ._generate_params [idx ]
467533 logging .info ("---------Generate params %d loaded.---------" , idx )
468534 time_of_last_generate = time .time ()
@@ -480,7 +546,6 @@ def _generate_thread(self, idx: int):
480546
481547 max_concurrent_decodes = generate_engine .max_concurrent_decodes
482548
483- # TODO: Move insert to prefill thread.
484549 # Check if there are any free my_slots. We don't want to block here since
485550 # we can still generate if we can't insert. We do this in a while loop to
486551 # insert as many sequences as possible.
@@ -499,6 +564,11 @@ def _generate_thread(self, idx: int):
499564 # the case when the prefill backlog is cancelled and we end up with no
500565 # more useful prefill work to do.
501566 block = my_slots_size == max_concurrent_decodes
567+ if self ._interleaved_mode :
568+ # For interleaved mode, we also block when the prefill backlog
569+ # is not empty or there is transfer work to do.
570+ block |= not self ._prefill_backlog .empty ()
571+ block |= not self ._transfer_backlogs [idx ].empty ()
502572 try :
503573 new_request = my_generate_backlog .get (block = block , timeout = 1.0 )
504574 # Got free slot and new request, use them.
@@ -598,7 +668,6 @@ def _detokenize_thread(self, idx: int):
598668 # Place the slot back on the free queue.
599669 my_live_requests [slot ] = None
600670 my_slots .put (slot , block = False ) # This should always have space.
601- self ._ready_to_prefill .set ()
602671 logging .info (
603672 "Detokenizing generate step %d took %.2fms" ,
604673 generate_timestep_added ,
0 commit comments