@@ -17,7 +17,7 @@ FROM ${BASEIMAGE}
1717ARG MODE
1818ENV MODE=$MODE
1919
20- RUN echo "Installing GRPO dependencies (vLLM, tpu-common, tunix ) with MODE=${MODE}"
20+ RUN echo "Installing GRPO dependencies (vLLM, tpu-inference ) with MODE=${MODE}"
2121
2222# Uninstall existing jax to avoid conflicts
2323# RUN pip uninstall -y jax jaxlib libtpu
@@ -27,23 +27,27 @@ RUN echo "Installing GRPO dependencies (vLLM, tpu-common, tunix) with MODE=${MOD
2727RUN --mount=type=cache,target=/root/.cache/pip pip install \
2828 aiohttp==3.12.15\
2929 keyring \
30- keyrings.google-artifactregistry-auth \
30+ keyrings.google-artifactregistry-auth
31+
32+ RUN --mount=type=cache,target=/root/.cache/pip pip install \
3133 numba==0.61.2
3234
35+ # RUN VLLM_TARGET_DEVICE="tpu" pip install vllm
3336# --- STAGE 2: Install Project Dependencies (The Main Cached Layer) ---
3437
3538# Copy *only* the dependency definition files.
36- # This assumes vllm and tpu_commons are in the build context, copied from the parent directory.
39+ # This assumes vllm and tpu-inference are in the build context, copied from the parent directory.
3740COPY vllm/requirements/tpu.txt /tmp/
3841COPY vllm/requirements/build.txt /tmp/
3942COPY vllm/requirements/common.txt /tmp/
40- COPY tpu_commons /requirements.txt /tmp/
43+ COPY tpu-inference /requirements.txt /tmp/
4144
4245# Run the full dependency installation.
4346# This entire layer is cached and will *only* be rebuilt if
4447# these .txt files change.
45- RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
48+ RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
4649 # Set the target device so pip installs the right JAX/libtpu
50+ # Install tpu-inference dependencies
4751 export VLLM_TARGET_DEVICE="tpu" && \
4852 pip install -r /tmp/tpu.txt -r /tmp/build.txt -r /tmp/common.txt -r /tmp/requirements.txt --no-cache-dir --pre \
4953 --extra-index-url https://pypi.org/simple/ \
@@ -55,16 +59,34 @@ RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
5559 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
5660 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
5761
62+ # Install tpu-inference dependencies
63+ RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
64+ pip install -r /tmp/requirements.txt --no-cache-dir --pre \
65+ --extra-index-url https://pypi.org/simple/ \
66+ --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
67+ --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
68+ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
69+ --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
70+ --find-links https://storage.googleapis.com/libtpu-releases/index.html \
71+ --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
72+ --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
73+
5874# --- STAGE 3: Install Project Source Code ---
5975
6076# Now, copy the full source code. This invalidates cache frequently,
6177# but the next step is fast.
6278COPY vllm /vllm/
63- COPY tpu_commons /tpu_commons/
79+ COPY tpu-inference /tpu-inference/
80+ COPY tunix /tunix
81+
6482
6583# Install in editable mode. This is lightning-fast because all
6684# dependencies were installed and cached in STAGE 2.
67- RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /vllm/ -e /tpu_commons/
85+ RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /vllm/
86+ RUN --mount=type=cache,target=/root/.cache/pip pip install -e /tpu-inference/
87+
88+ RUN --mount=type=cache,target=/root/.cache/pip pip install --no-deps /tunix/
89+ # RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /tpu-inference/
6890
6991RUN if [ "$MODE" = "grpo-experimental" ]; then \
7092 echo "MODE=grpo-experimental: Re-installing JAX/libtpu" ; \
0 commit comments