 import threading
 import time
 import traceback
-from typing import Any, AsyncIterator, Optional, Union
+from typing import Any, AsyncIterator, Optional, Tuple, Union, cast

 import grpc
 import jax
 from jetstream.core.proto import jetstream_pb2
 from jetstream.core.proto import jetstream_pb2_grpc
 from jetstream.core.utils import async_multifuture
-from jetstream.engine import engine_api
-
+from jetstream.core.utils.return_sample import ReturnSample
+from jetstream.engine import engine_api, tokenizer_api, token_utils
 import numpy as np
 import prometheus_client
+import shortuuid

 root = logging.getLogger()
 root.setLevel(logging.DEBUG)
@@ -127,27 +128,28 @@ class ActiveRequest:
   # We keep prefill and decode information together in the same object so that
   # there is less indirection about where this return channel is.
   # The return channel returns a list of strings, one per sample for that query.
-  return_channel: async_multifuture.AsyncMultifuture[list[list[int]]]
+  return_channel: async_multifuture.AsyncMultifuture[list[ReturnSample]]
   # [num_samples,] which corresponds to whether each sample is complete for the
   # requests.
   complete: Optional[np.ndarray] = None
   prefill_result: Any = None
   #################### Information relevant for prefill ########################
   history_path: Optional[str] = None
-  prefill_text: Optional[str] = None
+  prefill_content: Optional[str | list[int]] = None
   ################## Information relevant for detokenization ###################
   # Which generate step this was added at.
   generate_timestep_added: Optional[int] = None
+  is_client_side_tokenization: Optional[bool] = False

-  def enqueue_tokens(self, generated_tokens: list[list[int]]):
-    """Records information about the step.
+  def enqueue_samples(self, generated_samples: list[ReturnSample]):
+    """Adds the generated sample(s) to the return channel for the current step.

     Args:
-      generated_tokens: One token to put into the return channel
+      generated_samples: The generated sample(s) for the current step.

     This should be called only from within the Driver's background thread.
     """
-    self.return_channel.add_result(generated_tokens)
+    self.return_channel.add_result(generated_samples)


 class JetThread(threading.Thread):
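
ReturnSample itself is defined in jetstream/core/utils/return_sample.py, which this diff does not show. From its usage below (sample.token_ids in the client-side branch; item.text, resp.text and resp.token_ids in the buffering loop), it behaves like a small per-sample record. A hypothetical sketch of that shape, for orientation only:

from dataclasses import dataclass, field


@dataclass
class ReturnSample:
  """Hypothetical shape, inferred from how this diff consumes it."""

  # Detokenized text pieces for this step; the last piece may be a raw
  # byte token (e.g. "<0xC3>") that needs more bytes to form a string.
  text: list[str] = field(default_factory=list)
  # Raw token ids generated for this step.
  token_ids: list[int] = field(default_factory=list)
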
@@ -247,7 +249,8 @@ def __init__(
     # At first, a request is placed here in order to get prefilled.
     self._prefill_backlog = queue.Queue()
     self._prefill_backlog_size_metric = prometheus_client.Gauge(
-        "jetstream_prefill_backlog_size", "Size of prefill queue"
+        f"jetstream_prefill_backlog_size_{shortuuid.uuid()}",
+        "Size of prefill queue",
    )

    # Stage 2
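
The shortuuid suffix matters because prometheus_client registers metrics in a process-wide default registry and raises ValueError on a duplicate name, so two Driver instances (e.g. in tests) could not otherwise coexist. A minimal sketch of the failure mode being avoided:

import prometheus_client
import shortuuid

prometheus_client.Gauge("jetstream_prefill_backlog_size", "Size of prefill queue")
try:
  # Same name, same default registry -> ValueError (duplicated timeseries).
  prometheus_client.Gauge("jetstream_prefill_backlog_size", "Size of prefill queue")
except ValueError:
  # A unique suffix per Driver instance sidesteps the collision.
  prometheus_client.Gauge(
      f"jetstream_prefill_backlog_size_{shortuuid.uuid()}",
      "Size of prefill queue",
  )

The trade-off is that each Driver instance exports a differently named metric, so dashboards have to match on the common prefix.
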
@@ -438,6 +441,33 @@ def _load_cache_history(self, path: str) -> Union[None, Any]:
     else:
       return None

+  def _process_prefill_content(
+      self,
+      request: ActiveRequest,
+      tokenizer: tokenizer_api.Tokenizer,
+      is_bos: bool,
+      max_prefill_length: int,
+  ) -> Tuple[jax.Array | np.ndarray, int]:
+    content = request.prefill_content
+    if isinstance(content, str):
+      # If it's text input, tokenize and pad the input.
+      return tokenizer.encode(
+          content,
+          is_bos=is_bos,
+          max_prefill_length=max_prefill_length,
+          jax_padding=self._jax_padding,
+      )
+    else:
+      # If it's token input, pad the input.
+      return token_utils.pad_tokens(
+          content,
+          tokenizer.bos_id,
+          tokenizer.pad_id,
+          is_bos=is_bos,
+          max_prefill_length=max_prefill_length,
+          jax_padding=self._jax_padding,
+      )
+
   def _prefill_thread(self, idx: int):
     """Thread which runs in the background performing prefills."""
     logging.info("---------Spinning up prefill thread %d.---------", idx)
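
Both branches of _process_prefill_content return the same (padded_tokens, true_length) pair, so the prefill loop below never needs to know which kind of content arrived. A self-contained toy version of the dispatch (simplified padding, standing in for the real tokenizer.encode / token_utils.pad_tokens):

import numpy as np


def process_content_sketch(content, encode_fn, pad_id=0, max_len=8):
  # Text path: tokenize first; token path: ids arrive pre-tokenized.
  ids = encode_fn(content) if isinstance(content, str) else list(content)
  true_length = min(len(ids), max_len)
  padded = np.full(max_len, pad_id, dtype=np.int32)
  padded[:true_length] = ids[:true_length]
  return padded, true_length


# process_content_sketch([5, 6, 7], encode_fn=None)
# -> (array([5, 6, 7, 0, 0, 0, 0, 0], dtype=int32), 3)
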
@@ -455,7 +485,6 @@ def _prefill_thread(self, idx: int):

       if request is None:
         break
-      # Tokenize, and introduce a leading dimension
       is_bos = not bool(request.history_path)
       logging.info(
           "Prefilling on prefill engine %d : prefill queue size, %d,"
@@ -465,13 +494,11 @@ def _prefill_thread(self, idx: int):
           is_bos,
           request.history_path,
       )
-      padded_tokens, true_length = tokenizer.encode(
-          request.prefill_text,
-          is_bos=is_bos,
-          max_prefill_length=prefill_engine.max_prefill_length,
-          jax_padding=self._jax_padding,
+      # Tokenize and pad the text or token input.
+      padded_tokens, true_length = self._process_prefill_content(
+          request, tokenizer, is_bos, prefill_engine.max_prefill_length
       )
-      # Compute new kv cache for the prefill_text.
+      # Compute new kv cache for the prefill_content.
       prefill_result = prefill_engine.prefill(
           params=prefill_params,
           padded_tokens=padded_tokens,
@@ -497,6 +524,8 @@ def _transfer_thread(self, idx: int):
     while self.live:
       # The transfer thread can just sleep until it has work to do.
       new_request = transfer_backlog.get(block=True)
+      if new_request is None:
+        break
       target_idx = min(
           self._generate_backlogs.items(), key=lambda q: q[1].qsize()
       )[0]
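
The new None check gives the transfer thread the standard poison-pill shutdown path for queue-fed workers: putting None on the queue wakes the blocking get() and lets the thread exit cleanly. The pattern in isolation (the Driver code that enqueues the sentinel is outside this diff):

import queue
import threading


def worker(backlog: queue.Queue):
  while True:
    item = backlog.get(block=True)
    if item is None:  # Poison pill: unblock and exit cleanly.
      break
    # ... process item ...


backlog = queue.Queue()
t = threading.Thread(target=worker, args=(backlog,))
t.start()
backlog.put(None)  # Request shutdown.
t.join()
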
@@ -665,15 +694,17 @@ def _detokenize_thread(self, idx: int):

         for slot, request in my_live_requests.items():
           if request is not None:
-            results, complete = tokenizer.decode(
+            results, complete = token_utils.process_result_tokens(
+                tokenizer=tokenizer,
                 slot=slot,
                 slot_max_length=request.max_tokens,
                 result_tokens=result_tokens,
+                is_client_side_tokenization=request.is_client_side_tokenization,
                 complete=request.complete,
             )
             request.complete = complete
-            # Return some tokens.
-            request.enqueue_tokens(results)
+            # Return some output samples.
+            request.enqueue_samples(results)
             if request.complete.all():
               request.return_channel.close()
               # Place the slot back on the free queue.
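
process_result_tokens (in token_utils, not shown in this diff) replaces the old tokenizer.decode call and, per its usage here, still returns the per-sample results plus an updated boolean mask. A toy illustration of the complete-mask semantics with hypothetical values (the real update happens inside process_result_tokens):

import numpy as np

# One flag per sample; a slot stays live until every sample finishes.
complete = np.zeros(2, dtype=bool)   # two samples, neither done
complete = complete | np.array([True, False])   # sample 0 hit its stop
assert not complete.all()            # slot stays live, channel stays open
complete = complete | np.array([False, True])   # sample 1 finishes
assert complete.all()                # now return_channel.close() would run
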
@@ -698,6 +729,21 @@ class LLMOrchestrator(jetstream_pb2_grpc.OrchestratorServicer):
   def __init__(self, driver: Driver):
     self._driver = driver

+  def _get_prefill_content(
+      self, request: jetstream_pb2.DecodeRequest
+  ) -> Tuple[str | list[int], bool]:
+    which_content = request.WhichOneof("content")
+    content = getattr(request, which_content)
+    if which_content == "text_content":
+      return cast(jetstream_pb2.DecodeRequest.TextContent, content).text, False
+    else:
+      return (
+          list(
+              cast(jetstream_pb2.DecodeRequest.TokenContent, content).token_ids
+          ),
+          True,
+      )
+
   async def Decode(  # pylint: disable=invalid-overridden-method
       self,
       request: jetstream_pb2.DecodeRequest,
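
WhichOneof is the standard protobuf message API: it returns the name of whichever field in the named oneof group is set, or None if none is. A small usage sketch against the DecodeRequest shapes used here (field and message names as they appear in this diff; the .proto itself is not shown):

from jetstream.core.proto import jetstream_pb2

# A text client lets the server tokenize.
req = jetstream_pb2.DecodeRequest(
    text_content=jetstream_pb2.DecodeRequest.TextContent(text="hello"),
)
assert req.WhichOneof("content") == "text_content"

# A token client tokenized already, so is_client_side_tokenization comes
# back True and the server skips tokenization and detokenization.
req = jetstream_pb2.DecodeRequest(
    token_content=jetstream_pb2.DecodeRequest.TokenContent(token_ids=[1, 17, 42]),
)
assert req.WhichOneof("content") == "token_content"
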
@@ -709,14 +755,19 @@ async def Decode(  # pylint: disable=invalid-overridden-method
           "LLM orchestrator is being used in offline test mode, and will not"
           " respond to gRPC queries - only direct function calls."
       )
+    is_client_side_tokenization = False
     return_channel = async_multifuture.AsyncMultifuture()
     if context:
       context.add_done_callback(return_channel.cancel)
+    prefill_content, is_client_side_tokenization = self._get_prefill_content(
+        request
+    )
     # Wrap request as an ActiveRequest.
     active_request = ActiveRequest(
         max_tokens=request.max_tokens,
         history_path=request.session_cache,
-        prefill_text=request.additional_text,
+        prefill_content=prefill_content,
+        is_client_side_tokenization=is_client_side_tokenization,
         return_channel=return_channel,
     )
     # The first stage is being prefilled, all other stages are handled
@@ -736,18 +787,78 @@ async def Decode(  # pylint: disable=invalid-overridden-method
     logging.info(
         "Placed request on the prefill queue.",
     )
-    async for response in active_request.return_channel:
-      # When an active request is created a queue is instantiated. New tokens
-      # are placed there during the decoding loop, we pop from that queue by
-      # using the .next method on the active request.
-      # Yielding allows for the response to be a streaming grpc call - which
-      # can be called via iterating over a for loop on the other side.
-      # The DecodeResponse stream should consume all generated tokens in
-      # return_channel when complete signal is received. It should check if
-      # return_channel is empty to decide if it should exit the while loop.
-      repeated_token_ids = []
-      for token_ids in response:
-        repeated_token_ids.append(
-            jetstream_pb2.RepeatedTokenIds(token_ids=token_ids)
+    # When an active request is created, a queue is instantiated. New tokens
+    # are placed there during the decoding loop; we pop from that queue by
+    # using the .next method on the active request.
+    # Yielding allows the response to be a streaming grpc call, which can
+    # be consumed by iterating over a for loop on the client side.
+    # The DecodeResponse stream should consume all generated tokens in
+    # return_channel when the complete signal is received (AsyncMultifuture
+    # guarantees this).
+    if is_client_side_tokenization:
+      # With client-side tokenization, the client sends token ids in the
+      # request and the JetStream server returns token ids in the response.
+      # The client handles both tokenization and detokenization.
+      async for response in active_request.return_channel:
+        response = cast(list[ReturnSample], response)
+        samples = []
+        for sample in response:
+          samples.append(
+              jetstream_pb2.DecodeResponse.StreamContent.Sample(
+                  token_ids=sample.token_ids,
+              )
+          )
+        yield jetstream_pb2.DecodeResponse(
+            stream_content=jetstream_pb2.DecodeResponse.StreamContent(
+                samples=samples
+            )
+        )
+    else:
+      # A response buffer is used to handle streaming detokenization of
+      # special characters (in some edge cases the SentencePiece tokenizer
+      # must decode a complete byte sequence rather than a single token),
+      # so partially decoded bytes are held back until they form a string.
+      buffered_response_list = []
+      async for response in active_request.return_channel:
+        response = cast(list[ReturnSample], response)
+        buffered = False
+        for item in response:
+          if item.text and token_utils.is_byte_token(item.text[-1]):
+            # If any sample ends in bytes, we might still need to decode
+            # more bytes to compose the string.
+            buffered_response_list.append(response)
+            buffered = True
+            break
+        if buffered:
+          continue
+        # Flush the buffered responses into each sample of the current response.
+        current_response_with_flushed_buffer = list(
+            zip(*buffered_response_list, response)
+        )
+        # Empty buffer: [[s0_cur], [s1_cur], ...]
+        # Has buffer:
+        # [[s0_b0, s0_b1, ..., s0_cur], [s1_b0, s1_b1, ..., s1_cur], ...]
+        current_response_with_flushed_buffer = cast(
+            list[list[ReturnSample]], current_response_with_flushed_buffer
+        )
+        # Reset the buffer after flushing.
+        buffered_response_list = []
+        # Form the sample(s) and return them as StreamContent for this step.
+        samples = []
+        for sample in current_response_with_flushed_buffer:
+          text = []
+          token_ids = []
+          for resp in sample:
+            text.extend(resp.text)
+            token_ids.extend(resp.token_ids)
+          samples.append(
+              jetstream_pb2.DecodeResponse.StreamContent.Sample(
+                  text=token_utils.text_tokens_to_str(text),
+                  token_ids=token_ids,
+              )
+          )
+        yield jetstream_pb2.DecodeResponse(
+            stream_content=jetstream_pb2.DecodeResponse.StreamContent(
+                samples=samples
+            )
         )
-      yield jetstream_pb2.DecodeResponse(response=repeated_token_ids)
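
The buffering above exists because SentencePiece-style tokenizers can emit the raw bytes of a multi-byte UTF-8 character as separate tokens (rendered like "<0xE2>"); the text only becomes printable once all bytes have arrived. A toy walk-through of the flush, assuming byte tokens in that "<0xHH>" form and hypothetical token ids (the real joining is done by token_utils.text_tokens_to_str):

# Each step's response holds (text_piece, token_ids) per sample; "€" is
# the three UTF-8 bytes 0xE2 0x82 0xAC, arriving one per decode step.
buffered = [
    [("<0xE2>", [101])],  # step 1 ends in a byte token -> buffered
    [("<0x82>", [102])],  # step 2 still incomplete -> buffered
]
current = [("<0xAC>", [103])]  # step 3: last byte, so flush

# zip(*buffered, current) regroups the buffer per sample, exactly like
# current_response_with_flushed_buffer above.
per_sample = list(zip(*buffered, current))
raw = bytes(int(text[1:-1], 16) for text, _ in per_sample[0])
assert raw.decode("utf-8") == "€"  # one StreamContent.Sample(text="€")
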