Skip to content

Commit 0b4f346

Browse files
committed
Add gpu image creation pipeline
Signed-off-by: Kunjan <kunjanp@google.com>
1 parent c9254be commit 0b4f346

3 files changed

Lines changed: 19 additions & 4 deletions

File tree

.github/workflows/UploadDockerImages.yml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,16 @@ jobs:
3535
- name: build maxdiffusion jax nightly image
3636
run: |
3737
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly
38+
39+
build-gpu-image:
40+
runs-on: ["self-hosted", "e2", "cpu"]
41+
steps:
42+
- uses: actions/checkout@v3
43+
- name: Cleanup old docker images
44+
run: docker system prune --all --force
45+
- name: build maxdiffusion jax stable stack gpu image
46+
run: |
47+
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack_gpu MODE=stable_stack PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack_gpu BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:latest DEVICE=gpu
48+
- name: build maxdiffusion jax nightly image
49+
run: |
50+
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly_gpu MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly DEVICE=gpu

.github/workflows/build_and_upload_images.sh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,15 @@ for ARGUMENT in "$@"; do
3434
echo "$KEY"="$VALUE"
3535
done
3636

37+
export DEVICE="${DEVICE:-tpu}"
38+
3739
if [[ ! -v CLOUD_IMAGE_NAME ]] || [[ ! -v PROJECT ]] || [[ ! -v MODE ]] ; then
3840
echo "You must set CLOUD_IMAGE_NAME, PROJECT and MODE"
3941
exit 1
4042
fi
4143

4244
gcloud auth configure-docker us-docker.pkg.dev --quiet
43-
bash docker_build_dependency_image.sh LOCAL_IMAGE_NAME=$LOCAL_IMAGE_NAME MODE=$MODE
45+
bash docker_build_dependency_image.sh LOCAL_IMAGE_NAME=$LOCAL_IMAGE_NAME MODE=$MODE DEVICE=$DEVICE
4446
image_date=$(date +%Y-%m-%d)
4547

4648
# Upload only dependencies image

maxdiffusion_gpu_dependencies.Dockerfile

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ RUN apt-get update && apt-get install -y google-cloud-sdk
2222
# Set environment variables for Google Cloud SDK
2323
ENV PATH="/usr/local/google-cloud-sdk/bin:${PATH}"
2424

25-
# Upgrade libcusparse to work with Jax
26-
RUN apt-get update && apt-get install -y libcusparse-12-3
25+
2726

2827
ARG MODE
2928
ENV ENV_MODE=$MODE
@@ -45,6 +44,7 @@ RUN ls .
4544

4645
RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION DEVICE=${ENV_DEVICE}"
4746
RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} DEVICE=${ENV_DEVICE}
48-
47+
RUN pip install -r requirements.txt
48+
RUN pip install -U "jax[cuda12]"
4949

5050
WORKDIR /deps

0 commit comments

Comments
 (0)