Skip to content

Commit 45f8735

Browse files
authored
Request input/output size metrics (#123)
* first commit * remove unused code * fmt * changed buckets * now using DEFAULT_PREFILL_BUCKETS * missing parenthesis
1 parent 64ff9ea commit 45f8735

2 files changed

Lines changed: 47 additions & 10 deletions

File tree

jetstream/core/metrics/prometheus.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616

1717
import os
1818
import shortuuid
19-
from prometheus_client import Counter, Gauge
19+
from prometheus_client import Counter, Gauge, Histogram
20+
21+
from jetstream.engine.token_utils import DEFAULT_PREFILL_BUCKETS
2022

2123

2224
class JetstreamMetricsCollector:
@@ -55,6 +57,39 @@ def __new__(cls):
5557
documentation="Total time taken to start the Jetstream server",
5658
labelnames=["id"],
5759
)
60+
_request_input_length = Histogram(
61+
name="jetstream_request_input_length",
62+
documentation="Number of input tokens per request",
63+
labelnames=["id"],
64+
buckets=DEFAULT_PREFILL_BUCKETS,
65+
)
66+
_request_output_length = Histogram(
67+
name="jetstream_request_output_length",
68+
documentation="Number of output tokens per request",
69+
labelnames=["id"],
70+
buckets=[
71+
1,
72+
2,
73+
5,
74+
10,
75+
20,
76+
50,
77+
100,
78+
200,
79+
500,
80+
1000,
81+
2000,
82+
5000,
83+
10000,
84+
20000,
85+
50000,
86+
100000,
87+
200000,
88+
500000,
89+
1000000,
90+
2000000,
91+
],
92+
)
5893
_request_success_count = Counter(
5994
name="jetstream_request_success_count",
6095
documentation="Number of requests successfully completed",
@@ -76,5 +111,11 @@ def get_slots_used_percentage_metric(self, idx: int):
76111
def get_server_startup_latency_metric(self):
77112
return self._server_startup_latency.labels(id=self._id)
78113

114+
def get_request_input_length(self):
115+
return self._request_input_length.labels(id=self._id)
116+
117+
def get_request_output_length(self):
118+
return self._request_output_length.labels(id=self._id)
119+
79120
def get_request_success_count_metric(self):
80121
return self._request_success_count.labels(id=self._id)

jetstream/core/orchestrator.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -109,15 +109,6 @@
109109
root.addHandler(handler)
110110

111111

112-
def delete_pytree(p):
113-
def delete_leaf(leaf):
114-
if isinstance(leaf, jax.Array):
115-
leaf.delete()
116-
del leaf
117-
118-
jax.tree_map(delete_leaf, p)
119-
120-
121112
@dataclasses.dataclass
122113
class ActiveRequest:
123114
"""Current state of the driver."""
@@ -532,6 +523,8 @@ def _prefill_thread(self, idx: int):
532523
idx,
533524
my_transfer_backlog.qsize(),
534525
)
526+
if self._metrics_collector:
527+
self._metrics_collector.get_request_input_length().observe(true_length)
535528

536529
del prefill_result
537530
del request
@@ -781,6 +774,9 @@ def _detokenize_thread(self, idx: int):
781774
request.enqueue_samples(results)
782775
if request.complete.all():
783776
if self._metrics_collector:
777+
self._metrics_collector.get_request_output_length().observe(
778+
result_tokens.get_result_at_slot(slot).lengths
779+
)
784780
self._metrics_collector.get_request_success_count_metric().inc()
785781
request.return_channel.close()
786782
# Place the slot back on the free queue.

0 commit comments

Comments
 (0)