Skip to content

Commit 7de3241

Browse files
committed
optionally separate out the dependencies of vllm
1 parent 3f819cc commit 7de3241

File tree

7 files changed

+222
-71
lines changed

7 files changed

+222
-71
lines changed

base_requirements/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,5 @@ tiktoken
3939
tokamax
4040
transformers
4141
qwix
42-
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
42+
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/daedc21c393f23449fb54ddc4f75fca34348ea9c.zip
4343
mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip

maxtext_grpo_dependencies.Dockerfile

Lines changed: 115 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -18,75 +18,126 @@ ARG MODE
1818
ENV MODE=$MODE
1919

2020
RUN echo "Installing GRPO dependencies (vLLM, tpu-inference) with MODE=${MODE}"
21+
RUN pip uninstall -y jax jaxlib libtpu
22+
23+
RUN pip install aiohttp==3.12.15
24+
25+
# Install Python packages that enable pip to authenticate with Google Artifact Registry automatically.
26+
RUN pip install keyring keyrings.google-artifactregistry-auth
27+
28+
RUN pip install numba==0.61.2
2129

22-
# Uninstall existing jax to avoid conflicts
23-
# RUN pip uninstall -y jax jaxlib libtpu
24-
25-
# --- STAGE 1: Install Static Dependencies ---
26-
# Install any packages *not* defined in your project dependency files
27-
RUN --mount=type=cache,target=/root/.cache/pip pip install \
28-
aiohttp==3.12.15\
29-
keyring \
30-
keyrings.google-artifactregistry-auth
31-
32-
RUN --mount=type=cache,target=/root/.cache/pip pip install \
33-
numba==0.61.2
34-
35-
# RUN VLLM_TARGET_DEVICE="tpu" pip install vllm
36-
# --- STAGE 2: Install Project Dependencies (The Main Cached Layer) ---
37-
38-
# Copy *only* the dependency definition files.
39-
# This assumes vllm and tpu-inference are in the build context, copied from the parent directory.
40-
COPY vllm/requirements/tpu.txt /tmp/
41-
COPY vllm/requirements/build.txt /tmp/
42-
COPY vllm/requirements/common.txt /tmp/
43-
COPY tpu-inference/requirements.txt /tmp/
44-
45-
# Run the full dependency installation.
46-
# This entire layer is cached and will *only* be rebuilt if
47-
# these .txt files change.
48-
RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
49-
# Set the target device so pip installs the right JAX/libtpu
50-
# Install tpu-inference dependencies
51-
export VLLM_TARGET_DEVICE="tpu" && \
52-
pip install -r /tmp/tpu.txt -r /tmp/build.txt -r /tmp/common.txt -r /tmp/requirements.txt --no-cache-dir --pre \
53-
--extra-index-url https://pypi.org/simple/ \
54-
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
55-
--extra-index-url https://download.pytorch.org/whl/nightly/cpu \
56-
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
57-
--find-links https://storage.googleapis.com/libtpu-wheels/index.html \
58-
--find-links https://storage.googleapis.com/libtpu-releases/index.html \
59-
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
60-
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
61-
62-
# Install tpu-inference dependencies
63-
RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
64-
pip install -r /tmp/requirements.txt --no-cache-dir --pre \
65-
--extra-index-url https://pypi.org/simple/ \
66-
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
67-
--extra-index-url https://download.pytorch.org/whl/nightly/cpu \
68-
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
69-
--find-links https://storage.googleapis.com/libtpu-wheels/index.html \
70-
--find-links https://storage.googleapis.com/libtpu-releases/index.html \
71-
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
72-
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
73-
74-
# --- STAGE 3: Install Project Source Code ---
75-
76-
# Now, copy the full source code. This invalidates cache frequently,
77-
# but the next step is fast.
78-
COPY vllm /vllm/
79-
COPY tpu-inference /tpu-inference/
8030
COPY tunix /tunix
31+
RUN pip install -e /tunix --no-cache-dir
32+
33+
34+
COPY vllm /vllm
35+
RUN VLLM_TARGET_DEVICE="tpu" pip install -e /vllm --no-cache-dir --pre \
36+
--extra-index-url https://pypi.org/simple/ \
37+
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
38+
--extra-index-url https://download.pytorch.org/whl/nightly/cpu \
39+
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
40+
--find-links https://storage.googleapis.com/libtpu-wheels/index.html \
41+
--find-links https://storage.googleapis.com/libtpu-releases/index.html \
42+
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
43+
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
44+
45+
46+
COPY tpu-inference /tpu-inference
47+
RUN pip install -e /tpu-inference --no-cache-dir --pre \
48+
--extra-index-url https://pypi.org/simple/ \
49+
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
50+
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
51+
52+
# # Install vLLM for Jax and TPUs from the artifact registry
53+
# RUN VLLM_TARGET_DEVICE="tpu" pip install --no-cache-dir --pre \
54+
# --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
55+
# --extra-index-url https://pypi.org/simple/ \
56+
# --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
57+
# --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
58+
# --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
59+
# --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
60+
# --find-links https://storage.googleapis.com/libtpu-releases/index.html \
61+
# --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
62+
# --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
63+
# vllm==0.11.1rc1.dev292+g1b86bd8e1.tpu
64+
65+
# # Install tpu-commons from the artifact registry
66+
# RUN pip install --no-cache-dir --pre \
67+
# --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
68+
# --extra-index-url https://pypi.org/simple/ \
69+
# --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
70+
# --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
71+
# tpu-commons==0.1.2
72+
73+
# # Uninstall existing jax to avoid conflicts
74+
# # RUN pip uninstall -y jax jaxlib libtpu
75+
76+
# # --- STAGE 1: Install Static Dependencies ---
77+
# # Install any packages *not* defined in your project dependency files
78+
# RUN --mount=type=cache,target=/root/.cache/pip pip install \
79+
# aiohttp==3.12.15\
80+
# keyring \
81+
# keyrings.google-artifactregistry-auth
82+
83+
# RUN --mount=type=cache,target=/root/.cache/pip pip install \
84+
# numba==0.61.2
85+
86+
# # RUN VLLM_TARGET_DEVICE="tpu" pip install vllm
87+
# # --- STAGE 2: Install Project Dependencies (The Main Cached Layer) ---
88+
89+
# # Copy *only* the dependency definition files.
90+
# # This assumes vllm and tpu-inference are in the build context, copied from the parent directory.
91+
# COPY vllm/requirements/tpu.txt /tmp/
92+
# COPY vllm/requirements/build.txt /tmp/
93+
# COPY vllm/requirements/common.txt /tmp/
94+
# COPY tpu-inference/requirements.txt /tmp/
95+
96+
# # Run the full dependency installation.
97+
# # This entire layer is cached and will *only* be rebuilt if
98+
# # these .txt files change.
99+
# RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
100+
# # Set the target device so pip installs the right JAX/libtpu
101+
# # Install tpu-inference dependencies
102+
# export VLLM_TARGET_DEVICE="tpu" && \
103+
# pip install -r /tmp/tpu.txt -r /tmp/build.txt -r /tmp/common.txt -r /tmp/requirements.txt --no-cache-dir --pre \
104+
# --extra-index-url https://pypi.org/simple/ \
105+
# --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
106+
# --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
107+
# --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
108+
# --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
109+
# --find-links https://storage.googleapis.com/libtpu-releases/index.html \
110+
# --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
111+
# --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
112+
113+
# # Install tpu-inference dependencies
114+
# RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
115+
# pip install -r /tmp/requirements.txt --no-cache-dir --pre \
116+
# --extra-index-url https://pypi.org/simple/ \
117+
# --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
118+
# --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
119+
# --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
120+
# --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
121+
# --find-links https://storage.googleapis.com/libtpu-releases/index.html \
122+
# --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
123+
# --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
124+
125+
# # --- STAGE 3: Install Project Source Code ---
126+
127+
# # Now, copy the full source code. This invalidates cache frequently,
128+
# # but the next step is fast.
129+
# COPY vllm /vllm/
130+
# COPY tpu-inference /tpu-inference/
131+
# COPY tunix /tunix
81132

82133

83-
# Install in editable mode. This is lightning-fast because all
84-
# dependencies were installed and cached in STAGE 2.
85-
RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /vllm/
86-
RUN --mount=type=cache,target=/root/.cache/pip pip install -e /tpu-inference/
134+
# # Install in editable mode. This is lightning-fast because all
135+
# # dependencies were installed and cached in STAGE 2.
136+
# RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /vllm/
137+
# RUN --mount=type=cache,target=/root/.cache/pip pip install -e /tpu-inference/
87138

88-
RUN --mount=type=cache,target=/root/.cache/pip pip install --no-deps /tunix/
89-
# RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /tpu-inference/
139+
# RUN --mount=type=cache,target=/root/.cache/pip pip install --no-deps /tunix/
140+
# # RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /tpu-inference/
90141

91142
RUN if [ "$MODE" = "grpo-experimental" ]; then \
92143
echo "MODE=grpo-experimental: Re-installing JAX/libtpu"; \
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
ARG BASEIMAGE
16+
FROM ${BASEIMAGE}
17+
ARG MODE
18+
ENV MODE=$MODE
19+
20+
RUN echo "Installing GRPO dependencies (vLLM, tpu-inference) with MODE=${MODE}"
21+
22+
# Uninstall existing jax to avoid conflicts
23+
# RUN pip uninstall -y jax jaxlib libtpu
24+
25+
# --- STAGE 1: Install Static Dependencies ---
26+
# Install any packages *not* defined in your project dependency files
27+
RUN --mount=type=cache,target=/root/.cache/pip pip install \
28+
aiohttp==3.12.15\
29+
keyring \
30+
keyrings.google-artifactregistry-auth
31+
32+
RUN --mount=type=cache,target=/root/.cache/pip pip install \
33+
numba==0.61.2
34+
35+
# RUN VLLM_TARGET_DEVICE="tpu" pip install vllm
36+
# --- STAGE 2: Install Project Dependencies (The Main Cached Layer) ---
37+
38+
# Copy *only* the dependency definition files.
39+
# This assumes vllm and tpu-inference are in the build context, copied from the parent directory.
40+
COPY vllm/requirements/tpu.txt /tmp/
41+
COPY vllm/requirements/build.txt /tmp/
42+
COPY vllm/requirements/common.txt /tmp/
43+
COPY tpu-inference/requirements.txt /tmp/
44+
45+
# Run the full dependency installation.
46+
# This entire layer is cached and will *only* be rebuilt if
47+
# these .txt files change.
48+
RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
49+
# Set the target device so pip installs the right JAX/libtpu
50+
# Install tpu-inference dependencies
51+
export VLLM_TARGET_DEVICE="tpu" && \
52+
pip install -r /tmp/tpu.txt -r /tmp/build.txt -r /tmp/common.txt -r /tmp/requirements.txt --no-cache-dir --pre \
53+
--extra-index-url https://pypi.org/simple/ \
54+
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
55+
--extra-index-url https://download.pytorch.org/whl/nightly/cpu \
56+
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
57+
--find-links https://storage.googleapis.com/libtpu-wheels/index.html \
58+
--find-links https://storage.googleapis.com/libtpu-releases/index.html \
59+
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
60+
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
61+
62+
# Install tpu-inference dependencies
63+
RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
64+
pip install -r /tmp/requirements.txt --no-cache-dir --pre \
65+
--extra-index-url https://pypi.org/simple/ \
66+
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
67+
--extra-index-url https://download.pytorch.org/whl/nightly/cpu \
68+
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
69+
--find-links https://storage.googleapis.com/libtpu-wheels/index.html \
70+
--find-links https://storage.googleapis.com/libtpu-releases/index.html \
71+
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
72+
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
73+
74+
# --- STAGE 3: Install Project Source Code ---
75+
76+
# Now, copy the full source code. This invalidates cache frequently,
77+
# but the next step is fast.
78+
COPY vllm /vllm/
79+
COPY tpu-inference /tpu-inference/
80+
COPY tunix /tunix
81+
82+
83+
# Install in editable mode. This is lightning-fast because all
84+
# dependencies were installed and cached in STAGE 2.
85+
RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /vllm/
86+
RUN --mount=type=cache,target=/root/.cache/pip pip install -e /tpu-inference/
87+
88+
RUN --mount=type=cache,target=/root/.cache/pip pip install --no-deps /tunix/
89+
# RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /tpu-inference/
90+
91+
RUN if [ "$MODE" = "grpo-experimental" ]; then \
92+
echo "MODE=grpo-experimental: Re-installing JAX/libtpu"; \
93+
pip uninstall -y jax jaxlib libtpu && \
94+
pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \
95+
pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
96+
fi

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,5 @@ tensorflow-text
3838
tensorflow
3939
tiktoken
4040
transformers
41-
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
41+
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/daedc21c393f23449fb54ddc4f75fca34348ea9c.zip
4242
mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip

requirements_with_jax_ai_image.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
datasets @ https://github.com/huggingface/datasets/archive/6790e138c00b87a1ddc72184f89e7814cf784360.zip
44
flax>=0.11.0
55
google-api-python-client
6-
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
6+
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/daedc21c393f23449fb54ddc4f75fca34348ea9c.zip
77
grain[parquet]>=0.2.12
88
jaxtyping
99
jsonlines

src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,20 +142,24 @@
142142

143143

144144
# ====== Input Checkpoint directory =====
145-
MODEL_CHECKPOINT_PATH = "gs://maxtext-model-checkpoints/llama3.1_70b_instruct/2025-10-15/pathways/scanned/0/items"
145+
MODEL_CHECKPOINT_PATH = "gs://mazumdera-test-bucket-europe-west4/llama3.1_70b_instruct/2025-10-15/pathways/scanned/0/items"
146+
# MODEL_CHECKPOINT_PATH = "gs://maxtext-model-checkpoints/llama3.1_70b_instruct/2025-10-15/pathways/scanned/0/items"
146147

147148
# ====== Checkpoint directory =====
148-
LOG_DIR = f"{HOME}/content/tensorboard/grpo/logs_llama3/"
149+
LOG_DIR = f"gs://mazumdera-test-bucket-europe-west4/rl-tuning/grpo/anisha-{run_id}/tensorboard/grpo/logs_llama3/"
150+
# LOG_DIR = f"gs://mazumdera-test-bucket-us-central2/grpo/rl-tuning/grpo/anisha-{run_id}/tensorboard/grpo/logs_llama3/"
149151
if not os.path.exists(LOG_DIR):
150-
os.makedirs(LOG_DIR)
152+
epath.Path(LOG_DIR).mkdir(parents=True)
151153

152154
# ===== Profiling =====
153155
PROFILE_DIR = f"gs://mazumdera-test-bucket-europe-west4/rl-tuning/grpo/anisha-{run_id}/profiles_llama3/"
156+
# PROFILE_DIR = f"gs://mazumdera-test-bucket-us-central2/grpo/rl-tuning/grpo/anisha-{run_id}/profiles_llama3/"
154157
if not epath.Path(PROFILE_DIR).exists():
155158
epath.Path(PROFILE_DIR).mkdir(parents=True)
156159

157160
# ====== Checkpoint saving ======
158161
CKPT_DIR = f"gs://mazumdera-test-bucket-europe-west4/rl-tuning/grpo/anisha-{run_id}/ckpts_llama3/"
162+
# CKPT_DIR = f"gs://mazumdera-test-bucket-us-central2/grpo/rl-tuning/grpo/anisha-{run_id}/ckpts_llama3/"
159163

160164
if not epath.Path(CKPT_DIR).exists():
161165
epath.Path(CKPT_DIR).mkdir(parents=True)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
1+
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/daedc21c393f23449fb54ddc4f75fca34348ea9c.zip
22
mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip

0 commit comments

Comments
 (0)