Before you start, complete the following steps:
- Make sure that you have XPK and its prerequisites installed by following the instructions found here.
- Ensure you have a Google Cloud project with billing enabled.
- Get access to TPU7x. For more information, contact your account team.
- Ensure the account you're using with XPK has the roles listed in the XPK GitHub repository.
- Set the following environment variables:

  NOTE: For multi-host provisioning, use an ACCELERATOR_TYPE with any topology that results in more than 8 chips, e.g. tpu7x-2x2x2 or tpu7x-16. For single-host provisioning, use an ACCELERATOR_TYPE with any topology that results in 8 or fewer chips, e.g. tpu7x-2x2x1 or tpu7x-8.

  ```shell
  export PROJECT_ID=<project_id>        # Your GCP project name
  export ZONE=<zone>                    # Example: us-central1-c
  export CLUSTER_NAME=<cluster_name>    # Your cluster name

  # For a list of supported topologies, see:
  # https://docs.cloud.google.com/tpu/docs/tpu7x#configurations
  export ACCELERATOR_TYPE=<tpu_type>    # Example: tpu7x-2x2x2
  ```
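The chip count implied by a topology string is the product of its dimensions. As a quick sanity check you can compute it in the shell; `chips_for_type` below is a hypothetical helper for illustration, not part of XPK:

```shell
# Hypothetical helper (not part of XPK): compute the chip count implied by
# an ACCELERATOR_TYPE string such as "tpu7x-2x2x2" (product of the topology
# dimensions) or "tpu7x-16" (chip count given directly).
chips_for_type() {
  local topo="${1#tpu7x-}"    # strip the "tpu7x-" prefix
  local total=1 dim
  if [[ "$topo" == *x* ]]; then
    IFS=x read -r -a dims <<< "$topo"
    for dim in "${dims[@]}"; do total=$(( total * dim )); done
  else
    total="$topo"
  fi
  echo "$total"
}

chips_for_type tpu7x-2x2x1    # prints 4
chips_for_type tpu7x-16       # prints 16
```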
- Set up your network configuration:

  ```shell
  export NETWORK_NAME=<network_name>    # Your network name
  export SUBNET_NAME=<subnet_name>      # Your subnet name
  export NETWORK_FW_NAME=${NETWORK_NAME}-privatefirewall    # Your firewall name
  export IP_RANGE=<ip_range>            # Your IP range in CIDR notation, e.g. 10.0.0.0/24
  export REGION=${ZONE%-*}

  gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \
      --subnet-mode=custom --bgp-routing-mode=regional

  gcloud compute networks subnets create ${SUBNET_NAME} --project=${PROJECT_ID} \
      --network=${NETWORK_NAME} --region=${REGION} --range=${IP_RANGE}

  gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \
      --allow tcp,icmp,udp --project=${PROJECT_ID}
  ```
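The `${ZONE%-*}` expansion derives the region by stripping the trailing zone letter from the zone name; a minimal sketch of how it behaves:

```shell
# ${ZONE%-*} removes the shortest trailing "-..." segment, turning a zone
# name such as "us-central1-c" into its region "us-central1".
ZONE=us-central1-c
REGION=${ZONE%-*}
echo "$REGION"    # prints us-central1
```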
- Populate the ${CLUSTER_ARGUMENTS} variable, which you'll use in the xpk cluster create command:

  ```shell
  export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${SUBNET_NAME}"
  ```
- Create your GKE cluster with TPU7x node pools using the xpk cluster create command:

  ```shell
  xpk cluster create \
      --project=${PROJECT_ID} \
      --zone=${ZONE} \
      --cluster ${CLUSTER_NAME} \
      --cluster-cpu-machine-type=n1-standard-8 \
      --tpu-type=${ACCELERATOR_TYPE} \
      --flex \
      --custom-cluster-arguments="${CLUSTER_ARGUMENTS}"
  ```

  Setting the --cluster-cpu-machine-type flag to n1-standard-8 (or larger) ensures that the default node pool has sufficient CPU for system pods (for example, the JobSet webhook), preventing errors. By default, XPK uses e2-standard-16. Some zones only support specific CPU types, so you might need to switch between n1, n2, and e2 types; otherwise, you might encounter quota errors.
- Add a maintenance exclusion to prevent upgrades for the cluster:

  ```shell
  # Your selected start time for the maintenance exclusion in
  # `YYYY-MM-DDTHH:MM:SSZ` format, e.g. "2025-11-24T00:00:00Z"
  export EXCLUSION_START_TIME=<exclusion_start_time>

  # Your selected end time for the maintenance exclusion in
  # `YYYY-MM-DDTHH:MM:SSZ` format, e.g. "2025-12-24T00:00:00Z"
  export EXCLUSION_END_TIME=<exclusion_end_time>

  gcloud container clusters update ${CLUSTER_NAME} \
      --region=${REGION} \
      --project=${PROJECT_ID} \
      --add-maintenance-exclusion-name="no-upgrade-next-month" \
      --add-maintenance-exclusion-start="${EXCLUSION_START_TIME}" \
      --add-maintenance-exclusion-end="${EXCLUSION_END_TIME}" \
      --add-maintenance-exclusion-scope="no_upgrades"
  ```
Option A: Mock training workload
- Download the fake training script:

  ```shell
  curl -o fake_training.py https://raw.githubusercontent.com/AI-Hypercomputer/xpk/refs/heads/main/examples/fake_training.py
  ```
- Run a mock training workload on the cluster:

  ```shell
  xpk workload create \
      --cluster ${CLUSTER_NAME} \
      --workload tf-mock-$(date +%H%M) \
      --tpu-type=${ACCELERATOR_TYPE} \
      --zone ${ZONE} \
      --project ${PROJECT_ID} \
      --command "python3 fake_training.py"
  ```
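The `$(date +%H%M)` suffix in the workload name keeps repeated runs from colliding; a minimal sketch of the naming pattern:

```shell
# The workload name embeds the current hour and minute, so re-running the
# xpk workload create command produces a distinct workload name each time.
WORKLOAD="tf-mock-$(date +%H%M)"
echo "$WORKLOAD"    # e.g. tf-mock-1432
```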
Option B: Training a generic model with MaxText
- Create a Filestore storage by running the commands below:

  ```shell
  export STORAGE_NAME=<storage_name>    # Your storage name

  xpk storage create ${STORAGE_NAME} --type=gcpfilestore \
      --auto-mount=false --mount-point=/data-fs --readonly=false \
      --size=1024 --tier=BASIC_HDD --vol=default \
      --project=${PROJECT_ID} --cluster=${CLUSTER_NAME} --zone=${ZONE}
  ```
- Attach the Filestore storage to your cluster by running the commands below:

  ```shell
  export BASE_OUTPUT_DIR="/data-fs"

  xpk storage attach ${STORAGE_NAME} --cluster=${CLUSTER_NAME} --zone=${ZONE} \
      --project=${PROJECT_ID} --type=gcpfilestore --auto-mount=true \
      --vol=default --mount-point=/data-fs --readonly=false
  ```
- Build or upload the MaxText Docker image.

  NOTE: MaxText supports Python 3.12 only. Build your virtual environment with Python 3.12 to install the correct dependencies.

  You can either build a Docker image locally using scripts provided by MaxText or use a prebuilt image. The following commands copy your local directory into the container:

  ```shell
  # Make sure you're running in a virtual environment with Python 3.12.
  # If nothing is printed, you have the correct version.
  [[ "$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")' 2>/dev/null)" == "3.12" ]] || \
      { >&2 echo "Error: Python version must be 3.12."; false; }

  # Clone MaxText
  git clone https://github.com/AI-Hypercomputer/maxtext.git
  cd maxtext
  git checkout maxtext-tutorial-v1.0.0

  # Custom JAX and LibTPU wheels
  pip install flax==0.12.0
  pip download libtpu==0.0.28.dev20251104+nightly -f "https://storage.googleapis.com/jax-releases/libtpu_releases.html"
  pip download --pre jax==0.8.1.dev20251104 jaxlib==0.8.1.dev20251104 --index https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/

  # Build the Docker image
  bash docker_build_dependency_image.sh MODE=custom_wheels
  ```

  After the commands complete successfully, you should see an image named maxtext_base_image created locally. You can use your local image directly in the xpk workload command.
- Run a MaxText workload on the cluster:

  ```shell
  export MAXTEXT_COMMAND="JAX_PLATFORMS=tpu,cpu \
      ENABLE_PJRT_COMPATIBILITY=true \
      python3 src/MaxText/train.py src/MaxText/configs/base.yml \
      base_output_directory=$BASE_OUTPUT_DIR \
      dataset_type=synthetic \
      per_device_batch_size=2 \
      enable_checkpointing=false \
      gcs_metrics=true \
      run_name=maxtext_xpk \
      steps=30"

  xpk workload create \
      --cluster ${CLUSTER_NAME} \
      --base-docker-image maxtext_base_image \
      --workload maxtext-1b-$(date +%H%M) \
      --tpu-type=${ACCELERATOR_TYPE} \
      --zone ${ZONE} \
      --project ${PROJECT_ID} \
      --command "${MAXTEXT_COMMAND}"
  ```
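The MAXTEXT_COMMAND string relies on the shell convention that variable assignments written before a command apply only to that command's environment, not to the surrounding shell. A minimal sketch with a hypothetical variable FOO:

```shell
# An environment assignment prefixed to a command (as with JAX_PLATFORMS
# and ENABLE_PJRT_COMPATIBILITY above) is visible only to that command.
FOO=bar sh -c 'echo "$FOO"'    # prints bar
echo "${FOO:-unset}"           # prints unset: FOO never entered this shell
```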
Option C: Training a Llama3.1 model with MaxText
NOTE: For Llama3.1-70b, it is recommended that you use at least a 4x4x4 topology (that is, 64 chips). If the cluster you created uses fewer chips, recreate the cluster with a larger topology before running the steps below.
- Create a Filestore storage by running the commands below:

  ```shell
  export STORAGE_NAME=<storage_name>    # Your storage name

  xpk storage create ${STORAGE_NAME} --type=gcpfilestore \
      --auto-mount=false --mount-point=/data-fs --readonly=false \
      --size=1024 --tier=BASIC_HDD --vol=default \
      --project=${PROJECT_ID} --cluster=${CLUSTER_NAME} --zone=${ZONE}
  ```
- Attach the Filestore storage to your cluster by running the commands below:

  ```shell
  export BASE_OUTPUT_DIR="/data-fs"

  xpk storage attach ${STORAGE_NAME} --cluster=${CLUSTER_NAME} --zone=${ZONE} \
      --project=${PROJECT_ID} --type=gcpfilestore --auto-mount=true \
      --vol=default --mount-point=/data-fs --readonly=false
  ```
- Build the Docker image:

  ```shell
  export CONTAINER_REGISTRY=<registry_name>    # Initialize with your registry, e.g. gcr.io
  export CLOUD_IMAGE_NAME="llama-maxtext-runner"
  export WORKLOAD_IMAGE="${CONTAINER_REGISTRY}/${PROJECT_ID}/${CLOUD_IMAGE_NAME}"

  # Make sure you're running in a virtual environment with Python 3.12.
  if [[ "$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")' 2>/dev/null)" == "3.12" ]]; then
      echo "You have the correct Python version 3.12"
  else
      >&2 echo "Error: Python version must be 3.12"
  fi

  # Clone the MaxText repository and check out the recipe branch
  git clone https://github.com/AI-Hypercomputer/maxtext.git
  cd maxtext
  git checkout maxtext-tutorial-v1.3.0

  # Custom JAX and LibTPU wheels
  pip download libtpu==0.0.31.dev20251119+nightly -f "https://storage.googleapis.com/jax-releases/libtpu_releases.html"
  pip download --pre jax==0.8.1 jaxlib==0.8.1 --index https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/

  # Build and upload the Docker image
  bash dependencies/scripts/docker_build_dependency_image.sh MODE=custom_wheels
  bash dependencies/scripts/docker_upload_runner.sh CLOUD_IMAGE_NAME=${CLOUD_IMAGE_NAME}
  ```
- Run the Llama 3.1 MaxText workload on the cluster:

  ```shell
  export WORKLOAD_NAME="$(printf "%.26s" "llama3-1-70b-8192-fp8-4x4x4")-$(date +%Y%m%d-%H%M)"

  export XLA_FLAGS=" \
      --xla_tpu_scoped_vmem_limit_kib=65536 \
      --xla_tpu_bf16_emission_mode=NATIVE_EMISSION \
      --xla_tpu_enable_sparse_core_reduce_scatter_v2=true \
      --xla_tpu_enable_sparse_core_collective_offload_all_gather=true \
      --xla_tpu_enable_sparse_core_collective_offload_2d_all_gather=true \
      --xla_tpu_enable_all_gather_offload_tracing=true \
      --xla_tpu_use_tc_device_shape_on_sc=True \
      --xla_sc_disable_megacore_partitioning=True \
      --xla_tpu_enable_async_collective_fusion_fuse_all_gather=false \
      --xla_enable_async_all_gather=true \
      --xla_tpu_prefer_async_allgather_to_allreduce=true \
      --xla_tpu_enable_sparse_core_collective_offload_all_reduce=true \
      --xla_tpu_enable_sparse_core_collective_offload_reduce_scatter=true \
      --xla_tpu_enable_sparse_core_collective_offload_3d_all_gather=true \
      --xla_tpu_use_single_sparse_core_for_all_gather_offload=true"

  export MAXTEXT_ARGS="\
      model_name=llama3.1-70b \
      skip_jax_distributed_system=True \
      dtype=bfloat16 \
      per_device_batch_size=2 \
      profile_periodically_period=10000 \
      async_checkpointing=False \
      enable_checkpointing=False \
      use_iota_embed=True \
      remat_policy=custom \
      decoder_layer_input=device \
      context=device \
      query_proj=device \
      key_proj=device \
      value_proj=device \
      ici_fsdp_parallelism=-1 \
      dataset_type=synthetic \
      opt_type=adamw \
      mu_dtype=bfloat16 \
      sa_block_q=2048 \
      sa_block_kv=1024 \
      sa_block_kv_compute=512 \
      sa_block_q_dkv=2048 \
      sa_block_kv_dkv=2048 \
      sa_block_kv_dkv_compute=256 \
      sa_q_layout=SEQ_MINOR \
      sa_k_layout=SEQ_MINOR \
      sa_v_layout=HEAD_DIM_MINOR \
      sa_use_fused_bwd_kernel=True \
      use_tokamax_splash=True \
      max_target_length=8192 \
      profiler=xplane \
      skip_first_n_steps_for_profiler=5 \
      profiler_steps=2 \
      attention=flash \
      quantization=fp8_full \
      use_qwix_quantization=True \
      steps=30 \
      base_output_directory=${BASE_OUTPUT_DIR} \
      run_name=${WORKLOAD_NAME}"

  xpk workload create \
      --cluster=${CLUSTER_NAME} \
      --project=${PROJECT_ID} \
      --zone=${ZONE} \
      --priority=very-high \
      --max-restarts=0 \
      --device-type=${ACCELERATOR_TYPE} \
      --num-slices=1 \
      --docker-image="${WORKLOAD_IMAGE}" \
      --enable-debug-logs \
      --workload="${WORKLOAD_NAME}" \
      --command="set -e && export ENABLE_PATHWAYS_PERSISTENCE='1' && \
          export LIBTPU_INIT_ARGS='${XLA_FLAGS}' && \
          export JAX_PLATFORMS='tpu,cpu' && export ENABLE_PJRT_COMPATIBILITY='true' && \
          python3 -m MaxText.train MaxText/configs/base.yml ${MAXTEXT_ARGS}"
  ```
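The `printf "%.26s"` in WORKLOAD_NAME truncates the descriptive prefix to at most 26 characters before the timestamp is appended, which keeps the full name within Kubernetes resource-name length limits; a minimal sketch of the truncation:

```shell
# printf "%.26s" keeps at most the first 26 characters of its argument,
# so long model descriptors are cut before the timestamp is appended.
PREFIX="$(printf "%.26s" "llama3-1-70b-8192-fp8-4x4x4")"
echo "$PREFIX"       # prints llama3-1-70b-8192-fp8-4x4x
echo "${#PREFIX}"    # prints 26
WORKLOAD_NAME="${PREFIX}-$(date +%Y%m%d-%H%M)"
echo "$WORKLOAD_NAME"    # e.g. llama3-1-70b-8192-fp8-4x4x-20251124-0930
```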