Skip to content

Commit 64b1e50

Browse files
authored
Fix starvation with async server and interleaving optimization (#13)
* Fix starvation with async server
* Keep grpc option
* Remove default for return_channel
* Revert removing default for return_channel
* Use a blocking fixed-size queue to block and yield threads efficiently
* Fix AsyncMultifuture
* Complete fix — optimized interleaving of prefill, insert, and generate
* Fix unit test and pytype
* Add TODO
1 parent 2a91d38 commit 64b1e50

6 files changed

Lines changed: 392 additions & 233 deletions

File tree

benchmarks/benchmark_serving.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -250,18 +250,17 @@ def calculate_metrics(
250250
return metrics
251251

252252

253-
def grpc_sync_request(api_url: str, request: Any) -> tuple[list[str], float, float]:
253+
async def grpc_async_request(api_url: str, request: Any) -> tuple[list[str], float, float]:
254254
"""Send grpc synchronous request since the current grpc server is sync."""
255255
options = [("grpc.keepalive_timeout_ms", 10000)]
256-
with grpc.insecure_channel(api_url, options=options) as channel:
257-
grpc.channel_ready_future(channel).result()
256+
async with grpc.aio.insecure_channel(api_url, options=options) as channel:
258257
stub = jetstream_pb2_grpc.OrchestratorStub(channel)
259258
print("Making request")
260259
ttft = 0
261260
token_list = []
262261
request_start_time = time.perf_counter()
263262
response = stub.Decode(request)
264-
for token in response:
263+
async for token in response:
265264
if ttft == 0:
266265
ttft = time.perf_counter() - request_start_time
267266
token_list.append(token.response[0])
@@ -278,8 +277,6 @@ async def send_request(
278277
threads: int,
279278
) -> RequestFuncOutput:
280279
"""Send the request to JetStream server."""
281-
loop = asyncio.get_running_loop()
282-
loop.set_default_executor(ThreadPoolExecutor(max_workers=threads))
283280
request = jetstream_pb2.DecodeRequest(
284281
session_cache=session_cache,
285282
additional_text=input_request.prompt,
@@ -289,9 +286,7 @@ async def send_request(
289286
output = RequestFuncOutput()
290287
output.input_request = input_request
291288
output.prompt_len = input_request.prompt_len
292-
generated_token_list, ttft, latency = await loop.run_in_executor(
293-
None, grpc_sync_request, api_url, request
294-
)
289+
generated_token_list, ttft, latency = await grpc_async_request(api_url, request)
295290
output.ttft = ttft
296291
output.latency = latency
297292
output.generated_token_list = generated_token_list

0 commit comments

Comments
 (0)