File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -44,7 +44,5 @@ RUN ls .
4444
4545RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION DEVICE=${ENV_DEVICE}"
4646RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} DEVICE=${ENV_DEVICE}
47- RUN pip install -r requirements.txt
48- RUN pip install -U "jax[cuda12]"
4947
5048WORKDIR /deps
Original file line number Diff line number Diff line change @@ -55,6 +55,9 @@ if [[ -n $JAX_VERSION && ! ($MODE == "stable" || -z $MODE) ]]; then
5555 exit 1
5656fi
5757
58+ # Install dependencies from requirements.txt first
59+ pip3 install -U -r requirements.txt || echo " Failed to install dependencies in the requirements" >&2
60+
5861# Install JAX and JAXlib based on the specified mode
5962if [[ " $MODE " == " stable" || ! -v MODE ]]; then
6063 # Stable mode
@@ -78,7 +81,7 @@ if [[ "$MODE" == "stable" || ! -v MODE ]]; then
7881 pip3 install " jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
7982 fi
8083 export NVTE_FRAMEWORK=jax
81- pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@stable
84+ pip3 install transformer_engine[jax]==2.1.0
8285 fi
8386
8487elif [[ $MODE == " nightly" ]]; then
106109 exit 1
107110fi
108111
109- # Install dependencies from requirements.txt
110- pip3 install -U -r requirements.txt || echo " Failed to install dependencies in the requirements" >&2
111-
112112# Install maxdiffusion
113113pip3 install -U . || echo " Failed to install maxdiffusion" >&2
You can’t perform that action at this time.
0 commit comments