
Commit 44686b0

Decode Batch Percentage Metrics/Improved Scraping (#82)
Squashed commit messages:

* initial-commit
* newline
* moved gauge
* Added labels for metrics
* revert removing driver metric field
* removed newline
* Removed unneccesary lambda
* Proper use of labels api
* format
* missing quotes
* convert result to float, add lambda
* Moved registration to driver level
* made metric driver property
* missing field on class
* Dont regenerate uuid on each scrape
* Init uuid
* Moved metrics to separate file
* Typos
* Initialize metrics
* reformat
* added 'global'
* Moved registration out of function
* removed unused import
* Update server_lib.py
* jetstream_prefill_backlog_size -> prefill_backlog_size
* label fields
* rename metrics
* Moved metrics to singleton class
* Revert requirements change
* revert requirements changes
* Plumb metrics config to JetStreamServer run header
* Cleanup of prior commit
* Default for hostname metric label
* __ -> _
* Linter error
* linter error
* linter error
* default value for idx
* Linter error
* Linter error
* Linter error final fix hopefully
* Fixed type annotation
* Type fix in server_lib.py
* __ -> _ in JetstreamMetricsCollector property names
* added docstrings
* added module docstring
* laxy % formatting
* metrics port cannot be 0
* idx can be None
* Removed redundant class
* removed lingering import
* Update test_server.py
* hostname -> id
* linter
* is not -> !=
* Zijun nits
* reformat
* missing protocol scheme in url
* requests -> aiohttp
* linter
* Cleaned readme
* Changes to tests, readme
* fixed assertion
* better description
* parameterized test setup
* linter
* Update online-inference-with-maxtext-engine.md
* disable protected-access
* make prometheus test not async, revert requirements.txt changes
* Update requirements.txt
* remove aiohttp
* moved prometheus test to test_server function
* log line
* timeout
1 parent 8128c8a commit 44686b0

7 files changed

Lines changed: 164 additions & 29 deletions

docs/online-inference-with-maxtext-engine.md

Lines changed: 36 additions & 1 deletion
````diff
@@ -205,6 +205,41 @@ Prompt: Today is a good day
 Response: to be a fan
 ```
 
+### (optional) Observe Jetstream metrics
+
+Metrics are not exported by default. To configure Jetstream to emit metrics, start this guide again from step four and replace the `Run the following command to start the JetStream MaxText server` step with the following:
+
+```bash
+export PROMETHEUS_PORT=9090
+
+cd ~/maxtext
+python MaxText/maxengine_server.py \
+  MaxText/configs/base.yml \
+  tokenizer_path=${TOKENIZER_PATH} \
+  load_parameters_path=${LOAD_PARAMETERS_PATH} \
+  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
+  max_target_length=${MAX_TARGET_LENGTH} \
+  model_name=${MODEL_NAME} \
+  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
+  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
+  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
+  scan_layers=${SCAN_LAYERS} \
+  weight_dtype=${WEIGHT_DTYPE} \
+  per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
+  prometheus_port=${PROMETHEUS_PORT}
+```
+
+Now that we configured `prometheus_port=9090` above, we can observe various Jetstream metrics via HTTP requests to `0.0.0.0:9090`. Towards the end, the response should have content similar to the following:
+
+```
+# HELP jetstream_prefill_backlog_size Size of prefill queue
+# TYPE jetstream_prefill_backlog_size gauge
+jetstream_prefill_backlog_size{id="<SOME-HOSTNAME-HERE>"} 0.0
+# HELP jetstream_slots_available_percentage The percentage of available slots in decode batch
+# TYPE jetstream_slots_available_percentage gauge
+jetstream_slots_available_percentage{id="<SOME-HOSTNAME-HERE>",idx="0"} 0.96875
+```
+
 ## Step 6: Run benchmarks with JetStream MaxText server
 
 Note: The JetStream MaxText Server is not running with quantization optimization in Step 3. To get the best benchmark results, we need to enable quantization for both weights and KV cache (please use AQT-trained or fine-tuned checkpoints to ensure accuracy); add the quantization flags and restart the server as follows:
@@ -289,4 +324,4 @@ rm -rf maxtext
 rm -rf JetStream
 # Clean up python virtual environment
 rm -rf .env
-```
+```
````
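A quick way to sanity-check that output is to parse the Prometheus text exposition format directly. The sketch below is illustrative only (the `parse_gauges` helper and the hard-coded sample with a `host-1` stand-in id are not part of JetStream); it pulls labeled gauge samples out of a response body like the one above:

```python
import re

# Sample response body modeled on the output shown above ("host-1" is a stand-in id).
SAMPLE = """\
# HELP jetstream_prefill_backlog_size Size of prefill queue
# TYPE jetstream_prefill_backlog_size gauge
jetstream_prefill_backlog_size{id="host-1"} 0.0
# HELP jetstream_slots_available_percentage The percentage of available slots in decode batch
# TYPE jetstream_slots_available_percentage gauge
jetstream_slots_available_percentage{id="host-1",idx="0"} 0.96875
"""


def parse_gauges(body: str) -> dict:
    """Map (metric_name, label_string) -> float value, skipping comment lines."""
    gauges = {}
    for line in body.splitlines():
        if not line or line.startswith("#"):
            continue
        match = re.match(r'^(\w+)\{([^}]*)\}\s+(\S+)$', line)
        if match:
            gauges[(match.group(1), match.group(2))] = float(match.group(3))
    return gauges


gauges = parse_gauges(SAMPLE)
print(gauges[("jetstream_slots_available_percentage", 'id="host-1",idx="0"')])  # 0.96875
```

In practice you would scrape `http://<host>:9090` and feed the body through a real parser (for example the one shipped with `prometheus_client`), but a regex like this is enough to spot-check values by hand.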

jetstream/core/config_lib.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -17,6 +17,7 @@
 import dataclasses
 import functools
 from typing import Any, Callable, List, Tuple, Type
+from numpy import uint16
 
 from jetstream.engine import engine_api
 from jetstream.engine import mock_engine
@@ -46,6 +47,11 @@ class InstantiatedEngines:
   interleaved_engines: List[engine_api.Engine]
 
 
+@dataclasses.dataclass
+class MetricsServerConfig:
+  port: uint16
+
+
 # ▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼#
 
 
```

jetstream/core/metrics/__init__.py

Lines changed: 13 additions & 0 deletions
```diff
@@ -0,0 +1,13 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
```
jetstream/core/metrics/prometheus.py

Lines changed: 48 additions & 0 deletions
```diff
@@ -0,0 +1,48 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Contains common functions for configuring Jetstream server metrics"""
+
+import os
+import shortuuid
+from prometheus_client import Gauge
+
+
+class JetstreamMetricsCollector:
+  """Wrapper class should be used to assure all metrics have proper tags"""
+
+  _id: str = os.getenv("HOSTNAME", shortuuid.uuid())
+
+  def __new__(cls):
+    if not hasattr(cls, "instance"):
+      cls.instance = super(JetstreamMetricsCollector, cls).__new__(cls)
+    return cls.instance
+
+  # Metric definitions
+  _prefill_backlog = Gauge(
+      name="jetstream_prefill_backlog_size",
+      documentation="Size of prefill queue",
+      labelnames=["id"],
+  )
+  _slots_available_percentage = Gauge(
+      name="jetstream_slots_available_percentage",
+      documentation="The percentage of available slots in decode batch",
+      labelnames=["id", "idx"],
+  )
+
+  def get_prefill_backlog_metric(self):
+    return self._prefill_backlog.labels(id=self._id)
+
+  def get_slots_available_percentage_metric(self, idx: int):
+    return self._slots_available_percentage.labels(id=self._id, idx=idx)
```
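The `__new__` override in `JetstreamMetricsCollector` is a process-wide singleton: every call to the constructor returns the same object, which keeps the class-level `Gauge` objects from being registered with the Prometheus client more than once. A minimal standalone sketch of just that mechanism (with `prometheus_client` left out for brevity; `SingletonCollector` is an illustrative name, not JetStream code):

```python
class SingletonCollector:
    """Mirrors the __new__ pattern used by JetstreamMetricsCollector:
    the first instantiation creates the instance, later ones reuse it."""

    def __new__(cls):
        if not hasattr(cls, "instance"):
            cls.instance = super().__new__(cls)
        return cls.instance


a = SingletonCollector()
b = SingletonCollector()
print(a is b)  # True
```

Because gauge registration happens at class-definition time and construction always yields the one shared instance, repeated `JetstreamMetricsCollector()` calls across the server are safe.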

jetstream/core/orchestrator.py

Lines changed: 14 additions & 10 deletions
```diff
@@ -94,9 +94,8 @@
 from jetstream.core.utils import async_multifuture
 from jetstream.core.utils.return_sample import ReturnSample
 from jetstream.engine import engine_api, tokenizer_api, token_utils
+from jetstream.core.metrics.prometheus import JetstreamMetricsCollector
 import numpy as np
-import prometheus_client
-import shortuuid
 
 root = logging.getLogger()
 root.setLevel(logging.DEBUG)
@@ -212,8 +211,8 @@ class Driver:
   # todo: remove jax_padding after all then engine migrate to np padding
   _jax_padding = True
 
-  # Record metrics for prefill_backlog size
-  _prefill_backlog_size_metric: prometheus_client.Gauge
+  # All metrics we want to monitor should be collected with this
+  _metrics_collector: JetstreamMetricsCollector | None = None
 
   def __init__(
       self,
@@ -223,6 +222,7 @@ def __init__(
       generate_params: Optional[list[Any]] = None,
       interleaved_mode: bool = False,
       jax_padding: bool = True,
+      metrics_collector: JetstreamMetricsCollector | None = None,
   ):
     if prefill_engines is None:
       prefill_engines = []
@@ -243,15 +243,16 @@
     self._prefill_params = prefill_params
     self._generate_params = generate_params
     self._interleaved_mode = interleaved_mode
+    self._metrics_collector = metrics_collector
 
     # Stages 1-4 represent the life cycle of a request.
     # Stage 1
     # At first, a request is placed here in order to get prefilled.
     self._prefill_backlog = queue.Queue()
-    self._prefill_backlog_size_metric = prometheus_client.Gauge(
-        f"jetstream_prefill_backlog_size_{shortuuid.uuid()}",
-        "Size of prefill queue",
-    )
+    if self._metrics_collector:
+      self._metrics_collector.get_prefill_backlog_metric().set_function(
+          lambda: float(self._prefill_backlog.qsize())
+      )
 
     # Stage 2
     # After prefilling, it is placed here in order to get transferred to
@@ -432,7 +433,6 @@ def place_request_on_prefill_queue(self, request: ActiveRequest):
     """Used to place new requests for prefilling and generation."""
     # Don't block so we can fail and shed load when the queue is full.
     self._prefill_backlog.put(request, block=False)
-    self._prefill_backlog_size_metric.set(self._prefill_backlog.qsize())
 
   def _process_prefill_content(
       self,
@@ -474,7 +474,6 @@ def _prefill_thread(self, idx: int):
       my_transfer_backlog = self._transfer_backlogs[idx]
       # The prefill thread can just sleep until it has work to do.
       request = self._prefill_backlog.get(block=True)
-      self._prefill_backlog_size_metric.set(self._prefill_backlog.qsize())
 
       if request is None:
         break
@@ -579,6 +578,11 @@ def _generate_thread(self, idx: int):
 
     max_concurrent_decodes = generate_engine.max_concurrent_decodes
 
+    if self._metrics_collector:
+      self._metrics_collector.get_slots_available_percentage_metric(
+          idx
+      ).set_function(lambda: float(my_slots.qsize() / max_concurrent_decodes))
+
     # Check if there are any free my_slots. We don't want to block here since
     # we can still generate if we can't insert. We do this in a while loop to
     # insert as many sequences as possible.
```
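The key change in this file is swapping eager `Gauge.set(...)` calls at every put/get site for a single `set_function(...)` registration, so the backlog size is recomputed only when the metric is actually scraped. A rough stand-in for that callback pattern (the `CallbackGauge` class below is illustrative, not the `prometheus_client` API, though the real `Gauge.set_function` behaves the same way):

```python
import queue


class CallbackGauge:
    """Minimal stand-in for a gauge with set_function: the value is
    computed on each read rather than pushed on each update."""

    def __init__(self):
        self._fn = lambda: 0.0

    def set_function(self, fn):
        # Register a zero-argument callable evaluated at scrape time.
        self._fn = fn

    def collect(self) -> float:
        return self._fn()


backlog = queue.Queue()
g = CallbackGauge()
g.set_function(lambda: float(backlog.qsize()))

backlog.put("req-1")
backlog.put("req-2")
print(g.collect())  # 2.0
backlog.get()
print(g.collect())  # 1.0
```

This is why the two `self._prefill_backlog_size_metric.set(...)` lines could be deleted: the queue itself is the source of truth, and the scrape reads it lazily.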

jetstream/core/server_lib.py

Lines changed: 17 additions & 17 deletions
```diff
@@ -20,24 +20,19 @@
 import asyncio
 from concurrent import futures
 import logging
-import os
 import threading
 from typing import Any, Type
 
 import grpc
 import jax
 from jetstream.core import config_lib
 from jetstream.core import orchestrator
+from jetstream.core.metrics.prometheus import JetstreamMetricsCollector
 from jetstream.core.proto import jetstream_pb2_grpc
 
 from prometheus_client import start_http_server
 
 _HOST = "[::]"
-PROMETHEUS_ENABLED_ON_PORT = (
-    int(os.getenv("PROMETHEUS_ENABLED_ON_PORT"))
-    if os.getenv("PROMETHEUS_ENABLED_ON_PORT")
-    else None
-)
 
 
 class JetStreamServer:
@@ -99,6 +94,7 @@ def run(
     credentials: Any = grpc.insecure_server_credentials(),
     threads: int | None = None,
     jax_padding: bool = True,
+    metrics_server_config: config_lib.MetricsServerConfig | None = None,
 ) -> JetStreamServer:
   """Runs a server with a specified config.
 
@@ -122,13 +118,28 @@
   interleaved_mode = (
       len(config.prefill_slices) + len(config.generate_slices) == 0
   )
+
+  # Setup Prometheus server
+  metrics_collector: JetstreamMetricsCollector = None
+  if metrics_server_config and metrics_server_config.port:
+    logging.info(
+        "Starting Prometheus server on port %d", metrics_server_config.port
+    )
+    start_http_server(metrics_server_config.port)
+    metrics_collector = JetstreamMetricsCollector()
+  else:
+    logging.info(
+        "Not starting Prometheus server: --prometheus_port flag not set"
+    )
+
   driver = orchestrator.Driver(
       prefill_engines=engines.prefill_engines + engines.interleaved_engines,
       generate_engines=engines.generate_engines + engines.interleaved_engines,
       prefill_params=prefill_params + shared_params,
       generate_params=generate_params + shared_params,
       interleaved_mode=interleaved_mode,
       jax_padding=jax_padding,
+      metrics_collector=metrics_collector,
   )
   # We default threads to the total number of concurrent allowed decodes,
   # to make sure we can fully saturate the model. Set default minimum to 64.
@@ -137,17 +148,6 @@
   logging.info("Starting server on port %d with %d threads", port, threads)
 
   jetstream_server.start()
-
-  # Setup Prometheus server
-  if PROMETHEUS_ENABLED_ON_PORT is not None:
-    logging.info(
-        "Starting Prometheus server on port %d", PROMETHEUS_ENABLED_ON_PORT
-    )
-    start_http_server(PROMETHEUS_ENABLED_ON_PORT)
-  else:
-    logging.info(
-        "Not starting Prometheus server: PROMETHEUS_ENABLED_ON_PORT not set"
-    )
   return jetstream_server
 
 
```
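The net effect in server_lib.py is that exporter startup moves from a module-level `PROMETHEUS_ENABLED_ON_PORT` environment read to an explicit `MetricsServerConfig` resolved inside `run`. A hedged sketch of just that gating logic (the `maybe_start_metrics` helper and its injected `start_server` callback are illustrative, not the actual `server_lib` API):

```python
import logging
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class MetricsServerConfig:
    port: int


def maybe_start_metrics(
    config: Optional[MetricsServerConfig],
    start_server: Callable[[int], None],
) -> bool:
    """Start the metrics exporter only when a nonzero port is configured."""
    if config and config.port:
        logging.info("Starting Prometheus server on port %d", config.port)
        start_server(config.port)
        return True
    logging.info("Not starting Prometheus server: --prometheus_port flag not set")
    return False


started = []
maybe_start_metrics(MetricsServerConfig(port=9090), started.append)
print(started)  # [9090]
```

Passing the config in as an argument (instead of reading the environment at import time) makes the behavior testable and lets callers opt in per server instance.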

jetstream/tests/core/test_server.py

Lines changed: 30 additions & 1 deletion
```diff
@@ -21,6 +21,8 @@
 from typing import Any, Type
 import unittest
 
+
+import requests
 from parameterized import parameterized
 import grpc
 from jetstream.core import config_lib
@@ -60,7 +62,10 @@ async def test_server(
     """Sets up a server and requests token responses."""
     ######################### Server side ######################################
     port = portpicker.pick_unused_port()
+    metrics_port = portpicker.pick_unused_port()
+
     print("port: " + str(port))
+    print("metrics port: " + str(metrics_port))
     credentials = grpc.local_server_credentials()
 
     server = server_lib.run(
@@ -70,12 +75,16 @@ async def test_server(
         credentials=credentials,
     )
     ###################### Requester side ######################################
+
+    # prometheus not configured, assert no metrics collector on Driver
+    assert server._driver._metrics_collector is None  # pylint: disable=protected-access
+
     async with grpc.aio.secure_channel(
         f"localhost:{port}", grpc.local_channel_credentials()
     ) as channel:
       stub = jetstream_pb2_grpc.OrchestratorStub(channel)
 
-      # The string representation of np.array([[65, 66]]), [2] will be prependd
+      # The string representation of np.array([[65, 66]]), [2] will be prepended
       # as BOS
       text = "AB"
       request = jetstream_pb2.DecodeRequest(
@@ -96,5 +105,25 @@ async def test_server(
         counter += 1
     server.stop()
 
+    # Now test server with prometheus config
+    server = server_lib.run(
+        port=port,
+        config=config,
+        devices=devices,
+        credentials=credentials,
+        metrics_server_config=config_lib.MetricsServerConfig(
+            port=metrics_port
+        ),
+    )
+    # assert prometheus server is running and responding
+    assert server._driver._metrics_collector is not None  # pylint: disable=protected-access
+    assert (
+        requests.get(
+            f"http://localhost:{metrics_port}", timeout=5
+        ).status_code
+        == requests.status_codes.codes["ok"]
+    )
+    server.stop()
+
   def test_get_devices(self):
     assert len(server_lib.get_devices()) == 1
```
