Commit 01c5a03

Update JetStream grpc proto to support I/O with text and token ids (#78)
* Update JetStream grpc proto to support I/O with text and token ids
* Update orchestrator and token utils to support text and token I/O
* Add and update unit tests
* Fix prometheus duplicate metrics issue
* Add shortuuid dep
* Update docstring
* Add client tokenization mode
* Update client side I/O handling
* Latest pylint fix
1 parent 2f8924d commit 01c5a03

17 files changed

Lines changed: 632 additions & 245 deletions

benchmarks/benchmark_serving.py

Lines changed: 8 additions & 3 deletions
@@ -388,10 +388,10 @@ async def grpc_async_request(
     token_list = []
     request_start_time = time.perf_counter()
     response = stub.Decode(request)
-    async for sample_list in response:
+    async for resp in response:
       if ttft == 0:
         ttft = time.perf_counter() - request_start_time
-      token_list.extend(sample_list.response[0].token_ids)
+      token_list.extend(resp.stream_content.samples[0].token_ids)
     latency = time.perf_counter() - request_start_time
     return token_list, ttft, latency

@@ -405,9 +405,13 @@ async def send_request(
     priority: int,
 ) -> RequestFuncOutput:
   """Send the request to JetStream server."""
+  # Tokenization on client side following MLPerf standard.
+  token_ids = tokenizer.encode(input_request.prompt)
   request = jetstream_pb2.DecodeRequest(
       session_cache=session_cache,
-      additional_text=input_request.prompt,
+      token_content=jetstream_pb2.DecodeRequest.TokenContent(
+          token_ids=token_ids
+      ),
       priority=priority,
       max_tokens=input_request.output_len,
   )
@@ -551,6 +555,7 @@ def main(args: argparse.Namespace):
         args.total_mock_requests
     ) # e.g. [("AB", 2, "AB", 3)]
   else:
+    dataset = []
     if args.dataset == "openorca":
       dataset = load_openorca_dataset_pkl()
     elif args.dataset == "sharegpt":

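For context, the new client-side-tokenization path can be exercised end to end with a helper along the lines of the sketch below. This is hypothetical code, not part of the commit: it assumes a reachable JetStream server address and a prompt already tokenized with the serving model's tokenizer, and it mirrors what grpc_async_request and send_request now do (send a TokenContent request, then accumulate token ids from the streamed stream_content).

# Hypothetical client sketch (not from this commit); message and stub names
# follow the updated proto, while the server address and priority are assumptions.
import grpc

from jetstream.core.proto import jetstream_pb2
from jetstream.core.proto import jetstream_pb2_grpc


async def decode_with_token_ids(
    server_address: str, token_ids: list[int], max_tokens: int
) -> list[int]:
  """Streams one decode call and returns all generated token ids."""
  async with grpc.aio.insecure_channel(server_address) as channel:
    stub = jetstream_pb2_grpc.OrchestratorStub(channel)
    request = jetstream_pb2.DecodeRequest(
        token_content=jetstream_pb2.DecodeRequest.TokenContent(
            token_ids=token_ids
        ),
        priority=0,
        max_tokens=max_tokens,
    )
    output: list[int] = []
    async for resp in stub.Decode(request):
      # Each streamed DecodeResponse carries a StreamContent with one or more
      # samples; keep only the first sample, as the benchmark above does.
      output.extend(resp.stream_content.samples[0].token_ids)
    return output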
jetstream/core/orchestrator.py

Lines changed: 146 additions & 35 deletions
@@ -85,17 +85,18 @@
 import threading
 import time
 import traceback
-from typing import Any, AsyncIterator, Optional, Union
+from typing import Any, AsyncIterator, Optional, Tuple, Union, cast

 import grpc
 import jax
 from jetstream.core.proto import jetstream_pb2
 from jetstream.core.proto import jetstream_pb2_grpc
 from jetstream.core.utils import async_multifuture
-from jetstream.engine import engine_api
-
+from jetstream.core.utils.return_sample import ReturnSample
+from jetstream.engine import engine_api, tokenizer_api, token_utils
 import numpy as np
 import prometheus_client
+import shortuuid

 root = logging.getLogger()
 root.setLevel(logging.DEBUG)
@@ -127,27 +128,28 @@ class ActiveRequest:
   # We keep prefill and decode information together in the same object so that
   # there is less indirection about where this return channel is.
   # The return channel returns a list of strings, one per sample for that query.
-  return_channel: async_multifuture.AsyncMultifuture[list[list[int]]]
+  return_channel: async_multifuture.AsyncMultifuture[list[ReturnSample]]
   # [num_samples,] which corresponds to whether each sample is complete for the
   # requests.
   complete: Optional[np.ndarray] = None
   prefill_result: Any = None
   #################### Information relevant for prefill ########################
   history_path: Optional[str] = None
-  prefill_text: Optional[str] = None
+  prefill_content: Optional[str | list[int]] = None
   ################## Information relevant for detokenization ###################
   # Which generate step this was added at.
   generate_timestep_added: Optional[int] = None
+  is_client_side_tokenization: Optional[bool] = False

-  def enqueue_tokens(self, generated_tokens: list[list[int]]):
-    """Records information about the step.
+  def enqueue_samples(self, generated_samples: list[ReturnSample]):
+    """Adds the generated sample(s) to return channel for current step.

     Args:
-      generated_tokens: One token to put into the return channel
+      generated_samples: The generated sample(s) for current step.

     This should be called only from within the Drivers background thread.
     """
-    self.return_channel.add_result(generated_tokens)
+    self.return_channel.add_result(generated_samples)


 class JetThread(threading.Thread):
@@ -247,7 +249,8 @@ def __init__(
     # At first, a request is placed here in order to get prefilled.
     self._prefill_backlog = queue.Queue()
     self._prefill_backlog_size_metric = prometheus_client.Gauge(
-        "jetstream_prefill_backlog_size", "Size of prefill queue"
+        f"jetstream_prefill_backlog_size_{shortuuid.uuid()}",
+        "Size of prefill queue",
     )

     # Stage 2
@@ -438,6 +441,33 @@ def _load_cache_history(self, path: str) -> Union[None, Any]:
     else:
       return None

+  def _process_prefill_content(
+      self,
+      request: ActiveRequest,
+      tokenizer: tokenizer_api.Tokenizer,
+      is_bos: bool,
+      max_prefill_length: int,
+  ) -> Tuple[jax.Array | np.ndarray, int]:
+    content = request.prefill_content
+    if isinstance(content, str):
+      # If it's text input, tokenize and pad the input.
+      return tokenizer.encode(
+          content,
+          is_bos=is_bos,
+          max_prefill_length=max_prefill_length,
+          jax_padding=self._jax_padding,
+      )
+    else:
+      # If it's token input, pad the input.
+      return token_utils.pad_tokens(
+          content,
+          tokenizer.bos_id,
+          tokenizer.pad_id,
+          is_bos=is_bos,
+          max_prefill_length=max_prefill_length,
+          jax_padding=self._jax_padding,
+      )
+
   def _prefill_thread(self, idx: int):
     """Thread which runs in the background performing prefills."""
     logging.info("---------Spinning up prefill thread %d.---------", idx)
@@ -455,7 +485,6 @@ def _prefill_thread(self, idx: int):

       if request is None:
         break
-      # Tokenize, and introduce a leading dimension
       is_bos = not bool(request.history_path)
       logging.info(
           "Prefilling on prefill engine %d : prefill queue size, %d,"
@@ -465,13 +494,11 @@ def _prefill_thread(self, idx: int):
           is_bos,
           request.history_path,
       )
-      padded_tokens, true_length = tokenizer.encode(
-          request.prefill_text,
-          is_bos=is_bos,
-          max_prefill_length=prefill_engine.max_prefill_length,
-          jax_padding=self._jax_padding,
+      # Tokenize and pad the text or token input.
+      padded_tokens, true_length = self._process_prefill_content(
+          request, tokenizer, is_bos, prefill_engine.max_prefill_length
       )
-      # Compute new kv cache for the prefill_text.
+      # Compute new kv cache for the prefill_content.
       prefill_result = prefill_engine.prefill(
           params=prefill_params,
           padded_tokens=padded_tokens,
@@ -497,6 +524,8 @@ def _transfer_thread(self, idx: int):
     while self.live:
       # The transfer thread can just sleep until it has work to do.
       new_request = transfer_backlog.get(block=True)
+      if new_request is None:
+        break
       target_idx = min(
           self._generate_backlogs.items(), key=lambda q: q[1].qsize()
       )[0]
@@ -665,15 +694,17 @@ def _detokenize_thread(self, idx: int):

         for slot, request in my_live_requests.items():
           if request is not None:
-            results, complete = tokenizer.decode(
+            results, complete = token_utils.process_result_tokens(
+                tokenizer=tokenizer,
                 slot=slot,
                 slot_max_length=request.max_tokens,
                 result_tokens=result_tokens,
+                is_client_side_tokenization=request.is_client_side_tokenization,
                 complete=request.complete,
             )
             request.complete = complete
-            # Return some tokens.
-            request.enqueue_tokens(results)
+            # Return some output samples.
+            request.enqueue_samples(results)
             if request.complete.all():
               request.return_channel.close()
             # Place the slot back on the free queue.
@@ -698,6 +729,21 @@ class LLMOrchestrator(jetstream_pb2_grpc.OrchestratorServicer):
   def __init__(self, driver: Driver):
     self._driver = driver

+  def _get_prefill_content(
+      self, request: jetstream_pb2.DecodeRequest
+  ) -> Tuple[str | list[int], bool]:
+    which_content = request.WhichOneof("content")
+    content = getattr(request, which_content)
+    if which_content == "text_content":
+      return cast(jetstream_pb2.DecodeRequest.TextContent, content).text, False
+    else:
+      return (
+          list(
+              cast(jetstream_pb2.DecodeRequest.TokenContent, content).token_ids
+          ),
+          True,
+      )
+
   async def Decode(  # pylint: disable=invalid-overridden-method
       self,
       request: jetstream_pb2.DecodeRequest,
@@ -709,14 +755,19 @@ async def Decode(  # pylint: disable=invalid-overridden-method
           "LLM orchestrator is being used in offline test mode, and will not"
           " respond to gRPC queries - only direct function calls."
       )
+    is_client_side_tokenization = False
     return_channel = async_multifuture.AsyncMultifuture()
     if context:
      context.add_done_callback(return_channel.cancel)
+    prefill_content, is_client_side_tokenization = self._get_prefill_content(
+        request
+    )
     # Wrap request as an ActiveRequest.
     active_request = ActiveRequest(
         max_tokens=request.max_tokens,
         history_path=request.session_cache,
-        prefill_text=request.additional_text,
+        prefill_content=prefill_content,
+        is_client_side_tokenization=is_client_side_tokenization,
         return_channel=return_channel,
     )
     # The first stage is being prefilled, all other stages are handled
@@ -736,18 +787,78 @@ async def Decode(  # pylint: disable=invalid-overridden-method
     logging.info(
         "Placed request on the prefill queue.",
     )
-    async for response in active_request.return_channel:
-      # When an active request is created a queue is instantiated. New tokens
-      # are placed there during the decoding loop, we pop from that queue by
-      # using the .next method on the active request.
-      # Yielding allows for the response to be a streaming grpc call - which
-      # can be called via iterating over a for loop on the other side.
-      # The DecodeResponse stream should consume all generated tokens in
-      # return_channel when complete signal is received. It should check if
-      # return_channel is empty to decide if it should exit the while loop.
-      repeated_token_ids = []
-      for token_ids in response:
-        repeated_token_ids.append(
-            jetstream_pb2.RepeatedTokenIds(token_ids=token_ids)
+    # When an active request is created a queue is instantiated. New tokens
+    # are placed there during the decoding loop; we pop from that queue by
+    # using the .next method on the active request.
+    # Yielding allows the response to be a streaming gRPC call, which
+    # can be consumed by iterating over a for loop on the client side.
+    # The DecodeResponse stream should consume all generated tokens in
+    # return_channel when the complete signal is received (AsyncMultifuture
+    # promises this).
+    if is_client_side_tokenization:
+      # If is_client_side_tokenization, the client should request with token
+      # ids, and the JetStream server will return token ids as the response.
+      # The client should take care of tokenization and detokenization.
+      async for response in active_request.return_channel:
+        response = cast(list[ReturnSample], response)
+        samples = []
+        for sample in response:
+          samples.append(
+              jetstream_pb2.DecodeResponse.StreamContent.Sample(
+                  token_ids=sample.token_ids,
+              )
+          )
+        yield jetstream_pb2.DecodeResponse(
+            stream_content=jetstream_pb2.DecodeResponse.StreamContent(
+                samples=samples
+            )
+        )
+    else:
+      # A buffered-response mechanism is used to handle streaming
+      # detokenization with special characters (for some edge cases with the
+      # SentencePiece tokenizer, a complete sequence must be decoded instead
+      # of a single token).
+      buffered_response_list = []
+      async for response in active_request.return_channel:
+        response = cast(list[ReturnSample], response)
+        buffered = False
+        for item in response:
+          if item.text and token_utils.is_byte_token(item.text[-1]):
+            # If any sample ends in bytes, this means we might still need to
+            # decode more bytes to compose the string.
+            buffered_response_list.append(response)
+            buffered = True
+            break
+        if buffered:
+          continue
+        # Flush the buffered responses into each sample of the current response.
+        current_response_with_flushed_buffer = list(
+            zip(*buffered_response_list, response)
+        )
+        # Empty buffer: [[s0_cur], [s1_cur], ...]
+        # Has buffer:
+        # [[s0_b0, s0_b1, ..., s0_cur], [s1_b0, s1_b1, ..., s1_cur], ...]
+        current_response_with_flushed_buffer = cast(
+            list[list[ReturnSample]], current_response_with_flushed_buffer
+        )
+        # Reset the buffer after flushing.
+        buffered_response_list = []
+        # Form correct sample(s) and return as StreamContent for this iteration.
+        samples = []
+        for sample in current_response_with_flushed_buffer:
+          text = []
+          token_ids = []
+          for resp in sample:
+            text.extend(resp.text)
+            token_ids.extend(resp.token_ids)
+          samples.append(
+              jetstream_pb2.DecodeResponse.StreamContent.Sample(
+                  text=token_utils.text_tokens_to_str(text),
+                  token_ids=token_ids,
+              )
+          )
+        yield jetstream_pb2.DecodeResponse(
+            stream_content=jetstream_pb2.DecodeResponse.StreamContent(
+                samples=samples
+            )
         )
-      yield jetstream_pb2.DecodeResponse(response=repeated_token_ids)

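The per-step unit flowing through the return channel above is the new ReturnSample type from jetstream/core/utils/return_sample.py, which is part of this commit but not shown in this excerpt. Judging from how the orchestrator consumes it (resp.text and resp.token_ids), it is presumably a small dataclass; the sketch below illustrates that assumption together with the buffer-and-flush behaviour used for byte tokens, using made-up byte pieces and token ids.

# Sketch only: ReturnSample as inferred from its usage above; the real
# definition lives in jetstream/core/utils/return_sample.py and may differ.
import dataclasses


@dataclasses.dataclass
class ReturnSample:
  text: list[str]       # decoded text piece(s) for this step
  token_ids: list[int]  # token id(s) for this step


# Illustration of the buffer-and-flush step for a single sample: two earlier
# responses that ended in byte tokens are merged with the current response.
buffered = [
    [ReturnSample(["<0xE4>"], [101])],
    [ReturnSample(["<0xB8>"], [102])],
]
current = [ReturnSample(["<0x96>"], [103])]
merged = list(zip(*buffered, current))
# merged[0] is the tuple of all three ReturnSamples for sample 0; the
# orchestrator concatenates their text and token_ids and turns the byte
# pieces into a string via token_utils.text_tokens_to_str.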
jetstream/core/proto/jetstream.proto

Lines changed: 40 additions & 9 deletions
@@ -19,14 +19,13 @@ package jetstream_proto;
 // TODO: Merge this with main JetStream core once we settle on an API.

 service Orchestrator {
-  // Generate the next model tokens.
+  // Query LLM to generate text or tokens.
   rpc Decode(DecodeRequest) returns (stream DecodeResponse) {}
 }
+
 message DecodeRequest {
   // Where to load any pre-existing kv cache from.
   string session_cache = 1;
-  // New text from a user or tool.
-  string additional_text = 2;
   int32 priority = 3;
   // The maximum output length of a sequence. It's used in JetStream to control
   // the output/decode length of a sequence. It would not be used in the engine.
@@ -35,12 +34,44 @@ message DecodeRequest {
   // sequence; max_prefill_predict_length is the maximum length of the
   // input/prefill of a sequence.
   int32 max_tokens = 4;
+
+  message TextContent {
+    string text = 1;
+  }
+  message TokenContent {
+    repeated int32 token_ids = 1;
+  }
+
+  // The client can pass the input either as a string, in which case the server will
+  // tokenize it, or as tokens, in which case it is the client's responsibility to
+  // ensure it tokenizes its input strings with the correct tokenizer.
+  oneof content {
+    TextContent text_content = 5;
+    TokenContent token_content = 6;
+  }
+  reserved 2;
+  // Next ID: 7
 }
+
 message DecodeResponse {
-  // List of responses, one per sample. The list size depends on text generation strategy the engine used.
-  repeated RepeatedTokenIds response = 1;
-}
-message RepeatedTokenIds {
-  // List of token ids, one list per sample. When speculative decoding is disabled, the list size should be 1; When speculative decoding is enabled, the list size should be >= 1.
-  repeated int32 token_ids = 1;
+  // InitialContent supports returning initial one-off response data from the
+  // stream. It's a placeholder for future features such as history cache.
+  message InitialContent {}
+  message StreamContent {
+    message Sample {
+      // The text string decoded from token id(s).
+      string text = 1;
+      // List of token ids, one list per sample. When speculative decoding is disabled, the list size should be 1; when speculative decoding is enabled, the list size should be >= 1.
+      repeated int32 token_ids = 2;
+    }
+    // Supports multiple samples in the StreamContent. The Sample list size depends on the text generation strategy the engine uses.
+    repeated Sample samples = 1;
+  }
+
+  oneof content {
+    InitialContent initial_content = 2;
+    StreamContent stream_content = 3;
+  }
+  reserved 1;
+  // Next ID: 4
 }

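For the server-side-tokenization path, a client sends TextContent and reads decoded text back from stream_content, checking the response oneof since a server may also emit initial_content. A hypothetical sketch, not part of this commit (the channel address is an assumption; message and stub names follow the proto above):

# Hypothetical client sketch for the text (server-side tokenization) mode.
import grpc

from jetstream.core.proto import jetstream_pb2
from jetstream.core.proto import jetstream_pb2_grpc


async def decode_text(prompt: str, max_tokens: int) -> str:
  async with grpc.aio.insecure_channel("localhost:9000") as channel:
    stub = jetstream_pb2_grpc.OrchestratorStub(channel)
    request = jetstream_pb2.DecodeRequest(
        text_content=jetstream_pb2.DecodeRequest.TextContent(text=prompt),
        priority=0,
        max_tokens=max_tokens,
    )
    pieces = []
    async for resp in stub.Decode(request):
      # Skip anything that is not stream content; initial_content is a
      # placeholder for future features such as history cache.
      if resp.WhichOneof("content") != "stream_content":
        continue
      pieces.append(resp.stream_content.samples[0].text)
    return "".join(pieces)

Reserving field 2 in DecodeRequest and field 1 in DecodeResponse keeps the retired additional_text and response field numbers from being reused by accident.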