Run training workload with Ironwood and flex-start using Lustre storage

Create a cluster with flex-start provisioning

Before you begin

Before you start, complete the following steps:

  • Make sure that you have XPK and its prerequisites installed by following the XPK installation instructions.
  • Ensure you have a Google Cloud project with billing enabled.
  • Get access to TPU7x. For more information, contact your account team.
  • Ensure the account you're using with XPK has the roles listed in the XPK GitHub repository.

Create a single-NIC, single slice cluster

  1. Set the following environment variables:

    NOTE: For multi-host provisioning, use an ACCELERATOR_TYPE with a topology that results in more than 8 chips, for example tpu7x-2x2x2 or tpu7x-16. For single-host provisioning, use an ACCELERATOR_TYPE with a topology that results in 8 or fewer chips, for example tpu7x-2x2x1 or tpu7x-8.

    export PROJECT_ID=<project_id> # Your GCP project name
    export ZONE=<zone> # Example: us-central1-c
    export CLUSTER_NAME=<cluster_name> # Your cluster name
    # For a list of supported topologies, see: https://docs.cloud.google.com/tpu/docs/tpu7x#configurations
    export ACCELERATOR_TYPE=<tpu_type> # Example: tpu7x-2x2x2
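The chip count implied by an ACCELERATOR_TYPE can be computed by multiplying the topology dimensions, or read directly from the plain-count form. A small illustrative helper (hypothetical, not part of XPK):

```shell
# Hypothetical helper: count the chips implied by a tpu7x accelerator type.
# Topology forms like tpu7x-2x2x2 multiply their dimensions; plain forms
# like tpu7x-8 state the chip count directly.
chips_for() {
  local topo="${1#tpu7x-}"   # strip the "tpu7x-" prefix
  if [[ "$topo" == *x* ]]; then
    local n=1 d
    for d in ${topo//x/ }; do n=$(( n * d )); done
    echo "$n"
  else
    echo "$topo"
  fi
}

chips_for tpu7x-2x2x2   # prints 8
chips_for tpu7x-16      # prints 16
```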
  2. Set up your network configuration.

    export NETWORK_NAME=<network_name> # Your network name
    export SUBNET_NAME=<subnet_name> # Your subnet name
    export NETWORK_FW_NAME=${NETWORK_NAME}-privatefirewall # Your firewall name
    export IP_RANGE=<ip_range> # Your IP range in CIDR notation, e.g. 10.0.0.0/24
    export REGION=${ZONE%-*}
    gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \
        --subnet-mode=custom --bgp-routing-mode=regional
    gcloud compute networks subnets create ${SUBNET_NAME} --project=${PROJECT_ID} \
        --network=${NETWORK_NAME} --region=${REGION} --range=${IP_RANGE}
    gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \
        --allow tcp,icmp,udp --project=${PROJECT_ID}
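The REGION value above is derived from the zone with the `${ZONE%-*}` parameter expansion, which strips the shortest trailing `-<suffix>` from the zone name:

```shell
# ${ZONE%-*} removes the shortest suffix matching "-*", turning a zone
# name into its region name.
ZONE=us-central1-c
echo "${ZONE%-*}"   # prints us-central1
```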
  3. Populate the ${CLUSTER_ARGUMENTS} variable, which you'll use in the xpk cluster create command:

    export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${SUBNET_NAME}"
  4. Create your GKE cluster with TPU7x node pools using the xpk cluster create command:

    xpk cluster create \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --cluster ${CLUSTER_NAME} \
        --cluster-cpu-machine-type=n1-standard-8 \
        --tpu-type=${ACCELERATOR_TYPE} \
        --flex \
        --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
        --enable-lustre-csi-driver

    Setting the --cluster-cpu-machine-type flag to n1-standard-8 (or larger) ensures that the default node pool has enough CPU for system pods, such as the JobSet webhook, and prevents scheduling errors. By default, XPK uses e2-standard-16. Some zones support only specific machine series, so you might need to switch between n1, n2, and e2 types to avoid quota errors.

  5. Add a maintenance exclusion to prevent upgrades for the cluster:

    # Your selected start time for the maintenance exclusion in
    # `YYYY-MM-DDTHH:MM:SSZ` format, e.g. "2025-11-24T00:00:00Z"
    export EXCLUSION_START_TIME=<exclusion_start_time>
    # Your selected end time for the maintenance exclusion in
    # `YYYY-MM-DDTHH:MM:SSZ` format, e.g. "2025-12-24T00:00:00Z"
    export EXCLUSION_END_TIME=<exclusion_end_time>
    gcloud container clusters update ${CLUSTER_NAME} \
        --region=${REGION} \
        --project=${PROJECT_ID} \
        --add-maintenance-exclusion-name="no-upgrade-next-month" \
        --add-maintenance-exclusion-start="${EXCLUSION_START_TIME}" \
        --add-maintenance-exclusion-end="${EXCLUSION_END_TIME}" \
        --add-maintenance-exclusion-scope="no_upgrades"
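The exclusion timestamps must be well-formed `YYYY-MM-DDTHH:MM:SSZ` instants. A quick pre-check (an illustrative sketch, assuming GNU date) can catch malformed values before the gcloud call:

```shell
# Illustrative pre-check: the timestamp must match YYYY-MM-DDTHH:MM:SSZ and
# parse as a real date (assumes GNU date).
valid_exclusion_time() {
  [[ "$1" =~ ^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}Z$ ]] \
    && date -u -d "$1" >/dev/null 2>&1
}

valid_exclusion_time "2025-11-24T00:00:00Z" && echo "ok"
valid_exclusion_time "2025-11-24" || echo "rejected: missing time"
```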

Run a workload

Option A: Mock training workload
  1. Download a fake training script:

    curl -o fake_training.py https://raw.githubusercontent.com/AI-Hypercomputer/xpk/refs/heads/main/examples/fake_training.py
  2. Run a mock training workload on the cluster.

    xpk workload create \
        --cluster ${CLUSTER_NAME} \
        --workload tf-mock-$(date +%H%M) \
        --tpu-type=${ACCELERATOR_TYPE} \
        --zone ${ZONE} \
        --project ${PROJECT_ID} \
        --command "python3 fake_training.py"
Option B: Training a generic model with MaxText
  1. Set up the networking needed for Lustre storage by running the commands below:

    export IP_RANGE_NAME=<ip_range_name> # Your IP range name
    export FIREWALL_RULE_NAME=<fw_rule_name> # Your firewall rule name
    
    # a. Enable service networking
    gcloud services enable servicenetworking.googleapis.com \
      --project=${PROJECT_ID}
    
    # b. Create an IP address range
    gcloud compute addresses create ${IP_RANGE_NAME} \
      --global \
      --purpose=VPC_PEERING \
      --prefix-length=20 \
      --description="Managed Lustre VPC Peering" \
      --network=${NETWORK_NAME} \
      --project=${PROJECT_ID}
    
    # c. Get the CIDR range of the IP address range
    CIDR_RANGE=$(
      gcloud compute addresses describe ${IP_RANGE_NAME} \
          --global  \
          --format="value[separator=/](address, prefixLength)" \
          --project=${PROJECT_ID}
    )
    
    # d. Create a firewall rule to allow TCP traffic from the IP address range
    gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
        --allow=tcp:988,tcp:6988 \
        --network=${NETWORK_NAME} \
        --source-ranges=${CIDR_RANGE} \
        --project=${PROJECT_ID}
    
    # e. Connect the peering (requires the compute.networkAdmin or servicenetworking.networksAdmin IAM role)
    gcloud services vpc-peerings connect \
        --network=${NETWORK_NAME} \
        --project=${PROJECT_ID} \
        --ranges=${IP_RANGE_NAME} \
        --service=servicenetworking.googleapis.com
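The --prefix-length=20 reservation above carves out a /20 block for Managed Lustre to allocate from; the number of addresses in a CIDR block is 2^(32 - prefix):

```shell
# Address count for a given CIDR prefix length: 2^(32 - prefix).
prefix=20
echo $(( 1 << (32 - prefix) ))   # prints 4096
```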
  2. Create the Lustre storage by running the commands below:

    export STORAGE_NAME=<storage_name> # Your storage name
    export STORAGE_THROUGHPUT=1000
    export STORAGE_CAPACITY=18000
    export STORAGE_FS=lfs
    export LOCATION=${ZONE}
    
    gcloud lustre instances create ${STORAGE_NAME} \
      --per-unit-storage-throughput=${STORAGE_THROUGHPUT} \
      --capacity-gib=${STORAGE_CAPACITY} \
      --filesystem=${STORAGE_FS} \
      --location=${LOCATION} \
      --network=projects/${PROJECT_ID}/global/networks/${NETWORK_NAME} \
      --project=${PROJECT_ID}
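If --per-unit-storage-throughput is MB/s per TiB of provisioned capacity (an assumption to verify against the Managed Lustre documentation), the instance above would provide roughly:

```shell
# Rough aggregate throughput, assuming --per-unit-storage-throughput is
# MB/s per TiB of capacity (assumption; verify against the docs).
capacity_gib=18000
per_unit_mbps=1000
echo "$(( capacity_gib * per_unit_mbps / 1024 )) MB/s (approx.)"   # prints 17578 MB/s (approx.)
```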
  3. Get Lustre properties. Note the mountPoint property.

    gcloud lustre instances describe ${STORAGE_NAME} --location=${LOCATION} --project=${PROJECT_ID}
  4. Prepare the Lustre manifest file. Use the IP address part of the mountPoint value from the previous command.

    export VOLUME_IP=<volume_ip> # The IP address part of the mountPoint value
    export VOLUME_HANDLE="${PROJECT_ID}/${ZONE}/${STORAGE_NAME}" # Your volume handle
    
    echo "apiVersion: v1
          kind: PersistentVolume
          metadata:
            name: xpk-lustre-pv
          spec:
            storageClassName: ""
            capacity:
              storage: 18000Gi
            accessModes:
              - ReadWriteMany
            persistentVolumeReclaimPolicy: Retain
            volumeMode: Filesystem
            claimRef:
              namespace: default
              name: xpk-lustre-pvc
            csi:
              driver: lustre.csi.storage.gke.io
              volumeHandle: ${VOLUME_HANDLE}
              volumeAttributes:
                ip: ${VOLUME_IP}
                filesystem: lfs
          ---
            kind: PersistentVolumeClaim
            apiVersion: v1
            metadata:
              name: xpk-lustre-pvc
            spec:
              accessModes:
                - ReadWriteMany
              storageClassName: ""
              volumeName: xpk-lustre-pv
              resources:
                requests:
                  storage: 18000Gi" > lustre-manifest-attach.yaml
  5. Attach the Lustre storage to your cluster by running the commands below:

    export BASE_OUTPUT_DIR="/lustre-data"
    xpk storage attach ${STORAGE_NAME} \
      --cluster=${CLUSTER_NAME} --project=${PROJECT_ID} --zone=${LOCATION} \
      --type=lustre \
      --mount-point=$BASE_OUTPUT_DIR \
      --readonly=false \
      --auto-mount=true \
      --manifest='./lustre-manifest-attach.yaml'
  6. Build or upload the MaxText Docker image. Note: MaxText supports Python 3.12 only. Build your virtual environment with 3.12 to install the correct dependencies.

    You can either build a Docker image locally using scripts provided by MaxText or use a prebuilt image. The following commands copy your local directory into the container:

    # Make sure you're running in a virtual environment with Python 3.12. If nothing is printed, you have the correct version.
    [[ "$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")' 2>/dev/null)" == "3.12" ]] || { >&2 echo "Error: Python version must be 3.12."; false; }
    # Clone MaxText
    git clone https://github.com/AI-Hypercomputer/maxtext.git
    cd maxtext
    git checkout maxtext-tutorial-v1.0.0
    # Custom Jax and LibTPU wheels
    pip install flax==0.12.0
    pip download libtpu==0.0.28.dev20251104+nightly -f "https://storage.googleapis.com/jax-releases/libtpu_releases.html"
    pip download --pre jax==0.8.1.dev20251104 jaxlib==0.8.1.dev20251104 --index https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
    # Build the Docker image
    bash docker_build_dependency_image.sh MODE=custom_wheels

    After the commands complete successfully, you should see an image named maxtext_base_image created locally. You can use this local image directly in the xpk workload command.

  7. Run a MaxText workload on the cluster.

    export MAXTEXT_COMMAND="JAX_PLATFORMS=tpu,cpu \
      ENABLE_PJRT_COMPATIBILITY=true \
      python3 src/MaxText/train.py src/MaxText/configs/base.yml \
          base_output_directory=$BASE_OUTPUT_DIR \
          dataset_type=synthetic \
          per_device_batch_size=2 \
          enable_checkpointing=false \
          run_name=maxtext_xpk \
          steps=30"
    
    xpk workload create \
        --cluster ${CLUSTER_NAME} \
        --base-docker-image maxtext_base_image \
        --workload maxtext-1b-$(date +%H%M) \
        --tpu-type=${ACCELERATOR_TYPE} \
        --zone ${ZONE} \
        --project ${PROJECT_ID} \
        --command "${MAXTEXT_COMMAND}"
Option C: Training a Llama3.1 model with MaxText

NOTE: For Llama3.1-70b, we recommend using at least a 4x4x4 topology (64 chips). If the cluster you created uses fewer chips, recreate it with a larger topology before running the steps below.

  1. Set up the networking needed for Lustre storage by running the commands below:

    export IP_RANGE_NAME=<ip_range_name> # Your IP range name
    export FIREWALL_RULE_NAME=<fw_rule_name> # Your firewall rule name
    
    # a. Enable service networking
    gcloud services enable servicenetworking.googleapis.com \
      --project=${PROJECT_ID}
    
    # b. Create an IP address range
    gcloud compute addresses create ${IP_RANGE_NAME} \
      --global \
      --purpose=VPC_PEERING \
      --prefix-length=20 \
      --description="Managed Lustre VPC Peering" \
      --network=${NETWORK_NAME} \
      --project=${PROJECT_ID}
    
    # c. Get the CIDR range of the IP address range
    CIDR_RANGE=$(
      gcloud compute addresses describe ${IP_RANGE_NAME} \
          --global  \
          --format="value[separator=/](address, prefixLength)" \
          --project=${PROJECT_ID}
    )
    
    # d. Create a firewall rule to allow TCP traffic from the IP address range
    gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
        --allow=tcp:988,tcp:6988 \
        --network=${NETWORK_NAME} \
        --source-ranges=${CIDR_RANGE} \
        --project=${PROJECT_ID}
    
    # e. Connect the peering (requires the compute.networkAdmin or servicenetworking.networksAdmin IAM role)
    gcloud services vpc-peerings connect \
        --network=${NETWORK_NAME} \
        --project=${PROJECT_ID} \
        --ranges=${IP_RANGE_NAME} \
        --service=servicenetworking.googleapis.com
  2. Create the Lustre storage by running the commands below:

    export STORAGE_NAME=<storage_name> # Your storage name
    export STORAGE_THROUGHPUT=1000
    export STORAGE_CAPACITY=18000
    export STORAGE_FS=lfs
    export LOCATION=${ZONE}
    
    gcloud lustre instances create ${STORAGE_NAME} \
      --per-unit-storage-throughput=${STORAGE_THROUGHPUT} \
      --capacity-gib=${STORAGE_CAPACITY} \
      --filesystem=${STORAGE_FS} \
      --location=${LOCATION} \
      --network=projects/${PROJECT_ID}/global/networks/${NETWORK_NAME} \
      --project=${PROJECT_ID}
  3. Get Lustre properties. Note the mountPoint property.

    gcloud lustre instances describe ${STORAGE_NAME} --location=${LOCATION} --project=${PROJECT_ID}
  4. Prepare the Lustre manifest file. Use the IP address part of the mountPoint value from the previous command.

    export VOLUME_IP=<volume_ip> # The IP address part of the mountPoint value
    export VOLUME_HANDLE="${PROJECT_ID}/${ZONE}/${STORAGE_NAME}" # Your volume handle
    
    echo "apiVersion: v1
          kind: PersistentVolume
          metadata:
            name: xpk-lustre-pv
          spec:
            storageClassName: ""
            capacity:
              storage: 18000Gi
            accessModes:
              - ReadWriteMany
            persistentVolumeReclaimPolicy: Retain
            volumeMode: Filesystem
            claimRef:
              namespace: default
              name: xpk-lustre-pvc
            csi:
              driver: lustre.csi.storage.gke.io
              volumeHandle: ${VOLUME_HANDLE}
              volumeAttributes:
                ip: ${VOLUME_IP}
                filesystem: lfs
          ---
            kind: PersistentVolumeClaim
            apiVersion: v1
            metadata:
              name: xpk-lustre-pvc
            spec:
              accessModes:
                - ReadWriteMany
              storageClassName: ""
              volumeName: xpk-lustre-pv
              resources:
                requests:
                  storage: 18000Gi" > lustre-manifest-attach.yaml
  5. Attach the Lustre storage to your cluster by running the commands below:

    export BASE_OUTPUT_DIR="/lustre-data"
    xpk storage attach ${STORAGE_NAME} \
      --cluster=${CLUSTER_NAME} --project=${PROJECT_ID} --zone=${LOCATION} \
      --type=lustre \
      --mount-point=$BASE_OUTPUT_DIR \
      --readonly=false \
      --auto-mount=true \
      --manifest='./lustre-manifest-attach.yaml'
  6. Build and upload the Docker image.

    export CONTAINER_REGISTRY=<registry_name> # Your container registry, e.g. gcr.io
    export CLOUD_IMAGE_NAME="llama-maxtext-runner"
    export WORKLOAD_IMAGE="${CONTAINER_REGISTRY}/${PROJECT_ID}/${CLOUD_IMAGE_NAME}"
    # Make sure you're running in a virtual environment with Python 3.12
    if [[ "$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")' 2>/dev/null)" == "3.12" ]]; then echo "You have the correct Python version 3.12"; else >&2 echo "Error: Python version must be 3.12."; false; fi
    # Clone MaxText Repository and Checkout Recipe Branch
    git clone https://github.com/AI-Hypercomputer/maxtext.git
    cd maxtext
    git checkout maxtext-tutorial-v1.3.0
    # Custom Jax and LibTPU wheels
    pip download libtpu==0.0.31.dev20251119+nightly -f "https://storage.googleapis.com/jax-releases/libtpu_releases.html"
    pip download --pre jax==0.8.1 jaxlib==0.8.1 --index https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
    # Build and upload the docker image
    bash dependencies/scripts/docker_build_dependency_image.sh MODE=custom_wheels
    bash dependencies/scripts/docker_upload_runner.sh CLOUD_IMAGE_NAME=${CLOUD_IMAGE_NAME}
  7. Run the Llama 3.1 MaxText workload on the cluster.

    export WORKLOAD_NAME="$(printf "%.26s" "llama3-1-70b-8192-fp8-4x4x4")-$(date +%Y%m%d-%H%M)"
    export XLA_FLAGS=" \
      --xla_tpu_scoped_vmem_limit_kib=65536 \
      --xla_tpu_bf16_emission_mode=NATIVE_EMISSION \
      --xla_tpu_enable_sparse_core_reduce_scatter_v2=true \
      --xla_tpu_enable_sparse_core_collective_offload_all_gather=true \
      --xla_tpu_enable_sparse_core_collective_offload_2d_all_gather=true \
      --xla_tpu_enable_all_gather_offload_tracing=true \
      --xla_tpu_use_tc_device_shape_on_sc=True \
      --xla_sc_disable_megacore_partitioning=True \
      --xla_tpu_enable_async_collective_fusion_fuse_all_gather=false \
      --xla_enable_async_all_gather=true \
      --xla_tpu_prefer_async_allgather_to_allreduce=true \
      --xla_tpu_enable_sparse_core_collective_offload_all_reduce=true \
      --xla_tpu_enable_sparse_core_collective_offload_reduce_scatter=true \
      --xla_tpu_enable_sparse_core_collective_offload_3d_all_gather=true \
      --xla_tpu_use_single_sparse_core_for_all_gather_offload=true "
    
    export MAXTEXT_ARGS="\
      model_name=llama3.1-70b \
      skip_jax_distributed_system=True \
      dtype=bfloat16 \
      per_device_batch_size=2 \
      profile_periodically_period=10000 \
      async_checkpointing=False \
      enable_checkpointing=False \
      use_iota_embed=True \
      remat_policy=custom \
      decoder_layer_input=device \
      context=device \
      query_proj=device \
      key_proj=device \
      value_proj=device \
      ici_fsdp_parallelism=-1 \
      dataset_type=synthetic \
      opt_type=adamw \
      mu_dtype=bfloat16 \
      sa_block_q=2048 \
      sa_block_kv=1024 \
      sa_block_kv_compute=512 \
      sa_block_q_dkv=2048 \
      sa_block_kv_dkv=2048 \
      sa_block_kv_dkv_compute=256 \
      sa_q_layout=SEQ_MINOR \
      sa_k_layout=SEQ_MINOR \
      sa_v_layout=HEAD_DIM_MINOR \
      sa_use_fused_bwd_kernel=True \
      use_tokamax_splash=True \
      max_target_length=8192 \
      profiler=xplane \
      skip_first_n_steps_for_profiler=5 \
      profiler_steps=2 \
      attention=flash \
      quantization=fp8_full \
      use_qwix_quantization=True \
      steps=30 \
      base_output_directory=${BASE_OUTPUT_DIR} \
      run_name=${WORKLOAD_NAME}"
    
    xpk workload create \
      --cluster=${CLUSTER_NAME} \
      --project=${PROJECT_ID} \
      --zone=${ZONE} \
      --priority=very-high \
      --max-restarts=0 \
      --device-type=${ACCELERATOR_TYPE} \
      --num-slices=1 \
      --docker-image="${WORKLOAD_IMAGE}" \
      --enable-debug-logs \
      --workload="${WORKLOAD_NAME}" \
      --command="set -e && export ENABLE_PATHWAYS_PERSISTENCE='1' && \
    export LIBTPU_INIT_ARGS='${XLA_FLAGS}' && \
    export JAX_PLATFORMS='tpu,cpu' && export ENABLE_PJRT_COMPATIBILITY='true' && \
    python3 -m MaxText.train MaxText/configs/base.yml ${MAXTEXT_ARGS}"
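The printf "%.26s" in the WORKLOAD_NAME expression above truncates the base name to at most 26 characters, presumably so that the full name, together with the 14-character -YYYYMMDD-HHMM suffix, stays within length limits for derived Kubernetes resource names:

```shell
# printf "%.26s" keeps at most 26 characters of the base name; the date
# suffix adds 14 more, for a 40-character maximum overall.
base="llama3-1-70b-8192-fp8-4x4x4"
printf "%.26s\n" "$base"   # prints llama3-1-70b-8192-fp8-4x4x
```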