
Commit eaf0d6e

Add profiling support and update docs (#85)
* Add profiling support and update docs
* docs format
* pylint
* fix test
* configurable port
1 parent 5d1e317 commit eaf0d6e

6 files changed

Lines changed: 158 additions & 57 deletions

README.md

Lines changed: 3 additions & 1 deletion
@@ -24,9 +24,11 @@ Currently, there are two reference engine implementations available -- one for J
  - README: https://github.com/google/jetstream-pytorch/blob/main/README.md

  ## Documentation
- - [Online Inference with MaxText on v5e Cloud TPU VM](https://cloud.google.com/tpu/docs/tutorials/LLM/jetstream) [[README](#jetstream-maxtext-inference-on-v5e-cloud-tpu-vm-user-guide)]
+ - [Online Inference with MaxText on v5e Cloud TPU VM](https://cloud.google.com/tpu/docs/tutorials/LLM/jetstream) [[README](https://github.com/google/JetStream/blob/main/docs/online-inference-with-maxtext-engine.md)]
  - [Online Inference with Pytorch on v5e Cloud TPU VM](https://cloud.google.com/tpu/docs/tutorials/LLM/jetstream-pytorch) [[README](https://github.com/google/jetstream-pytorch/tree/main?tab=readme-ov-file#jetstream-pytorch)]
  - [Serve Gemma using TPUs on GKE with JetStream](https://cloud.google.com/kubernetes-engine/docs/tutorials/serve-gemma-tpu-jetstream)
+ - [Observability in JetStream Server](https://github.com/google/JetStream/blob/main/docs/observability-prometheus-metrics-in-jetstream-server.md)
+ - [Profiling in JetStream Server](https://github.com/google/JetStream/blob/main/docs/profiling-with-jax-profiler-and-tensorboard.md)
  - [JetStream Standalone Local Setup](#jetstream-standalone-local-setup)

docs/observability-prometheus-metrics-in-jetstream-server.md

Lines changed: 51 additions & 0 deletions

@@ -0,0 +1,51 @@
# Observability in JetStream Server

In JetStream Server, we use [Prometheus](https://prometheus.io/docs/introduction/overview/) to collect key metrics from the JetStream orchestrator and engines. We implemented a [Prometheus client server](https://prometheus.github.io/client_python/exporting/http/) in JetStream `server_lib.py` and use `MetricsServerConfig` (by passing `prometheus_port` at the server entrypoint) to guard the metrics observability feature.
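The gating itself is a small pattern: start a [prometheus_client](https://prometheus.github.io/client_python/) HTTP server only when a config is supplied. A minimal sketch of that pattern follows; the `MetricsServerConfig` stand-in and the helper name are simplified assumptions, not the exact `server_lib.py` code:

```python
from dataclasses import dataclass

from prometheus_client import Gauge, start_http_server


@dataclass
class MetricsServerConfig:
  """Simplified stand-in for JetStream's MetricsServerConfig."""

  port: int


def maybe_start_metrics_server(config: MetricsServerConfig | None) -> Gauge | None:
  """Starts a Prometheus HTTP server only when a config is provided."""
  if config is None:
    return None  # Metrics are not exported by default.
  start_http_server(config.port)
  # Example gauge mirroring one of the metrics shown below.
  return Gauge("jetstream_prefill_backlog_size", "Size of prefill queue", ["id"])
```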
## Enable Prometheus server to observe JetStream metrics

Metrics are not exported by default. Here is an example of running the JetStream MaxText server with metrics observability enabled:
```bash
# Refer to JetStream MaxText User Guide for the following server config.
export TOKENIZER_PATH=assets/tokenizer.gemma
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=gemma-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11
# Set PROMETHEUS_PORT to enable Prometheus metrics.
export PROMETHEUS_PORT=9090

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
  prometheus_port=${PROMETHEUS_PORT}
```
Now that we configured `prometheus_port=9090` above, we can observe various JetStream metrics via HTTP requests to `0.0.0.0:9090`. Toward the end of the response, you should see content similar to the following:

```
# HELP jetstream_prefill_backlog_size Size of prefill queue
# TYPE jetstream_prefill_backlog_size gauge
jetstream_prefill_backlog_size{id="<SOME-HOSTNAME-HERE>"} 0.0
# HELP jetstream_slots_available_percentage The percentage of available slots in decode batch
# TYPE jetstream_slots_available_percentage gauge
jetstream_slots_available_percentage{id="<SOME-HOSTNAME-HERE>",idx="0"} 0.96875
```
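To verify the endpoint programmatically rather than through a browser, a short Python check, similar to what `jetstream/tests/core/test_server.py` does below, can fetch and filter the exposition text:

```python
import requests

# Fetch the Prometheus text exposition from the metrics port configured above.
resp = requests.get("http://0.0.0.0:9090", timeout=5)
resp.raise_for_status()
for line in resp.text.splitlines():
  if line.startswith("jetstream_"):
    print(line)
```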

docs/online-inference-with-maxtext-engine.md

Lines changed: 0 additions & 35 deletions
@@ -205,41 +205,6 @@ Prompt: Today is a good day
  Response: to be a fan
  ```

- ### (optional) Observe Jetstream metrics
-
- Metrics are not exported by default, to configure Jetstream to emit metrics start this guide again from step four and replace the `Run the following command to start the JetStream MaxText server` step with the following:
-
- ```bash
- export PROMETHEUS_PORT=9090
-
- cd ~/maxtext
- python MaxText/maxengine_server.py \
-   MaxText/configs/base.yml \
-   tokenizer_path=${TOKENIZER_PATH} \
-   load_parameters_path=${LOAD_PARAMETERS_PATH} \
-   max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
-   max_target_length=${MAX_TARGET_LENGTH} \
-   model_name=${MODEL_NAME} \
-   ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
-   ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
-   ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
-   scan_layers=${SCAN_LAYERS} \
-   weight_dtype=${WEIGHT_DTYPE} \
-   per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
-   prometheus_port=${PROMETHEUS_PORT}
- ```
-
- Now that we configured `prometheus_port=9090` above, we can observe various Jetstream metrics via HTTP requests to `0.0.0.0:9000`. Towards the end, the response should have content similar to the following:
-
- ```
- # HELP jetstream_prefill_backlog_size Size of prefill queue
- # TYPE jetstream_prefill_backlog_size gauge
- jetstream_prefill_backlog_size{id="SOME-HOSTNAME-HERE>"} 0.0
- # HELP jetstream_slots_available_percentage The percentage of available slots in decode batch
- # TYPE jetstream_slots_available_percentage gauge
- jetstream_slots_available_percentage{id="<SOME-HOSTNAME-HERE>",idx="0"} 0.96875
- ```
  ## Step 6: Run benchmarks with JetStream MaxText server

  Note: The JetStream MaxText Server is not running with quantization optimization in Step 3. To get best benchmark results, we need to enable quantization (Please use AQT trained or fine tuned checkpoints to ensure accuracy) for both weights and KV cache, please add the quantization flags and restart the server as following:
docs/profiling-with-jax-profiler-and-tensorboard.md

Lines changed: 54 additions & 0 deletions

@@ -0,0 +1,54 @@
# Profiling in JetStream Server

In the JetStream server, we have implemented a JAX profiler server to support profiling JAX programs with TensorBoard.

## Profiling with JAX profiler server and TensorBoard server

Following the [JAX official manual profiling approach](https://jax.readthedocs.io/en/latest/profiling.html#manual-capture-via-tensorboard), here is an example of profiling the JetStream MaxText server with TensorBoard:
1. Start a TensorBoard server:

   ```bash
   tensorboard --logdir /tmp/tensorboard/
   ```

   You should be able to load TensorBoard at http://localhost:6006/. You can specify a different port with the `--port` flag.
2. Start the JetStream MaxText server:

   ```bash
   # Refer to JetStream MaxText User Guide for the following server config.
   export TOKENIZER_PATH=assets/tokenizer.gemma
   export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
   export MAX_PREFILL_PREDICT_LENGTH=1024
   export MAX_TARGET_LENGTH=2048
   export MODEL_NAME=gemma-7b
   export ICI_FSDP_PARALLELISM=1
   export ICI_AUTOREGRESSIVE_PARALLELISM=-1
   export ICI_TENSOR_PARALLELISM=1
   export SCAN_LAYERS=false
   export WEIGHT_DTYPE=bfloat16
   export PER_DEVICE_BATCH_SIZE=11
   # Set ENABLE_JAX_PROFILER to enable the JAX profiler server.
   export ENABLE_JAX_PROFILER=true
   # Set JAX_PROFILER_PORT to customize the JAX profiler server port (default: 9999).
   export JAX_PROFILER_PORT=9999

   cd ~/maxtext
   python MaxText/maxengine_server.py \
     MaxText/configs/base.yml \
     tokenizer_path=${TOKENIZER_PATH} \
     load_parameters_path=${LOAD_PARAMETERS_PATH} \
     max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
     max_target_length=${MAX_TARGET_LENGTH} \
     model_name=${MODEL_NAME} \
     ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
     ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
     ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
     scan_layers=${SCAN_LAYERS} \
     weight_dtype=${WEIGHT_DTYPE} \
     per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
     enable_jax_profiler=${ENABLE_JAX_PROFILER} \
     jax_profiler_port=${JAX_PROFILER_PORT}
   ```
3. Open http://localhost:6006/#profile, and click the “CAPTURE PROFILE” button in the upper left. Enter “localhost:9999” as the profile service URL (this is the address of the profiler server you started in the previous step). Enter the number of milliseconds you’d like to profile for, and click “CAPTURE”.

4. After the capture finishes, TensorBoard should refresh automatically. (Not all of the TensorBoard profiling features are hooked up with JAX, so it may initially look like nothing was captured.) On the left under “Tools”, select `trace_viewer`.
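For script-driven workflows, the same capture can be done programmatically with standard JAX profiler APIs instead of the TensorBoard UI. The sketch below is illustrative only; the matrix multiply is a stand-in workload, not JetStream code:

```python
import jax
import jax.numpy as jnp

# Start a profiler server, as the JetStream server does when
# enable_jax_profiler=true (the port mirrors jax_profiler_port).
jax.profiler.start_server(9999)

# Capture a trace into the TensorBoard log directory from step 1.
with jax.profiler.trace("/tmp/tensorboard"):
  x = jnp.ones((1024, 1024))
  (x @ x).block_until_ready()  # Force execution inside the trace.
```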

jetstream/core/server_lib.py

Lines changed: 13 additions & 0 deletions
@@ -95,6 +95,8 @@ def run(
  threads: int | None = None,
  jax_padding: bool = True,
  metrics_server_config: config_lib.MetricsServerConfig | None = None,
+ enable_jax_profiler: bool = False,
+ jax_profiler_port: int = 9999,
  ) -> JetStreamServer:
  """Runs a server with a specified config.
@@ -105,6 +107,10 @@ def run(
  credentials: Should use grpc credentials by default.
  threads: Number of RPC handlers worker threads. This should be at least
    equal to the decoding batch size to fully saturate the decoding queue.
+ jax_padding: The flag to enable JAX padding during tokenization.
+ metrics_server_config: The config to enable the Prometheus metrics server.
+ enable_jax_profiler: The flag to enable the JAX profiler server.
+ jax_profiler_port: The port of the JAX profiler server (defaults to 9999).

  Returns:
    JetStreamServer that wraps the grpc server and orchestrator driver.
@@ -148,6 +154,13 @@ def run(
  logging.info("Starting server on port %d with %d threads", port, threads)

  jetstream_server.start()
+
+ # Set up the JAX profiler server if enabled.
+ if enable_jax_profiler:
+   logging.info("Starting JAX profiler server on port %s", jax_profiler_port)
+   jax.profiler.start_server(jax_profiler_port)
+ else:
+   logging.info("Not starting JAX profiler server: %s", enable_jax_profiler)
  return jetstream_server
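For reference, the new parameters compose with the existing ones roughly as follows. This is a hedged sketch rather than a tested invocation; the config and devices values are borrowed from the test file below, and defaults are assumed for the remaining arguments:

```python
import grpc

from jetstream.core import config_lib, server_lib

# Sketch: run a server with both metrics and profiling enabled.
server = server_lib.run(
    port=9000,
    config=config_lib.InterleavedCPUTestServer,
    devices=[None],
    credentials=grpc.local_server_credentials(),
    metrics_server_config=config_lib.MetricsServerConfig(port=9090),
    enable_jax_profiler=True,  # triggers jax.profiler.start_server(...)
    jax_profiler_port=9999,
)
server.stop()
```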

jetstream/tests/core/test_server.py

Lines changed: 37 additions & 21 deletions
@@ -62,10 +62,8 @@ async def test_server(
  """Sets up a server and requests token responses."""
  ######################### Server side ######################################
  port = portpicker.pick_unused_port()
- metrics_port = portpicker.pick_unused_port()

  print("port: " + str(port))
- print("metrics port: " + str(metrics_port))
  credentials = grpc.local_server_credentials()

  server = server_lib.run(
@@ -105,25 +103,43 @@ async def test_server(
    counter += 1
    server.stop()

-   # Now test server with prometheus config
-   server = server_lib.run(
-       port=port,
-       config=config,
-       devices=devices,
-       credentials=credentials,
-       metrics_server_config=config_lib.MetricsServerConfig(
-           port=metrics_port
-       ),
-   )
-   # assert prometheus server is running and responding
-   assert server._driver._metrics_collector is not None  # pylint: disable=protected-access
-   assert (
-       requests.get(
-           f"http://localhost:{metrics_port}", timeout=5
-       ).status_code
-       == requests.status_codes.codes["ok"]
-   )
-   server.stop()
+
+ def test_prometheus_server(self):
+   port = portpicker.pick_unused_port()
+   metrics_port = portpicker.pick_unused_port()
+
+   print("port: " + str(port))
+   print("metrics port: " + str(metrics_port))
+   credentials = grpc.local_server_credentials()
+   # Test the server with a Prometheus metrics config.
+   server = server_lib.run(
+       port=port,
+       config=config_lib.InterleavedCPUTestServer,
+       devices=[None],
+       credentials=credentials,
+       metrics_server_config=config_lib.MetricsServerConfig(port=metrics_port),
+   )
+   # Assert the Prometheus server is running and responding.
+   assert server._driver._metrics_collector is not None  # pylint: disable=protected-access
+   assert (
+       requests.get(f"http://localhost:{metrics_port}", timeout=5).status_code
+       == requests.status_codes.codes["ok"]
+   )
+   server.stop()
+
+ def test_jax_profiler_server(self):
+   port = portpicker.pick_unused_port()
+   print("port: " + str(port))
+   credentials = grpc.local_server_credentials()
+   # Test the server with the JAX profiler enabled.
+   server = server_lib.run(
+       port=port,
+       config=config_lib.InterleavedCPUTestServer,
+       devices=[None],
+       credentials=credentials,
+       enable_jax_profiler=True,
+   )
+   assert server
+   server.stop()

  def test_get_devices(self):
    assert len(server_lib.get_devices()) == 1
