Skip to content

Commit 0d7fcf7

Browse files
authored
Support gracefully stopping orchestrator and server (#6)

* Support gracefully stopping orchestrator and server
* Add JetStream server
1 parent da4328e commit 0d7fcf7

6 files changed

Lines changed: 123 additions & 27 deletions

File tree

.github/workflows/UnitTests.yaml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,4 +69,10 @@ jobs:
6969
python -m jetstream.engine.utils_test
7070
- name: Test mock JetStream engine implementation
7171
run: |
72-
python -m jetstream.engine.mock_engine_test
72+
python -m jetstream.engine.mock_engine_test
73+
- name: Test JetStream core orchestrator
74+
run: |
75+
python -m jetstream.core.orchestrator_test
76+
- name: Test JetStream core server library
77+
run: |
78+
python -m jetstream.core.server_test

benchmarks/benchmark_serving.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ async def send_request(
242242
max_tokens: int,
243243
threads: int,
244244
) -> RequestFuncOutput:
245-
"""Send the request to wiz server."""
245+
"""Send the request to JetStream server."""
246246
loop = asyncio.get_running_loop()
247247
loop.set_default_executor(ThreadPoolExecutor(max_workers=threads))
248248
request = jetstream_pb2.DecodeRequest(
@@ -406,7 +406,7 @@ def main(args: argparse.Namespace):
406406
# Save to file
407407
base_model_id = model_id.split("/")[-1]
408408
file_name = (
409-
f"JetEngine-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
409+
f"JetStream-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
410410
)
411411
with open(file_name, "w") as outfile:
412412
json.dump(result_json, outfile)
@@ -433,7 +433,7 @@ def main(args: argparse.Namespace):
433433
help=(
434434
"Name of the model. (it's just used to label the benchmark, the model"
435435
" config is defined in config_lib, and passed as the server config"
436-
" flag when we run the wiz-pathways server)"
436+
" flag when we run the JetStream server)"
437437
),
438438
)
439439
parser.add_argument(

jetstream/core/orchestrator.py

Lines changed: 70 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676

7777
import dataclasses
7878
import functools
79+
import itertools
7980
import logging
8081
import os
8182
import queue
@@ -100,15 +101,19 @@
100101

101102
handler = logging.StreamHandler(sys.stdout)
102103
handler.setLevel(logging.DEBUG)
103-
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
104+
formatter = logging.Formatter(
105+
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
106+
)
104107
handler.setFormatter(formatter)
105108
root.addHandler(handler)
106109

110+
107111
def delete_pytree(p):
108112
def delete_leaf(leaf):
109113
if isinstance(leaf, jax.Array):
110114
leaf.delete()
111115
del leaf
116+
112117
jax.tree_map(delete_leaf, p)
113118

114119

@@ -185,19 +190,19 @@ class Driver:
185190
_prefill_params: Optional[dict[int, Any]] = {}
186191
_generate_params: Optional[dict[int, Any]] = {}
187192
# Stage 1
188-
_prefill_backlog: queue.Queue[ActiveRequest]
193+
_prefill_backlog: queue.Queue[ActiveRequest | None]
189194
# Stage 2
190195
# We keep this as a dict to avoid a possibly expensive object comparison
191196
# when logging the index of the generate engine we send a prefill result
192197
# to, it allows us to natively have the index from the min operation, rather
193198
# than have to call .index()
194-
_generate_backlogs: dict[int, queue.Queue[ActiveRequest]] = {}
199+
_generate_backlogs: dict[int, queue.Queue[ActiveRequest | None]] = {}
195200
# Stage 3
196201
# This can be a list because we can pass it as an arg to generate and
197202
# detokenize threads. It is a list of tokens to be detokenized.
198203
_detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = []
199204
_generate_slots: list[queue.Queue[int]] = []
200-
_active_requests: list[queue.Queue[tuple[int, ActiveRequest]]] = []
205+
_active_requests: list[queue.Queue[tuple[int, ActiveRequest | None]]] = []
201206

202207
def __init__(
203208
self,
@@ -296,11 +301,58 @@ def __init__(
296301
)
297302
for idx, engine in enumerate(self._generate_engines)
298303
]
304+
self._all_threads = list(
305+
itertools.chain(
306+
self._prefill_threads,
307+
self._generate_threads,
308+
self.detokenize_threads,
309+
)
310+
)
299311
self.live = True
300312
# Kick off all threads
301-
_ = [f.start() for f in self._prefill_threads]
302-
_ = [f.start() for f in self._generate_threads]
303-
_ = [f.start() for f in self.detokenize_threads]
313+
for t in self._all_threads:
314+
t.start()
315+
316+
def stop(self):
317+
"""Stops the driver and all background threads."""
318+
# Signal to all threads that they should stop.
319+
self.live = False
320+
321+
all_backlogs = list(
322+
itertools.chain(
323+
[self._prefill_backlog],
324+
self._generate_backlogs.values(),
325+
self._detokenize_backlogs,
326+
)
327+
)
328+
329+
while any(t.is_alive() for t in self._all_threads):
330+
# Empty all backlogs and mark any remaining requests as cancelled.
331+
for q in all_backlogs:
332+
while True:
333+
try:
334+
r = q.get_nowait()
335+
if r is None:
336+
continue
337+
elif isinstance(r, ActiveRequest):
338+
r.return_channel = None
339+
else: # detokenize backlog
340+
_, r = r
341+
if isinstance(r, ActiveRequest):
342+
r.return_channel = None
343+
except queue.Empty:
344+
break
345+
346+
# Put sentinels to unblock threads.
347+
for q in all_backlogs:
348+
try:
349+
q.put_nowait(None)
350+
except queue.Full:
351+
pass
352+
353+
# Wait for all threads to stop.
354+
for t in self._all_threads:
355+
t.join()
304356

305357
def get_total_concurrent_requests(self) -> int:
306358
"""Returns the total number of concurrent requests the driver can service."""
@@ -338,10 +390,12 @@ def _prefill_thread(
338390
while self.live:
339391
# We don't want to keep lots of kv caches live in memory on the prefill
340392
# slice that aren't about to be sent over to a generation slice.
341-
if (self._generate_backlogs[idx].qsize() < generate_backpressure):
393+
if self._generate_backlogs[idx].qsize() < generate_backpressure:
342394
# Check if there is anything on the prefill backlog, pop if so.
343395
try:
344396
request = self._prefill_backlog.get(block=True)
397+
if request is None:
398+
break
345399
# TODO: Implement hot/cold cache for history.
346400
history = self._load_cache_history(request.history_path) # pylint: disable = assignment-from-none
347401
# Tokenize, and introduce a leading dimension
@@ -372,7 +426,8 @@ def _prefill_thread(
372426
# Once prefill is complete, place it on the generation queue.
373427
self._generate_backlogs[idx].put(request)
374428
logging.info(
375-
f'Placed request on the generate queue, {self._generate_backlogs[idx].qsize()=}'
429+
'Placed request on the generate queue,'
430+
f' {self._generate_backlogs[idx].qsize()=}'
376431
)
377432
except queue.Empty:
378433
# Otherwise, don't do anything!
@@ -410,13 +465,16 @@ def _generate_thread(
410465
while self.live:
411466
if (time.time() - time_of_last_print) > 1:
412467
logging.info(
413-
f'Generate thread making a decision with: prefill_backlog={self._prefill_backlog.qsize()} generate_free_slots={my_slots.qsize()}'
468+
'Generate thread making a decision with:'
469+
f' prefill_backlog={self._prefill_backlog.qsize()} generate_free_slots={my_slots.qsize()}'
414470
)
415471
time_of_last_print = time.time()
416472
# Check if there are any free my_slots.
417473
if not my_slots.empty() and not self._generate_backlogs[idx].empty():
418474
# Only get requests from the backlog corresponding to this engine.
419475
new_request = self._generate_backlogs[idx].get()
476+
if new_request is None:
477+
break
420478
slot = my_slots.get()
421479
logging.info(
422480
'Generate slice %d slot %d step %d',
@@ -475,6 +533,8 @@ def _detokenize_thread(
475533
while self.live:
476534
try:
477535
data = my_detokenize_backlog.get(block=True)
536+
if data is None:
537+
break
478538
start_detokenise_time = time.time()
479539
if isinstance(data[1], engine_api.ResultTokens):
480540
# We want to detokenise them.

jetstream/core/orchestrator_test.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,10 @@
4141
tokenizer returns).
4242
"""
4343

44-
from jetstream.engine import mock_engine
44+
from absl.testing import absltest
4545
from jetstream.core import orchestrator
4646
from jetstream.core.proto import jetstream_pb2
47-
from absl.testing import absltest
47+
from jetstream.engine import mock_engine
4848

4949

5050
class OrchestratorTest(absltest.TestCase):
@@ -87,12 +87,16 @@ def test_orchestrator(self):
8787
counter = 0
8888
for token in iterator:
8989
# Tokens come through as bytes.
90-
print('actual output: ' + bytes(token.response[0], encoding='utf-8').decode())
90+
print(
91+
'actual output: '
92+
+ bytes(token.response[0], encoding='utf-8').decode()
93+
)
9194
assert (
9295
bytes(token.response[0], encoding='utf-8').decode()
9396
== expected_tokens[counter]
9497
)
9598
counter += 1
99+
driver.stop()
96100

97101

98102
if __name__ == '__main__':

jetstream/core/server_lib.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,22 +31,45 @@
3131
_HOST = '[::]'
3232

3333

34+
class JetStreamServer:
35+
"""JetStream grpc server."""
36+
37+
def __init__(self, driver: orchestrator.Driver, server: grpc.Server):
38+
self._driver = driver
39+
self._server = server
40+
41+
def start(self, port, credentials) -> None:
42+
self._server.add_secure_port(f'{_HOST}:{port}', credentials)
43+
self._server.start()
44+
45+
def stop(self) -> None:
46+
# Gracefully clean up threads in the orchestrator.
47+
self._driver.stop()
48+
self._server.stop(0)
49+
50+
def wait_for_termination(self) -> None:
51+
self._server.wait_for_termination()
52+
53+
3454
def run(
3555
port: int,
3656
config: Type[config_lib.ServerConfig],
3757
devices: Any,
3858
credentials: Any = grpc.insecure_server_credentials(),
3959
threads: int | None = None,
40-
) -> grpc.Server:
60+
) -> JetStreamServer:
4161
"""Runs a server with a specified config.
4262
4363
Args:
4464
port: Port on which the server will be made available.
4565
config: A ServerConfig to config engine, model, device slices, etc.
46-
device: Device objects, will be used to get engine with proper slicing.
66+
devices: Device objects, will be used to get engine with proper slicing.
4767
credentials: Should use grpc credentials by default.
4868
threads: Number of RPC handlers worker threads. This should be at least
4969
equal to the decoding batch size to fully saturate the decoding queue.
70+
71+
Returns:
72+
JetStreamServer that wraps the grpc server and orchestrator driver.
5073
"""
5174
logging.info('Kicking off gRPC server.')
5275
engines = config_lib.get_engines(config, devices=devices)
@@ -69,9 +92,9 @@ def run(
6992
)
7093
logging.info('Starting server on port %d with %d threads', port, threads)
7194

72-
server.add_secure_port(f'{_HOST}:{port}', credentials)
73-
server.start()
74-
return server
95+
jetstream_server = JetStreamServer(driver, server)
96+
jetstream_server.start(port, credentials)
97+
return jetstream_server
7598

7699

77100
def get_devices() -> Any:

jetstream/core/server_test.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,13 @@
2020

2121
from typing import Any, Type
2222

23+
from absl.testing import absltest, parameterized
2324
import grpc
24-
import portpicker
25-
2625
from jetstream.core import config_lib
2726
from jetstream.core import server_lib
2827
from jetstream.core.proto import jetstream_pb2
2928
from jetstream.core.proto import jetstream_pb2_grpc
30-
from absl.testing import absltest, parameterized
29+
import portpicker
3130

3231

3332
class ServerTest(parameterized.TestCase):
@@ -58,7 +57,7 @@ def test_server(
5857
print('port: ' + str(port))
5958
credentials = grpc.local_server_credentials()
6059

61-
_ = server_lib.run(
60+
server = server_lib.run(
6261
port=port,
6362
config=config,
6463
devices=devices,
@@ -83,12 +82,16 @@ def test_server(
8382
counter = 0
8483
for token in iterator:
8584
# Tokens come through as bytes
86-
print('actual output: ' + bytes(token.response[0], encoding='utf-8').decode())
85+
print(
86+
'actual output: '
87+
+ bytes(token.response[0], encoding='utf-8').decode()
88+
)
8789
assert (
8890
bytes(token.response[0], encoding='utf-8').decode()
8991
== expected_tokens[counter]
9092
)
9193
counter += 1
94+
server.stop()
9295

9396

9497
if __name__ == '__main__':

0 commit comments

Comments (0)