@@ -18,75 +18,126 @@ ARG MODE
1818ENV MODE=$MODE
1919
2020RUN echo "Installing GRPO dependencies (vLLM, tpu-inference) with MODE=${MODE}"
21+ RUN pip uninstall -y jax jaxlib libtpu
22+
23+ RUN pip install aiohttp==3.12.15
24+
25+ # Install Python packages that enable pip to authenticate with Google Artifact Registry automatically.
26+ RUN pip install keyring keyrings.google-artifactregistry-auth
27+
28+ RUN pip install numba==0.61.2
2129
22- # Uninstall existing jax to avoid conflicts
23- # RUN pip uninstall -y jax jaxlib libtpu
24-
25- # --- STAGE 1: Install Static Dependencies ---
26- # Install any packages *not* defined in your project dependency files
27- RUN --mount=type=cache,target=/root/.cache/pip pip install \
28- aiohttp==3.12.15\
29- keyring \
30- keyrings.google-artifactregistry-auth
31-
32- RUN --mount=type=cache,target=/root/.cache/pip pip install \
33- numba==0.61.2
34-
35- # RUN VLLM_TARGET_DEVICE="tpu" pip install vllm
36- # --- STAGE 2: Install Project Dependencies (The Main Cached Layer) ---
37-
38- # Copy *only* the dependency definition files.
39- # This assumes vllm and tpu-inference are in the build context, copied from the parent directory.
40- COPY vllm/requirements/tpu.txt /tmp/
41- COPY vllm/requirements/build.txt /tmp/
42- COPY vllm/requirements/common.txt /tmp/
43- COPY tpu-inference/requirements.txt /tmp/
44-
45- # Run the full dependency installation.
46- # This entire layer is cached and will *only* be rebuilt if
47- # these .txt files change.
48- RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
49- # Set the target device so pip installs the right JAX/libtpu
50- # Install tpu-inference dependencies
51- export VLLM_TARGET_DEVICE="tpu" && \
52- pip install -r /tmp/tpu.txt -r /tmp/build.txt -r /tmp/common.txt -r /tmp/requirements.txt --no-cache-dir --pre \
53- --extra-index-url https://pypi.org/simple/ \
54- --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
55- --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
56- --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
57- --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
58- --find-links https://storage.googleapis.com/libtpu-releases/index.html \
59- --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
60- --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
61-
62- # Install tpu-inference dependencies
63- RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
64- pip install -r /tmp/requirements.txt --no-cache-dir --pre \
65- --extra-index-url https://pypi.org/simple/ \
66- --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
67- --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
68- --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
69- --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
70- --find-links https://storage.googleapis.com/libtpu-releases/index.html \
71- --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
72- --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
73-
74- # --- STAGE 3: Install Project Source Code ---
75-
76- # Now, copy the full source code. This invalidates cache frequently,
77- # but the next step is fast.
78- COPY vllm /vllm/
79- COPY tpu-inference /tpu-inference/
8030COPY tunix /tunix
31+ RUN pip install -e /tunix --no-cache-dir
32+
33+
34+ COPY vllm /vllm
35+ RUN VLLM_TARGET_DEVICE="tpu" pip install -e /vllm --no-cache-dir --pre \
36+ --extra-index-url https://pypi.org/simple/ \
37+ --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
38+ --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
39+ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
40+ --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
41+ --find-links https://storage.googleapis.com/libtpu-releases/index.html \
42+ --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
43+ --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
44+
45+
46+ COPY tpu-inference /tpu-inference
47+ RUN pip install -e /tpu-inference --no-cache-dir --pre \
48+ --extra-index-url https://pypi.org/simple/ \
49+ --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
50+ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
51+
52+ # # Install vLLM for Jax and TPUs from the artifact registry
53+ # RUN VLLM_TARGET_DEVICE="tpu" pip install --no-cache-dir --pre \
54+ # --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
55+ # --extra-index-url https://pypi.org/simple/ \
56+ # --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
57+ # --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
58+ # --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
59+ # --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
60+ # --find-links https://storage.googleapis.com/libtpu-releases/index.html \
61+ # --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
62+ # --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
63+ # vllm==0.11.1rc1.dev292+g1b86bd8e1.tpu
64+
65+ # # Install tpu-commons from the artifact registry
66+ # RUN pip install --no-cache-dir --pre \
67+ # --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
68+ # --extra-index-url https://pypi.org/simple/ \
69+ # --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
70+ # --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
71+ # tpu-commons==0.1.2
72+
73+ # # Uninstall existing jax to avoid conflicts
74+ # # RUN pip uninstall -y jax jaxlib libtpu
75+
76+ # # --- STAGE 1: Install Static Dependencies ---
77+ # # Install any packages *not* defined in your project dependency files
78+ # RUN --mount=type=cache,target=/root/.cache/pip pip install \
79+ # aiohttp==3.12.15\
80+ # keyring \
81+ # keyrings.google-artifactregistry-auth
82+
83+ # RUN --mount=type=cache,target=/root/.cache/pip pip install \
84+ # numba==0.61.2
85+
86+ # # RUN VLLM_TARGET_DEVICE="tpu" pip install vllm
87+ # # --- STAGE 2: Install Project Dependencies (The Main Cached Layer) ---
88+
89+ # # Copy *only* the dependency definition files.
90+ # # This assumes vllm and tpu-inference are in the build context, copied from the parent directory.
91+ # COPY vllm/requirements/tpu.txt /tmp/
92+ # COPY vllm/requirements/build.txt /tmp/
93+ # COPY vllm/requirements/common.txt /tmp/
94+ # COPY tpu-inference/requirements.txt /tmp/
95+
96+ # # Run the full dependency installation.
97+ # # This entire layer is cached and will *only* be rebuilt if
98+ # # these .txt files change.
99+ # RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
100+ # # Set the target device so pip installs the right JAX/libtpu
101+ # # Install tpu-inference dependencies
102+ # export VLLM_TARGET_DEVICE="tpu" && \
103+ # pip install -r /tmp/tpu.txt -r /tmp/build.txt -r /tmp/common.txt -r /tmp/requirements.txt --no-cache-dir --pre \
104+ # --extra-index-url https://pypi.org/simple/ \
105+ # --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
106+ # --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
107+ # --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
108+ # --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
109+ # --find-links https://storage.googleapis.com/libtpu-releases/index.html \
110+ # --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
111+ # --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
112+
113+ # # Install tpu-inference dependencies
114+ # RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
115+ # pip install -r /tmp/requirements.txt --no-cache-dir --pre \
116+ # --extra-index-url https://pypi.org/simple/ \
117+ # --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
118+ # --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
119+ # --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
120+ # --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
121+ # --find-links https://storage.googleapis.com/libtpu-releases/index.html \
122+ # --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
123+ # --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
124+
125+ # # --- STAGE 3: Install Project Source Code ---
126+
127+ # # Now, copy the full source code. This invalidates cache frequently,
128+ # # but the next step is fast.
129+ # COPY vllm /vllm/
130+ # COPY tpu-inference /tpu-inference/
131+ # COPY tunix /tunix
81132
82133
83- # Install in editable mode. This is lightning-fast because all
84- # dependencies were installed and cached in STAGE 2.
85- RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /vllm/
86- RUN --mount=type=cache,target=/root/.cache/pip pip install -e /tpu-inference/
134+ # # Install in editable mode. This is lightning-fast because all
135+ # # dependencies were installed and cached in STAGE 2.
136+ # RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /vllm/
137+ # RUN --mount=type=cache,target=/root/.cache/pip pip install -e /tpu-inference/
87138
88- RUN --mount=type=cache,target=/root/.cache/pip pip install --no-deps /tunix/
89- # RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /tpu-inference/
139+ # RUN --mount=type=cache,target=/root/.cache/pip pip install --no-deps /tunix/
140+ # # RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /tpu-inference/
90141
91142RUN if [ "$MODE" = "grpo-experimental" ]; then \
92143 echo "MODE=grpo-experimental: Re-installing JAX/libtpu" ; \
0 commit comments