Skip to content

Commit 45f39ac

Browse files
committed
Move jax[cuda] installation after requirements.txt to avoid clobbering of it by jax installation in requirements.txt
Signed-off-by: Kunjan <kunjanp@google.com>
1 parent 5febb77 commit 45f39ac

2 files changed

Lines changed: 4 additions & 6 deletions

File tree

maxdiffusion_gpu_dependencies.Dockerfile

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,5 @@ RUN ls .
4444

4545
RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION DEVICE=${ENV_DEVICE}"
4646
RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} DEVICE=${ENV_DEVICE}
47-
RUN pip install -r requirements.txt
48-
RUN pip install -U "jax[cuda12]"
4947

5048
WORKDIR /deps

setup.sh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ if [[ -n $JAX_VERSION && ! ($MODE == "stable" || -z $MODE) ]]; then
5555
exit 1
5656
fi
5757

58+
# Install dependencies from requirements.txt first
59+
pip3 install -U -r requirements.txt || echo "Failed to install dependencies in the requirements" >&2
60+
5861
# Install JAX and JAXlib based on the specified mode
5962
if [[ "$MODE" == "stable" || ! -v MODE ]]; then
6063
# Stable mode
@@ -78,7 +81,7 @@ if [[ "$MODE" == "stable" || ! -v MODE ]]; then
7881
pip3 install "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
7982
fi
8083
export NVTE_FRAMEWORK=jax
81-
pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@stable
84+
pip3 install transformer_engine[jax]==2.1.0
8285
fi
8386

8487
elif [[ $MODE == "nightly" ]]; then
@@ -106,8 +109,5 @@ else
106109
exit 1
107110
fi
108111

109-
# Install dependencies from requirements.txt
110-
pip3 install -U -r requirements.txt || echo "Failed to install dependencies in the requirements" >&2
111-
112112
# Install maxdiffusion
113113
pip3 install -U . || echo "Failed to install maxdiffusion" >&2

0 commit comments

Comments
 (0)