8080import os
8181import queue
8282import signal
83+ import sys
8384import threading
8485import time
8586import traceback
9495import numpy as np
9596
9697
# Configure the root logger to emit everything (DEBUG and up) to stdout,
# with a timestamped "<time> - <logger> - <level> - <message>" format.
root = logging.getLogger()
root.setLevel(logging.DEBUG)

stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.setLevel(logging.DEBUG)
stdout_handler.setFormatter(
    logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
)
root.addHandler(stdout_handler)
106+
def delete_pytree(p):
  """Eagerly free the device buffers backing every jax.Array leaf of `p`.

  Used to release prefill KV-cache memory as soon as a result has been
  inserted into the decode state, instead of waiting for Python garbage
  collection. Non-array leaves are left untouched.

  Args:
    p: An arbitrary pytree; only `jax.Array` leaves are affected.
  """

  def delete_leaf(leaf):
    if isinstance(leaf, jax.Array):
      # Releases the underlying device buffer immediately; the Python
      # object survives but any further use raises an error.
      leaf.delete()

  # `jax.tree_map` is deprecated (JAX 0.4.26+); use the tree_util namespace.
  jax.tree_util.tree_map(delete_leaf, p)
113+
114+
97115@dataclasses .dataclass
98116class ActiveRequest :
99117 """Current state of the driver."""
@@ -169,14 +187,12 @@ class Driver:
169187 # Stage 1
170188 _prefill_backlog : queue .Queue [ActiveRequest ]
171189 # Stage 2
172- _transfer_backlog : queue .Queue [ActiveRequest ]
173- # Stage 3
174190 # We keep this as a dict to avoid a possibly expensive object comparison
175191 # when logging the index of the generate engine we send a prefill result
176192 # to, it allows us to natively have the index from the min operation, rather
177193 # than have to call .index()
178194 _generate_backlogs : dict [int , queue .Queue [ActiveRequest ]] = {}
179- # Stage 4
195+ # Stage 3
180196 # This can be a list because we can pass it as an arg to generate and
181197 # detokenize threads. It is a list of tokens to be detokenized.
182198 _detokenize_backlogs : list [queue .Queue [engine_api .ResultTokens ]] = []
@@ -204,15 +220,11 @@ def __init__(
204220 # At first, a request is placed here in order to get prefilled.
205221 self ._prefill_backlog = queue .Queue ()
206222 # Stage 2
207- # After prefilling, it is placed here in order to get transferred to
208- # one of the generate backlogs.
209- self ._transfer_backlog = queue .Queue ()
210- # Stage 3
211223 # Each generate engine accesses its own generate backlog.
212224 self ._generate_backlogs = {
213225 idx : queue .Queue () for idx , _ in enumerate (generate_engines )
214226 }
215- # Stage 4
227+ # Stage 3
216228 # After generation, ActiveRequests are placed on the detokenization backlog
217229 # for tokens to be sent into each ActiveRequest's return channel.
218230 # We have one of these per generate engine to simplify the logic keeping
@@ -257,7 +269,6 @@ def __init__(
257269 )
258270 for idx , engine in enumerate (self ._prefill_engines )
259271 ]
260- self ._transfer_thread = JetThread (target = self ._transfer_thread )
261272 self ._generate_threads = [
262273 JetThread (
263274 target = functools .partial (
@@ -288,7 +299,6 @@ def __init__(
288299 self .live = True
289300 # Kick off all threads
290301 _ = [f .start () for f in self ._prefill_threads ]
291- self ._transfer_thread .start ()
292302 _ = [f .start () for f in self ._generate_threads ]
293303 _ = [f .start () for f in self .detokenize_threads ]
294304
@@ -316,7 +326,7 @@ def _prefill_thread(
316326 self ,
317327 idx : int ,
318328 prefill_engine : engine_api .Engine ,
319- transfer_backpressure : int = 8 ,
329+ generate_backpressure : int = 3 ,
320330 ):
321331 """Thread which runs in the background performing prefills."""
322332 logging .info ('---------Spinning up prefill thread %d.---------' , idx )
@@ -328,7 +338,7 @@ def _prefill_thread(
328338 while self .live :
329339 # We don't want to keep lots of kv caches live in memory on the prefill
330340 # slice that aren't about to be sent over to a generation slice.
331- if self ._transfer_backlog .qsize () < transfer_backpressure :
341+ if ( self ._generate_backlogs [ idx ] .qsize () < generate_backpressure ) :
332342 # Check if there is anything on the prefill backlog, pop if so.
333343 try :
334344 request = self ._prefill_backlog .get (block = True )
@@ -337,10 +347,9 @@ def _prefill_thread(
337347 # Tokenize, and introduce a leading dimension
338348 is_bos = not bool (request .history_path )
339349 logging .info (
340- 'Prefilling on prefill engine %d : "%s", prefill queue size, %d,'
350+ 'Prefilling on prefill engine %d : prefill queue size, %d,'
341351 ' is_bos: %s, history: %s' ,
342352 idx ,
343- request .prefill_text ,
344353 self ._prefill_backlog .qsize (),
345354 is_bos ,
346355 request .history_path ,
@@ -359,53 +368,16 @@ def _prefill_thread(
359368 padded_tokens = padded_tokens ,
360369 true_length = true_length ,
361370 )
362- jax .block_until_ready (prefill_result )
363371 request .prefill_result = prefill_result
364372 # Once prefill is complete, place it on the generation queue.
365- self ._transfer_backlog .put (request )
373+ self ._generate_backlogs [ idx ] .put (request )
366374 logging .info (
367- 'Placed request "%s" on the transfer queue.' , request . prefill_text
375+ f 'Placed request on the generate queue, { self . _generate_backlogs [ idx ]. qsize () = } '
368376 )
369377 except queue .Empty :
370378 # Otherwise, don't do anything!
371379 pass
372380
373- def _transfer_thread (self , generation_backpressure : int = 8 ):
374- """Transfers the kv cache on an active request to the least full generate backlog."""
375- while self .live :
376- # We don't want to keep lots of kv caches live in memory on a generate
377- # slice that haven't been inserted
378- if (
379- sum ([backlog .qsize () for backlog in self ._generate_backlogs .values ()])
380- < generation_backpressure
381- ):
382- try :
383- new_request = self ._transfer_backlog .get (block = True )
384-
385- # Get the index of the generate queue with the minimmum qsize.
386- idx = min (
387- self ._generate_backlogs .items (), key = lambda q : q [1 ].qsize ()
388- )[0 ]
389- logging .info (
390- 'Transferring "%s" to generate engine %d.' ,
391- new_request .prefill_text ,
392- idx ,
393- )
394- # Transfer the info to the relevant generate backlog.
395- new_request .prefill_result = jax .device_put (
396- new_request .prefill_result ,
397- self ._generate_engines [idx ].get_prefix_destination_sharding (),
398- )
399- # Place the request on the correct generate backlog.
400- self ._generate_backlogs [idx ].put (new_request )
401- logging .info (
402- 'Request "%s" tsuccessfully transferrred to generate engine %d.' ,
403- new_request .prefill_text ,
404- idx ,
405- )
406- except queue .Empty :
407- pass
408-
409381 def _generate_thread (
410382 self ,
411383 idx : int ,
@@ -434,22 +406,28 @@ def _generate_thread(
434406 generate_params = self ._generate_params [idx ]
435407 logging .info ('---------Generate params %d loaded.---------' , idx )
436408 time_of_last_generate = time .time ()
409+ time_of_last_print = time .time ()
437410 while self .live :
411+ if (time .time () - time_of_last_print ) > 1 :
412+ logging .info (
413+ f'Generate thread making a decision with: prefill_backlog={ self ._prefill_backlog .qsize ()} generate_free_slots={ my_slots .qsize ()} '
414+ )
415+ time_of_last_print = time .time ()
438416 # Check if there are any free my_slots.
439417 if not my_slots .empty () and not self ._generate_backlogs [idx ].empty ():
440418 # Only get requests from the backlog corresponding to this engine.
441419 new_request = self ._generate_backlogs [idx ].get ()
442420 slot = my_slots .get ()
443421 logging .info (
444- 'Generate slice %d slot %d step %d, generating for : "%s" ' ,
422+ 'Generate slice %d slot %d step %d' ,
445423 idx ,
446424 slot ,
447425 generate_timestep ,
448- new_request .prefill_text ,
449426 )
450427 decode_state = generate_engine .insert (
451428 new_request .prefill_result , decode_state , slot = slot
452429 )
430+ delete_pytree (new_request .prefill_result )
453431 new_request .generate_timestep_added = generate_timestep
454432 new_request .complete = np .zeros (
455433 (generate_engine .samples_per_slot ,), dtype = np .bool_
@@ -576,8 +554,7 @@ def Decode(
576554 ),
577555 )
578556 logging .info (
579- 'Placed request with text "%s" on the prefill queue.' ,
580- active_request .prefill_text ,
557+ 'Placed request on the prefill queue.' ,
581558 )
582559
583560 while True :
0 commit comments