#!/bin/bash
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# This script will do the following:
#   - Create GCS buckets to store model artifacts for the JetStream MaxText Inference demo.
#   - Copy the downloaded checkpoints to a GCS bucket.
#   - Convert the downloaded checkpoints to MaxText compatible checkpoints.
#   - Convert the MaxText compatible checkpoints to unscanned checkpoints for inference.
#
# Usage:
#   <this script> MODEL MODEL_VARIATION CHECKPOINT_DIR
#     MODEL           - model family; "gemma" selects the Gemma converter, any
#                       other value selects the Llama/Mistral converter.
#     MODEL_VARIATION - model size/variation, passed to the converter as the model size.
#     CHECKPOINT_DIR  - local path of the downloaded checkpoints to upload.
#
# Required environment:
#   USER - used to derive the names of the GCS buckets.
#
# Device requirements:
#   - Both checkpoint conversions are run with JAX_PLATFORMS=cpu (JAX CPU mode).
set -euxo pipefail

usage="usage: ${0##*/} MODEL MODEL_VARIATION CHECKPOINT_DIR"

# Fail fast, before any bucket is created, when an argument is missing or empty.
# Modify the `MODEL` and `MODEL_VARIATION` arguments based on the model you use.
export MODEL=${1:?"MODEL is required; ${usage}"}
export MODEL_VARIATION=${2:?"MODEL_VARIATION is required; ${usage}"}
checkpoint_dir=${3:?"CHECKPOINT_DIR is required; ${usage}"}
: "${USER:?USER must be set; it is used to build the GCS bucket names}"

export MODEL_NAME="${MODEL}-${MODEL_VARIATION}"

# Timestamp (minute resolution) that makes the output paths of each run unique.
idx=$(date +%Y-%m-%d-%H-%M)

# After downloading checkpoints, copy them to the GCS bucket at $CHKPT_BUCKET.
# Please use separate GCS paths for uploading open source model weights ($CHKPT_BUCKET)
# and MaxText compatible weights ($MODEL_BUCKET).
# Point these variables to a GCS bucket that you created.
export CHKPT_BUCKET="gs://${USER}-maxtext/chkpt/${MODEL}/${MODEL_VARIATION}"
export MODEL_BUCKET="gs://${USER}-maxtext"

# Point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you created, this bucket will store
# all the files generated by MaxText during a run.
export BASE_OUTPUT_DIRECTORY="gs://${USER}-runner-maxtext-logs"

# Point `DATASET_PATH` to the GCS bucket where you have your training data.
export DATASET_PATH="gs://${USER}-maxtext-dataset"

export BUCKET_LOCATION=US

# Create three GCS buckets for the demo.
# `|| true` is intentional: creation is best-effort so that a failure here
# (for example, the bucket already exists) does not abort the script under `set -e`.
for bucket in "${MODEL_BUCKET}" "${BASE_OUTPUT_DIRECTORY}" "${DATASET_PATH}"; do
  gcloud storage buckets create "${bucket}" --location="${BUCKET_LOCATION}" || true
done

# Copy the downloaded checkpoints to `CHKPT_BUCKET`.
# Gemma example: gsutil -m cp -r 7b ${CHKPT_BUCKET}
# Llama2 example: gsutil -m cp -r llama-2-7b ${CHKPT_BUCKET}
# NOTE(review): `sudo` makes gsutil run as root, which may use different
# credentials than the invoking user - confirm that it is really needed.
sudo gsutil -m cp -r "${checkpoint_dir}" "${CHKPT_BUCKET}"

# Destination of the MaxText compatible checkpoint produced by the conversion below.
maxtext_ckpt_path="${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}"

# Convert model checkpoints to MaxText compatible checkpoints.
if [[ "${MODEL}" == "gemma" ]]; then
  CONVERT_CKPT_SCRIPT="convert_gemma_chkpt.py"
  JAX_PLATFORMS=cpu python "MaxText/${CONVERT_CKPT_SCRIPT}" \
    --base_model_path "${CHKPT_BUCKET}" \
    --maxtext_model_path "${maxtext_ckpt_path}" \
    --model_size "${MODEL_VARIATION}"
else
  # We install torch CPU because the checkpoint conversion script
  # MaxText/llama_or_mistral_ckpt.py does not need a TPU/GPU.
  pip install torch --index-url https://download.pytorch.org/whl/cpu
  CONVERT_CKPT_SCRIPT="llama_or_mistral_ckpt.py"
  JAX_PLATFORMS=cpu python "MaxText/${CONVERT_CKPT_SCRIPT}" \
    --base-model-path "${CHKPT_BUCKET}" \
    --maxtext-model-path "${maxtext_ckpt_path}" \
    --model-size "${MODEL_VARIATION}"
fi
echo "Written MaxText compatible checkpoint to ${maxtext_ckpt_path}"

# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory.
export CONVERTED_CHECKPOINT="${maxtext_ckpt_path}/0/items"

# Convert MaxText compatible checkpoints to unscanned checkpoints.
# Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training,
# but for efficient decoding performance we want the checkpoint in an `unscanned` format.
export RUN_NAME="${MODEL_NAME}_unscanned_chkpt_${idx}"

JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py \
  MaxText/configs/base.yml \
  base_output_directory="${BASE_OUTPUT_DIRECTORY}" \
  load_parameters_path="${CONVERTED_CHECKPOINT}" \
  run_name="${RUN_NAME}" \
  model_name="${MODEL_NAME}" \
  force_unroll=true
echo "Written MaxText unscanned checkpoint to ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints"

# We will use the unscanned checkpoints by passing `UNSCANNED_CKPT_PATH` into
# `LOAD_PARAMETERS_PATH` in the following sections.
export UNSCANNED_CKPT_PATH="${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items"