
Commit da4328e

First support necessary for MaxText (#5)
- remove transfer thread
- have generate thread issue deletes for prefill content
- logging to stdout
- benchmark script to use TF tokenizer
1 parent 25527d6 commit da4328e

2 files changed: 48 additions & 62 deletions


benchmarks/benchmark_serving.py

Lines changed: 14 additions & 5 deletions
@@ -41,8 +41,14 @@
 (run with mock)
 python -m benchmarks.benchmark_serving \
 --request-rate 1
+
+e2e example: python3 benchmark_serving.py --tokenizer /home/rwitten/maxtext/assets/tokenizer --num-prompts 100 --dataset ~/ShareGPT_V3_unfiltered_cleaned_split.json
 """

+
+import tensorflow as tf
+import tensorflow_text as tftxt
+
 import argparse
 import asyncio
 from concurrent.futures import ThreadPoolExecutor
@@ -89,8 +95,11 @@ def get_tokenizer(tokenizer_name: str) -> Any:
   if tokenizer_name == "test":
     return "test"
   else:
-    raise NotImplementedError
-
+    with tf.io.gfile.GFile(tokenizer_name, 'rb') as model_fp:
+      sp_model = model_fp.read()
+    sp_tokenizer = tftxt.SentencepieceTokenizer(
+        model=sp_model, add_bos=True, add_eos=False, reverse=False)
+    return sp_tokenizer

 def sample_requests(
     dataset_path: str,
@@ -114,11 +123,11 @@ def sample_requests(

   # Tokenize the prompts and completions.
   prompts = [prompt for prompt, _ in dataset]
-  prompt_token_ids = tokenizer.encode(
+  prompt_token_ids = tokenizer.tokenize(
       prompts
   )  # adjust this code based on tokenizer method
   completions = [completion for _, completion in dataset]
-  completion_token_ids = tokenizer.encode(
+  completion_token_ids = tokenizer.tokenize(
       completions
   )  # adjust this code based on tokenizer method
   tokenized_dataset = []
@@ -176,7 +185,7 @@ def calculate_metrics(
   for i in range(len(outputs)):
     if outputs[i].success:
       output_len = len(
-          tokenizer.encode(outputs[i].generated_text)
+          tokenizer.tokenize(outputs[i].generated_text)
           if tokenizer != "test"
           else "ĊŌƟ"
       )
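Note on the tokenizer change above: get_tokenizer now returns a tensorflow_text SentencepieceTokenizer, whose batch entry point is tokenize() rather than a Hugging-Face-style encode(), which is why the call sites in sample_requests and calculate_metrics switch methods. Below is a minimal sketch of how the returned tokenizer behaves; the model path and prompts are made up for illustration and are not part of the commit.

import tensorflow as tf
import tensorflow_text as tftxt

# Hypothetical SentencePiece model path, standing in for the --tokenizer flag.
with tf.io.gfile.GFile('/tmp/tokenizer.model', 'rb') as model_fp:
  sp_model = model_fp.read()
sp_tokenizer = tftxt.SentencepieceTokenizer(
    model=sp_model, add_bos=True, add_eos=False, reverse=False)

# tokenize() on a list of strings returns a tf.RaggedTensor of token ids, so
# per-prompt lengths come from the ragged rows rather than len(encode(...)).
ids = sp_tokenizer.tokenize(['hello world', 'a somewhat longer prompt'])
prompt_lens = [len(row) for row in ids.to_list()]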

jetstream/core/orchestrator.py

Lines changed: 34 additions & 57 deletions
@@ -80,6 +80,7 @@
 import os
 import queue
 import signal
+import sys
 import threading
 import time
 import traceback
@@ -94,6 +95,23 @@
 import numpy as np


+root = logging.getLogger()
+root.setLevel(logging.DEBUG)
+
+handler = logging.StreamHandler(sys.stdout)
+handler.setLevel(logging.DEBUG)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+handler.setFormatter(formatter)
+root.addHandler(handler)
+
+def delete_pytree(p):
+  def delete_leaf(leaf):
+    if isinstance(leaf, jax.Array):
+      leaf.delete()
+    del leaf
+  jax.tree_map(delete_leaf, p)
+
+
 @dataclasses.dataclass
 class ActiveRequest:
   """Current state of the driver."""
@@ -169,14 +187,12 @@ class Driver:
   # Stage 1
   _prefill_backlog: queue.Queue[ActiveRequest]
   # Stage 2
-  _transfer_backlog: queue.Queue[ActiveRequest]
-  # Stage 3
   # We keep this as a dict to avoid a possibly expensive object comparison
   # when logging the index of the generate engine we send a prefill result
   # to, it allows us to natively have the index from the min operation, rather
   # than have to call .index()
   _generate_backlogs: dict[int, queue.Queue[ActiveRequest]] = {}
-  # Stage 4
+  # Stage 3
   # This can be a list because we can pass it as an arg to generate and
   # detokenize threads. It is a list of tokens to be detokenized.
   _detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = []
@@ -204,15 +220,11 @@ def __init__(
     # At first, a request is placed here in order to get prefilled.
     self._prefill_backlog = queue.Queue()
     # Stage 2
-    # After prefilling, it is placed here in order to get transferred to
-    # one of the generate backlogs.
-    self._transfer_backlog = queue.Queue()
-    # Stage 3
     # Each generate engine accesses its own generate backlog.
     self._generate_backlogs = {
         idx: queue.Queue() for idx, _ in enumerate(generate_engines)
     }
-    # Stage 4
+    # Stage 3
     # After generation, ActiveRequests are placed on the detokenization backlog
     # for tokens to be sent into each ActiveRequest's return channel.
     # We have one of these per generate engine to simplify the logic keeping
@@ -257,7 +269,6 @@ def __init__(
         )
         for idx, engine in enumerate(self._prefill_engines)
     ]
-    self._transfer_thread = JetThread(target=self._transfer_thread)
     self._generate_threads = [
         JetThread(
             target=functools.partial(
@@ -288,7 +299,6 @@ def __init__(
     self.live = True
     # Kick off all threads
     _ = [f.start() for f in self._prefill_threads]
-    self._transfer_thread.start()
     _ = [f.start() for f in self._generate_threads]
     _ = [f.start() for f in self.detokenize_threads]

@@ -316,7 +326,7 @@ def _prefill_thread(
       self,
       idx: int,
       prefill_engine: engine_api.Engine,
-      transfer_backpressure: int = 8,
+      generate_backpressure: int = 3,
   ):
     """Thread which runs in the background performing prefills."""
     logging.info('---------Spinning up prefill thread %d.---------', idx)
@@ -328,7 +338,7 @@ def _prefill_thread(
     while self.live:
       # We don't want to keep lots of kv caches live in memory on the prefill
       # slice that aren't about to be sent over to a generation slice.
-      if self._transfer_backlog.qsize() < transfer_backpressure:
+      if (self._generate_backlogs[idx].qsize() < generate_backpressure):
         # Check if there is anything on the prefill backlog, pop if so.
         try:
           request = self._prefill_backlog.get(block=True)
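With the transfer backlog removed, the check above gates prefilling on the size of this engine's own generate backlog. A standalone sketch of that rule; the names and the default of 3 mirror the diff, while the rest is illustrative only:

import queue

generate_backpressure = 3           # default from the new _prefill_thread signature
generate_backlog = queue.Queue()    # stands in for self._generate_backlogs[idx]

def should_prefill() -> bool:
  # Prefill only while fewer than generate_backpressure prefill results are
  # waiting to be inserted, so KV caches do not pile up in device memory.
  return generate_backlog.qsize() < generate_backpressure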
@@ -337,10 +347,9 @@
           # Tokenize, and introduce a leading dimension
           is_bos = not bool(request.history_path)
           logging.info(
-              'Prefilling on prefill engine %d : "%s", prefill queue size, %d,'
+              'Prefilling on prefill engine %d : prefill queue size, %d,'
               ' is_bos: %s, history: %s',
               idx,
-              request.prefill_text,
               self._prefill_backlog.qsize(),
               is_bos,
               request.history_path,
@@ -359,53 +368,16 @@
               padded_tokens=padded_tokens,
               true_length=true_length,
           )
-          jax.block_until_ready(prefill_result)
           request.prefill_result = prefill_result
           # Once prefill is complete, place it on the generation queue.
-          self._transfer_backlog.put(request)
+          self._generate_backlogs[idx].put(request)
           logging.info(
-              'Placed request "%s" on the transfer queue.', request.prefill_text
+              f'Placed request on the generate queue, {self._generate_backlogs[idx].qsize()=}'
           )
         except queue.Empty:
           # Otherwise, don't do anything!
           pass

-  def _transfer_thread(self, generation_backpressure: int = 8):
-    """Transfers the kv cache on an active request to the least full generate backlog."""
-    while self.live:
-      # We don't want to keep lots of kv caches live in memory on a generate
-      # slice that haven't been inserted
-      if (
-          sum([backlog.qsize() for backlog in self._generate_backlogs.values()])
-          < generation_backpressure
-      ):
-        try:
-          new_request = self._transfer_backlog.get(block=True)
-
-          # Get the index of the generate queue with the minimmum qsize.
-          idx = min(
-              self._generate_backlogs.items(), key=lambda q: q[1].qsize()
-          )[0]
-          logging.info(
-              'Transferring "%s" to generate engine %d.',
-              new_request.prefill_text,
-              idx,
-          )
-          # Transfer the info to the relevant generate backlog.
-          new_request.prefill_result = jax.device_put(
-              new_request.prefill_result,
-              self._generate_engines[idx].get_prefix_destination_sharding(),
-          )
-          # Place the request on the correct generate backlog.
-          self._generate_backlogs[idx].put(new_request)
-          logging.info(
-              'Request "%s" tsuccessfully transferrred to generate engine %d.',
-              new_request.prefill_text,
-              idx,
-          )
-        except queue.Empty:
-          pass
-
   def _generate_thread(
       self,
       idx: int,
@@ -434,22 +406,28 @@ def _generate_thread(
     generate_params = self._generate_params[idx]
     logging.info('---------Generate params %d loaded.---------', idx)
     time_of_last_generate = time.time()
+    time_of_last_print = time.time()
     while self.live:
+      if (time.time() - time_of_last_print) > 1:
+        logging.info(
+            f'Generate thread making a decision with: prefill_backlog={self._prefill_backlog.qsize()} generate_free_slots={my_slots.qsize()}'
+        )
+        time_of_last_print = time.time()
       # Check if there are any free my_slots.
       if not my_slots.empty() and not self._generate_backlogs[idx].empty():
         # Only get requests from the backlog corresponding to this engine.
         new_request = self._generate_backlogs[idx].get()
         slot = my_slots.get()
         logging.info(
-            'Generate slice %d slot %d step %d, generating for : "%s"',
+            'Generate slice %d slot %d step %d',
             idx,
             slot,
             generate_timestep,
-            new_request.prefill_text,
         )
         decode_state = generate_engine.insert(
             new_request.prefill_result, decode_state, slot=slot
         )
+        delete_pytree(new_request.prefill_result)
         new_request.generate_timestep_added = generate_timestep
         new_request.complete = np.zeros(
             (generate_engine.samples_per_slot,), dtype=np.bool_
@@ -576,8 +554,7 @@ def Decode(
         ),
     )
     logging.info(
-        'Placed request with text "%s" on the prefill queue.',
-        active_request.prefill_text,
+        'Placed request on the prefill queue.',
     )

     while True:
