Skip to content

Commit a3546e8

Browse files
zhihaoshan-google (Zhihao Shan)
and co-authors authored
Prerequisite work for supporting disaggregation: (#68)
1. Add transfer thread to transfer KV Cache. 2. For interleaved mode, prioritize prefill and improve the HBM utilization. Co-authored-by: Zhihao Shan <zhihaoshan@google.com>
1 parent 2db6c14 commit a3546e8

2 files changed

Lines changed: 108 additions & 35 deletions

File tree

jetstream/core/orchestrator.py

Lines changed: 101 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -188,17 +188,23 @@ class Driver:
188188
# Stage 1
189189
_prefill_backlog: queue.Queue[ActiveRequest | None]
190190
# Stage 2
191+
_transfer_backlogs: list[queue.Queue[ActiveRequest]] = []
192+
# Stage 3
191193
# We keep this as a dict to avoid a possibly expensive object comparison
192194
# when logging the index of the generate engine we send a prefill result
193195
# to, it allows us to natively have the index from the min operation, rather
194196
# than have to call .index()
195-
_generate_backlogs: dict[int, queue.Queue[ActiveRequest | None]] = {}
196-
# Stage 3
197+
_generate_backlogs: dict[int, queue.Queue[ActiveRequest]] = {}
198+
# Stage 4
197199
# This can be a list because we can pass it as an arg to generate and
198200
# detokenize threads. It is a list of tokens to be detokenized.
199201
_detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = []
200202
_generate_slots: list[queue.Queue[int]] = []
201-
_active_requests: list[queue.Queue[tuple[int, ActiveRequest | None]]] = []
203+
_active_requests: list[queue.Queue[tuple[int, ActiveRequest]]] = []
204+
205+
# For interleaved_mode, only generate if all slots are full
206+
# or corresponding prefill queue is empty.
207+
_interleaved_mode: bool = False
202208

203209
# todo: remove jax_padding after all then engine migrate to np padding
204210
_jax_padding = True
@@ -209,6 +215,7 @@ def __init__(
209215
generate_engines: Optional[list[engine_api.Engine]] = None,
210216
prefill_params: Optional[list[Any]] = None,
211217
generate_params: Optional[list[Any]] = None,
218+
interleaved_mode: bool = False,
212219
jax_padding: bool = True,
213220
):
214221
if prefill_engines is None:
@@ -229,22 +236,39 @@ def __init__(
229236
self._generate_engines = generate_engines
230237
self._prefill_params = prefill_params
231238
self._generate_params = generate_params
239+
self._interleaved_mode = interleaved_mode
240+
232241
# Stages 1-4 represent the life cycle of a request.
233242
# Stage 1
234243
# At first, a request is placed here in order to get prefilled.
235244
self._prefill_backlog = queue.Queue()
236-
# _ready_to_prefill event will block the prefill thread until there is
237-
# available decode slot to insert the prefill result.
238-
self._ready_to_prefill = threading.Event()
239245
# Stage 2
246+
# After prefilling, it is placed here in order to get transferred to
247+
# one of the generate backlogs.
248+
# Interleaved Mode: Max size is 1 to increase the HBM utilization
249+
# during generate.
250+
# Disaggregated Mode: Max size is 4 to allow for 2 prefills to be enqueued
251+
# while 1 transfer is enqueued while 1 is being transferred.
252+
# TODO: Make queue size configurable.
253+
self._transfer_backlogs = [
254+
queue.Queue(1 if self._interleaved_mode else 4)
255+
for i in range(len(self._prefill_engines))
256+
]
257+
# Stage 3
240258
# Each generate engine accesses its own generate backlog.
259+
# Interleaved Mode: Max size is 1 to increase the HBM utilization
260+
# during generate.
261+
# Disaggregated Mode: Set as 1/3 the number of concurrent decodes.
262+
# TODO: Calculate the backlog to saturate the generate engine while
263+
# minimizing the memory usage for disaggregated mode.
264+
# TODO: Make queue size configurable.
241265
self._generate_backlogs = {
242-
# Don't receive more than 1/3 the number of concurrent decodes to avoid
243-
# OOM for single host.
244-
idx: queue.Queue(engine.max_concurrent_decodes // 3)
266+
idx: queue.Queue(
267+
1 if self._interleaved_mode else engine.max_concurrent_decodes // 3
268+
)
245269
for idx, engine in enumerate(self._generate_engines)
246270
}
247-
# Stage 3
271+
# Stage 4
248272
# After generation, ActiveRequests are placed on the detokenization backlog
249273
# for tokens to be sent into each ActiveRequest's return channel.
250274
# We have one of these per generate engine to simplify the logic keeping
@@ -293,6 +317,18 @@ def __init__(
293317
JetThread(
294318
target=functools.partial(self._prefill_thread, idx),
295319
name=f"prefill-{idx}",
320+
daemon=True,
321+
)
322+
for idx in range(len(self._prefill_engines))
323+
]
324+
self._transfer_threads = [
325+
JetThread(
326+
target=functools.partial(
327+
self._transfer_thread,
328+
idx,
329+
),
330+
name=f"transfer-{idx}",
331+
daemon=True,
296332
)
297333
for idx in range(len(self._prefill_engines))
298334
]
@@ -303,6 +339,7 @@ def __init__(
303339
idx,
304340
),
305341
name=f"generate-{idx}",
342+
daemon=True,
306343
)
307344
for idx in range(len(self._generate_engines))
308345
]
@@ -319,6 +356,7 @@ def __init__(
319356
self._all_threads = list(
320357
itertools.chain(
321358
self._prefill_threads,
359+
self._transfer_threads,
322360
self._generate_threads,
323361
self.detokenize_threads,
324362
)
@@ -336,6 +374,7 @@ def stop(self):
336374
all_backlogs = list(
337375
itertools.chain(
338376
[self._prefill_backlog],
377+
self._transfer_backlogs,
339378
self._generate_backlogs.values(),
340379
self._detokenize_backlogs,
341380
)
@@ -400,24 +439,11 @@ def _prefill_thread(self, idx: int):
400439
logging.info("---------Prefill params %d loaded.---------", idx)
401440

402441
while self.live:
403-
# The prefill thread can wait until there is available decode slot to
404-
# insert.
405-
if self._generate_slots[idx].qsize() == 0:
406-
logging.info(
407-
"Prefill waits for available slot; prefill queue size %d",
408-
self._prefill_backlog.qsize(),
409-
)
410-
self._ready_to_prefill.wait()
411-
logging.info(
412-
"Prefill continues; prefill queue size %d",
413-
self._prefill_backlog.qsize(),
414-
)
442+
my_transfer_backlog = self._transfer_backlogs[idx]
415443
# The prefill thread can just sleep until it has work to do.
416444
request = self._prefill_backlog.get(block=True)
417445
if request is None:
418446
break
419-
# TODO: Implement hot/cold cache for history.
420-
history = self._load_cache_history(request.history_path) # pylint: disable = assignment-from-none
421447
# Tokenize, and introduce a leading dimension
422448
is_bos = not bool(request.history_path)
423449
logging.info(
@@ -434,21 +460,60 @@ def _prefill_thread(self, idx: int):
434460
max_prefill_length=prefill_engine.max_prefill_length,
435461
jax_padding=self._jax_padding,
436462
)
437-
# Compute new kv cache for the prefill_text, conditional on
438-
# history.
463+
# Compute new kv cache for the prefill_text.
439464
prefill_result = prefill_engine.prefill(
440465
params=prefill_params,
441-
existing_prefix=history,
442466
padded_tokens=padded_tokens,
443467
true_length=true_length,
444468
)
445469
request.prefill_result = prefill_result
446470
# Once prefill is complete, place it on the generation queue and block if
447471
# full.
448-
self._generate_backlogs[idx].put(request, block=True)
472+
my_transfer_backlog.put(request, block=True)
473+
logging.info(
474+
"Placed request on transfer queue %d, %d queued requests.",
475+
idx,
476+
my_transfer_backlog.qsize(),
477+
)
478+
del prefill_result
479+
del request
480+
481+
def _transfer_thread(self, idx: int):
482+
"""Transfers the kv cache on an active request to the least full
483+
generate backlog."""
484+
transfer_backlog = self._transfer_backlogs[idx]
485+
486+
while self.live:
487+
# The transfer thread can just sleep until it has work to do.
488+
new_request = transfer_backlog.get(block=True)
489+
target_idx = min(
490+
self._generate_backlogs.items(), key=lambda q: q[1].qsize()
491+
)[0]
492+
# Only transfer the KVCache for the disaggregated serving.
493+
# TODO: Remove the conditional after fixing the compatibility.
494+
if not self._interleaved_mode:
495+
logging.info(
496+
"Transferring prefill from prefill engine %d "
497+
"to generate engine %d.",
498+
idx,
499+
target_idx,
500+
)
501+
# Transfer the info to the relevant generate slice.
502+
new_request.prefill_result = jax.device_put(
503+
new_request.prefill_result,
504+
self._generate_engines[
505+
target_idx
506+
].get_prefix_destination_sharding(),
507+
)
508+
# Block here so we don't block on the generate thread that steps.
509+
jax.block_until_ready(new_request.prefill_result)
510+
# Place the request on the correct generate backlog and block if full.
511+
self._generate_backlogs[target_idx].put(new_request, block=True)
449512
logging.info(
450-
"Placed request on the generate queue, generate_backlogs=%d",
451-
self._generate_backlogs[idx].qsize(),
513+
"Successfully transferred prefill "
514+
"from prefill engine %d to generate engine %d.",
515+
idx,
516+
target_idx,
452517
)
453518

454519
def _generate_thread(self, idx: int):
@@ -463,6 +528,7 @@ def _generate_thread(self, idx: int):
463528
generate_timestep = 0
464529
# State to store things like running kv cache in.
465530
decode_state = generate_engine.init_decode_state()
531+
466532
generate_params = self._generate_params[idx]
467533
logging.info("---------Generate params %d loaded.---------", idx)
468534
time_of_last_generate = time.time()
@@ -480,7 +546,6 @@ def _generate_thread(self, idx: int):
480546

481547
max_concurrent_decodes = generate_engine.max_concurrent_decodes
482548

483-
# TODO: Move insert to prefill thread.
484549
# Check if there are any free my_slots. We don't want to block here since
485550
# we can still generate if we can't insert. We do this in a while loop to
486551
# insert as many sequences as possible.
@@ -499,6 +564,11 @@ def _generate_thread(self, idx: int):
499564
# the case when the prefill backlog is cancelled and we end up with no
500565
# more useful prefill work to do.
501566
block = my_slots_size == max_concurrent_decodes
567+
if self._interleaved_mode:
568+
# For interleaved mode, we also block when the prefill backlog
569+
# is not empty or there is transfer work to do.
570+
block |= not self._prefill_backlog.empty()
571+
block |= not self._transfer_backlogs[idx].empty()
502572
try:
503573
new_request = my_generate_backlog.get(block=block, timeout=1.0)
504574
# Got free slot and new request, use them.
@@ -598,7 +668,6 @@ def _detokenize_thread(self, idx: int):
598668
# Place the slot back on the free queue.
599669
my_live_requests[slot] = None
600670
my_slots.put(slot, block=False) # This should always have space.
601-
self._ready_to_prefill.set()
602671
logging.info(
603672
"Detokenizing generate step %d took %.2fms",
604673
generate_timestep_added,

jetstream/core/server_lib.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,15 @@ def run(
112112
generate_params = [ge.load_params() for ge in engines.generate_engines]
113113
shared_params = [ie.load_params() for ie in engines.interleaved_engines]
114114
logging.info("Loaded all weights.")
115+
interleaved_mode = (
116+
len(config.prefill_slices) + len(config.generate_slices) == 0
117+
)
115118
driver = orchestrator.Driver(
116119
prefill_engines=engines.prefill_engines + engines.interleaved_engines,
117120
generate_engines=engines.generate_engines + engines.interleaved_engines,
118121
prefill_params=prefill_params + shared_params,
119122
generate_params=generate_params + shared_params,
123+
interleaved_mode=interleaved_mode,
120124
jax_padding=jax_padding,
121125
)
122126
# We default threads to the total number of concurrent allowed decodes,
@@ -130,8 +134,8 @@ def run(
130134

131135

132136
def get_devices() -> Any:
133-
"""Gets devices locally."""
134-
# Run interleaved engine on local device.
137+
"""Gets devices."""
138+
# TODO: Add more logs for the devices.
135139
devices = jax.devices()
136-
logging.info("Using local devices for interleaved serving: %d", len(devices))
140+
logging.info("Using devices: %d", len(devices))
137141
return devices

0 commit comments

Comments
 (0)