diff --git a/Dockerfile-cuda b/Dockerfile-cuda index 537ad59f..bba1c1b5 100644 --- a/Dockerfile-cuda +++ b/Dockerfile-cuda @@ -1,4 +1,4 @@ -FROM nvidia/cuda:12.2.0-devel-ubuntu22.04 AS base-builder +FROM nvidia/cuda:12.9.0-devel-ubuntu22.04 AS base-builder ENV SCCACHE=0.10.0 ENV RUSTC_WRAPPER=/usr/local/bin/sccache @@ -58,6 +58,9 @@ RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \ elif [ ${CUDA_COMPUTE_CAP} -eq 90 ]; \ then \ nvprune --generate-code code=sm_90 /usr/local/cuda/lib64/libcublas_static.a -o /usr/local/cuda/lib64/libcublas_static.a; \ + elif [ ${CUDA_COMPUTE_CAP} -eq 120 ]; \ + then \ + nvprune --generate-code code=sm_120 /usr/local/cuda/lib64/libcublas_static.a -o /usr/local/cuda/lib64/libcublas_static.a; \ else \ echo "cuda compute cap ${CUDA_COMPUTE_CAP} is not supported"; exit 1; \ fi; diff --git a/Dockerfile-cuda-all b/Dockerfile-cuda-all index 5dca432a..d3908112 100644 --- a/Dockerfile-cuda-all +++ b/Dockerfile-cuda-all @@ -1,4 +1,4 @@ -FROM nvidia/cuda:12.2.0-devel-ubuntu22.04 AS base-builder +FROM nvidia/cuda:12.9.0-devel-ubuntu22.04 AS base-builder ENV SCCACHE=0.10.0 ENV RUSTC_WRAPPER=/usr/local/bin/sccache @@ -85,6 +85,15 @@ RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \ CUDA_COMPUTE_CAP=90 cargo chef cook --release --features candle-cuda --recipe-path recipe.json && sccache -s; \ fi; +RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \ + --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \ + if [ $VERTEX = "true" ]; \ + then \ + CUDA_COMPUTE_CAP=120 cargo chef cook --release --features google --features candle-cuda --recipe-path recipe.json && sccache -s; \ + else \ + CUDA_COMPUTE_CAP=120 cargo chef cook --release --features candle-cuda --recipe-path recipe.json && sccache -s; \ + fi; + COPY backends backends COPY core core COPY router router @@ -122,9 +131,18 @@ RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \ 
     CUDA_COMPUTE_CAP=90 cargo build --release --bin text-embeddings-router -F candle-cuda && sccache -s; \
     fi;
 
 RUN mv /usr/src/target/release/text-embeddings-router /usr/src/target/release/text-embeddings-router-90
 
+RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
+    --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
+    if [ $VERTEX = "true" ]; \
+    then \
+    CUDA_COMPUTE_CAP=120 cargo build --release --bin text-embeddings-router -F candle-cuda -F google && sccache -s; \
+    else \
+    CUDA_COMPUTE_CAP=120 cargo build --release --bin text-embeddings-router -F candle-cuda && sccache -s; \
+    fi && mv /usr/src/target/release/text-embeddings-router /usr/src/target/release/text-embeddings-router-120
+
-FROM nvidia/cuda:12.2.0-runtime-ubuntu22.04 AS base
+FROM nvidia/cuda:12.9.0-runtime-ubuntu22.04 AS base
 
 ARG DEFAULT_USE_FLASH_ATTENTION=True
 
diff --git a/README.md b/README.md
index e9e98d0b..71c49156 100644
--- a/README.md
+++ b/README.md
@@ -581,6 +581,9 @@ runtime_compute_cap=89
 
 # Example for H100
 runtime_compute_cap=90
 
+# Example for Blackwell (RTX 5000 series, ...)
+runtime_compute_cap=120
+
 docker build .
-f Dockerfile-cuda --build-arg CUDA_COMPUTE_CAP=$runtime_compute_cap ``` diff --git a/backends/candle/src/compute_cap.rs b/backends/candle/src/compute_cap.rs index ac79fcf1..56978f5d 100644 --- a/backends/candle/src/compute_cap.rs +++ b/backends/candle/src/compute_cap.rs @@ -30,6 +30,7 @@ fn compute_cap_matching(runtime_compute_cap: usize, compile_compute_cap: usize) (86..=89, 80..=86) => true, (89, 89) => true, (90, 90) => true, + (120, 120) => true, (_, _) => false, } } @@ -54,6 +55,7 @@ mod tests { assert!(compute_cap_matching(86, 86)); assert!(compute_cap_matching(89, 89)); assert!(compute_cap_matching(90, 90)); + assert!(compute_cap_matching(120, 120)); assert!(compute_cap_matching(86, 80)); assert!(compute_cap_matching(89, 80)); diff --git a/backends/candle/src/flash_attn.rs b/backends/candle/src/flash_attn.rs index 8dbe58cf..f1b69c72 100644 --- a/backends/candle/src/flash_attn.rs +++ b/backends/candle/src/flash_attn.rs @@ -61,7 +61,7 @@ pub(crate) fn flash_attn_varlen( } #[cfg(not(feature = "flash-attn-v1"))] candle::bail!("Flash attention v1 is not installed. Use `flash-attn-v1` feature.") - } else if (80..90).contains(&runtime_compute_cap) || runtime_compute_cap == 90 { + } else if (80..90).contains(&runtime_compute_cap) || runtime_compute_cap == 90 || runtime_compute_cap == 120 { #[cfg(feature = "flash-attn")] { use candle_flash_attn::{flash_attn_varlen_alibi_windowed, flash_attn_varlen_windowed}; diff --git a/docs/source/en/custom_container.md b/docs/source/en/custom_container.md index c670026c..8b85262e 100644 --- a/docs/source/en/custom_container.md +++ b/docs/source/en/custom_container.md @@ -32,6 +32,7 @@ the examples of runtime compute capabilities for various GPU types: - A10 - `runtime_compute_cap=86` - Ada Lovelace (RTX 4000 series, ...) - `runtime_compute_cap=89` - H100 - `runtime_compute_cap=90` +- Blackwell (RTX 5000 series, ...) 
- `runtime_compute_cap=120` Once you have determined the compute capability is determined, set it as the `runtime_compute_cap` variable and build the container as shown in the example below: