From 94312d3643239db9d4acfff1827c747ef5f4cbbf Mon Sep 17 00:00:00 2001
From: Terry Kong
Date: Thu, 4 Apr 2024 16:28:55 -0700
Subject: [PATCH 1/3] Update front-page readme with link to XLA flag doc

---
 README.md | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index e8d598626..dd76d978d 100644
--- a/README.md
+++ b/README.md
@@ -271,7 +271,7 @@ We will update this table as new models become available, so stay tuned.
 
 ## Environment Variables
 
-The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is embedded with the following flags and environment variables for performance tuning:
+The [JAX images](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) are embedded with the following flags and environment variables for performance tuning:
 
 | XLA Flags | Value | Explanation |
 | --------- | ----- | ----------- |
@@ -280,6 +280,8 @@ The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is emb
 | `--xla_gpu_enable_async_reduce_scatter` | `true` | allows XLA to run NCCL [ReduceScatter](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/operations.html#reducescatter) kernels on a separate CUDA stream to allow overlap with compute kernels |
 | `--xla_gpu_enable_triton_gemm` | `false` | use cuBLAS instead of Triton GeMM kernels |
 
+See [GPU performance](./rosetta/docs/GPU_performance.md) for details about these, and other XLA flags, that enable high-performance for LLMs on NVIDIA GPUs.
+
 | Environment Variable | Value | Explanation |
 | -------------------- | ----- | ----------- |
 | `CUDA_DEVICE_MAX_CONNECTIONS` | `1` | use a single queue for GPU work to lower latency of stream operations; OK since XLA already orders launches |

From 45e1cc00ab98d43238450ef120c0f1bb9e0a1d97 Mon Sep 17 00:00:00 2001
From: Terry Kong
Date: Fri, 5 Apr 2024 09:51:48 -0700
Subject: [PATCH 2/3] Remove default list of flags on front page readme and
 instead provide instructions for how to inspect them remotely

---
 README.md | 23 ++++++++++++-----------
 1 file changed, 12 insertions(+), 11 deletions(-)

diff --git a/README.md b/README.md
index dd76d978d..2503ce21a 100644
--- a/README.md
+++ b/README.md
@@ -270,17 +270,7 @@ We currently enable training and evaluation for the following models:
 We will update this table as new models become available, so stay tuned.
 
 ## Environment Variables
-
-The [JAX images](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) are embedded with the following flags and environment variables for performance tuning:
-
-| XLA Flags | Value | Explanation |
-| --------- | ----- | ----------- |
-| `--xla_gpu_enable_latency_hiding_scheduler` | `true` | allows XLA to move communication collectives to increase overlap with compute kernels |
-| `--xla_gpu_enable_async_all_gather` | `true` | allows XLA to run NCCL [AllGather](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/operations.html#allgather) kernels on a separate CUDA stream to allow overlap with compute kernels |
-| `--xla_gpu_enable_async_reduce_scatter` | `true` | allows XLA to run NCCL [ReduceScatter](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/operations.html#reducescatter) kernels on a separate CUDA stream to allow overlap with compute kernels |
-| `--xla_gpu_enable_triton_gemm` | `false` | use cuBLAS instead of Triton GeMM kernels |
-
-See [GPU performance](./rosetta/docs/GPU_performance.md) for details about these, and other XLA flags, that enable high-performance for LLMs on NVIDIA GPUs.
+The [JAX images](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) are embedded with the following environment variables and XLA flags for performance tuning:
 
 | Environment Variable | Value | Explanation |
 | -------------------- | ----- | ----------- |
@@ -288,6 +278,17 @@ See [GPU performance](./rosetta/docs/GPU_performance.md) for details about these
 | `NCCL_NVLS_ENABLE` | `0` | Disables NVLink SHARP ([1](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-nvls-enable)). Future releases will re-enable this feature. |
 | `CUDA_MODULE_LOADING` | `EAGER` | Disables lazy-loading ([1](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#cuda-environment-variables)) which uses slightly more GPU memory. |
 
+XLA flags that tune performance are also set by default in the [JAX images](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax). To view
+the flags currently set, you can inspect the container's environment variables:
+```sh
+# Update IMAGE to inspect a container of your choosing
+IMAGE=ghcr.io/nvidia/jax:jax
+
+docker run --rm quay.io/skopeo/stable inspect docker://$IMAGE | jq -r '.Env[]' | grep '^XLA_FLAGS='
+```
+
+See [GPU performance](./rosetta/docs/GPU_performance.md) for details about these, and other XLA flags, that enable high-performance for LLMs on NVIDIA GPUs.
+
 ## Profiling JAX programs on GPU
 
 See [this page](./docs/profiling.md) for more information about how to profile JAX programs on GPU.
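[Editor's note on the patch above: the skopeo/jq pipeline prints the entire `XLA_FLAGS` value on a single line. For readability it can be split into one flag per line with plain POSIX tools. This is an illustrative sketch, not part of the patch series; `FLAGS` here is a hypothetical stand-in for whatever the inspection command prints for your image.]

```shell
# Split a captured XLA_FLAGS value into one flag per line for easier reading.
# FLAGS is a placeholder value; in practice, capture the output of the
# skopeo/jq pipeline shown in the patch above.
FLAGS='--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false'
echo "$FLAGS" | tr ' ' '\n'
```

Splitting on spaces is safe here because XLA flags take the form `--flag=value` with no embedded whitespace.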
From 0ffcebb1009e7e466394f4eb79ade3e5380f7a0a Mon Sep 17 00:00:00 2001
From: Terry Kong
Date: Wed, 10 Apr 2024 10:17:13 -0700
Subject: [PATCH 3/3] Adds an example output of the XLA flags present in
 container

---
 README.md | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/README.md b/README.md
index 2503ce21a..a0db1f404 100644
--- a/README.md
+++ b/README.md
@@ -285,6 +285,10 @@ the flags currently set, you can inspect the container's environment variables:
 IMAGE=ghcr.io/nvidia/jax:jax
 
 docker run --rm quay.io/skopeo/stable inspect docker://$IMAGE | jq -r '.Env[]' | grep '^XLA_FLAGS='
+
+# which returns
+
+XLA_FLAGS= --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_triton_gemm=false
 ```
 
 See [GPU performance](./rosetta/docs/GPU_performance.md) for details about these, and other XLA flags, that enable high-performance for LLMs on NVIDIA GPUs.
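[Editor's note on the patch above: the example output it adds can double as a quick sanity check in a script. A minimal POSIX-sh sketch follows; the `FLAGS` value is copied from the example output in the patch, and the check itself is illustrative, not part of the patch series.]

```shell
# Warn if the captured flags do not disable Triton GeMM (cuBLAS expected).
# FLAGS mirrors the XLA_FLAGS value shown in the example output above.
FLAGS='--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_triton_gemm=false'
case "$FLAGS" in
  *--xla_gpu_enable_triton_gemm=false*) echo "triton gemm disabled" ;;
  *) echo "warning: triton gemm not explicitly disabled" ;;
esac
```

The `case` pattern match keeps the check dependency-free, so it works in any POSIX shell without grep.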