Commit d681995

Various request time metrics (#121)
* first commit
* nit
* fmt
* description tweak
* added more metrics
* nit
* nit
* default metadata values
* move `new_request.metadata.transfer_start_time = time.perf_counter()`
* avoid NoneType
* NoneType
* set transfer_end_time and fmt
* camel case -> snake case
* description update
* change descriptions
* fmt
* logs
* better logs
* changed timings
* observing queue duration metric
* buckets in sorted order
* buckets not in sorted order
* corrected times
* number of output tokens
* move prefill_start_time, enable debug, maybe correct len for num tokens in detokenize
* fmt
* correct lengths of output tokens based on debug
* debug transfer queue time
* remove log
* removed logs, almost final
* nits
* readd log
* change logs
* reomve log
* condence
* improve test coverage
* revert _abort_or_raise deletion
* start_time mandatory
* undo
* nit
* updated buckets
* added 'jetstream_time_per_request'
* nit
* add 'jetstream_wait_time_per_request'
* nit
* missing .metadata
* lint
* change order of params
* changed metric description
* Add metadata field to proto
* update proto
* tweak generated file
* tweak generated file
* update proto
* pylint
* generate protos
* change start time assignment
* .value
* CopyFrom
* change definition of queue duration metric
* Increase test coverage
* fixed assertions
* fmt
* incorrect prefill time
* Add license statements
* Protobuf Python Version
* fmt
* pylint
1 parent 3946afa · commit d681995

7 files changed: 300 additions & 52 deletions

jetstream/core/metrics/prometheus.py

Lines changed: 137 additions & 1 deletion
@@ -17,7 +17,6 @@
 import os
 import shortuuid
 from prometheus_client import Counter, Gauge, Histogram
-
 from jetstream.engine.token_utils import DEFAULT_PREFILL_BUCKETS
 
 
@@ -37,21 +36,46 @@ def __new__(cls):
       documentation="Size of prefill queue",
       labelnames=["id"],
   )
+
   _transfer_backlog = Gauge(
       name="jetstream_transfer_backlog_size",
       documentation="Size of transfer queue",
       labelnames=["id", "idx"],
   )
+
   _generate_backlog = Gauge(
       name="jetstream_generate_backlog_size",
       documentation="Size of generate queue",
       labelnames=["id", "idx"],
   )
+
+  _queue_duration = Histogram(
+      name="jetstream_queue_duration",
+      documentation="The total time each request spends enqueued in seconds",
+      labelnames=["id"],
+      buckets=[
+          0.01,
+          0.02,
+          0.05,
+          0.1,
+          0.2,
+          0.5,
+          1.0,
+          2.0,
+          5.0,
+          10.0,
+          20.0,
+          50.0,
+          100.0,
+      ],
+  )
+
   _slots_used_percentage = Gauge(
       name="jetstream_slots_used_percentage",
       documentation="The percentage of decode slots currently being used",
       labelnames=["id", "idx"],
   )
+
   _server_startup_latency = Gauge(
       name="jetstream_server_startup_latency",
       documentation="Total time taken to start the Jetstream server",
@@ -96,6 +120,100 @@ def __new__(cls):
       labelnames=["id"],
   )
 
+  _time_to_first_token = Histogram(
+      name="jetstream_time_to_first_token",
+      documentation="Time to first token per request in seconds",
+      labelnames=["id"],
+      buckets=[
+          0.001,
+          0.005,
+          0.01,
+          0.02,
+          0.04,
+          0.06,
+          0.08,
+          0.1,
+          0.25,
+          0.5,
+          0.75,
+          1.0,
+          2.5,
+          5.0,
+          7.5,
+          10.0,
+      ],
+  )
+
+  _time_per_output_token = Histogram(
+      name="jetstream_time_per_output_token",
+      documentation="Average time per output token per request in seconds",
+      labelnames=["id"],
+      buckets=[
+          0.01,
+          0.025,
+          0.05,
+          0.075,
+          0.1,
+          0.15,
+          0.2,
+          0.3,
+          0.4,
+          0.5,
+          0.75,
+          1.0,
+          2.5,
+      ],
+  )
+
+  _time_per_prefill_token = Histogram(
+      name="jetstream_time_per_prefill_token",
+      documentation="Prefill time per token per request in seconds",
+      labelnames=["id"],
+      buckets=[
+          0.00001,
+          0.00002,
+          0.00005,
+          0.0001,
+          0.0002,
+          0.0005,
+          0.001,
+          0.002,
+          0.005,
+          0.01,
+          0.02,
+          0.05,
+          0.1,
+      ],
+  )
+
+  _time_per_request = Histogram(
+      name="jetstream_time_per_request",
+      documentation="End to end request latency in seconds",
+      labelnames=["id"],
+      buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0],
+  )
+
+  _wait_time_per_request = Histogram(
+      name="jetstream_wait_time_per_request",
+      documentation="Time each request is not being prefilled or decoded",
+      labelnames=["id"],
+      buckets=[
+          0.01,
+          0.02,
+          0.05,
+          0.1,
+          0.2,
+          0.5,
+          1.0,
+          2.0,
+          5.0,
+          10.0,
+          20.0,
+          50.0,
+          100.0,
+      ],
+  )
+
   def get_prefill_backlog_metric(self):
     return self._prefill_backlog.labels(id=self._id)
 
@@ -105,12 +223,30 @@ def get_transfer_backlog_metric(self, idx: int):
   def get_generate_backlog_metric(self, idx: int):
     return self._generate_backlog.labels(id=self._id, idx=idx)
 
+  def get_queue_duration(self):
+    return self._queue_duration.labels(id=self._id)
+
   def get_slots_used_percentage_metric(self, idx: int):
     return self._slots_used_percentage.labels(id=self._id, idx=idx)
 
   def get_server_startup_latency_metric(self):
     return self._server_startup_latency.labels(id=self._id)
 
+  def get_time_to_first_token(self):
+    return self._time_to_first_token.labels(id=self._id)
+
+  def get_time_per_output_token(self):
+    return self._time_per_output_token.labels(id=self._id)
+
+  def get_time_per_prefill_token(self):
+    return self._time_per_prefill_token.labels(id=self._id)
+
+  def get_time_per_request(self):
+    return self._time_per_request.labels(id=self._id)
+
+  def get_wait_time_per_request(self):
+    return self._wait_time_per_request.labels(id=self._id)
+
   def get_request_input_length(self):
     return self._request_input_length.labels(id=self._id)

jetstream/core/orchestrator.py

Lines changed: 94 additions & 5 deletions
@@ -109,6 +109,24 @@
   root.addHandler(handler)
 
 
+@dataclasses.dataclass
+class ActiveRequestMetadata:
+  """Inference request metadata."""
+
+  start_time: Optional[float] = None
+
+  prefill_enqueue_time: Optional[float] = None
+  prefill_dequeue_time: Optional[float] = None
+
+  transfer_enqueue_time: Optional[float] = None
+  transfer_dequeue_time: Optional[float] = None
+
+  generate_enqueue_time: Optional[float] = None
+  generate_dequeue_time: Optional[float] = None
+
+  complete_time: Optional[float] = None
+
+
 @dataclasses.dataclass
 class ActiveRequest:
   """Current state of the driver."""
@@ -130,6 +148,8 @@ class ActiveRequest:
   # Which generate step this was added at.
   generate_timestep_added: Optional[int] = None
   is_client_side_tokenization: Optional[bool] = False
+  ################## Information relevant for metrics ###################
+  metadata: ActiveRequestMetadata = ActiveRequestMetadata()
 
   def enqueue_samples(self, generated_samples: list[ReturnSample]):
     """Adds the generated sample(s) to return channel for current step.
@@ -477,10 +497,10 @@ def _prefill_thread(self, idx: int):
       my_transfer_backlog = self._transfer_backlogs[idx]
       # The prefill thread can just sleep until it has work to do.
       request = self._prefill_backlog.get(block=True)
-      request_start_time = time.perf_counter()
 
       if request is None:
         break
+      request.metadata.prefill_dequeue_time = time.perf_counter()
       is_bos = True
       logging.info(
           "Prefilling on prefill engine %d : prefill queue size, %d,"
@@ -511,8 +531,10 @@
       # put first token to detokenize queue
       request.complete = np.zeros((prefill_engine.samples_per_slot,), np.bool_)
       my_detokenize_backlog = self._detokenize_backlogs[idx]
+      request.metadata.transfer_enqueue_time = time.perf_counter()
       my_detokenize_backlog.put(
-          (first_token, request, request_start_time), block=True
+          (first_token, request, request.metadata.prefill_dequeue_time),
+          block=True,
       )
 
@@ -526,6 +548,15 @@
       if self._metrics_collector:
         self._metrics_collector.get_request_input_length().observe(true_length)
 
+      if self._metrics_collector:
+        self._metrics_collector.get_time_per_prefill_token().observe(
+            (
+                request.metadata.transfer_enqueue_time
+                - request.metadata.prefill_dequeue_time
+            )
+            / true_length
+        )
+
       del prefill_result
       del request
 
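The per-token normalization divides the whole prefill phase (prefill_dequeue_time to transfer_enqueue_time) by the true, unpadded prompt length. With illustrative numbers (hypothetical, for the arithmetic only):

prefill_seconds = 0.25  # transfer_enqueue_time - prefill_dequeue_time
true_length = 512       # unpadded prompt length in tokens
per_token = prefill_seconds / true_length
print(per_token)        # ~0.000488 s/token, landing in the 0.0005 bucket
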
@@ -562,6 +593,7 @@ def _transfer_thread(self, idx: int):
       new_request = transfer_backlog.get(block=True)
       if new_request is None:
         break
+      new_request.metadata.transfer_dequeue_time = time.perf_counter()
       target_idx = min(
           self._generate_backlogs.items(), key=lambda q: q[1].qsize()
       )[0]
@@ -577,6 +609,7 @@
         # Transfer the info to the relevant generate slice.
         self._transfer_prefill_result(new_request, target_idx)
       # Place the request on the correct generate backlog and block if full.
+      new_request.metadata.generate_enqueue_time = time.perf_counter()
       self._generate_backlogs[target_idx].put(new_request, block=True)
       logging.info(
           "Successfully transferred prefill "
@@ -649,6 +682,24 @@ def _generate_thread(self, idx: int):
         block |= not self._transfer_backlogs[idx].empty()
       try:
         new_request = my_generate_backlog.get(block=block, timeout=1.0)
+        if new_request is None:
+          break
+        new_request.metadata.generate_dequeue_time = time.perf_counter()
+        if (
+            self._metrics_collector
+            and new_request.metadata.start_time is not None
+        ):
+          self._metrics_collector.get_queue_duration().observe(
+              # Time in prefill queue
+              new_request.metadata.prefill_dequeue_time
+              - new_request.metadata.prefill_enqueue_time
+              # Time in transfer queue
+              + new_request.metadata.transfer_dequeue_time
+              - new_request.metadata.transfer_enqueue_time
+              # Time in generate queue
+              + new_request.metadata.generate_dequeue_time
+              - new_request.metadata.generate_enqueue_time
+          )
         # Got free slot and new request, use them.
       except queue.Empty:
         # No new requests, we can't insert, so put back slot.
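
The observation above sums the three backlog waits rather than taking one end-to-end difference, so compute time is excluded by construction. With illustrative numbers (hypothetical, for the arithmetic only):

prefill_wait = 0.40   # prefill_dequeue_time - prefill_enqueue_time
transfer_wait = 0.05  # transfer_dequeue_time - transfer_enqueue_time
generate_wait = 0.25  # generate_dequeue_time - generate_enqueue_time
print(prefill_wait + transfer_wait + generate_wait)  # 0.70 s -> le=1.0 bucket
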
@@ -731,7 +782,7 @@ def _detokenize_thread(self, idx: int):
       start_detokenize_time = time.time()
       # prefill first token
       if isinstance(data[0], engine_api.ResultTokens):
-        request_first_token, request, request_start_time = data
+        request_first_token, request, _ = data
         request_first_token = request_first_token.convert_to_numpy()
 
         results, complete = token_utils.process_result_tokens(
@@ -747,9 +798,14 @@
         request.enqueue_samples(results)
 
         first_token_return_time = time.perf_counter()
+        if self._metrics_collector:
+          self._metrics_collector.get_time_to_first_token().observe(
+              first_token_return_time - request.metadata.prefill_dequeue_time
+          )
         logging.info(
            "TTFT duration: %fms",
-            (first_token_return_time - request_start_time) * 1000,
+            (first_token_return_time - request.metadata.prefill_dequeue_time)
+            * 1000,
         )
       # generate step tokens
       elif isinstance(data[1], engine_api.ResultTokens):
@@ -773,12 +829,41 @@
             # Return some output samples.
             request.enqueue_samples(results)
             if request.complete.all():
+              request.metadata.complete_time = time.perf_counter()
+              request.return_channel.close()
               if self._metrics_collector:
                 self._metrics_collector.get_request_output_length().observe(
                     result_tokens.get_result_at_slot(slot).lengths
                 )
                 self._metrics_collector.get_request_success_count_metric().inc()
-              request.return_channel.close()
+                self._metrics_collector.get_time_per_output_token().observe(
+                    (
+                        request.metadata.complete_time
+                        - request.metadata.transfer_enqueue_time
+                    )
+                    / result_tokens.get_result_at_slot(slot).lengths
+                )
+                self._metrics_collector.get_time_per_request().observe(
+                    request.metadata.complete_time
+                    - request.metadata.transfer_enqueue_time
+                )
+
+                if request.metadata.start_time:
+                  total_time = (
+                      request.metadata.complete_time
+                      - request.metadata.start_time
+                  )
+                  prefill_time = (
+                      request.metadata.transfer_enqueue_time
+                      - request.metadata.prefill_dequeue_time
+                  )
+                  generate_time = (
+                      request.metadata.complete_time
+                      - request.metadata.generate_dequeue_time
+                  )
+                  self._metrics_collector.get_wait_time_per_request().observe(
+                      total_time - prefill_time - generate_time
+                  )
               # Place the slot back on the free queue.
               my_live_requests[slot] = None
               my_slots.put(slot, block=False)  # This should always have space.
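
Note that both jetstream_time_per_output_token and jetstream_time_per_request are measured from transfer_enqueue_time, i.e. from the moment prefill finished, not from request arrival, while the wait-time metric subtracts the two compute phases from the full client-visible lifetime. With illustrative numbers (hypothetical, for the arithmetic only):

start_time = 0.00             # stamped in the Decode handler
prefill_dequeue_time = 0.40
transfer_enqueue_time = 0.70  # prefill took 0.30 s
generate_dequeue_time = 0.95
complete_time = 2.95          # generation took 2.00 s

total_time = complete_time - start_time                      # 2.95
prefill_time = transfer_enqueue_time - prefill_dequeue_time  # 0.30
generate_time = complete_time - generate_dequeue_time        # 2.00
print(total_time - prefill_time - generate_time)             # 0.65 s waiting
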
@@ -895,6 +980,10 @@ async def Decode(  # pylint: disable=invalid-overridden-method
         prefill_content=prefill_content,
         is_client_side_tokenization=is_client_side_tokenization,
         return_channel=return_channel,
+        metadata=ActiveRequestMetadata(
+            start_time=request.metadata.start_time,
+            prefill_enqueue_time=time.perf_counter(),
+        ),
     )
     # The first stage is being prefilled, all other stages are handled
     # inside the driver (transfer, generate*N, detokenize).

jetstream/core/proto/jetstream.proto

Lines changed: 10 additions & 1 deletion
@@ -50,8 +50,17 @@ message DecodeRequest {
     TextContent text_content = 5;
     TokenContent token_content = 6;
   }
+
+  message Metadata {
+    float start_time = 1;
+  }
+
+  oneof metadata_optional {
+    Metadata metadata = 7;
+  }
+
   reserved 1, 2, 3;
-  // Next ID: 7
+  // Next ID: 8
 }
 
 message DecodeResponse {
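
Because the new field sits inside a oneof, servers can distinguish "metadata never sent" from a zero start_time. A hypothetical client-side sketch (assuming the generated jetstream_pb2 module and the existing text_content field of DecodeRequest; not code from this commit):

import time

from jetstream.core.proto import jetstream_pb2

request = jetstream_pb2.DecodeRequest(
    text_content=jetstream_pb2.DecodeRequest.TextContent(text="hello"),
)
# CopyFrom (or direct field assignment) populates the oneof member.
request.metadata.CopyFrom(
    jetstream_pb2.DecodeRequest.Metadata(start_time=time.perf_counter())
)
# Caveat: the server compares this value against its own time.perf_counter(),
# so the timestamp is only meaningful when client and server share a
# monotonic clock domain (e.g. co-located on one machine).
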
