Commit 2245876

Add model ckpt conversion and AQT scripts for JetStream MaxText Serving (#23)
* add model ckpt conversion scripts for JetStream MaxText Serving
* Add description and device requirements
* address nit
* verification
1 parent 81beb11 commit 2245876

8 files changed

Lines changed: 221 additions & 2 deletions

README.md

Lines changed: 7 additions & 2 deletions
@@ -6,6 +6,11 @@ JetStream is a fast library for LLM inference and serving on TPUs.
 
 ## Getting Started
 
+### Setup
+```
+pip install -r requirements.txt
+```
+
 ### Run local server & Testing
 
 Use the following commands to run a server locally:
@@ -14,10 +19,10 @@ Use the following commands to run a server locally:
 python -m jetstream.core.implementations.mock.server
 
 # Test local mock server
-python -m jetstream.core.tools.requester
+python -m jetstream.tools.requester
 
 # Load test local mock server
-python -m jetstream.core.tools.load_tester
+python -m jetstream.tools.load_tester
 
 ```
 

benchmarks/eval_accuracy.py

Lines changed: 14 additions & 0 deletions
@@ -1,3 +1,17 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import argparse
 import nltk
 import evaluate

jetstream/core/utils/async_multifuture.py

Lines changed: 14 additions & 0 deletions
@@ -1,3 +1,17 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import asyncio
 from concurrent import futures
 import threading

jetstream/tests/engine/test_token_utils.py

Lines changed: 14 additions & 0 deletions
@@ -1,3 +1,17 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 import unittest
 from typing import List
Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
+#!/bin/bash
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# This script will do the following:
+# - Create GCS buckets to store model artifacts for the JetStream MaxText Inference demo.
+# - Convert the downloaded checkpoints to MaxText compatible checkpoints.
+# - Convert the MaxText compatible checkpoints to unscanned checkpoints for inference.
+# Device requirements:
+# - Both checkpoint conversions require only a CPU (with JAX CPU mode).
+set -ex
+
+idx=$(date +%Y-%m-%d-%H-%M)
+# Modify the `MODEL` and `MODEL_VARIATION` based on the model you use.
+export MODEL=$1
+export MODEL_VARIATION=$2
+export MODEL_NAME=${MODEL}-${MODEL_VARIATION}
+
+# After downloading checkpoints, copy them to the GCS bucket at $CHKPT_BUCKET.
+# Please use separate GCS paths for uploading open source model weights ($CHKPT_BUCKET) and MaxText compatible weights ($MODEL_BUCKET).
+# Point these variables to a GCS bucket that you created.
+export CHKPT_BUCKET=gs://${USER}-maxtext/chkpt/${MODEL}/${MODEL_VARIATION}
+export MODEL_BUCKET=gs://${USER}-maxtext
+
+# Point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you created; this bucket will store all the files generated by MaxText during a run.
+export BASE_OUTPUT_DIRECTORY=gs://${USER}-runner-maxtext-logs
+
+# Point `DATASET_PATH` to the GCS bucket where you have your training data.
+export DATASET_PATH=gs://${USER}-maxtext-dataset
+
+export BUCKET_LOCATION=US
+
+# Create three GCS buckets for the demo.
+gcloud storage buckets create ${MODEL_BUCKET} --location=${BUCKET_LOCATION} || true
+gcloud storage buckets create ${BASE_OUTPUT_DIRECTORY} --location=${BUCKET_LOCATION} || true
+gcloud storage buckets create ${DATASET_PATH} --location=${BUCKET_LOCATION} || true
+
+# Copy the downloaded checkpoints to `CHKPT_BUCKET`.
+# Gemma example: gsutil -m cp -r 7b ${CHKPT_BUCKET}
+# Llama2 example: gsutil -m cp -r llama-2-7b ${CHKPT_BUCKET}
+sudo gsutil -m cp -r $3 ${CHKPT_BUCKET}
+
+# Convert model checkpoints to MaxText compatible checkpoints.
+if [ "$MODEL" == "gemma" ]; then
+  CONVERT_CKPT_SCRIPT="convert_gemma_chkpt.py"
+  JAX_PLATFORMS=cpu python MaxText/${CONVERT_CKPT_SCRIPT} \
+    --base_model_path ${CHKPT_BUCKET} \
+    --maxtext_model_path ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx} \
+    --model_size ${MODEL_VARIATION}
+else
+  # We install torch CPU because the checkpoint conversion script MaxText/llama_or_mistral_ckpt.py does not need a TPU/GPU.
+  pip install torch --index-url https://download.pytorch.org/whl/cpu
+  CONVERT_CKPT_SCRIPT="llama_or_mistral_ckpt.py"
+  JAX_PLATFORMS=cpu python MaxText/${CONVERT_CKPT_SCRIPT} \
+    --base-model-path ${CHKPT_BUCKET} \
+    --maxtext-model-path ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx} \
+    --model-size ${MODEL_VARIATION}
+fi
+echo "Written MaxText compatible checkpoint to ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}"
+
+# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory.
+export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}/0/items
+
+# Convert MaxText compatible checkpoints to unscanned checkpoints.
+# Note that `CONVERTED_CHECKPOINT` is in a `scanned` format, which is great for training, but for efficient decoding performance we want the checkpoint in an `unscanned` format.
+export RUN_NAME=${MODEL_NAME}_unscanned_chkpt_${idx}
+
+JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py \
+  MaxText/configs/base.yml \
+  base_output_directory=${BASE_OUTPUT_DIRECTORY} \
+  load_parameters_path=${CONVERTED_CHECKPOINT} \
+  run_name=${RUN_NAME} \
+  model_name=${MODEL_NAME} \
+  force_unroll=true
+echo "Written MaxText unscanned checkpoint to ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints"
+
+# We will use the unscanned checkpoints by passing `UNSCANNED_CKPT_PATH` into `LOAD_PARAMETERS_PATH` in the following sections.
+export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items
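
The conversion script builds every GCS path from `$USER`, its first two positional arguments, and a timestamp. The following is a minimal sketch of that derivation as a side-effect-free function, so the resulting paths can be inspected before anything is written to GCS. The function name `derive_paths`, its argument order, and the sample values are assumptions for illustration; none of them appear in the commit.

```shell
#!/bin/bash
# Hypothetical helper (not part of the commit): reproduce the path layout the
# conversion script exports, without calling gcloud or gsutil.
derive_paths() {
  if [ "$#" -ne 4 ]; then
    echo "usage: derive_paths USER MODEL MODEL_VARIATION IDX" >&2
    return 1
  fi
  user=$1; model=$2; variation=$3; idx=$4
  MODEL_NAME="${model}-${variation}"
  # Raw (open source) weights and MaxText compatible weights use separate paths.
  CHKPT_BUCKET="gs://${user}-maxtext/chkpt/${model}/${variation}"
  MODEL_BUCKET="gs://${user}-maxtext"
  BASE_OUTPUT_DIRECTORY="gs://${user}-runner-maxtext-logs"
  # Scanned checkpoint written by the first conversion step.
  CONVERTED_CHECKPOINT="${MODEL_BUCKET}/${model}/${variation}/${idx}/0/items"
  # Unscanned checkpoint written by the second conversion step.
  RUN_NAME="${MODEL_NAME}_unscanned_chkpt_${idx}"
  UNSCANNED_CKPT_PATH="${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items"
}

# Example values only; "alice" stands in for $USER.
derive_paths alice llama2 7b 2024-03-01-12-00
echo "${CONVERTED_CHECKPOINT}"
echo "${UNSCANNED_CKPT_PATH}"
```

Printing the two checkpoint paths this way is a cheap check that `MODEL` and `MODEL_VARIATION` are spelled as intended before the real script creates buckets under those names.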
Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
+#!/bin/bash
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# This script will do the following:
+# - Fine-tune the MaxText compatible checkpoint (converted from the original checkpoints) with AQT.
+# - Convert the AQT-fine-tuned checkpoints to unscanned checkpoints for inference.
+# TPU device requirements:
+# - For llama2-7b, it requires at least a v5e-8 TPU VM.
+# - For llama2-13b/70b, it requires a v4-128 TPU VM.
+set -ex
+
+idx=$(date +%Y-%m-%d-%H-%M)
+# Modify the `MODEL` and `MODEL_VARIATION` based on the model you use.
+export MODEL=$1
+export MODEL_VARIATION=$2
+export MODEL_NAME=${MODEL}-${MODEL_VARIATION}
+
+# After downloading checkpoints, copy them to the GCS bucket at $CHKPT_BUCKET.
+# Please use separate GCS paths for uploading open source model weights ($CHKPT_BUCKET) and MaxText compatible weights ($MODEL_BUCKET).
+# Point these variables to a GCS bucket that you created.
+export CHKPT_BUCKET=gs://${USER}-maxtext/chkpt/${MODEL}/${MODEL_VARIATION}
+export MODEL_BUCKET=gs://${USER}-maxtext
+
+# Point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you created; this bucket will store all the files generated by MaxText during a run.
+export BASE_OUTPUT_DIRECTORY=gs://${USER}-runner-maxtext-logs
+
+# Point `DATASET_PATH` to the GCS bucket where you have your training data.
+export DATASET_PATH=gs://${USER}-maxtext-dataset
+
+# Prepare C4 dataset for fine tuning: https://github.com/allenai/allennlp/discussions/5056
+sudo gsutil -u $3 -m cp 'gs://allennlp-tensorflow-datasets/c4/en/3.0.1/*' ${DATASET_PATH}/c4/en/3.0.1/
+
+# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory.
+export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}/0/items
+
+# Fine tune the converted model checkpoints with AQT.
+export RUN_NAME=finetune_aqt_${idx}
+
+python3 MaxText/train.py \
+  MaxText/configs/base.yml \
+  run_name=${RUN_NAME} \
+  base_output_directory=${BASE_OUTPUT_DIRECTORY} \
+  dataset_path=${DATASET_PATH} \
+  steps=501 \
+  enable_checkpointing=True \
+  load_parameters_path=${CONVERTED_CHECKPOINT} \
+  model_name=${MODEL_NAME} \
+  per_device_batch_size=1 \
+  quantization=int8 \
+  checkpoint_period=100
+
+# We will convert the `AQT_CKPT` to an unscanned checkpoint in the next step.
+export AQT_CKPT=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/100/items
+
+# Convert MaxText compatible AQT-fine-tuned checkpoints to unscanned checkpoints.
+# Note that `AQT_CKPT` is in a `scanned` format, which is great for training, but for efficient decoding performance we want the checkpoint in an `unscanned` format.
+export RUN_NAME=${MODEL_NAME}_unscanned_chkpt_${idx}
+
+JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py \
+  MaxText/configs/base.yml \
+  base_output_directory=${BASE_OUTPUT_DIRECTORY} \
+  load_parameters_path=${AQT_CKPT} \
+  run_name=${RUN_NAME} \
+  model_name=${MODEL_NAME} \
+  force_unroll=true
+echo "Written MaxText unscanned checkpoint to ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints"
+
+# We will use the unscanned checkpoints by passing `UNSCANNED_CKPT_PATH` into `LOAD_PARAMETERS_PATH` in the following sections.
+export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items
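
Both scripts in this commit compute `idx` from the current minute, so when they are run as separate processes at different times, the `CONVERTED_CHECKPOINT` path in the AQT script will generally not match the directory the conversion script wrote. One possible way to keep the two in step, sketched below, is to let the caller pin the timestamp through an environment variable and fall back to the clock only when it is unset. The variable name `IDX`, the function `resolve_idx`, and the example bucket are hypothetical; the scripts as committed do not read any such variable.

```shell
#!/bin/bash
# Hypothetical sketch: IDX is NOT read by the scripts in this commit.
resolve_idx() {
  # Use the caller-provided IDX when it is set and non-empty; otherwise fall
  # back to the minute-resolution timestamp format the scripts already use.
  echo "${IDX:-$(date +%Y-%m-%d-%H-%M)}"
}

# Reuse the timestamp printed by an earlier conversion run (example value).
IDX=2024-03-01-12-00
idx=$(resolve_idx)
echo "gs://example-maxtext/llama2/7b/${idx}/0/items"
```

The `${VAR:-default}` expansion keeps the existing behaviour when `IDX` is absent, so a change along these lines would not affect callers who run everything in one go.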
