diff --git a/.bazelrc b/.bazelrc index 47cf3e766e88..a800a96bcd6b 100644 --- a/.bazelrc +++ b/.bazelrc @@ -31,6 +31,7 @@ build -c opt build --output_filter=DONT_MATCH_ANYTHING build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir. +build --copt=-DNB_DOMAIN=jax # ############################################################################# # Platform Specific configs below. These are automatically picked up by Bazel @@ -97,6 +98,7 @@ build:windows --incompatible_strict_action_env=true # ############################################################################# build:nonccl --define=no_nccl_support=true +build --repo_env USE_PYWRAP_RULES=1 build:posix --copt=-fvisibility=hidden build:posix --copt=-Wno-sign-compare build:posix --cxxopt=-std=c++17 @@ -138,13 +140,13 @@ build:cuda --repo_env TF_NEED_CUDA=1 build:cuda --repo_env TF_NCCL_USE_STUB=1 # "sm" means we emit only cubin, which is forward compatible within a GPU generation. # "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations. -build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90" +build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,sm_90,sm_100,compute_120" build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain build:cuda --@local_config_cuda//:enable_cuda # Default hermetic CUDA and CUDNN versions. -build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" -build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" +build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.8.0" +build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.8.0" build:cuda --@local_config_cuda//cuda:include_cuda_libs=true # This config is used for building targets with CUDA libraries from stubs. @@ -262,8 +264,8 @@ build:ci_darwin_arm64 --color=yes # Windows x86 CI configs build:ci_windows_amd64 --config=avx_windows build:ci_windows_amd64 --compiler=clang-cl --config=clang --verbose_failures=true -build:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win/20240424:toolchain" -build:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win/20240424:cc-toolchain-x64_windows-clang-cl" +build:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win2022/20241118:toolchain" +build:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win2022/20241118:cc-toolchain-x64_windows-clang-cl" build:ci_windows_amd64 --host_linkopt=/FORCE:MULTIPLE --linkopt=/FORCE:MULTIPLE build:ci_windows_amd64 --color=yes @@ -331,9 +333,9 @@ common:rbe_windows_amd64 --remote_instance_name=projects/tensorflow-testing/inst build:rbe_windows_amd64 --config=rbe # Set the host, execution, and target platform -build:rbe_windows_amd64 --host_platform="@xla//tools/toolchains/win:x64_windows-clang-cl" -build:rbe_windows_amd64 --extra_execution_platforms="@xla//tools/toolchains/win:x64_windows-clang-cl" -build:rbe_windows_amd64 --platforms="@xla//tools/toolchains/win:x64_windows-clang-cl" +build:rbe_windows_amd64 --host_platform="@xla//tools/toolchains/win2022:windows_ltsc2022_clang" +build:rbe_windows_amd64 --extra_execution_platforms="@xla//tools/toolchains/win2022:windows_ltsc2022_clang" +build:rbe_windows_amd64 --platforms="@xla//tools/toolchains/win2022:windows_ltsc2022_clang" build:rbe_windows_amd64 --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe build:rbe_windows_amd64 --enable_runfiles diff --git a/.github/workflows/bazel_cuda_non_rbe.yml b/.github/workflows/bazel_cuda_non_rbe.yml index 0b0e1cb62497..ff1cf9900ce3 100644 --- 
a/.github/workflows/bazel_cuda_non_rbe.yml +++ b/.github/workflows/bazel_cuda_non_rbe.yml @@ -47,7 +47,7 @@ jobs: # Explicitly set the shell to bash shell: bash runs-on: ${{ inputs.runner }} - container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.3-cudnn9.1:latest" + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest" env: JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }} @@ -79,6 +79,7 @@ jobs: continue-on-error: true run: >- mkdir -p $(pwd)/dist && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ && gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ diff --git a/.github/workflows/bazel_optional_cuda.yml b/.github/workflows/bazel_optional_cuda.yml new file mode 100644 index 000000000000..71936aeb9ae8 --- /dev/null +++ b/.github/workflows/bazel_optional_cuda.yml @@ -0,0 +1,65 @@ +name: CI - Bazel Optional CUDA tests +on: + workflow_dispatch: + inputs: + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: choice + required: true + default: 'no' + options: + - 'yes' + - 'no' + pull_request: + branches: + - main + types: [ labeled, synchronize ] + schedule: + - cron: "0 */2 * * *" # Run once every 2 hours +permissions: + contents: read +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + # Don't cancel in-progress jobs for main/release branches. + cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }} +jobs: + run_tests: + if: ${{ github.event.repository.fork == false && (github.event_name == 'schedule' || contains(github.event.pull_request.labels.*.name, 'CI Optional GPU Presubmit')) }} + runs-on: ${{ matrix.runner }} + container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest' + strategy: + matrix: + # Optional gpus to run against + runner: ["linux-x86-a4-224-b200-1gpu"] + name: "Bazel single accelerator CUDA tests (${{ matrix.runner }})" +# End Presubmit Naming Check github-cuda-presubmits + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Run Bazel CUDA Tests + run: | + nvidia-smi + bazel test --config=ci_linux_x86_64_cuda \ + --config=resultstore \ + --config=rbe_cache \ + --repo_env=HERMETIC_CUDA_VERSION="12.8.0" \ + --repo_env=HERMETIC_CUDNN_VERSION="9.8.0" \ + --repo_env=HERMETIC_PYTHON_VERSION="3.13" \ + --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ + --run_under "$(pwd)/build/parallel_accelerator_execute.sh" \ + --test_output=errors \ + --test_env=JAX_ACCELERATOR_COUNT=1 \ + --test_env=JAX_TESTS_PER_ACCELERATOR=32 \ + --local_test_jobs=32 \ + --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \ + --test_tag_filters=-multiaccelerator \ + --test_env=TF_CPP_MIN_LOG_LEVEL=0 \ + --test_env=JAX_SKIP_SLOW_TESTS=true \ + --action_env=JAX_ENABLE_X64="1" \ + --action_env=NCCL_DEBUG=WARN \ + --color=yes \ + //tests:gpu_tests //tests:backend_independent_tests \ + //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests \ No newline at end of file 
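Reviewer note (not part of the patch): the new `bazel_optional_cuda.yml` workflow above pins `--action_env=JAX_ENABLE_X64="1"` and `--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform` for the single-GPU B200 runner. A minimal sketch of what those environment settings mean from inside a test process, assuming a CUDA-visible device; the printed device name and dtypes are illustrative, not taken from this patch:

```python
import os
import jax
import jax.numpy as jnp

# JAX_ENABLE_X64=1 (set via --action_env above) switches default dtypes to 64-bit.
print(os.environ.get("JAX_ENABLE_X64"))   # "1" on the CI runner
print(jax.config.jax_enable_x64)          # True once the flag is picked up
print(jnp.arange(3).dtype)                # int64 rather than the usual int32

# XLA_PYTHON_CLIENT_ALLOCATOR=platform avoids the preallocating BFC allocator,
# so concurrent test processes sharing one GPU allocate memory on demand.
print(jax.devices())                      # e.g. [CudaDevice(id=0)] on the B200 runner
```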
diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index c2e7acb91f7a..37a791784506 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -16,8 +16,8 @@ on: default: "linux-x86-n2-16" options: - "linux-x86-n2-16" - - "linux-arm64-c4a-64" - - "windows-x86-n2-64" + - "linux-arm64-t2a-48" + - "windows-x86-n2-16" artifact: description: "Which JAX artifact to build?" type: choice @@ -119,11 +119,11 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Enable RBE if building on Linux x86 - if: contains(inputs.runner, 'linux-x86') + - name: Enable RBE if building on Linux x86 or Windows x86 + if: contains(inputs.runner, 'linux-x86') || contains(inputs.runner, 'windows-x86') run: echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV - - name: Enable Bazel remote cache (with writes enabled) if building on Linux Aarch64 or Windows x86 - if: contains(inputs.runner, 'linux-arm64') || contains(inputs.runner, 'windows-x86') + - name: Enable Bazel remote cache (with writes enabled) if building on Linux Aarch64 + if: contains(inputs.runner, 'linux-arm64') run: echo "JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=1" >> $GITHUB_ENV # Halt for testing - name: Wait For Connection diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 699d33092b12..5bfbf4c9a705 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -33,11 +33,11 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python 3.11 - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: 3.11 - run: python -m pip install pre-commit - - uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + - uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: path: ~/.cache/pre-commit key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('.pre-commit-config.yaml', 'setup.py') }} @@ -64,7 +64,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -102,7 +102,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -130,7 +130,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -156,7 +156,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # 
v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -187,7 +187,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: 3.12 - name: Install JAX diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index 099f4ad5c520..b50b07d5cc4a 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -26,7 +26,6 @@ jobs: matrix: jaxlib-version: ["head", "pypi_latest", "nightly", "nightly+oldest_supported_libtpu"] tpu: [ - # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} ] diff --git a/.github/workflows/community_release_actions.yml b/.github/workflows/community_release_actions.yml new file mode 100644 index 000000000000..1110cbad9475 --- /dev/null +++ b/.github/workflows/community_release_actions.yml @@ -0,0 +1,34 @@ +name: Release Actions + +on: + release: + types: [published] + +permissions: + contents: read + +jobs: + discord_release: + if: github.repository_owner == 'jax-ml' + runs-on: ubuntu-latest + steps: + - name: Get release URL + id: get-release-url + run: | + URL="https://docs.jax.dev/en/latest/changelog.html" + echo "::set-output name=URL::$URL" + - name: Get content + uses: 2428392/gh-truncate-string-action@b3ff790d21cf42af3ca7579146eedb93c8fb0757 # v1.4.1 + id: get-content + with: + stringToTruncate: | + JAX [${{ github.event.release.tag_name }}](<${{ steps.get-release-url.outputs.URL }}>) was just released! + + ${{ github.event.release.body }} + maxLength: 2000 + truncationSymbol: "..." + - name: Discord Webhook Action + uses: tsickert/discord-webhook@b217a69502f52803de774ded2b1ab7c282e99645 # v7.0.0 + with: + webhook-url: ${{ secrets.DISCORD_WEBHOOK_URL }} + content: ${{ steps.get-content.outputs.string }} diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index 2b97c5a05c1c..7df4228dd2a3 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -28,11 +28,11 @@ jobs: with: repository: data-apis/array-api-tests # TODO(jakevdp) update this to a stable release/tag when available. 
- ref: '0b89c5268e4e4a352223a487b8f63dbd1023872d' # Latest commit as of 2025-03-04 + ref: 'c48410f96fc58e02eea844e6b7f6cc01680f77ce' # Latest commit as of 2025-04-02 submodules: 'true' path: 'array-api-tests' - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies diff --git a/.github/workflows/k8s.yaml b/.github/workflows/k8s.yaml new file mode 100644 index 000000000000..1042388fe9c6 --- /dev/null +++ b/.github/workflows/k8s.yaml @@ -0,0 +1,109 @@ +name: Distributed run using K8s Jobset + +on: + push: + branches: + - main + paths: + - 'jax/distributed.py' + - 'jax/_src/distributed.py' + - 'jax/_src/clusters/**' + pull_request: + branches: + - main + paths: + - 'jax/distributed.py' + - 'jax/_src/distributed.py' + - 'jax/_src/clusters/**' + +permissions: + contents: read + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + +defaults: + run: + shell: bash -ex -o pipefail {0} + +jobs: + + distributed-initialize: + runs-on: ubuntu-22.04 + steps: + - name: Checkout + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # ratchet:actions/checkout@v4 + with: + path: jax + + - name: Start Minikube cluster + uses: medyagh/setup-minikube@cea33675329b799adccc9526aa5daccc26cd5052 # ratchet:medyagh/setup-minikube@v0.0.19 + + - name: Install K8s Jobset + run: | + kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/v0.6.0/manifests.yaml + + - name: Build image + run: | + cat > Dockerfile <> $GITHUB_ENV + else + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ + fi - name: Skip the test run if the wheel artifacts were not downloaded successfully if: steps.download-wheel-artifacts.outcome == 'failure' run: | diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml index a105a2feb347..0b56635a8aac 100644 --- a/.github/workflows/pytest_tpu.yml +++ b/.github/workflows/pytest_tpu.yml @@ -54,6 +54,11 @@ on: # - "pypi_latest": Use the latest libtpu wheel from PyPI. # - "oldest_supported_libtpu": Use the oldest supported libtpu wheel. default: "nightly" + download-jax-only-from-gcs: + description: "Whether to download only the jax wheel from GCS (e.g for testing a jax only release)" + required: false + default: '0' + type: string gcs_download_uri: description: "GCS location prefix from where the artifacts should be downloaded" required: true @@ -110,7 +115,11 @@ jobs: run: | mkdir -p $(pwd)/dist gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ + if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then + echo "JAX only release. Only downloading the jax wheel from the release bucket." 
+ else + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ + fi - name: Skip the test run if the wheel artifacts were not downloaded successfully if: steps.download-wheel-artifacts.outcome == 'failure' run: | diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index 4efbeee7c5e2..4a0fab2a7703 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -22,7 +22,7 @@ jobs: strategy: matrix: python: ["3.10.13"] - rocm: ["6.2.4", "6.3.3"] + rocm: ["6.3.3"] env: BASE_IMAGE: "ubuntu:22.04" TEST_IMAGE: ubuntu-jax-${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }} diff --git a/.github/workflows/tsan-suppressions.txt b/.github/workflows/tsan-suppressions_3.13.txt similarity index 84% rename from .github/workflows/tsan-suppressions.txt rename to .github/workflows/tsan-suppressions_3.13.txt index 7b713b2da194..b600e38276cc 100644 --- a/.github/workflows/tsan-suppressions.txt +++ b/.github/workflows/tsan-suppressions_3.13.txt @@ -2,14 +2,11 @@ # are racing on a call to __register_frame_info(), but that function appears to be correctly locked internally. race:llvm::RuntimeDyldELF::registerEHFrames -# https://github.com/python/cpython/issues/128050 -race:partial_vectorcall_fallback - # https://github.com/openxla/xla/issues/20686 race:dnnl_sgemm -# https://github.com/python/cpython/issues/128130 -race_top:run_eval_code_obj +# https://github.com/python/cpython/issues/128050 +race:partial_vectorcall_fallback # Likely only happens when the process is crashing. race:dump_traceback @@ -18,18 +15,21 @@ race:dump_traceback # Fixed in Python 3.14, but not backported to 3.13. race:immortalize_interned race:_PyUnicode_InternMortal +race:_PyUnicode_InternImmortal # https://github.com/python/cpython/issues/128144 # Fixed in Python 3.14, but not backported to 3.13. race_top:PyMember_GetOne -# https://github.com/python/cpython/issues/129547 -race:type_get_annotations - +# https://github.com/python/cpython/issues/131680 +# Fixed in Python 3.14, but not backported to 3.13. +race_top: new_reference # https://github.com/python/cpython/issues/129748 race:mi_block_set_nextx +# https://github.com/python/cpython/issues/128130 +race_top:run_eval_code_obj # Races because the LAPACK and BLAS in our scipy isn't TSAN instrumented. race:heevd_ffi @@ -65,3 +65,11 @@ race:gemm_oncopy # https://github.com/python/cpython/issues/130547 # race:split_keys_entry_added + +# https://github.com/python/cpython/issues/129547 +# Maybe fixed? +# race:type_get_annotations + +# https://github.com/python/cpython/issues/132013 +# Fixed on 3.14 and not backported to 3.13 +race_top:frozenset_hash \ No newline at end of file diff --git a/.github/workflows/tsan-suppressions_3.14.txt b/.github/workflows/tsan-suppressions_3.14.txt new file mode 100644 index 000000000000..9cfc68e1ae36 --- /dev/null +++ b/.github/workflows/tsan-suppressions_3.14.txt @@ -0,0 +1,26 @@ +# false-positive caused because we haven't tsan-instrumented libgcc_s. Multiple threads +# are racing on a call to __register_frame_info(), but that function appears to be correctly locked internally. +race:llvm::RuntimeDyldELF::registerEHFrames + +# https://github.com/openxla/xla/issues/20686 +race:dnnl_sgemm + +# https://github.com/python/cpython/issues/128050 +race:partial_vectorcall_fallback + +# Likely only happens when the process is crashing. 
+race:dump_traceback + +# https://github.com/python/cpython/issues/129748 +race:mi_block_set_nextx + +# https://github.com/python/cpython/issues/128130 +race_top:run_eval_code_obj + +# Races because the LAPACK and BLAS in our scipy isn't TSAN instrumented. +race:heevd_ffi +race:gesdd_ffi +race:dscal_k_ +race:scal_k_ +race:gemm_beta +race:gemm_oncopy diff --git a/.github/workflows/tsan.yaml b/.github/workflows/tsan.yaml index 7d93707e4e92..ef1cf99d6d74 100644 --- a/.github/workflows/tsan.yaml +++ b/.github/workflows/tsan.yaml @@ -13,6 +13,8 @@ on: - main paths: - '**/workflows/tsan.yaml' + - '**/workflows/tsan-suppressions*.txt' + - '**/workflows/requirements_lock_3_13_ft.patch' jobs: tsan: @@ -21,6 +23,16 @@ jobs: image: index.docker.io/library/ubuntu@sha256:b359f1067efa76f37863778f7b6d0e8d911e3ee8efa807ad01fbf5dc1ef9006b # ratchet:ubuntu:24.04 strategy: fail-fast: false + matrix: + include: + - name-prefix: "with 3.13" + python-version: "3.13" + github_branch: "3.13" + requirements_lock_name: "requirements_lock_3_13_ft" + - name-prefix: "with 3.14" + python-version: "3.14" + github_branch: "main" + requirements_lock_name: "requirements_lock_3_14_ft" defaults: run: shell: bash -l {0} @@ -43,22 +55,33 @@ jobs: with: repository: python/cpython path: cpython - ref: "3.13" + ref: ${{ matrix.github_branch }} - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: numpy/numpy path: numpy submodules: true + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + if: ${{ matrix.python-version == '3.14' }} + with: + repository: scipy/scipy + path: scipy + submodules: true - - name: Restore cached TSAN CPython + - name: Get year & week number + id: get-date + run: echo "date=$(/bin/date "+%Y-%U")" >> $GITHUB_OUTPUT + shell: bash -l {0} + + - name: Restore cached TSAN CPython ${{ matrix.python-version }} id: cache-cpython-tsan-restore - uses: actions/cache/restore@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + uses: actions/cache/restore@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: path: | ./python-tsan.tgz - key: ${{ runner.os }}-cpython-tsan-${{ hashFiles('cpython/configure.ac') }} + key: ${{ runner.os }}-cpython-tsan-${{ matrix.python-version }}-${{ steps.get-date.outputs.date }} - - name: Build CPython with enabled TSAN + - name: Build TSAN CPython ${{ matrix.python-version }} if: steps.cache-cpython-tsan-restore.outputs.cache-hit != 'true' run: | cd cpython @@ -72,27 +95,22 @@ jobs: # Create archive to be used with bazel as hermetic python: cd ${GITHUB_WORKSPACE} && tar -czpf python-tsan.tgz cpython-tsan - - name: Save TSAN CPython + - name: Save TSAN CPython ${{ matrix.python-version }} id: cache-cpython-tsan-save if: steps.cache-cpython-tsan-restore.outputs.cache-hit != 'true' - uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + uses: actions/cache/save@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: path: | ./python-tsan.tgz - key: ${{ runner.os }}-cpython-tsan-${{ hashFiles('cpython/configure.ac') }} - - - name: Get year & week number - id: get-date - run: echo "date=$(/bin/date "+%Y-%U")" >> $GITHUB_OUTPUT - shell: bash -l {0} + key: ${{ runner.os }}-cpython-tsan-${{ matrix.python-version }}-${{ steps.get-date.outputs.date }} - name: Restore cached TSAN Numpy id: cache-numpy-tsan-restore - uses: actions/cache/restore@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + uses: actions/cache/restore@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: path: | ./wheelhouse - key: ${{ 
runner.os }}-numpy-tsan-${{ hashFiles('numpy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} + key: ${{ runner.os }}-numpy-tsan-${{ matrix.python-version }}-${{ hashFiles('numpy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} - name: Build TSAN Numpy wheel if: steps.cache-numpy-tsan-restore.outputs.cache-hit != 'true' @@ -113,7 +131,8 @@ jobs: python3 -m pip install uv~=0.5.30 # Make sure to install a compatible Cython version (master branch is best for now) - python3 -m uv pip install -r requirements/build_requirements.txt -U git+https://github.com/cython/cython + NO_CYTHON_COMPILE=true python3 -m uv pip install -U git+https://github.com/cython/cython + python3 -m uv pip install -r requirements/build_requirements.txt CC=clang-18 CXX=clang++-18 python3 -m pip wheel --wheel-dir dist -v . --no-build-isolation -Csetup-args=-Db_sanitize=thread -Csetup-args=-Dbuildtype=debugoptimized @@ -142,11 +161,87 @@ jobs: - name: Save TSAN Numpy wheel id: cache-numpy-tsan-save if: steps.cache-numpy-tsan-restore.outputs.cache-hit != 'true' - uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + uses: actions/cache/save@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 + with: + path: | + ./wheelhouse + key: ${{ runner.os }}-numpy-tsan-${{ matrix.python-version }}-${{ hashFiles('numpy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} + + - name: Restore cached Scipy + if: ${{ matrix.python-version == '3.14' }} + id: cache-scipy-restore + uses: actions/cache/restore@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: path: | ./wheelhouse - key: ${{ runner.os }}-numpy-tsan-${{ hashFiles('numpy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} + key: ${{ runner.os }}-scipy-${{ matrix.python-version }}-${{ hashFiles('scipy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} + + - name: Build Scipy wheel + if: ${{ steps.cache-scipy-restore.outputs.cache-hit != 'true' && matrix.python-version == '3.14' }} + run: | + # Install scipy dependencies: + apt-get install -y gfortran libopenblas-dev liblapack-dev pkg-config --no-install-recommends + + cd scipy + + # If we restored cpython from cache, we need to get python interpreter from python-tsan.tgz + if [ ! -d ${GITHUB_WORKSPACE}/cpython-tsan/bin/ ]; then + echo "Extract cpython from python-tsan.tgz" + pushd . + ls ${GITHUB_WORKSPACE}/python-tsan.tgz + cd ${GITHUB_WORKSPACE} && tar -xzf python-tsan.tgz + ls ${GITHUB_WORKSPACE}/cpython-tsan/bin/ + popd + fi + + export PATH=${GITHUB_WORKSPACE}/cpython-tsan/bin/:$PATH + + python3 -m pip install uv~=0.5.30 + # Make sure to install a compatible Cython version (master branch is best for now) + NO_CYTHON_COMPILE=true python3 -m uv pip install -U git+https://github.com/cython/cython + python3 -m uv pip install -U --pre numpy --extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/ + python3 -m uv pip install pythran pybind11 meson-python ninja + + python3 -m uv pip list | grep -E "(numpy|pythran|cython|pybind11)" + + export CC=clang-18 + export CXX=clang++-18 + python3 -m pip wheel --wheel-dir dist -vvv . 
--no-build-isolation --no-deps -Csetup-args=-Dbuildtype=debugoptimized + + python3 -m uv pip list | grep -E "(numpy|pythran|cython|pybind11)" + + # Create simple index and copy the wheel + mkdir -p ${GITHUB_WORKSPACE}/wheelhouse/scipy + + scipy_whl_name=($(cd dist && ls scipy*.whl)) + if [ -z "${scipy_whl_name}" ]; then exit 1; fi + + echo "Built TSAN Scipy wheel: ${scipy_whl_name}" + + cp dist/${scipy_whl_name} ${GITHUB_WORKSPACE}/wheelhouse/scipy + + # Recreate wheelhouse index with Numpy and Scipy + cat << EOF > ${GITHUB_WORKSPACE}/wheelhouse/index.html + + numpy>
+ scipy>
+ + EOF + + cat << EOF > ${GITHUB_WORKSPACE}/wheelhouse/scipy/index.html + + ${scipy_whl_name}
+ + EOF + + - name: Save Scipy wheel + id: cache-scipy-save + if: ${{ steps.cache-scipy-restore.outputs.cache-hit != 'true' && matrix.python-version == '3.14' }} + uses: actions/cache/save@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 + with: + path: | + ./wheelhouse + key: ${{ runner.os }}-scipy-${{ matrix.python-version }}-${{ hashFiles('scipy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} - name: Build Jax and run tests timeout-minutes: 120 @@ -163,7 +258,7 @@ jobs: python3 -VV python3 build/build.py build --configure_only \ - --python_version=3.13-ft \ + --python_version=${{ matrix.python-version }}-ft \ --bazel_options=--repo_env=HERMETIC_PYTHON_URL="file://${GITHUB_WORKSPACE}/python-tsan.tgz" \ --bazel_options=--repo_env=HERMETIC_PYTHON_SHA256=${PYTHON_SHA256} \ --bazel_options=--repo_env=HERMETIC_PYTHON_PREFIX="cpython-tsan/" \ @@ -173,18 +268,32 @@ jobs: --bazel_options=--copt=-g \ --clang_path=/usr/bin/clang-18 - # Patch build/requirements_lock_3_13_ft.txt to use TSAN instrumented NumPy - sed -i "s|+--extra-index-url.*|+--extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/|" .github/workflows/requirements_lock_3_13_ft.patch - cat .github/workflows/requirements_lock_3_13_ft.patch - git apply .github/workflows/requirements_lock_3_13_ft.patch || exit 1 + if [ "${{ matrix.python-version }}" == "3.13" ]; then + # Patch build/requirements_lock_3_13_ft.txt to use TSAN instrumented NumPy - # Display the content for debugging in logs - cat build/requirements_lock_3_13_ft.txt | head -15 - # Check the patch - cat build/requirements_lock_3_13_ft.txt | head -15 | grep -E "(--pre|.*${GITHUB_WORKSPACE}/wheelhouse/|numpy)" - if [ "$?" == "1" ]; then echo "Could not find the patch in the requirements_lock_3_13_ft.txt"; exit 1; fi - cat build/requirements_lock_3_13_ft.txt | grep -E "(numpy==)" - if [ "$?" == "0" ]; then "Found original numpy dependency in the requirements_lock_3_13_ft.txt"; exit 1; fi + sed -i "s|+--extra-index-url.*|+--extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/|" .github/workflows/${{ matrix.requirements_lock_name }}.patch + cat .github/workflows/${{ matrix.requirements_lock_name }}.patch + git apply .github/workflows/${{ matrix.requirements_lock_name }}.patch || exit 1 + + # Display the content for debugging in logs + cat build/${{ matrix.requirements_lock_name }}.txt | head -15 + # Check the patch + cat build/${{ matrix.requirements_lock_name }}.txt | head -15 | grep -E "(--pre|.*${GITHUB_WORKSPACE}/wheelhouse/|numpy)" + if [ "$?" == "1" ]; then echo "Could not find the patch in the ${{ matrix.requirements_lock_name }}.txt"; exit 1; fi + cat build/${{ matrix.requirements_lock_name }}.txt | grep -E "(numpy==)" + if [ "$?" 
== "0" ]; then "Found original numpy dependency in the ${{ matrix.requirements_lock_name }}.txt"; exit 1; fi + + else + # Patch build/requirements_lock_3_14_ft.txt to use TSAN instrumented NumPy and Scipy + + sed -i "s|--extra-index-url.*|--extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/|" build/${{ matrix.requirements_lock_name }}.txt + + # We should install jpeg dev package to be able to build Pillow from source: + apt-get install -y libjpeg-dev --no-install-recommends + + # Install scipy runtime dependencies (in case we restore scipy wheel from cache): + apt-get install -y libopenblas-dev liblapack-dev --no-install-recommends + fi echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES" echo "JAX_ENABLE_X64=$JAX_ENABLE_X64" @@ -200,17 +309,22 @@ jobs: # Check numpy version ./bazel cquery @pypi_numpy//:* | grep whl + if [ "${{ matrix.python-version }}" == "3.14" ]; then + # Check scipy version + ./bazel cquery @pypi_scipy//:* | grep whl + fi + # Build JAX and run tests ./bazel test \ --test_env=JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES \ --test_env=JAX_ENABLE_X64=$JAX_ENABLE_X64 \ --test_env=JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS \ --test_env=PYTHON_GIL=0 \ - --test_env=TSAN_OPTIONS=halt_on_error=1,suppressions=$PWD/.github/workflows/tsan-suppressions.txt \ + --test_env=TSAN_OPTIONS=halt_on_error=1,suppressions=$PWD/.github/workflows/tsan-suppressions_${{ matrix.python-version }}.txt \ --test_env=JAX_TEST_NUM_THREADS=8 \ --test_output=errors \ --local_test_jobs=32 \ - --test_timeout=600 \ + --test_timeout=1800 \ --config=resultstore \ --config=rbe_cache \ //tests:cpu_tests diff --git a/.github/workflows/upstream-nightly.yml b/.github/workflows/upstream-nightly.yml index 371426973436..7ccbe974e5ab 100644 --- a/.github/workflows/upstream-nightly.yml +++ b/.github/workflows/upstream-nightly.yml @@ -33,7 +33,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.python-version }} - name: Install JAX test requirements diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index ecdf43b133cc..3739c9267730 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -44,7 +44,7 @@ jobs: fail-fast: false # don't cancel all jobs on failure matrix: # Runner OS and Python values need to match the matrix stategy in the CPU tests job - runner: ["linux-x86-n2-16", "linux-arm64-t2a-48", "windows-x86-n2-64"] + runner: ["linux-x86-n2-16", "linux-arm64-t2a-48", "windows-x86-n2-16"] artifact: ["jaxlib"] python: ["3.10"] # Note: For reasons unknown, Github actions groups jobs with the same top-level name in the @@ -111,28 +111,21 @@ jobs: matrix: # Python values need to match the matrix stategy in the artifact build jobs above # See exlusions for what is fully tested - runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu","linux-x86-a4-224-b200-1gpu"] + runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu", "linux-x86-a4-224-b200-1gpu"] python: ["3.10",] - cuda: ["12.1","12.3","12.8"] + cuda: ["12.1", "12.8"] enable-x64: [1, 0] exclude: - # L4 does not run on cuda 12.8 but tests other configs - - runner: "linux-x86-g2-48-l4-4gpu" - cuda: "12.8" - # H100 runs only a single config, CUDA 12.3 Enable 
x64 1 - - runner: "linux-x86-a3-8g-h100-8gpu" - cuda: "12.8" + # H100 runs only a single config, CUDA 12.8 Enable x64 1 - runner: "linux-x86-a3-8g-h100-8gpu" cuda: "12.1" - runner: "linux-x86-a3-8g-h100-8gpu" enable-x64: "0" # B200 runs only a single config, CUDA 12.8 Enable x64 1 - - runner: "linux-x86-a4-224-b200-1gpu" - enable-x64: "0" - runner: "linux-x86-a4-224-b200-1gpu" cuda: "12.1" - runner: "linux-x86-a4-224-b200-1gpu" - cuda: "12.3" + enable-x64: "0" name: "Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }})" with: @@ -148,7 +141,7 @@ # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we # still want to run the tests for other platforms. if: ${{ !cancelled() }} - needs: [build-jaxlib-artifact, build-cuda-artifacts] + needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts] uses: ./.github/workflows/bazel_cuda_non_rbe.yml strategy: fail-fast: false # don't cancel all jobs on failure diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index adb678be9d9d..f6d2aa9b97c6 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -1,12 +1,14 @@ # CI - Wheel Tests (Nightly/Release) # -# This workflow builds JAX artifacts and runs CPU/CUDA tests. +# This workflow is used to test the JAX wheels that were built by internal CI jobs. # -# It orchestrates the following: -# 1. run-pytest-cpu: Calls the `pytest_cpu.yml` workflow which downloads the jaxlib wheel that was +# 1. run-pytest-cpu: Calls the `pytest_cpu.yml` workflow which downloads the JAX wheels that were built by internal CI jobs and runs CPU tests. -# 2. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow which downloads the jaxlib and CUDA -# artifacts that were built by internal CI jobs and runs the CUDA tests. +# 2. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow which downloads the JAX wheels that were +# built by internal CI jobs and runs CUDA tests. +# 3. run-pytest-tpu: Calls the `pytest_tpu.yml` workflow which downloads the JAX wheels that were +# built by internal CI jobs and runs TPU tests. +# 4. verify-release-wheels-install: Verifies that JAX's release wheels can be installed. name: CI - Wheel Tests (Nightly/Release) on: @@ -17,6 +19,11 @@ on: required: true default: 'gs://jax-nightly-release-transient/nightly/latest' type: string + download-jax-only-from-gcs: description: "Whether to download only the jax wheel from GCS (e.g for testing a jax only release)" required: true default: '0' type: string concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} @@ -41,6 +48,7 @@ jobs: runner: ${{ matrix.runner }} python: ${{ matrix.python }} enable-x64: ${{ matrix.enable-x64 }} + download-jax-only-from-gcs: ${{inputs.download-jax-only-from-gcs}} gcs_download_uri: ${{inputs.gcs_download_uri}} run-pytest-cuda: @@ -52,7 +60,7 @@ # that build the wheels.
runner: ["linux-x86-g2-48-l4-4gpu"] python: ["3.10","3.11", "3.12", "3.13", "3.13-nogil"] - cuda: ["12.3", "12.1"] + cuda: ["12.1", "12.8"] enable-x64: [0] name: "Pytest CUDA (JAX artifacts version = ${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})" with: @@ -60,6 +68,7 @@ jobs: python: ${{ matrix.python }} cuda: ${{ matrix.cuda }} enable-x64: ${{ matrix.enable-x64 }} + download-jax-only-from-gcs: ${{inputs.download-jax-only-from-gcs}} gcs_download_uri: ${{inputs.gcs_download_uri}} run-pytest-tpu: @@ -79,7 +88,7 @@ jobs: exclude: - libtpu-version-type: ${{ startsWith(github.ref_name, 'release/') && 'nightly' }} - libtpu-version-type: ${{ !startsWith(github.ref_name, 'release/') && 'pypi_latest' }} - # Run a single Python version for v4-8. + # Run a single Python version for v4-8 - tpu-specs: type: "v4-8" python: "3.10" @@ -98,4 +107,89 @@ jobs: python: ${{ matrix.python }} run-full-tpu-test-suite: "1" libtpu-version-type: ${{ matrix.libtpu-version-type }} - gcs_download_uri: ${{inputs.gcs_download_uri}} \ No newline at end of file + download-jax-only-from-gcs: ${{inputs.download-jax-only-from-gcs}} + gcs_download_uri: ${{inputs.gcs_download_uri}} + + verify-release-wheels-install: + if: ${{ startsWith(github.ref_name, 'release/')}} + defaults: + run: + # Set the shell to bash as GitHub actions runs with /bin/sh by default + shell: bash + runs-on: linux-x86-n2-16 + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + python: ["3.10", "3.13", "3.13-nogil"] + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" + + # Verifies that JAX's release wheels can be installed + name: "Verify release wheels install (Python ${{ matrix.python }})" + + env: + PYTHON: "python${{ matrix.python }}" + + steps: + - name: Download release wheels from GCS + run: | + mkdir -p $(pwd)/dist + final_gcs_download_uri=${{ inputs.gcs_download_uri }} + + # Get the major and minor version of Python. 
+ # E.g if python=3.10, then python_major_minor=310 + # E.g if python=3.13-nogil, then python_major_minor=313t + python_major_minor=${{ matrix.python }} + python_major_minor=$(echo "${python_major_minor//-nogil/t}" | tr -d '.') + python_major_minor="cp${python_major_minor%t}-cp${python_major_minor}-" + + gsutil -m cp -r "${final_gcs_download_uri}"/jax*py3*none*any.whl $(pwd)/dist/ + + jax_wheel=$(ls dist/jax*py3*none*any.whl 2>/dev/null) + echo "JAX_WHEEL=$jax_wheel" >> $GITHUB_ENV + + if [[ "${{ inputs.download-jax-only-from-gcs }}" != "1" ]]; then + gsutil -m cp -r "${final_gcs_download_uri}/jaxlib*${python_major_minor}*linux*x86_64*.whl" $(pwd)/dist/ + gsutil -m cp -r "${final_gcs_download_uri}/jax*cuda*plugin*${python_major_minor}*linux*x86_64*.whl" $(pwd)/dist/ + gsutil -m cp -r "${final_gcs_download_uri}/jax*cuda*pjrt*linux*x86_64*.whl" $(pwd)/dist/ + + jaxlib_wheel=$(ls dist/jaxlib*${python_major_minor}*linux*x86_64*.whl 2>/dev/null) + jax_cuda_plugin_wheel=$(ls dist/jax*cuda*plugin*${python_major_minor}*linux*x86_64*.whl 2>/dev/null) + jax_cuda_pjrt_wheel=$(ls dist/jax*cuda*pjrt*linux*x86_64*.whl 2>/dev/null) + + echo "JAXLIB_WHEEL=$jaxlib_wheel" >> $GITHUB_ENV + echo "JAX_CUDA_PLUGIN_WHEEL=$jax_cuda_plugin_wheel" >> $GITHUB_ENV + echo "JAX_CUDA_PJRT_WHEEL=$jax_cuda_pjrt_wheel" >> $GITHUB_ENV + fi + - name: Verify JAX CPU packages can be installed + run: | + $PYTHON -m uv venv ~/test_cpu && source ~/test_cpu/bin/activate + if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then + uv pip install $JAX_WHEEL + else + uv pip install $JAX_WHEEL $JAXLIB_WHEEL + fi + - name: Verify JAX TPU packages can be installed + run: | + $PYTHON -m uv venv ~/test_tpu && source ~/test_tpu/bin/activate + + if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then + uv pip install $JAX_WHEEL[tpu] + else + uv pip install $JAX_WHEEL[tpu] $JAXLIB_WHEEL + fi + - name: Verify JAX CUDA packages can be installed (Nvidia Pip Packages) + run: | + $PYTHON -m uv venv ~/test_cuda_pip && source ~/test_cuda_pip/bin/activate + if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then + uv pip install $JAX_WHEEL[cuda] + else + uv pip install $JAX_WHEEL[cuda] $JAXLIB_WHEEL $JAX_CUDA_PJRT_WHEEL $JAX_CUDA_PLUGIN_WHEEL[with-cuda] + fi + - name: Verify JAX CUDA packages can be installed (CUDA local) + run: | + $PYTHON -m uv venv ~/test_cuda_local && source ~/test_cuda_local/bin/activate + if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then + uv pip install $JAX_WHEEL[cuda12-local] + else + uv pip install $JAX_WHEEL $JAXLIB_WHEEL $JAX_CUDA_PJRT_WHEEL $JAX_CUDA_PLUGIN_WHEEL + fi \ No newline at end of file diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml index 444bc83f2889..a2b3aeddc24a 100644 --- a/.github/workflows/wheel_win_x64.yml +++ b/.github/workflows/wheel_win_x64.yml @@ -27,7 +27,7 @@ jobs: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + - uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.pyver }} cache: 'pip' @@ -38,7 +38,7 @@ jobs: JAXLIB_RELEASE: true run: | python -m pip install uv~=0.5.30 - python -m uv pip install -r build/test-requirements.txt \ + python -m uv pip install -r build/test-requirements.txt ` --upgrade numpy==2.0.0 scipy==1.13.1 "C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH python.exe build\build.py build --wheels=jaxlib ` @@ -58,7 +58,7 @@ jobs: 
JAX_SKIP_SLOW_TESTS: true PY_COLORS: 1 run: | - python -m uv pip install --find-links ${{ github.workspace }}\dist jaxlib \ + python -m uv pip install --find-links ${{ github.workspace }}\dist jaxlib ` -e ${{ github.workspace }} echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS" pytest -n auto --tb=short tests examples diff --git a/.github/workflows/windows_ci.yml b/.github/workflows/windows_ci.yml index fc2b63396f56..5a435023ffda 100644 --- a/.github/workflows/windows_ci.yml +++ b/.github/workflows/windows_ci.yml @@ -35,7 +35,7 @@ jobs: with: path: jax - - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + - uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.pyver }} cache: 'pip' diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 27ccc6d831f3..89ce80d9a815 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,6 +15,7 @@ repos: - id: check-merge-conflict - id: check-toml - id: check-yaml + exclude: examples/k8s/svc-acct.yaml - id: end-of-file-fixer # only include python files files: \.py$ diff --git a/BUILD.bazel b/BUILD.bazel index 33cbefd29f0b..8dbf2bed0902 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +load( + "@xla//third_party/py:py_import.bzl", + "py_import", +) load("@xla//third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps") load( "//jaxlib:jax.bzl", + "jax_source_package", "jax_wheel", + "py_deps", + "pytype_test", ) collect_data_files( @@ -41,6 +48,7 @@ transitive_py_deps( "//jax:sparse_test_util", "//jax:test_util", "//jax/_src/lib", + "//jax/_src/pallas/fuser", "//jax/_src/pallas/mosaic_gpu", "//jax/experimental/array_serialization:serialization", "//jax/experimental/jax2tf", @@ -65,20 +73,98 @@ py_binary( ], ) +WHEEL_SOURCE_FILES = [ + ":transitive_py_data", + ":transitive_py_deps", + "//jax:py.typed", + "AUTHORS", + "LICENSE", + "README.md", + "pyproject.toml", + "setup.py", +] + jax_wheel( name = "jax_wheel", - build_wheel_only = False, platform_independent = True, - source_files = [ - ":transitive_py_data", - ":transitive_py_deps", - "//jax:py.typed", - "AUTHORS", - "LICENSE", - "README.md", - "pyproject.toml", - "setup.py", - ], + source_files = WHEEL_SOURCE_FILES, + wheel_binary = ":build_wheel", + wheel_name = "jax", +) + +jax_wheel( + name = "jax_wheel_editable", + editable = True, + platform_independent = True, + source_files = WHEEL_SOURCE_FILES, wheel_binary = ":build_wheel", wheel_name = "jax", ) + +jax_source_package( + name = "jax_source_package", + source_files = WHEEL_SOURCE_FILES, + source_package_binary = ":build_wheel", + source_package_name = "jax", +) + +genrule( + name = "wheel_additives", + srcs = [ + "//jax:internal_export_back_compat_test_util", + "//jax:internal_test_harnesses", + "//jax:internal_test_util", + "//jax:internal_export_back_compat_test_data", + "//jax:experimental/pallas/ops/tpu/random/philox.py", + "//jax:experimental/pallas/ops/tpu/random/prng_utils.py", + "//jax:experimental/pallas/ops/tpu/random/threefry.py", + "//jax/experimental/mosaic/gpu/examples:flash_attention.py", + "//jax/experimental/mosaic/gpu/examples:matmul.py", + ], + outs = ["wheel_additives.zip"], + cmd = "$(location @bazel_tools//tools/zip:zipper) c $@ $(SRCS)", + tools = ["@bazel_tools//tools/zip:zipper"], +) + +COMMON_DEPS = py_deps([ + "absl/testing", + "numpy", + "ml_dtypes", + "scipy", + "opt_einsum", + 
"hypothesis", + "cloudpickle", + "flatbuffers", +]) + +py_import( + name = "jax_py_import", + wheel = ":jax_wheel", + wheel_deps = [":wheel_additives"], + deps = COMMON_DEPS, +) + +# This target is used to add more sources to the jax wheel. +# This is needed for the tests that depend on jax and use modules that are not part of +# the jax wheel, but share the same package paths as the modules in the jax wheel. +py_import( + name = "jax_wheel_with_internal_test_util", + wheel = "@pypi_jax//:whl", + wheel_deps = [":wheel_additives"], + deps = COMMON_DEPS, +) + +pytype_test( + name = "jax_wheel_size_test", + srcs = ["//jaxlib/tools:wheel_size_test.py"], + args = [ + "--wheel-path=$(location :jax_wheel)", + "--max-size-mib=5", + ], + data = [":jax_wheel"], + main = "wheel_size_test.py", + tags = [ + "manual", + "notap", + ], +) diff --git a/CHANGELOG.md b/CHANGELOG.md index c30877ecae14..7859c723ee60 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # Change log -Best viewed [here](https://jax.readthedocs.io/en/latest/changelog.html). +Best viewed [here](https://docs.jax.dev/en/latest/changelog.html). For the changes specific to the experimental Pallas APIs, see {ref}`pallas-changelog`. @@ -16,6 +16,54 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## Unreleased +* Breaking changes + + * {func}`jax.numpy.array` no longer accepts `None`. This behavior was + deprecated since November 2023 and is now removed. + * Removed the `config.jax_data_dependent_tracing_fallback` config option, + which was added temporarily in v0.4.36 to allow users to opt out of the + new "stackless" tracing machinery. + * Removed the `config.jax_eager_pmap` config option. + +* Changes + * The minimum CuDNN version is v9.8. + * JAX is now built using CUDA 12.8. All versions of CUDA 12.1 or newer remain + supported. + +* Deprecations + + * {func}`jax.tree_util.build_tree` is deprecated. Use {func}`jax.tree.unflatten` + instead. + * Implemented host callback handlers for CPU and GPU devices using XLA's FFI + and removed existing CPU/GPU handlers using XLA's custom call. + * All APIs in `jax.lib.xla_extension` are now deprecated. + * `jax.interpreters.mlir.hlo` and `jax.interpreters.mlir.func_dialect`, + which were accidental exports, have been removed. If needed, they are + available from `jax.extend.mlir`. + * `jax.interpreters.mlir.custom_call` is deprecated. The APIs provided by + {mod}`jax.ffi` should be used instead. + * The deprecated use of {func}`jax.ffi.ffi_call` with inline arguments is no + longer supported. {func}`~jax.ffi.ffi_call` now unconditionally returns a + callable. + * Several previously-deprecated APIs have been removed, including: + * From `jax.lib.xla_client`: `FftType`, `PaddingType`, `dtype_to_etype`, + and `shape_from_pyval`. + * From `jax.lib.xla_extension`: `ArrayImpl`, `XlaRuntimeError`. + * From `jax`: `jax.treedef_is_leaf`, `jax.tree_flatten`, `jax.tree_map`, + `jax.tree_leaves`, `jax.tree_structure`, `jax.tree_transpose`, and + `jax.tree_unflatten`. Replacements can be found in {mod}`jax.tree` or + {mod}`jax.tree_util`. 
+ * From `jax.core`: `AxisSize`, `ClosedJaxpr`, `EvalTrace`, `InDBIdx`, `InputType`, + `Jaxpr`, `JaxprEqn`, `Literal`, `MapPrimitive`, `OpaqueTraceState`, `OutDBIdx`, + `Primitive`, `Token`, `TRACER_LEAK_DEBUGGER_WARNING`, `Var`, `concrete_aval`, + `dedup_referents`, `escaped_tracer_error`, `extend_axis_env_nd`, `full_lower`, `get_referent`, `jaxpr_as_fun`, `join_effects`, `lattice_join`, + `leaked_tracer_error`, `maybe_find_leaked_tracers`, `raise_to_shaped`, + `raise_to_shaped_mappings`, `reset_trace_state`, `str_eqn_compact`, + `substitute_vars_in_output_ty`, `typecompat`, and `used_axis_names_jaxpr`. Most + have no public replacement, though a few are available at {mod}`jax.extend.core`. + +## jax 0.5.3 (Mar 19, 2025) + * New Features * Added a `allow_negative_indices` option to {func}`jax.lax.dynamic_slice`, @@ -81,7 +129,7 @@ Patch release of 0.5.1 ## jax 0.5.0 (Jan 17, 2025) As of this release, JAX now uses -[effort-based versioning](https://jax.readthedocs.io/en/latest/jep/25516-effver.html). +[effort-based versioning](https://docs.jax.dev/en/latest/jep/25516-effver.html). Since this release makes a breaking change to PRNG key semantics that may require users to update their code, we are bumping the "meso" version of JAX to signify this. @@ -172,7 +220,7 @@ to signify this. * New Features * {func}`jax.export.export` can be used for device-polymorphic export with shardings constructed with {func}`jax.sharding.AbstractMesh`. - See the [jax.export documentation](https://jax.readthedocs.io/en/latest/export/export.html#device-polymorphic-export). + See the [jax.export documentation](https://docs.jax.dev/en/latest/export/export.html#device-polymorphic-export). * Added {func}`jax.lax.split`. This is a primitive version of {func}`jax.numpy.split`, added because it yields a more compact transpose during automatic differentiation. @@ -214,7 +262,7 @@ This is a patch release of jax 0.4.36. Only "jax" was released at this version. after being deprecated in JAX v0.4.31. Instead use `xb = jax.lib.xla_bridge`, `xc = jax.lib.xla_client`, and `xe = jax.lib.xla_extension`. * The deprecated module `jax.experimental.export` has been removed. It was replaced - by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export) + by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://docs.jax.dev/en/latest/export/export.html#migration-guide-from-jax-experimental-export) for information on migrating to the new API. * The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax` has been removed, after being deprecated in v0.4.27. @@ -252,7 +300,7 @@ This is a patch release of jax 0.4.36. Only "jax" was released at this version. call that we guarantee export stability. This is because this custom call relies on Triton IR, which is not guaranteed to be stable. If you need to export code that uses this custom call, you can use the `disabled_checks` - parameter. See more details in the [documentation](https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees-for-custom-calls). + parameter. See more details in the [documentation](https://docs.jax.dev/en/latest/export/export.html#compatibility-guarantees-for-custom-calls). * New Features * {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for @@ -532,7 +580,7 @@ See the 0.4.33 release notes for more details. * Added an API for exporting and serializing JAX functions. 
This used to exist in `jax.experimental.export` (which is being deprecated), and will now live in `jax.export`. - See the [documentation](https://jax.readthedocs.io/en/latest/export/index.html). + See the [documentation](https://docs.jax.dev/en/latest/export/index.html). * Deprecations * Internal pretty-printing tools `jax.core.pp_*` are deprecated, and will be removed @@ -541,7 +589,7 @@ See the 0.4.33 release notes for more details. release. This previously was the case, but there was an inadvertent regression in the last several JAX releases. * `jax.experimental.export` is deprecated. Use {mod}`jax.export` instead. - See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export). + See the [migration guide](https://docs.jax.dev/en/latest/export/export.html#migration-guide-from-jax-experimental-export). * Passing an array in place of a dtype is now deprecated in most cases; e.g. for arrays `x` and `y`, `x.astype(y)` will raise a warning. To silence it use `x.astype(y.dtype)`. * `jax.xla_computation` is deprecated and will be removed in a future release. @@ -753,7 +801,7 @@ See the 0.4.33 release notes for more details. deprecated. Use `jax.experimental.shard_map` or `jax.vmap` with the `spmd_axis_name` argument for expressing SPMD device-parallel computations. * The `jax.experimental.host_callback` module is deprecated. - Use instead the [new JAX external callbacks](https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html). + Use instead the [new JAX external callbacks](https://docs.jax.dev/en/latest/notebooks/external_callbacks.html). Added `JAX_HOST_CALLBACK_LEGACY` flag to assist in the transition to the new callbacks. See {jax-issue}`#20385` for a discussion. * Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv` @@ -1225,9 +1273,9 @@ See the 0.4.33 release notes for more details. * Deprecations * Python 3.8 support has been dropped as per - https://jax.readthedocs.io/en/latest/deprecation.html + https://docs.jax.dev/en/latest/deprecation.html * JAX now requires NumPy 1.22 or newer as per - https://jax.readthedocs.io/en/latest/deprecation.html + https://docs.jax.dev/en/latest/deprecation.html * Passing optional arguments to {func}`jax.numpy.ndarray.at` by position is no longer supported, after being deprecated in JAX version 0.4.7. For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)` @@ -1272,7 +1320,7 @@ See the 0.4.33 release notes for more details. * Deprecations * Python 3.8 support has been dropped as per - https://jax.readthedocs.io/en/latest/deprecation.html + https://docs.jax.dev/en/latest/deprecation.html ## jax 0.4.13 (June 22, 2023) @@ -1451,7 +1499,7 @@ See the 0.4.33 release notes for more details. ## jax 0.4.7 (March 27, 2023) * Changes - * As per https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration + * As per https://docs.jax.dev/en/latest/jax_array_migration.html#jax-array-migration `jax.config.jax_array` cannot be disabled anymore. * `jax.config.jax_jit_pjit_api_merge` cannot be disabled anymore. * {func}`jax.experimental.jax2tf.convert` now supports the `native_serialization` @@ -1535,7 +1583,7 @@ Changes: on top of each other. With the `jit`-`pjit` implementation merge, `jit` becomes an initial style primitive which means that we trace to jaxpr as early as possible. 
For more information see - [this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing). + [this section in autodidax](https://docs.jax.dev/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing). Moving to initial style should simplify JAX's internals and make development of features like dynamic shapes, etc easier. You can disable it only via the environment variable i.e. @@ -1620,9 +1668,9 @@ Changes: simplifies and unifies JAX internals, and allows us to unify `jit` and `pjit`. `jax.Array` has been enabled by default in JAX 0.4 and makes some breaking change to the `pjit` API. The [jax.Array migration - guide](https://jax.readthedocs.io/en/latest/jax_array_migration.html) can + guide](https://docs.jax.dev/en/latest/jax_array_migration.html) can help you migrate your codebase to `jax.Array`. You can also look at the - [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) + [Distributed arrays and automatic parallelization](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) tutorial to understand the new concepts. * `PartitionSpec` and `Mesh` are now out of experimental. The new API endpoints are `jax.sharding.PartitionSpec` and `jax.sharding.Mesh`. @@ -1651,7 +1699,7 @@ Changes: * The behavior of `XLA_PYTHON_CLIENT_MEM_FRACTION=.XX` has been changed to allocate XX% of the total GPU memory instead of the previous behavior of using currently available GPU memory to calculate preallocation. Please refer to - [GPU memory allocation](https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html) for + [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html) for more details. * The deprecated method `.block_host_until_ready()` has been removed. Use `.block_until_ready()` instead. @@ -1765,7 +1813,7 @@ Changes: * Changes * Ahead-of-time lowering and compilation functionality (tracked in {jax-issue}`#7733`) is stable and public. See [the - overview](https://jax.readthedocs.io/en/latest/aot.html) and the API docs + overview](https://docs.jax.dev/en/latest/aot.html) and the API docs for {mod}`jax.stages`. * Introduced {class}`jax.Array`, intended to be used for both `isinstance` checks and type annotations for array types in JAX. Notice that this included some subtle @@ -1786,7 +1834,7 @@ Changes: * Breaking changes * {func}`jax.checkpoint`, also known as {func}`jax.remat`, no longer supports the `concrete` option, following the previous version's deprecation; see - [JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html). + [JEP 11830](https://docs.jax.dev/en/latest/jep/11830-new-remat-checkpoint.html). * Changes * Added {func}`jax.pure_callback` that enables calling back to pure Python functions from compiled functions (e.g. functions decorated with `jax.jit` or `jax.pmap`). * Deprecations: @@ -1798,7 +1846,7 @@ Changes: * [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.15...main). * Breaking changes * Support for NumPy 1.19 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to NumPy 1.20 or newer. 
* Changes * Added {mod}`jax.debug` that includes utilities for runtime value debugging such at {func}`jax.debug.print` and {func}`jax.debug.breakpoint`. @@ -1816,7 +1864,7 @@ Changes: {mod}`jax.example_libraries.optimizers`. * {func}`jax.checkpoint`, also known as {func}`jax.remat`, has a new implementation switched on by default, meaning the old implementation is - deprecated; see [JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html). + deprecated; see [JEP 11830](https://docs.jax.dev/en/latest/jep/11830-new-remat-checkpoint.html). ## jax 0.3.15 (July 22, 2022) * [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.14...jax-v0.3.15). @@ -1948,7 +1996,7 @@ Changes: * {func}`jax.numpy.linalg.matrix_rank` on TPUs now accepts complex input. * {func}`jax.scipy.cluster.vq.vq` has been added. * `jax.experimental.maps.mesh` has been deleted. - Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh + Please use `jax.experimental.maps.Mesh`. Please see https://docs.jax.dev/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh for more information. * {func}`jax.scipy.linalg.qr` now returns a length-1 tuple rather than the raw array when `mode='r'`, in order to match the behavior of `scipy.linalg.qr` ({jax-issue}`#10452`) @@ -2064,7 +2112,7 @@ Changes: * Changes: * The functions `jax.ops.index_update`, `jax.ops.index_add`, which were deprecated in 0.2.22, have been removed. Please use - [the `.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html) + [the `.at` property on JAX arrays](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html) instead, e.g., `x.at[idx].set(y)`. * Moved `jax.experimental.ann.approx_*_k` into `jax.lax`. These functions are optimized alternatives to `jax.lax.top_k`. @@ -2110,13 +2158,13 @@ Changes: commits](https://github.com/jax-ml/jax/compare/jax-v0.2.28...jax-v0.3.0). * Changes - * jax version has been bumped to 0.3.0. Please see the [design doc](https://jax.readthedocs.io/en/latest/design_notes/jax_versioning.html) + * jax version has been bumped to 0.3.0. Please see the [design doc](https://docs.jax.dev/en/latest/design_notes/jax_versioning.html) for the explanation. ## jaxlib 0.3.0 (Feb 10, 2022) * Changes * Bazel 5.0.0 is now required to build jaxlib. - * jaxlib version has been bumped to 0.3.0. Please see the [design doc](https://jax.readthedocs.io/en/latest/design_notes/jax_versioning.html) + * jaxlib version has been bumped to 0.3.0. Please see the [design doc](https://docs.jax.dev/en/latest/design_notes/jax_versioning.html) for the explanation. ## jax 0.2.28 (Feb 1, 2022) @@ -2138,7 +2186,7 @@ Changes: by default. * Breaking changes * Support for NumPy 1.18 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported NumPy version. * Bug fixes * Fixed a bug where apparently identical pytreedef objects constructed by different routes @@ -2150,7 +2198,7 @@ Changes: * Breaking changes: * Support for NumPy 1.18 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported NumPy version. 
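As a quick illustration of the indexed-update style referenced above (the `.at` property that replaces `jax.ops.index_update` / `jax.ops.index_add`), a minimal sketch with made-up values:

```python
import jax.numpy as jnp

x = jnp.zeros(4)
# Functional equivalents of the in-place updates x[1] = 5.0 and x[2] += 3.0:
y = x.at[1].set(5.0)
z = y.at[2].add(3.0)
print(z)  # [0. 5. 3. 0.]
```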
* The host_callback primitives have been simplified to drop the special autodiff handling for hcb.id_tap and id_print. @@ -2277,7 +2325,7 @@ Changes: * Deprecations * The functions `jax.ops.index_update`, `jax.ops.index_add` etc. are deprecated and will be removed in a future JAX release. Please use - [the `.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html) + [the `.at` property on JAX arrays](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html) instead, e.g., `x.at[idx].set(y)`. For now, these functions produce a `DeprecationWarning`. * New features: @@ -2341,7 +2389,7 @@ Changes: commits](https://github.com/jax-ml/jax/compare/jax-v0.2.18...jax-v0.2.19). * Breaking changes: * Support for NumPy 1.17 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported NumPy version. * The `jit` decorator has been added around the implementation of a number of operators on JAX arrays. This speeds up dispatch times for common @@ -2362,10 +2410,10 @@ Changes: ## jaxlib 0.1.70 (Aug 9, 2021) * Breaking changes: * Support for Python 3.6 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported Python version. * Support for NumPy 1.17 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported NumPy version. * The host_callback mechanism now uses one thread per local device for @@ -2379,7 +2427,7 @@ Changes: * Breaking changes: * Support for Python 3.6 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported Python version. * The minimum jaxlib version is now 0.1.69. * The `backend` argument to {py:func}`jax.dlpack.from_dlpack` has been @@ -2428,7 +2476,7 @@ Changes: * Breaking changes: * Support for NumPy 1.16 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). * Bug fixes: * Fixed bug that prevented round-tripping from JAX to TF and back: @@ -2968,7 +3016,7 @@ Changes: * Support for reduction over subsets of a pmapped axis using `axis_index_groups` {jax-issue}`#2382`. * Experimental support for printing and calling host-side Python function from - compiled code. See [id_print and id_tap](https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html) + compiled code. See [id_print and id_tap](https://docs.jax.dev/en/latest/jax.experimental.host_callback.html) ({jax-issue}`#3006`). * Notable changes: * The visibility of names exported from {mod}`jax.numpy` has been @@ -3040,7 +3088,7 @@ Changes: ## jax 0.1.63 (April 12, 2020) * [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.62...jax-v0.1.63). -* Added `jax.custom_jvp` and `jax.custom_vjp` from {jax-issue}`#2026`, see the [tutorial notebook](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). Deprecated `jax.custom_transforms` and removed it from the docs (though it still works). 
+* Added `jax.custom_jvp` and `jax.custom_vjp` from {jax-issue}`#2026`, see the [tutorial notebook](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). Deprecated `jax.custom_transforms` and removed it from the docs (though it still works). * Add `scipy.sparse.linalg.cg` {jax-issue}`#2566`. * Changed how Tracers are printed to show more useful information for debugging {jax-issue}`#2591`. * Made `jax.numpy.isclose` handle `nan` and `inf` correctly {jax-issue}`#2501`. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 314d4387a044..046d3df3195c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,4 +1,4 @@ # Contributing to JAX For information on how to contribute to JAX, see -[Contributing to JAX](https://jax.readthedocs.io/en/latest/contributing.html) +[Contributing to JAX](https://docs.jax.dev/en/latest/contributing.html) diff --git a/README.md b/README.md index 0aca7cf58e6e..0057440f5a55 100644 --- a/README.md +++ b/README.md @@ -11,8 +11,8 @@ | [**Transformations**](#transformations) | [**Install guide**](#installation) | [**Neural net libraries**](#neural-network-libraries) -| [**Change logs**](https://jax.readthedocs.io/en/latest/changelog.html) -| [**Reference docs**](https://jax.readthedocs.io/en/latest/) +| [**Change logs**](https://docs.jax.dev/en/latest/changelog.html) +| [**Reference docs**](https://docs.jax.dev/en/latest/) ## What is JAX? @@ -48,7 +48,7 @@ are instances of such transformations. Others are parallel programming of multiple accelerators, with more to come. This is a research project, not an official Google product. Expect -[sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html). +[sharp edges](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html). Please help by trying it out, [reporting bugs](https://github.com/jax-ml/jax/issues), and letting us know what you think! @@ -83,15 +83,15 @@ perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-example gra ## Quickstart: Colab in the Cloud Jump right in using a notebook in your browser, connected to a Google Cloud GPU. Here are some starter notebooks: -- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/quickstart.html) +- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://docs.jax.dev/en/latest/quickstart.html) - [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) **JAX now runs on Cloud TPUs.** To try out the preview, see the [Cloud TPU Colabs](https://github.com/jax-ml/jax/tree/main/cloud_tpu_colabs). For a deeper dive into JAX: -- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) -- [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) +- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html) +- [Common gotchas and sharp edges](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html) - See the [full list of notebooks](https://github.com/jax-ml/jax/tree/main/docs/notebooks). 
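As a small companion to the autodiff material linked above, here is an illustrative `jax.custom_jvp` sketch; the `log1pexp` function and its tangent rule are assumptions chosen for the example, not taken from this document:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
  return jnp.log(1.0 + jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  ans = log1pexp(x)
  # Numerically stable tangent rule: sigmoid(x) * x_dot.
  return ans, x_dot / (1.0 + jnp.exp(-x))

print(jax.grad(log1pexp)(10.0))  # ≈ sigmoid(10.0) ≈ 0.9999546
```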
@@ -105,7 +105,7 @@ Here are four transformations of primary interest: `grad`, `jit`, `vmap`, and JAX has roughly the same API as [Autograd](https://github.com/hips/autograd). The most popular function is -[`grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.grad) +[`grad`](https://docs.jax.dev/en/latest/jax.html#jax.grad) for reverse-mode gradients: ```python @@ -129,13 +129,13 @@ print(grad(grad(grad(tanh)))(1.0)) ``` For more advanced autodiff, you can use -[`jax.vjp`](https://jax.readthedocs.io/en/latest/jax.html#jax.vjp) for +[`jax.vjp`](https://docs.jax.dev/en/latest/jax.html#jax.vjp) for reverse-mode vector-Jacobian products and -[`jax.jvp`](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) for +[`jax.jvp`](https://docs.jax.dev/en/latest/jax.html#jax.jvp) for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another, and with other JAX transformations. Here's one way to compose those to make a function that efficiently computes [full Hessian -matrices](https://jax.readthedocs.io/en/latest/_autosummary/jax.hessian.html#jax.hessian): +matrices](https://docs.jax.dev/en/latest/_autosummary/jax.hessian.html#jax.hessian): ```python from jax import jit, jacfwd, jacrev @@ -160,15 +160,15 @@ print(abs_val_grad(-1.0)) # prints -1.0 (abs_val is re-evaluated) ``` See the [reference docs on automatic -differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) +differentiation](https://docs.jax.dev/en/latest/jax.html#automatic-differentiation) and the [JAX Autodiff -Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) +Cookbook](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html) for more. ### Compilation with `jit` You can use XLA to compile your functions end-to-end with -[`jit`](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit), +[`jit`](https://docs.jax.dev/en/latest/jax.html#just-in-time-compilation-jit), used either as an `@jit` decorator or as a higher-order function. ```python @@ -189,12 +189,12 @@ You can mix `jit` and `grad` and any other JAX transformation however you like. Using `jit` puts constraints on the kind of Python control flow the function can use; see -the tutorial on [Control Flow and Logical Operators with JIT](https://jax.readthedocs.io/en/latest/control-flow.html) +the tutorial on [Control Flow and Logical Operators with JIT](https://docs.jax.dev/en/latest/control-flow.html) for more. ### Auto-vectorization with `vmap` -[`vmap`](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) is +[`vmap`](https://docs.jax.dev/en/latest/jax.html#vectorization-vmap) is the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a @@ -259,7 +259,7 @@ differentiation for fast Jacobian and Hessian matrix calculations in ### SPMD programming with `pmap` For parallel programming of multiple accelerators, like multiple GPUs, use -[`pmap`](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap). +[`pmap`](https://docs.jax.dev/en/latest/jax.html#parallelization-pmap). With `pmap` you write single-program multiple-data (SPMD) programs, including fast parallel collective communication operations. 
Applying `pmap` will mean that the function you write is compiled by XLA (similarly to `jit`), then @@ -284,7 +284,7 @@ print(pmap(jnp.mean)(result)) ``` In addition to expressing pure maps, you can use fast [collective communication -operations](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators) +operations](https://docs.jax.dev/en/latest/jax.lax.html#parallel-operators) between devices: ```python @@ -341,20 +341,20 @@ for more. For a more thorough survey of current gotchas, with examples and explanations, we highly recommend reading the [Gotchas -Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html). +Notebook](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html). Some standouts: 1. JAX transformations only work on [pure functions](https://en.wikipedia.org/wiki/Pure_function), which don't have side-effects and respect [referential transparency](https://en.wikipedia.org/wiki/Referential_transparency) (i.e. object identity testing with `is` isn't preserved). If you use a JAX transformation on an impure Python function, you might see an error like `Exception: Can't lift Traced...` or `Exception: Different traces at same level`. 1. [In-place mutating updates of - arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically. + arrays](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://docs.jax.dev/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically. 1. [Random numbers are - different](https://jax.readthedocs.io/en/latest/random-numbers.html), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md). + different](https://docs.jax.dev/en/latest/random-numbers.html), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md). 1. If you're looking for [convolution - operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html), + operators](https://docs.jax.dev/en/latest/notebooks/convolutions.html), they're in the `jax.lax` package. 1. JAX enforces single-precision (32-bit, e.g. `float32`) values by default, and [to enable - double-precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision) + double-precision](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision) (64-bit, e.g. `float64`) one needs to set the `jax_enable_x64` variable at startup (or set the environment variable `JAX_ENABLE_X64=True`). On TPU, JAX uses 32-bit values by default for everything _except_ internal @@ -368,14 +368,14 @@ Some standouts: and NumPy types aren't preserved, namely `np.add(1, np.array([2], np.float32)).dtype` is `float64` rather than `float32`. 1. Some transformations, like `jit`, [constrain how you can use Python control - flow](https://jax.readthedocs.io/en/latest/control-flow.html). + flow](https://docs.jax.dev/en/latest/control-flow.html). You'll always get loud errors if something goes wrong. 
You might have to use [`jit`'s `static_argnums` - parameter](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit), + parameter](https://docs.jax.dev/en/latest/jax.html#just-in-time-compilation-jit), [structured control flow - primitives](https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators) + primitives](https://docs.jax.dev/en/latest/jax.lax.html#control-flow-operators) like - [`lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan), + [`lax.scan`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan), or just use `jit` on smaller subfunctions. ## Installation @@ -403,7 +403,7 @@ Some standouts: | Mac GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). | | Intel GPU | Follow [Intel's instructions](https://github.com/intel/intel-extension-for-openxla/blob/main/docs/acc_jax.md). | -See [the documentation](https://jax.readthedocs.io/en/latest/installation.html) +See [the documentation](https://docs.jax.dev/en/latest/installation.html) for information on alternative installation strategies. These include compiling from source, installing with Docker, using other versions of CUDA, a community-supported conda build, and answers to some frequently-asked questions. @@ -417,7 +417,7 @@ for training neural networks in JAX. If you want a fully featured library for ne training with examples and how-to guides, try [Flax](https://github.com/google/flax) and its [documentation site](https://flax.readthedocs.io/en/latest/nnx/index.html). -Check out the [JAX Ecosystem section](https://jax.readthedocs.io/en/latest/#ecosystem) +Check out the [JAX Ecosystem section](https://docs.jax.dev/en/latest/#ecosystem) on the JAX documentation site for a list of JAX-based network libraries, which includes [Optax](https://github.com/deepmind/optax) for gradient processing and optimization, [chex](https://github.com/deepmind/chex) for reliable code and testing, and @@ -452,7 +452,8 @@ paper. ## Reference documentation For details about the JAX API, see the -[reference documentation](https://jax.readthedocs.io/). +[reference documentation](https://docs.jax.dev/). For getting started as a JAX developer, see the -[developer documentation](https://jax.readthedocs.io/en/latest/developer.html). +[developer documentation](https://docs.jax.dev/en/latest/developer.html). + diff --git a/WORKSPACE b/WORKSPACE index 129488281ea9..5c093ec2228f 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -14,8 +14,10 @@ python_init_repositories( "3.12": "//build:requirements_lock_3_12.txt", "3.13": "//build:requirements_lock_3_13.txt", "3.13-ft": "//build:requirements_lock_3_13_ft.txt", + "3.14-ft": "//build:requirements_lock_3_14_ft.txt", }, local_wheel_inclusion_list = [ + "jax-*", "jaxlib*", "jax_cuda*", "jax-cuda*", diff --git a/benchmarks/tracing_benchmark.py b/benchmarks/tracing_benchmark.py new file mode 100644 index 000000000000..e06ad538d476 --- /dev/null +++ b/benchmarks/tracing_benchmark.py @@ -0,0 +1,76 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Benchmarks for Jax tracing.""" + +import google_benchmark +import jax +from jax import random +from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel as splash +from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib +import numpy as np + + +def make_mqa_splash_attention_fn_and_args(): + seed = 0 + key = random.key(seed) + k1, k2, k3 = random.split(key, 3) + + q_seq_len = 1024 + kv_seq_len = 1024 + num_q_heads = 2 + head_dim_qk = 128 + head_dim_v = 128 + dtype = np.dtype("float32") + + q = random.uniform(k1, (num_q_heads, q_seq_len, head_dim_qk), dtype=dtype) + k = random.uniform(k2, (kv_seq_len, head_dim_qk), dtype=dtype) + v = random.uniform(k3, (kv_seq_len, head_dim_v), dtype=dtype) + + mask = mask_lib.NumpyMask( + mask_lib.make_random_mask((q_seq_len, kv_seq_len), sparsity=0.5, seed=0) + ) + mask = mask_lib.MultiHeadMask(tuple(mask for _ in range(num_q_heads))) + block_sizes = splash.BlockSizes.get_default() + + return ( + jax.jit( + splash.make_splash_mqa_single_device(mask, block_sizes=block_sizes) + ) + ), (q, k, v) + + +@google_benchmark.register +@google_benchmark.option.unit(google_benchmark.kMillisecond) +def test_pallas_mqa_splash_attention_trace(state): + attn, (q, k, v) = make_mqa_splash_attention_fn_and_args() + + while state: + _ = attn.trace(q, k, v) + jax.clear_caches() + + +@google_benchmark.register +@google_benchmark.option.unit(google_benchmark.kMillisecond) +def test_pallas_mqa_splash_attention_lower(state): + attn, (q, k, v) = make_mqa_splash_attention_fn_and_args() + traced = attn.trace(q, k, v) + + while state: + _ = traced.lower(lowering_platforms=("tpu",)) + jax.clear_caches() + + +if __name__ == "__main__": + google_benchmark.main() diff --git a/build/build.py b/build/build.py index 5bcbdd862446..ad526faa0344 100755 --- a/build/build.py +++ b/build/build.py @@ -68,13 +68,20 @@ # rule as the default. 
WHEEL_BUILD_TARGET_DICT_NEW = { "jax": "//:jax_wheel", + "jax_editable": "//:jax_wheel_editable", + "jax_source_package": "//:jax_source_package", "jaxlib": "//jaxlib/tools:jaxlib_wheel", + "jaxlib_editable": "//jaxlib/tools:jaxlib_wheel_editable", "jax-cuda-plugin": "//jaxlib/tools:jax_cuda_plugin_wheel", + "jax-cuda-plugin_editable": "//jaxlib/tools:jax_cuda_plugin_wheel_editable", "jax-cuda-pjrt": "//jaxlib/tools:jax_cuda_pjrt_wheel", + "jax-cuda-pjrt_editable": "//jaxlib/tools:jax_cuda_pjrt_wheel_editable", "jax-rocm-plugin": "//jaxlib/tools:jax_rocm_plugin_wheel", "jax-rocm-pjrt": "//jaxlib/tools:jax_rocm_pjrt_wheel", } +_JAX_CUDA_VERSION = "12" + def add_global_arguments(parser: argparse.ArgumentParser): """Adds all the global arguments that applies to all the CLI subcommands.""" parser.add_argument( @@ -382,6 +389,11 @@ async def main(): arch = platform.machine() os_name = platform.system().lower() + custom_wheel_version_suffix = "" + wheel_build_date = "" + wheel_git_hash = "" + wheel_type = "snapshot" + args = parser.parse_args() logger.info("%s", BANNER) @@ -407,7 +419,7 @@ async def main(): for option in args.bazel_startup_options: bazel_command_base.append(option) - if not args.use_new_wheel_build_rule or args.command == "requirements_update": + if args.command == "requirements_update" or not args.use_new_wheel_build_rule: bazel_command_base.append("run") else: bazel_command_base.append("build") @@ -489,6 +501,7 @@ async def main(): if args.use_clang: clang_path = args.clang_path or utils.get_clang_path_or_exit() clang_major_version = utils.get_clang_major_version(clang_path) + clangpp_path = utils.get_clangpp_path(clang_path) logging.debug( "Using Clang as the compiler, clang path: %s, clang version: %s", clang_path, @@ -498,6 +511,7 @@ async def main(): # Use double quotes around clang path to avoid path issues on Windows. wheel_build_command_base.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") wheel_build_command_base.append(f"--repo_env=CC=\"{clang_path}\"") + wheel_build_command_base.append(f"--repo_env=CXX=\"{clangpp_path}\"") wheel_build_command_base.append(f"--repo_env=BAZEL_COMPILER=\"{clang_path}\"") if clang_major_version >= 16: @@ -612,6 +626,17 @@ async def main(): ) for option in args.bazel_options: wheel_build_command_base.append(option) + + # Parse the build options for the wheel version suffix. + if "ML_WHEEL_TYPE" in option: + wheel_type = option.split("=")[-1] + if "ML_WHEEL_VERSION_SUFFIX" in option: + custom_wheel_version_suffix = option.split("=")[-1].replace("-", "") + if "ML_WHEEL_BUILD_DATE" in option: + wheel_build_date = option.split("=")[-1].replace("-", "") + if "ML_WHEEL_GIT_HASH" in option: + wheel_git_hash = option.split("=")[-1][:9] + if "cuda" in args.wheels: wheel_build_command_base.append("--config=cuda_libraries_from_stubs") @@ -659,8 +684,13 @@ async def main(): ) # Append the build target to the Bazel command. 
- build_target = wheel_build_targets[wheel] + if args.use_new_wheel_build_rule and args.editable: + build_target = wheel_build_targets[wheel + "_editable"] + else: + build_target = wheel_build_targets[wheel] wheel_build_command.append(build_target) + if args.use_new_wheel_build_rule and wheel == "jax" and not args.editable: + wheel_build_command.append(wheel_build_targets["jax_source_package"]) if not args.use_new_wheel_build_rule: wheel_build_command.append("--") @@ -692,6 +722,54 @@ async def main(): if result.return_code != 0: raise RuntimeError(f"Command failed with return code {result.return_code}") + if args.use_new_wheel_build_rule: + output_path = args.output_path + jax_bazel_dir = os.path.join("bazel-bin", "dist") + jaxlib_and_plugins_bazel_dir = os.path.join( + "bazel-bin", "jaxlib", "tools", "dist" + ) + for wheel in args.wheels.split(","): + if wheel == "jax": + bazel_dir = jax_bazel_dir + else: + bazel_dir = jaxlib_and_plugins_bazel_dir + if "cuda" in wheel: + wheel_dir = wheel.replace("cuda", f"cuda{_JAX_CUDA_VERSION}").replace( + "-", "_" + ) + else: + wheel_dir = wheel + + if args.editable: + src_dir = os.path.join(bazel_dir, wheel_dir) + dst_dir = os.path.join(output_path, wheel_dir) + utils.copy_dir_recursively(src_dir, dst_dir) + else: + wheel_version_suffix = "dev0+selfbuilt" + if wheel_type == "release": + wheel_version_suffix = custom_wheel_version_suffix + elif wheel_type in ["nightly", "custom"]: + wheel_version_suffix = f".dev{wheel_build_date}" + if wheel_type == "custom": + wheel_version_suffix += ( + f"+{wheel_git_hash}{custom_wheel_version_suffix}" + ) + if wheel in ["jax", "jax-cuda-pjrt"]: + python_tag = "py" + else: + python_tag = "cp" + utils.copy_individual_files( + bazel_dir, + output_path, + f"{wheel_dir}*{wheel_version_suffix}-{python_tag}*.whl", + ) + if wheel == "jax": + utils.copy_individual_files( + bazel_dir, + output_path, + f"{wheel_dir}*{wheel_version_suffix}.tar.gz", + ) + # Exit with success if all wheels in the list were built successfully. 
sys.exit(0) diff --git a/build/gpu-test-requirements.txt b/build/gpu-test-requirements.txt index ff43f91ba90f..d0dda5cf526c 100644 --- a/build/gpu-test-requirements.txt +++ b/build/gpu-test-requirements.txt @@ -5,7 +5,7 @@ nvidia-cublas-cu12>=12.1.3.1 ; sys_platform == "linux" nvidia-cuda-cupti-cu12>=12.1.105 ; sys_platform == "linux" nvidia-cuda-nvcc-cu12>=12.6.85 ; sys_platform == "linux" nvidia-cuda-runtime-cu12>=12.1.105 ; sys_platform == "linux" -nvidia-cudnn-cu12>=9.1,<10.0 ; sys_platform == "linux" +nvidia-cudnn-cu12>=9.8,<10.0 ; sys_platform == "linux" nvidia-cufft-cu12>=11.0.2.54 ; sys_platform == "linux" nvidia-cusolver-cu12>=11.4.5.107 ; sys_platform == "linux" nvidia-cusparse-cu12>=12.1.0.106 ; sys_platform == "linux" diff --git a/build/requirements_lock_3_10.txt b/build/requirements_lock_3_10.txt index 6ed6b59aa584..8bf5293bd948 100644 --- a/build/requirements_lock_3_10.txt +++ b/build/requirements_lock_3_10.txt @@ -410,10 +410,10 @@ nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ - --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ - --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ - --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef +nvidia-cudnn-cu12==9.8.0.87 ; sys_platform == "linux" \ + --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ + --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ + --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 # via -r build/test-requirements.txt nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ diff --git a/build/requirements_lock_3_11.txt b/build/requirements_lock_3_11.txt index 8446e8361505..487346ab6d12 100644 --- a/build/requirements_lock_3_11.txt +++ b/build/requirements_lock_3_11.txt @@ -405,10 +405,10 @@ nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ - --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ - --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ - --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef +nvidia-cudnn-cu12==9.8.0.87 ; sys_platform == "linux" \ + --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ + --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ + --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 # via -r build/test-requirements.txt nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ diff --git a/build/requirements_lock_3_12.txt b/build/requirements_lock_3_12.txt index 0436ab6dd486..e2f76cab8abc 100644 --- a/build/requirements_lock_3_12.txt +++ b/build/requirements_lock_3_12.txt @@ -405,10 +405,10 @@ nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ 
--hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ - --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ - --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ - --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef +nvidia-cudnn-cu12==9.8.0.87 ; sys_platform == "linux" \ + --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ + --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ + --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 # via -r build/test-requirements.txt nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index e74d40b798f4..403d0ad8a061 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -460,10 +460,10 @@ nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ - --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ - --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ - --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef +nvidia-cudnn-cu12==9.8.0.87 ; sys_platform == "linux" \ + --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ + --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ + --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 # via -r build/test-requirements.txt nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt index e7a2968e981e..a96a3e6e489b 100644 --- a/build/requirements_lock_3_13_ft.txt +++ b/build/requirements_lock_3_13_ft.txt @@ -413,10 +413,10 @@ nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ - --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ - --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ - --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef +nvidia-cudnn-cu12==9.8.0.87 ; sys_platform == "linux" \ + --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ + --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ + --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 # via -r build/test-requirements.txt nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ @@ -658,7 +658,7 @@ zipp==3.21.0 \ 
--hash=sha256:2c9958f6430a2040341a52eb608ed6dd93ef4392e02ffe219417c1b28b5dd1f4 \ --hash=sha256:ac1bbe05fd2991f160ebce24ffbac5f6d11d83dc90891255885223d42b3cd931 # via etils -# python 3.13t can compile 0.23.0 +# python 3.13t can't compile 0.23.0 # due to https://github.com/indygreg/python-zstandard/issues/231 # zstandard==0.23.0 \ # --hash=sha256:034b88913ecc1b097f528e42b539453fa82c3557e414b3de9d5632c80439a473 \ diff --git a/build/requirements_lock_3_14_ft.txt b/build/requirements_lock_3_14_ft.txt new file mode 100644 index 000000000000..e50305f4fa48 --- /dev/null +++ b/build/requirements_lock_3_14_ft.txt @@ -0,0 +1,21 @@ +--pre +--extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple +numpy + +--pre +--extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple +scipy + +absl-py==2.1.0 + +attrs==24.3.0 + +hypothesis==6.123.9 + +sortedcontainers==2.4.0 + +flatbuffers==24.12.23 + +ml-dtypes==0.5.1 + +opt-einsum==3.4.0 diff --git a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm index 14bf6fd60f84..72a5ce41aa0c 100644 --- a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm +++ b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm @@ -9,7 +9,7 @@ ARG ROCM_BUILD_NUM # manylinux base image. However, adding this does fix an issue where Bazel isn't able # to find them. RUN --mount=type=cache,target=/var/cache/dnf \ - dnf install -y gcc-c++-8.5.0-22.el8_10.x86_64 numactl-devel + dnf install -y numactl-devel RUN --mount=type=cache,target=/var/cache/dnf \ --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \ @@ -25,5 +25,11 @@ RUN mkdir /tmp/llvm-project && wget -qO - https://github.com/llvm/llvm-project/a mkdir /tmp/llvm-project/build && cd /tmp/llvm-project/build && cmake -DLLVM_ENABLE_PROJECTS='clang;lld' -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=/usr/lib/llvm-18/ ../llvm && \ make -j$(nproc) && make -j$(nproc) install && rm -rf /tmp/llvm-project +# Set some clang config +COPY ./build/rocm/build_wheels/clang.cfg /usr/lib/llvm-18/bin/clang++.cfg +COPY ./build/rocm/build_wheels/clang.cfg /usr/lib/llvm-18/bin/clang.cfg +COPY ./build/rocm/build_wheels/clang.cfg /opt/rocm/llvm/bin/clang++.cfg +COPY ./build/rocm/build_wheels/clang.cfg /opt/rocm/llvm/bin/clang.cfg + # Stop git from erroring out when we don't own the repo RUN git config --global --add safe.directory '*' diff --git a/build/rocm/build_wheels/clang.cfg b/build/rocm/build_wheels/clang.cfg new file mode 100644 index 000000000000..767c04c03ae7 --- /dev/null +++ b/build/rocm/build_wheels/clang.cfg @@ -0,0 +1,3 @@ +# Tell clang where it can find gcc so that it can use gcc's standard libraries +--gcc-toolchain=/opt/rh/gcc-toolset-14/root/usr/ + diff --git a/build/rocm/ci_build b/build/rocm/ci_build index ee2c8698d346..b8faa485afd7 100755 --- a/build/rocm/ci_build +++ b/build/rocm/ci_build @@ -98,7 +98,10 @@ def dist_wheels( bw_cmd.append("/jax") - cmd = ["docker", "run"] + cmd = [ + "docker", + "run", + ] mounts = [ "-v", diff --git a/build/rocm/tools/fixwheel.py b/build/rocm/tools/fixwheel.py index ea77162728d5..7d8c1fcce055 100644 --- a/build/rocm/tools/fixwheel.py +++ b/build/rocm/tools/fixwheel.py @@ -87,7 +87,7 @@ def fix_wheel(path): exclude = list(ext_libs.keys()) # call auditwheel repair with excludes - cmd = ["auditwheel", "repair", "--plat", plat, "--only-plat"] + cmd = ["auditwheel", "-v", "repair", "--plat", plat, "--only-plat"] for ex in exclude: 
cmd.append("--exclude") diff --git a/build/tools/utils.py b/build/tools/utils.py index 7e375169827b..6170cdfabb50 100644 --- a/build/tools/utils.py +++ b/build/tools/utils.py @@ -14,6 +14,7 @@ # ============================================================================== # Helper script for tools/utilities used by the JAX build CLI. import collections +import glob import hashlib import logging import os @@ -201,6 +202,20 @@ def get_clang_major_version(clang_path): return major_version +def get_clangpp_path(clang_path): + clang_path = pathlib.Path(clang_path) + clang_exec_name = clang_path.name + clangpp_exec_name = clang_exec_name + if "clang++" not in clang_exec_name: + clangpp_exec_name = re.sub("clang(-[0-9.]*)?", "clang++", clangpp_exec_name) + clangpp_path = clang_path.parent / clangpp_exec_name + if not clangpp_path.exists(): + raise FileNotFoundError( + f"Failed to get clang++ path from clang path: '{clang_path!s}'. " + f"Tried the path: '{clangpp_path!s}'." + ) + return str(clangpp_path) + def get_gcc_major_version(gcc_path: str): gcc_version_proc = subprocess.run( [gcc_path, "-dumpversion"], @@ -256,3 +271,28 @@ def _parse_string_as_bool(s): return False else: raise ValueError(f"Expected either 'true' or 'false'; got {s}") + + +def copy_dir_recursively(src, dst): + if os.path.exists(dst): + shutil.rmtree(dst) + os.makedirs(dst, exist_ok=True) + for root, dirs, files in os.walk(src): + relative_path = os.path.relpath(root, src) + dst_dir = os.path.join(dst, relative_path) + os.makedirs(dst_dir, exist_ok=True) + for f in files: + src_file = os.path.join(root, f) + dst_file = os.path.join(dst_dir, f) + shutil.copy2(src_file, dst_file) + logging.info("Editable wheel path: %s" % dst) + + +def copy_individual_files(src, dst, regex): + os.makedirs(dst, exist_ok=True) + for f in glob.glob(os.path.join(src, regex)): + dst_file = os.path.join(dst, os.path.basename(f)) + if os.path.exists(dst_file): + os.remove(dst_file) + shutil.copy2(f, dst_file) + logging.info("Distribution path: %s" % dst_file) diff --git a/build_wheel.py b/build_wheel.py index f8e1595d3c3a..793523e8e3b2 100644 --- a/build_wheel.py +++ b/build_wheel.py @@ -47,6 +47,25 @@ parser.add_argument( "--srcs", help="source files for the wheel", action="append" ) +parser.add_argument( + "--build-wheel-only", + default=False, + help=( + "Whether to build the wheel only. Optional." + ), +) +parser.add_argument( + "--build-source-package-only", + default=False, + help=( + "Whether to build the source package only. Optional." 
+ ), +) +parser.add_argument( + "--editable", + action="store_true", + help="Create an 'editable' jax build instead of a wheel.", +) args = parser.parse_args() @@ -76,7 +95,11 @@ def prepare_srcs(deps: list[str], srcs_dir: str) -> None: """ for file in deps: - if not (file.startswith("bazel-out") or file.startswith("external")): + if not ( + file.startswith("bazel-out") + or file.startswith("external") + or file.startswith("jaxlib") + ): copy_file(file, srcs_dir) @@ -89,13 +112,18 @@ def prepare_srcs(deps: list[str], srcs_dir: str) -> None: try: os.makedirs(args.output_path, exist_ok=True) prepare_srcs(args.srcs, pathlib.Path(sources_path)) - build_utils.build_wheel( - sources_path, - args.output_path, - package_name="jax", - git_hash=args.jaxlib_git_hash, - build_wheel_only=False, - ) + package_name = "jax" + if args.editable: + build_utils.build_editable(sources_path, args.output_path, package_name) + else: + build_utils.build_wheel( + sources_path, + args.output_path, + package_name, + git_hash=args.jaxlib_git_hash, + build_wheel_only=args.build_wheel_only, + build_source_package_only=args.build_source_package_only, + ) finally: if tmpdir: tmpdir.cleanup() diff --git a/ci/build_artifacts.sh b/ci/build_artifacts.sh index 84b8d35a2a50..d7ffe82eb699 100755 --- a/ci/build_artifacts.sh +++ b/ci/build_artifacts.sh @@ -96,6 +96,7 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then --bazel_options=--config="$bazelrc_config" $bazel_remote_cache \ --python_version=$JAXCI_HERMETIC_PYTHON_VERSION \ --verbose --detailed_timestamped_log --use_new_wheel_build_rule \ + --output_path="$JAXCI_OUTPUT_DIR" \ $artifact_tag_flags # If building release artifacts, we also build a release candidate ("rc") @@ -105,18 +106,10 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then --bazel_options=--config="$bazelrc_config" $bazel_remote_cache \ --python_version=$JAXCI_HERMETIC_PYTHON_VERSION \ --verbose --detailed_timestamped_log --use_new_wheel_build_rule \ + --output_path="$JAXCI_OUTPUT_DIR" \ $artifact_tag_flags --bazel_options=--repo_env=ML_WHEEL_VERSION_SUFFIX="$JAXCI_WHEEL_RC_VERSION" fi - # Move the built artifacts from the Bazel cache directory to the output - # directory. - if [[ "$artifact" == "jax" ]]; then - mv bazel-bin/dist/*.whl "$JAXCI_OUTPUT_DIR" - mv bazel-bin/dist/*.tar.gz "$JAXCI_OUTPUT_DIR" - else - mv bazel-bin/jaxlib/tools/dist/*.whl "$JAXCI_OUTPUT_DIR" - fi - # If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we # run `auditwheel show` to verify manylinux compliance. 
if [[ "$os" == "linux" ]] && [[ "$artifact" != "jax" ]]; then diff --git a/ci/envs/docker.env b/ci/envs/docker.env index 82a76d33350c..a0f558520d45 100644 --- a/ci/envs/docker.env +++ b/ci/envs/docker.env @@ -41,5 +41,5 @@ fi # Windows image for building JAX artifacts, running Pytests CPU tests, and Bazel # tests if [[ $os =~ "msys_nt" ]]; then - export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/tf-test-windows@sha256:6e2b299f12418d70ea522646b3dd618042a102f2ac2e4f8b1e423638549ea801" + export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/tf-test-windows:latest" fi \ No newline at end of file diff --git a/ci/run_bazel_test_cpu_rbe.sh b/ci/run_bazel_test_cpu_rbe.sh index 248111e0247a..d8cb190079e0 100755 --- a/ci/run_bazel_test_cpu_rbe.sh +++ b/ci/run_bazel_test_cpu_rbe.sh @@ -64,5 +64,7 @@ else --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ --test_output=errors \ --color=yes \ - //tests:cpu_tests //tests:backend_independent_tests + //tests:cpu_tests //tests:backend_independent_tests \ + //jaxlib/tools:jaxlib_wheel_size_test \ + //:jax_wheel_size_test fi \ No newline at end of file diff --git a/ci/run_bazel_test_cuda_rbe.sh b/ci/run_bazel_test_cuda_rbe.sh index 17bd8d9db4f8..94c6a89fdb8c 100755 --- a/ci/run_bazel_test_cuda_rbe.sh +++ b/ci/run_bazel_test_cuda_rbe.sh @@ -48,4 +48,10 @@ bazel test --config=rbe_linux_x86_64_cuda \ --test_env=JAX_SKIP_SLOW_TESTS=true \ --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ --color=yes \ - //tests:gpu_tests //tests:backend_independent_tests //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests \ No newline at end of file + --@local_config_cuda//cuda:override_include_cuda_libs=true \ + //tests:gpu_tests //tests:backend_independent_tests \ + //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests \ + //jaxlib/tools:jax_cuda_plugin_wheel_size_test \ + //jaxlib/tools:jax_cuda_pjrt_wheel_size_test \ + //jaxlib/tools:jaxlib_wheel_size_test \ + //:jax_wheel_size_test \ No newline at end of file diff --git a/ci/run_pytest_cpu.sh b/ci/run_pytest_cpu.sh index 43581ef2c96c..9de29691f753 100755 --- a/ci/run_pytest_cpu.sh +++ b/ci/run_pytest_cpu.sh @@ -26,13 +26,13 @@ set -exu -o history -o allexport # Source default JAXCI environment variables. source ci/envs/default.env +# Set up the build environment. +source "ci/utilities/setup_build_environment.sh" + # Install jaxlib wheel inside the $JAXCI_OUTPUT_DIR directory on the system. echo "Installing wheels locally..." source ./ci/utilities/install_wheels_locally.sh -# Set up the build environment. -source "ci/utilities/setup_build_environment.sh" - # Print all the installed packages echo "Installed packages:" "$JAXCI_PYTHON" -m uv pip list diff --git a/ci/utilities/install_wheels_locally.sh b/ci/utilities/install_wheels_locally.sh index f98f7658ad18..53f070d1e0e6 100644 --- a/ci/utilities/install_wheels_locally.sh +++ b/ci/utilities/install_wheels_locally.sh @@ -26,27 +26,34 @@ for i in "${!WHEELS[@]}"; do # Append [tpu] to the jax wheel name to download the latest libtpu wheel # from PyPI. WHEELS[$i]="${WHEELS[$i]}[tpu]" + elif [[ "$JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI" == "jax_cuda_pypi" ]]; then + # Append [cuda12-local] to the jax wheel name to download the latest + # release of JAX's CUDA plugin and PJRT packages from PyPI. This is used + # when running CUDA tests for a "jax" only release. 
+ WHEELS[$i]="${WHEELS[$i]}[cuda12-local]" fi fi done -if [[ -z "${WHEELS[@]}" ]]; then - echo "ERROR: No wheels found under $JAXCI_OUTPUT_DIR" - exit 1 -fi +if [[ -n "${WHEELS[@]}" ]]; then + echo "Installing the following wheels:" + echo "${WHEELS[@]}" -echo "Installing the following wheels:" -echo "${WHEELS[@]}" - -# Install `uv` if it's not already installed. `uv` is much faster than pip for -# installing Python packages. -if ! command -v uv >/dev/null 2>&1; then - pip install uv~=0.5.30 -fi + # Install `uv` if it's not already installed. `uv` is much faster than pip for + # installing Python packages. + if ! command -v uv >/dev/null 2>&1; then + pip install uv~=0.5.30 + fi -# On Windows, convert MSYS Linux-like paths to Windows paths. -if [[ $(uname -s) =~ "MSYS_NT" ]]; then - "$JAXCI_PYTHON" -m uv pip install $(cygpath -w "${WHEELS[@]}") + # On Windows, convert MSYS Linux-like paths to Windows paths. + if [[ $(uname -s) =~ "MSYS_NT" ]]; then + "$JAXCI_PYTHON" -m uv pip install $(cygpath -w "${WHEELS[@]}") + else + "$JAXCI_PYTHON" -m uv pip install "${WHEELS[@]}" + fi else - "$JAXCI_PYTHON" -m uv pip install "${WHEELS[@]}" + # Note that we don't exit here because the wheels may have been installed + # earlier in a different step in the CI job. + echo "INFO: No wheels found under $JAXCI_OUTPUT_DIR" + echo "INFO: Skipping local wheel installation." fi \ No newline at end of file diff --git a/ci/utilities/run_auditwheel.sh b/ci/utilities/run_auditwheel.sh index 30b6a3b51865..b8f80c3e6778 100755 --- a/ci/utilities/run_auditwheel.sh +++ b/ci/utilities/run_auditwheel.sh @@ -26,6 +26,10 @@ if [[ -z "$WHEELS" ]]; then fi for wheel in $WHEELS; do + # Skip checking manylinux compliance for jax wheel. + if [[ "$wheel" =~ 'jax-' ]]; then + continue + fi printf "\nRunning auditwheel on the following wheel:" ls $wheel OUTPUT_FULL=$(python -m auditwheel show $wheel) diff --git a/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb b/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb index edaa71b93e85..5bc045d0f606 100644 --- a/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb +++ b/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb @@ -225,7 +225,7 @@ "* Jacobian pre-accumulation for elementwise operations (like `gelu`)\n", "\n", "\n", - "For much more, see the [JAX Autodiff Cookbook (Part 1)](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)." + "For much more, see the [JAX Autodiff Cookbook (Part 1)](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html)." ] }, { diff --git a/cloud_tpu_colabs/JAX_demo.ipynb b/cloud_tpu_colabs/JAX_demo.ipynb index d7ba5ed334f4..b69246c57e0b 100644 --- a/cloud_tpu_colabs/JAX_demo.ipynb +++ b/cloud_tpu_colabs/JAX_demo.ipynb @@ -315,7 +315,7 @@ "* Jacobian pre-accumulation for elementwise operations (like `gelu`)\n", "\n", "\n", - "For much more, see the [JAX Autodiff Cookbook (Part 1)](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)." + "For much more, see the [JAX Autodiff Cookbook (Part 1)](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html)." 
] }, { diff --git a/cloud_tpu_colabs/Pmap_Cookbook.ipynb b/cloud_tpu_colabs/Pmap_Cookbook.ipynb index ea126ac4f1e7..8b16cd7694eb 100644 --- a/cloud_tpu_colabs/Pmap_Cookbook.ipynb +++ b/cloud_tpu_colabs/Pmap_Cookbook.ipynb @@ -59,7 +59,7 @@ "id": "2e_06-OAJNyi" }, "source": [ - "A basic starting point is expressing parallel maps with [`pmap`](https://jax.readthedocs.io/en/latest/jax.html#jax.pmap):" + "A basic starting point is expressing parallel maps with [`pmap`](https://docs.jax.dev/en/latest/jax.html#jax.pmap):" ] }, { @@ -407,7 +407,7 @@ "source": [ "When writing nested `pmap` functions in the decorator style, axis names are resolved according to lexical scoping.\n", "\n", - "Check [the JAX reference documentation](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators) for a complete list of the parallel operators. More are being added!\n", + "Check [the JAX reference documentation](https://docs.jax.dev/en/latest/jax.lax.html#parallel-operators) for a complete list of the parallel operators. More are being added!\n", "\n", "Here's how to use `lax.ppermute` to implement a simple halo exchange for a [Rule 30](https://en.wikipedia.org/wiki/Rule_30) simulation:" ] diff --git a/cloud_tpu_colabs/README.md b/cloud_tpu_colabs/README.md index db3dc5f30814..6e5501584da0 100644 --- a/cloud_tpu_colabs/README.md +++ b/cloud_tpu_colabs/README.md @@ -4,7 +4,7 @@ The same JAX code that runs on CPU and GPU can also be run on TPU. Cloud TPUs have the advantage of quickly giving you access to multiple TPU accelerators, including in [Colab](https://research.google.com/colaboratory/). All of the example notebooks here use -[`jax.pmap`](https://jax.readthedocs.io/en/latest/jax.html#jax.pmap) to run JAX +[`jax.pmap`](https://docs.jax.dev/en/latest/jax.html#jax.pmap) to run JAX computation across multiple TPU cores from Colab. You can also run the same code directly on a [Cloud TPU VM](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm). diff --git a/docs/README.md b/docs/README.md index 12e00425592f..54b8a67477b0 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,2 +1,2 @@ To rebuild the documentation, -see [Update Documentation](https://jax.readthedocs.io/en/latest/developer.html#update-documentation). +see [Update Documentation](https://docs.jax.dev/en/latest/developer.html#update-documentation). diff --git a/docs/_static/pallas/gpu/nvidia_sm.svg b/docs/_static/pallas/gpu/nvidia_sm.svg new file mode 100644 index 000000000000..76b4edb2afad --- /dev/null +++ b/docs/_static/pallas/gpu/nvidia_sm.svg @@ -0,0 +1,99 @@ + + + + + Streaming Multiprocessor + + + + + + Warp Scheduler + + TensorCore + + ALU + (Float/Int) + + Load/Store + + + Special + Functions + + + + + + + Warp Scheduler + + TensorCore + + ALU + (Float/Int) + + Load/Store + + + Special + Functions + + + + + + + Warp Scheduler + + TensorCore + + ALU + (Float/Int) + + Load/Store + + + Special + Functions + + + + + + + Warp Scheduler + + TensorCore + + ALU + (Float/Int) + + Load/Store + + + Special + Functions + + + + + + Shared Memory / L1 Cache + + + diff --git a/docs/about.md b/docs/about.md index 58e1703842b9..baeed941c8c3 100644 --- a/docs/about.md +++ b/docs/about.md @@ -19,7 +19,7 @@ technology stack](#components). 
First, we design the `jax` module to be [composable](https://github.com/jax-ml/jax?tab=readme-ov-file#transformations) and -[extensible](https://jax.readthedocs.io/en/latest/jax.extend.html), so +[extensible](https://docs.jax.dev/en/latest/jax.extend.html), so that a wide variety of domain-specific libraries can thrive outside of it in a decentralized manner. Second, we lean heavily on a modular backend stack (compiler and runtime) to target different @@ -42,10 +42,10 @@ scale. JAX's day-to-day development takes place in the open on GitHub, using pull requests, the issue tracker, discussions, and [JAX Enhancement Proposals -(JEPs)](https://jax.readthedocs.io/en/latest/jep/index.html). Reading +(JEPs)](https://docs.jax.dev/en/latest/jep/index.html). Reading and participating in these is a good way to get involved. We also maintain [developer -notes](https://jax.readthedocs.io/en/latest/contributor_guide.html) +notes](https://docs.jax.dev/en/latest/contributor_guide.html) that cover JAX's internal design. The JAX core team determines whether to accept changes and @@ -56,7 +56,7 @@ intricate decision structure over time (e.g. with designated area owners) if/when it becomes useful to do so. For more see [contributing to -JAX](https://jax.readthedocs.io/en/latest/contributing.html). +JAX](https://docs.jax.dev/en/latest/contributing.html). (components)= ## A modular stack @@ -71,7 +71,7 @@ and (b) an advancing hardware landscape, we lean heavily on While the JAX core library focuses on the fundamentals, we want to encourage domain-specific libraries and tools to be built on top of JAX. Indeed, [many -libraries](https://jax.readthedocs.io/en/latest/#ecosystem) have +libraries](https://docs.jax.dev/en/latest/#ecosystem) have emerged around JAX to offer higher-level features and extensions. How do we encourage such decentralized development? We guide it with @@ -80,11 +80,11 @@ building blocks (e.g. numerical primitives, NumPy operations, arrays, and transformations), encouraging auxiliary libraries to develop utilities as needed for their domain. In addition, JAX exposes a handful of more advanced APIs for -[customization](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) +[customization](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) and -[extensibility](https://jax.readthedocs.io/en/latest/jax.extend.html). Libraries +[extensibility](https://docs.jax.dev/en/latest/jax.extend.html). Libraries can [lean on these -APIs](https://jax.readthedocs.io/en/latest/building_on_jax.html) in +APIs](https://docs.jax.dev/en/latest/building_on_jax.html) in order to use JAX as an internal means of implementation, to integrate more with its transformations like autodiff, and more. diff --git a/docs/advanced-autodiff.md b/docs/advanced-autodiff.md index eaa3bc7317c8..bef2fd088a3a 100644 --- a/docs/advanced-autodiff.md +++ b/docs/advanced-autodiff.md @@ -876,7 +876,7 @@ There are two ways to define differentiation rules in JAX: 1. Using {func}`jax.custom_jvp` and {func}`jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and 2. Defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems. -This notebook is about #1. 
To read instead about #2, refer to the [notebook on adding primitives](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html). +This notebook is about #1. To read instead about #2, refer to the [notebook on adding primitives](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html). ### TL;DR: Custom JVPs with {func}`jax.custom_jvp` @@ -1608,7 +1608,7 @@ Array(-0.91113025, dtype=float32) #### Working with `list` / `tuple` / `dict` containers (and other pytrees) -You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. +You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://docs.jax.dev/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. Here's a contrived example with {func}`jax.custom_jvp`: diff --git a/docs/aot.md b/docs/aot.md index 1fcf11ab945d..8f68c2758148 100644 --- a/docs/aot.md +++ b/docs/aot.md @@ -26,7 +26,7 @@ are arrays, JAX does the following in order: carries out this specialization by a process that we call _tracing_. During tracing, JAX stages the specialization of `F` to a jaxpr, which is a function in the [Jaxpr intermediate - language](https://jax.readthedocs.io/en/latest/jaxpr.html). + language](https://docs.jax.dev/en/latest/jaxpr.html). 2. **Lower** this specialized, staged-out computation to the XLA compiler's input language, StableHLO. diff --git a/docs/api_compatibility.md b/docs/api_compatibility.md index 749c5907bc6b..9dca1fc08f50 100644 --- a/docs/api_compatibility.md +++ b/docs/api_compatibility.md @@ -91,7 +91,7 @@ guarantees of the main JAX package. If you have code that uses `jax.extend`, we would strongly recommend CI tests against JAX's nightly releases, so as to catch potential changes before they are released. -For details on `jax.extend`, see the [`jax.extend` module docuementation](https://jax.readthedocs.io/en/latest/jax.extend.html), or the design document, {ref}`jax-extend-jep`. +For details on `jax.extend`, see the [`jax.extend` module docuementation](https://docs.jax.dev/en/latest/jax.extend.html), or the design document, {ref}`jax-extend-jep`. ## Numerics and randomness diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb index 7ec91affa05d..b6f12b624f8b 100644 --- a/docs/autodidax.ipynb +++ b/docs/autodidax.ipynb @@ -72,7 +72,7 @@ "outputs, we want to override primitive application and let different values\n", "flow through our program. For example, we might want to replace the\n", "application of every primitive with an application of [its JVP\n", - "rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html),\n", + "rule](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html),\n", "and let primal-tangent pairs flow through our program. Moreover, we want to be\n", "able to compose multiple transformations, leading to stacks of interpreters." ] @@ -3620,7 +3620,7 @@ "source": [ "Notice that we're not currently supporting the case where the predicate value\n", "itself is batched. 
In mainline JAX, we handle this case by transforming the\n", - "conditional to a [select primitive](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html).\n", + "conditional to a [select primitive](https://docs.jax.dev/en/latest/_autosummary/jax.lax.select.html).\n", "That transformation is semantically correct so long as `true_fun` and\n", "`false_fun` do not involve any side-effecting primitives.\n", "\n", diff --git a/docs/autodidax.md b/docs/autodidax.md index 2d4d6cd528af..1c375e21227c 100644 --- a/docs/autodidax.md +++ b/docs/autodidax.md @@ -72,7 +72,7 @@ where we apply primitive operations to numerical inputs to produce numerical outputs, we want to override primitive application and let different values flow through our program. For example, we might want to replace the application of every primitive with an application of [its JVP -rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html), +rule](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html), and let primal-tangent pairs flow through our program. Moreover, we want to be able to compose multiple transformations, leading to stacks of interpreters. @@ -2843,7 +2843,7 @@ print(out) Notice that we're not currently supporting the case where the predicate value itself is batched. In mainline JAX, we handle this case by transforming the -conditional to a [select primitive](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html). +conditional to a [select primitive](https://docs.jax.dev/en/latest/_autosummary/jax.lax.select.html). That transformation is semantically correct so long as `true_fun` and `false_fun` do not involve any side-effecting primitives. diff --git a/docs/autodidax.py b/docs/autodidax.py index f8c6372fe30d..6329234224cb 100644 --- a/docs/autodidax.py +++ b/docs/autodidax.py @@ -62,7 +62,7 @@ # outputs, we want to override primitive application and let different values # flow through our program. For example, we might want to replace the # application of every primitive with an application of [its JVP -# rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html), +# rule](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html), # and let primal-tangent pairs flow through our program. Moreover, we want to be # able to compose multiple transformations, leading to stacks of interpreters. @@ -2837,7 +2837,7 @@ def cond_vmap_rule(axis_size, vals_in, dims_in, *, true_jaxpr, false_jaxpr): # Notice that we're not currently supporting the case where the predicate value # itself is batched. In mainline JAX, we handle this case by transforming the -# conditional to a [select primitive](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html). +# conditional to a [select primitive](https://docs.jax.dev/en/latest/_autosummary/jax.lax.select.html). # That transformation is semantically correct so long as `true_fun` and # `false_fun` do not involve any side-effecting primitives. # diff --git a/docs/building_on_jax.md b/docs/building_on_jax.md index 9416b16cde10..6d13f517f50b 100644 --- a/docs/building_on_jax.md +++ b/docs/building_on_jax.md @@ -45,8 +45,8 @@ Here are more specific examples of each pattern. ### Direct usage Jax can be directly imported and utilized to build models “from scratch” as shown across this website, -for example in [JAX Tutorials](https://jax.readthedocs.io/en/latest/tutorials.html) -or [Neural Network with JAX](https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html). 
+for example in [JAX Tutorials](https://docs.jax.dev/en/latest/tutorials.html) +or [Neural Network with JAX](https://docs.jax.dev/en/latest/notebooks/neural_network_with_tfds_data.html). This may be the best option if you are unable to find prebuilt code for your particular challenge, or if you're looking to reduce the number of dependencies in your codebase. diff --git a/docs/contributing.md b/docs/contributing.md index 99d78453c436..53a863fdcd8c 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -6,7 +6,7 @@ Everyone can contribute to JAX, and we value everyone's contributions. There are ways to contribute, including: - Answering questions on JAX's [discussions page](https://github.com/jax-ml/jax/discussions) -- Improving or expanding JAX's [documentation](http://jax.readthedocs.io/) +- Improving or expanding JAX's [documentation](http://docs.jax.dev/) - Contributing to JAX's [code-base](http://github.com/jax-ml/jax/) - Contributing in any of the above ways to the broader ecosystem of [libraries built on JAX](https://github.com/jax-ml/jax#neural-network-libraries) diff --git a/docs/control-flow.md b/docs/control-flow.md index 7cb959f3e434..8f59bd92add7 100644 --- a/docs/control-flow.md +++ b/docs/control-flow.md @@ -244,19 +244,19 @@ lax.cond(False, lambda x: x+1, lambda x: x-1, operand) `jax.lax` provides two other functions that allow branching on dynamic predicates: -- [`lax.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html) is +- [`lax.select`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.select.html) is like a batched version of `lax.cond`, with the choices expressed as pre-computed arrays rather than as functions. -- [`lax.switch`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html) is +- [`lax.switch`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.switch.html) is like `lax.cond`, but allows switching between any number of callable choices. In addition, `jax.numpy` provides several numpy-style interfaces to these functions: -- [`jnp.where`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html) with +- [`jnp.where`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.where.html) with three arguments is the numpy-style wrapper of `lax.select`. -- [`jnp.piecewise`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.piecewise.html) +- [`jnp.piecewise`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.piecewise.html) is a numpy-style wrapper of `lax.switch`, but switches on a list of boolean conditions rather than a single scalar index. -- [`jnp.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.select.html) has +- [`jnp.select`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.select.html) has an API similar to `jnp.piecewise`, but the choices are given as pre-computed arrays rather than as functions. It is implemented in terms of multiple calls to `lax.select`. diff --git a/docs/default_dtypes.md b/docs/default_dtypes.md new file mode 100644 index 000000000000..629f7fb5c314 --- /dev/null +++ b/docs/default_dtypes.md @@ -0,0 +1,82 @@ +(default-dtypes)= +# Default dtypes and the X64 flag +JAX strives to meet the needs of a range of numerical computing practitioners, who +sometimes have conflicting preferences. When it comes to default dtypes, there are +two different camps: + +- Classic scientific computing practitioners (i.e. 
users of tools like {mod}`numpy` or + {mod}`scipy`) tend to value accuracy of computations foremost: such users would + prefer that computations default to the **widest available representation**: e.g. + floating point values should default to `float64`, integers to `int64`, etc. +- AI researchers (i.e. folks implementing and training neural networks) tend to value + speed over accuracy, to the point where they have developed special data types like + [bfloat16](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format) and others + which deliberately discard the least significant bits in order to speed up computation. + For these users, the mere presence of a float64 value in their computation can lead + to programs that are slow at best, and incompatible with their hardware at worst! + These users would prefer that computations default to `float32` or `int32`. + +The main mechanism JAX offers for this is the `jax_enable_x64` flag, which controls +whether 64-bit values can be created at all. By default this flag is set to `False` +(serving the needs of AI researchers and practitioners), but can be set to `True` +by users who value accuracy over computational speed. + +## Default setting: 32-bits everywhere +By default `jax_enable_x64` is set to False, and so {mod}`jax.numpy` array creation +functions will default to returning 32-bit values. + +For example: +```python +>>> import jax.numpy as jnp + +>>> jnp.arange(5) +Array([0, 1, 2, 3, 4], dtype=int32) + +>>> jnp.zeros(5) +Array([0., 0., 0., 0., 0.], dtype=float32) + +>>> jnp.ones(5, dtype=int) +Array([1, 1, 1, 1, 1], dtype=int32) + +``` + +Beyond defaults, because 64-bit values can be so poisonous to AI workflows, having +this flag set to False prevents you from creating 64-bit arrays at all! For example: +``` +>>> jnp.arange(5, dtype='float64') # doctest: +SKIP +UserWarning: Explicitly requested dtype float64 requested in arange is not available, and will be +truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the +JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more. +Array([0., 1., 2., 3., 4.], dtype=float32) +``` + +## The X64 flag: enabling 64-bit values +To work in the "other mode" where functions default to producing 64-bit values, you can set the +`jax_enable_x64` flag to `True`: +```python +import jax +import jax.numpy as jnp + +jax.config.update('jax_enable_x64', True) + +print(repr(jnp.arange(5))) +print(repr(jnp.zeros(5))) +print(repr(jnp.ones(5, dtype=int))) +``` +``` +Array([0, 1, 2, 3, 4], dtype=int64) +Array([0., 0., 0., 0., 0.], dtype=float64) +Array([1, 1, 1, 1, 1], dtype=int64) +``` + +The X64 configuration can also be set via the `JAX_ENABLE_X64` shell environment variable, +for example: +```bash +$ JAX_ENABLE_X64=1 python main.py +``` +The X64 flag is intended as a **global setting** that should have one value for your whole +program, set at the top of your main file. A common feature request is for the flag to +be contextually configurable (e.g. enabling X64 just for one section of a long program): +this turns out to be difficult to implement within JAX's programming model, where code +execution may happen in a different context than code compilation. There is ongoing work +exploring the feasibility of relaxing this constraint, so stay tuned! 
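One more practical note on the default setting: data that is already 64-bit, for example a NumPy
`float64` array, is demoted to 32 bits when it becomes a JAX array while `jax_enable_x64` is `False`.
A minimal sketch (assuming the default setting):
```python
import numpy as np
import jax.numpy as jnp

x_np = np.linspace(0.0, 1.0, 5)   # NumPy defaults to float64
x_jax = jnp.asarray(x_np)         # demoted to JAX's 32-bit default

print(x_np.dtype)    # float64
print(x_jax.dtype)   # float32 (with jax_enable_x64 left at False)
```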
diff --git a/docs/developer.md b/docs/developer.md index 0affbba9ed36..9edeaeac83f8 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -1,7 +1,7 @@ (building-from-source)= # Building from source - + First, obtain the JAX source code: @@ -526,23 +526,27 @@ bazel test //tests:cpu_tests //tests:backend_independent_tests `//tests:gpu_tests` and `//tests:tpu_tests` are also available, if you have the necessary hardware. -To use a preinstalled `jaxlib` instead of building it you first need to -make it available in the hermetic Python. To install a specific version of -`jaxlib` within hermetic Python run (using `jaxlib >= 0.4.26` as an example): +To use the preinstalled `jax` and `jaxlib` instead of building them you first +need to make them available in the hermetic Python. To install the specific +versions of `jax` and `jaxlib` within hermetic Python run (using `jax >= 0.4.26` +and `jaxlib >= 0.4.26` as an example): ``` +echo -e "\njax >= 0.4.26" >> build/requirements.in echo -e "\njaxlib >= 0.4.26" >> build/requirements.in python build/build.py requirements_update ``` -Alternatively, to install `jaxlib` from a local wheel (assuming Python 3.12): +Alternatively, to install `jax` and `jaxlib` from the local wheels +(assuming Python 3.12): ``` +echo -e "\n$(realpath jax-0.4.26-py3-none-any.whl)" >> build/requirements.in echo -e "\n$(realpath jaxlib-0.4.26-cp312-cp312-manylinux2014_x86_64.whl)" >> build/requirements.in python build/build.py requirements_update --python_version=3.12 ``` -Once you have `jaxlib` installed hermetically, run: +Once you have `jax` and `jaxlib` installed hermetically, run: ``` bazel test --//jax:build_jaxlib=false //tests:cpu_tests //tests:backend_independent_tests @@ -785,7 +789,7 @@ desired formats, and which the `jupytext --sync` command recognizes when invoked #### Notebooks within the Sphinx build Some of the notebooks are built automatically as part of the pre-submit checks and -as part of the [Read the docs](https://jax.readthedocs.io/en/latest) build. +as part of the [Read the docs](https://docs.jax.dev/en/latest) build. The build will fail if cells raise errors. If the errors are intentional, you can either catch them, or tag the cell with `raises-exceptions` metadata ([example PR](https://github.com/jax-ml/jax/pull/2402/files)). You have to add this metadata by hand in the `.ipynb` file. It will be preserved when somebody else @@ -796,7 +800,7 @@ See `exclude_patterns` in [conf.py](https://github.com/jax-ml/jax/blob/main/docs ### Documentation building on `readthedocs.io` -JAX's auto-generated documentation is at . +JAX's auto-generated documentation is at . The documentation building is controlled for the entire project by the [readthedocs JAX settings](https://readthedocs.org/dashboard/jax). The current settings @@ -809,7 +813,7 @@ For each automated documentation build you can see the If you want to test the documentation generation on Readthedocs, you can push code to the `test-docs` branch. That branch is also built automatically, and you can -see the generated documentation [here](https://jax.readthedocs.io/en/test-docs/). If the documentation build +see the generated documentation [here](https://docs.jax.dev/en/test-docs/). If the documentation build fails you may want to [wipe the build environment for test-docs](https://docs.readthedocs.io/en/stable/guides/wipe-environment.html). 
For a local test, I was able to do it in a fresh directory by replaying the commands diff --git a/docs/export/export.md b/docs/export/export.md index 18cdcc6c51d0..63c0db14f905 100644 --- a/docs/export/export.md +++ b/docs/export/export.md @@ -161,7 +161,7 @@ e.g., the inference system.) What **matters is when the exporting and consuming components were built**, not the time when the exporting and the compilation happen. For external JAX users, it is -[possible to run JAX and jaxlib at different versions](https://jax.readthedocs.io/en/latest/jep/9419-jax-versioning.html#how-are-jax-and-jaxlib-versioned); +[possible to run JAX and jaxlib at different versions](https://docs.jax.dev/en/latest/jep/9419-jax-versioning.html#how-are-jax-and-jaxlib-versioned); what matters is when the jaxlib release was built. To reduce chances of incompatibility, internal JAX users should: diff --git a/docs/export/shape_poly.md b/docs/export/shape_poly.md index 9254030a4e1c..6b63a536ab48 100644 --- a/docs/export/shape_poly.md +++ b/docs/export/shape_poly.md @@ -86,7 +86,7 @@ matching the structure of the arguments passed to it. The polymorphic shapes specification can be a pytree prefix in cases where one specification should apply to multiple arguments, as in the above example. -See [how optional parameters are matched to arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). +See [how optional parameters are matched to arguments](https://docs.jax.dev/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). A few examples of shape specifications: @@ -609,7 +609,7 @@ Division had remainder 1 when computing the value of 'd'. Using the following polymorphic shapes specifications: args[0].shape = (b, b, 2*d). Obtained dimension variables: 'b' = 3 from specification 'b' for dimension args[0].shape[0] (= 3), . -Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details. +Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details. ``` diff --git a/docs/faq.rst b/docs/faq.rst index 44267f6f5f7d..f5d43d25afb6 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -4,7 +4,7 @@ Frequently asked questions (FAQ) .. comment RST primer for Sphinx: https://thomas-cokelaer.info/tutorials/sphinx/rest_syntax.html .. comment Some links referenced here. Use `JAX - The Sharp Bits`_ (underscore at the end) to reference -.. _JAX - The Sharp Bits: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html +.. _JAX - The Sharp Bits: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html We are collecting answers to frequently asked questions here. Contributions welcome! @@ -116,7 +116,7 @@ code in JAX's internal representation, typically because it makes heavy use of Python control flow such as ``for`` loops. For a handful of loop iterations, Python is OK, but if you need *many* loop iterations, you should rewrite your code to make use of JAX's -`structured control flow primitives `_ +`structured control flow primitives `_ (such as :func:`lax.scan`) or avoid wrapping the loop with ``jit`` (you can still use ``jit`` decorated functions *inside* the loop). @@ -454,8 +454,8 @@ performing matrix-matrix multiplication) to amortize the increased overhead of JAX/accelerators vs NumPy/CPU. For example, if we switch this example to use 10x10 input instead, JAX/GPU runs 10x slower than NumPy/CPU (100 µs vs 10 µs). -.. 
_To JIT or not to JIT: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#to-jit-or-not-to-jit -.. _Double (64 bit) precision: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision +.. _To JIT or not to JIT: https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html#to-jit-or-not-to-jit +.. _Double (64 bit) precision: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision .. _`%time and %timeit magics`: https://ipython.readthedocs.io/en/stable/interactive/magics.html#magic-time .. _Colab: https://colab.research.google.com/ @@ -841,12 +841,12 @@ reducing :code:`XLA_PYTHON_CLIENT_MEM_FRACTION` from the default of :code:`.75`, or setting :code:`XLA_PYTHON_CLIENT_PREALLOCATE=false`. For more details, please see the page on `JAX GPU memory allocation`_. -.. _JIT mechanics: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables -.. _External callbacks in JAX: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html -.. _Pure callback example: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html#example-pure-callback-with-custom-jvp -.. _IO callback example: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html#exploring-jax-experimental-io-callback +.. _JIT mechanics: https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables +.. _External callbacks in JAX: https://docs.jax.dev/en/latest/notebooks/external_callbacks.html +.. _Pure callback example: https://docs.jax.dev/en/latest/notebooks/external_callbacks.html#example-pure-callback-with-custom-jvp +.. _IO callback example: https://docs.jax.dev/en/latest/notebooks/external_callbacks.html#exploring-jax-experimental-io-callback .. _Heaviside Step Function: https://en.wikipedia.org/wiki/Heaviside_step_function .. _Sigmoid Function: https://en.wikipedia.org/wiki/Sigmoid_function .. _algebraic_simplifier.cc: https://github.com/openxla/xla/blob/33f815e190982dac4f20d1f35adb98497a382377/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc#L4851 -.. _JAX GPU memory allocation: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html +.. _JAX GPU memory allocation: https://docs.jax.dev/en/latest/gpu_memory_allocation.html .. 
_dynamic linker search pattern: https://man7.org/linux/man-pages/man8/ld.so.8.html diff --git a/docs/ffi.ipynb b/docs/ffi.ipynb index b622fba9d5bc..f74ae9d58a78 100644 --- a/docs/ffi.ipynb +++ b/docs/ffi.ipynb @@ -439,7 +439,7 @@ "As far as JAX is concerned, the foreign function is a black box that can't be inspected to determine the appropriate behavior when differentiated.\n", "Therefore, it is the {func}`~jax.ffi.ffi_call` user's responsibility to define a custom derivative rule.\n", "\n", - "More details about custom derivative rules can be found in the [custom derivatives tutorial](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function.\n", + "More details about custom derivative rules can be found in the [custom derivatives tutorial](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function.\n", "In this case, we actually define two new FFI calls:\n", "\n", "1. `rms_norm_fwd` returns two outputs: (a) the \"primal\" result, and (b) the \"residuals\" which are used in the backwards pass.\n", @@ -785,7 +785,7 @@ "{func}`~jax.experimental.custom_partitioning.custom_partitioning` works by adding Python callbacks into the XLA compiler's partitioning pass, which allows very flexible logic, but also comes with some rough edges.\n", "We won't go into too much detail on the caveats here, but the main issues that you should be aware of are:\n", "\n", - "1. `custom_partitioning` can cause unexpected cache misses when used with the JAX's [Persistent compilation cache](https://jax.readthedocs.io/en/latest/persistent_compilation_cache.html). This can be mitigated using the `jax_remove_custom_partitioning_ptr_from_cache_key` configuration flag, but that isn't always appropriate either.\n", + "1. `custom_partitioning` can cause unexpected cache misses when used with the JAX's [Persistent compilation cache](https://docs.jax.dev/en/latest/persistent_compilation_cache.html). This can be mitigated using the `jax_remove_custom_partitioning_ptr_from_cache_key` configuration flag, but that isn't always appropriate either.\n", "2. Debugging `custom_partitioning` logic can be tedious because Python errors don't always get propagated, instead causing your Python process to exit. That being said, any exceptions will show up in the process logs, so you should be able to track them down there.\n", "\n", "All that being said, here's how we can wrap our FFI implementation of `rms_norm` using {func}`~jax.experimental.custom_partitioning.custom_partitioning`:" diff --git a/docs/ffi.md b/docs/ffi.md index 4aa03c217855..97648c78e118 100644 --- a/docs/ffi.md +++ b/docs/ffi.md @@ -353,7 +353,7 @@ Unlike with batching, {func}`~jax.ffi.ffi_call` doesn't provide any default supp As far as JAX is concerned, the foreign function is a black box that can't be inspected to determine the appropriate behavior when differentiated. Therefore, it is the {func}`~jax.ffi.ffi_call` user's responsibility to define a custom derivative rule. 
-More details about custom derivative rules can be found in the [custom derivatives tutorial](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function. +More details about custom derivative rules can be found in the [custom derivatives tutorial](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function. In this case, we actually define two new FFI calls: 1. `rms_norm_fwd` returns two outputs: (a) the "primal" result, and (b) the "residuals" which are used in the backwards pass. @@ -591,7 +591,7 @@ If you can't use {func}`~jax.experimental.shard_map.shard_map`, an alternative a {func}`~jax.experimental.custom_partitioning.custom_partitioning` works by adding Python callbacks into the XLA compiler's partitioning pass, which allows very flexible logic, but also comes with some rough edges. We won't go into too much detail on the caveats here, but the main issues that you should be aware of are: -1. `custom_partitioning` can cause unexpected cache misses when used with the JAX's [Persistent compilation cache](https://jax.readthedocs.io/en/latest/persistent_compilation_cache.html). This can be mitigated using the `jax_remove_custom_partitioning_ptr_from_cache_key` configuration flag, but that isn't always appropriate either. +1. `custom_partitioning` can cause unexpected cache misses when used with the JAX's [Persistent compilation cache](https://docs.jax.dev/en/latest/persistent_compilation_cache.html). This can be mitigated using the `jax_remove_custom_partitioning_ptr_from_cache_key` configuration flag, but that isn't always appropriate either. 2. Debugging `custom_partitioning` logic can be tedious because Python errors don't always get propagated, instead causing your Python process to exit. That being said, any exceptions will show up in the process logs, so you should be able to track them down there. All that being said, here's how we can wrap our FFI implementation of `rms_norm` using {func}`~jax.experimental.custom_partitioning.custom_partitioning`: diff --git a/docs/gpu_memory_allocation.rst b/docs/gpu_memory_allocation.rst index 6667589e7b72..be40dfc8004c 100644 --- a/docs/gpu_memory_allocation.rst +++ b/docs/gpu_memory_allocation.rst @@ -69,7 +69,7 @@ Common causes of OOM failures disabling the automatic remat pass produces different trade-offs between compute and memory. Note however, that the algorithm is basic and you can often get better trade-off between compute and memory by disabling the automatic remat pass and doing - it manually with `the jax.remat API `_ + it manually with `the jax.remat API `_ Experimental features diff --git a/docs/gpu_performance_tips.md b/docs/gpu_performance_tips.md index bf032dccff88..b3643cb8e292 100644 --- a/docs/gpu_performance_tips.md +++ b/docs/gpu_performance_tips.md @@ -1,6 +1,6 @@ # GPU performance tips - + This document focuses on performance tips for neural network workloads @@ -58,7 +58,173 @@ training on Nvidia GPUs](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta * **--xla_gpu_triton_gemm_any** Use the Triton-based GEMM (matmul) emitter for any GEMM that it supports. The default value is False. 
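XLA flags such as the one above are typically supplied through the `XLA_FLAGS` environment variable
before JAX's GPU backend is initialized; setting it before the first `import jax` is the simplest way
to guarantee that. A minimal sketch (the specific flag and value here are only for illustration):
```python
import os

# Append to any flags already set; do this before importing jax so the
# GPU backend sees the flag when it is initialized.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_gpu_triton_gemm_any=true"
)

import jax  # imported after setting the environment on purpose
```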
-### Communication flags
+## Communication tips
+
+### Auto and manual PGLE
+
+The Profile Guided Latency Estimator (PGLE) workflow measures the actual running time
+of compute and collectives, and feeds that profile information back into the XLA
+compiler so that it can make better scheduling decisions.
+
+The Profile Guided Latency Estimator can be used manually or automatically. In auto mode,
+JAX collects profile information and recompiles a module within a single run. In manual
+mode, you run the task twice: the first time to collect and save profiles, and the
+second time to compile and run with the provided data.
+
+**Important**: the JAX profiler, which is used by both of the PGLE workflows documented
+below, cannot co-exist with the NVIDIA Nsight Systems profiler. This limitation can be
+avoided by using the JAX compilation cache, as described below.
+
+### Auto PGLE
+Auto PGLE can be turned on by setting the following environment variables:
+
+Mandatory:
+```bash
+export JAX_ENABLE_PGLE=true
+
+# For JAX version <= 0.5.0 make sure to include:
+export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true"
+```
+
+Optional:
+```bash
+export JAX_PGLE_PROFILING_RUNS=3
+export JAX_PGLE_AGGREGATION_PERCENTILE=85
+
+# Right now, auto PGLE profile collection doesn't work with the command buffer.
+# If the command buffer is enabled, auto PGLE will disable it during profile
+# collection and re-enable it after recompilation. If you need consistent
+# command buffer behavior with and without a PGLE profile, you can disable it
+# manually:
+export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_command_buffer=''"
+```
+
+Alternatively, this can be set from within JAX as follows:
+
+```
+import jax
+from jax._src import config
+
+with config.enable_pgle(True), config.pgle_profiling_runs(1):
+  # Run with the profiler collecting performance information.
+  train_step()
+  # Automatically re-compile with PGLE profile results
+  train_step()
+  ...
+```
+
+You can control the number of reruns used to collect profile data by changing
+`JAX_PGLE_PROFILING_RUNS`. Increasing this parameter yields better profile information,
+but it also increases the number of non-optimized training steps.
+
+Decreasing the `JAX_PGLE_AGGREGATION_PERCENTILE` parameter can help when performance between steps is too noisy to filter out irrelevant measurements.
+
+**Attention:** Auto PGLE doesn't work for pre-compiled modules. Since JAX needs to recompile the module during execution, auto PGLE works neither for ahead-of-time compilation nor for the following case:
+
+```
+import jax
+from jax._src import config
+
+train_step_compiled = train_step().lower().compile()
+
+with config.enable_pgle(True), config.pgle_profiling_runs(1):
+  train_step_compiled()
+  # No effect since module was pre-compiled.
+  train_step_compiled()
+```
+
+#### Collecting NVIDIA Nsight Systems profiles when using AutoPGLE
+[jax#24910](https://github.com/jax-ml/jax/pull/24910) (JAX v0.5.1 and newer) added a
+new JAX configuration option, `JAX_COMPILATION_CACHE_EXPECT_PGLE`, which tells JAX to
+attempt to load PGLE-optimized compiled functions from the persistent compilation
+cache.
+
+This allows a two-step process, where the first step writes a PGLE-optimized function
+to the cache:
+```bash
+export JAX_ENABLE_COMPILATION_CACHE=yes # not strictly needed, on by default
+export JAX_COMPILATION_CACHE_DIR=/root/jax_cache
+JAX_ENABLE_PGLE=yes python my-model.py
+```
+And the second step uses Nsight Systems and loads the PGLE-optimized function from the
+cache:
+```bash
+JAX_COMPILATION_CACHE_EXPECT_PGLE=yes nsys profile python my-model.py
+```
+See also [this page](
+https://docs.jax.dev/en/latest/persistent_compilation_cache.html#pitfalls) for more
+information about the persistent compilation cache and possible pitfalls.
+
+### Manual PGLE
+
+If you still want to use the manual Profile Guided Latency Estimator, the workflow in XLA/GPU is:
+
+- 1. Run your workload once, with async collectives and the latency hiding scheduler enabled.
+
+You can do so by setting:
+
+```bash
+export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true"
+```
+
+- 2. Collect and post-process a profile using the JAX profiler, saving the extracted instruction latencies into a binary protobuf file.
+
+```python
+import logging
+import os
+from etils import epath
+import jax
+from jax.experimental import profiler as exp_profiler
+
+# Define your profile directory
+profile_dir = 'gs://my_bucket/profile'
+jax.profiler.start_trace(profile_dir)
+
+# Run your workflow, e.g.:
+# for i in range(10):
+#   train_step()
+
+# Stop trace
+jax.profiler.stop_trace()
+profile_dir = epath.Path(profile_dir)
+directories = profile_dir.glob('plugins/profile/*/')
+directories = [d for d in directories if d.is_dir()]
+rundir = directories[-1]
+logging.info('rundir: %s', rundir)
+
+# Post-process the profile
+fdo_profile = exp_profiler.get_profiled_instructions_proto(os.fspath(rundir))
+
+# Save the profile proto to a file.
+dump_dir = rundir / 'profile.pb'
+dump_dir.parent.mkdir(parents=True, exist_ok=True)
+dump_dir.write_bytes(fdo_profile)
+```
+
+After this step, you will get a `profile.pb` file under the `rundir` printed in the code.
+
+- 3. Run the workload again, feeding that file into the compilation.
+
+You need to pass the `profile.pb` file to the `--xla_gpu_pgle_profile_file_or_directory_path` flag:
+
+```bash
+export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_pgle_profile_file_or_directory_path=/path/to/profile/profile.pb"
+```
+
+To enable logging in XLA and check whether the profile is being used, set the logging level to include `INFO`:
+
+```bash
+export TF_CPP_MIN_LOG_LEVEL=0
+```
+
+Run the real workload again; if you find the following lines in the log, it means the profile is being used by the latency hiding scheduler:
+
+```
+2023-07-21 16:09:43.551600: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:478] Using PGLE profile from /tmp/profile/plugins/profile/2023_07_20_18_29_30/profile.pb
+2023-07-21 16:09:43.551741: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:573] Found profile, using profile guided latency estimator
+```
+
+#### Flags
 * **--xla_gpu_enable_latency_hiding_scheduler** This flag enables latency hiding
 schedulers to overlap asynchronous communication with computation efficiently.
diff --git a/docs/gradient-checkpointing.md b/docs/gradient-checkpointing.md
index 0938a5da944f..e4e842df49f0 100644
--- a/docs/gradient-checkpointing.md
+++ b/docs/gradient-checkpointing.md
@@ -341,7 +341,7 @@ def predict(params, x):
   return x
 ```
-By itself, {func}`jax.ad_checkpoint import.checkpoint_name` is just an identity function.
But because some policy functions know to look for them, you can use the names to control whether certain values output by {func}`jax.ad_checkpoint import.checkpoint_name` are considered saveable: +By itself, {func}`jax.ad_checkpoint.checkpoint_name` is just an identity function. But because some policy functions know to look for them, you can use the names to control whether certain values output by {func}`jax.ad_checkpoint.checkpoint_name` are considered saveable: ```{code-cell} print_saved_residuals(loss, params, x, y) diff --git a/docs/installation.md b/docs/installation.md index ee675dd1e586..34274d7596aa 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -229,7 +229,7 @@ refer to JAX has experimental ROCm support. There are two ways to install JAX: * Use [AMD's Docker container](https://hub.docker.com/r/rocm/jax-community/tags); or -* Build from source. Refer to the section [Additional notes for building a ROCm jaxlib for AMD GPUs](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). +* Build from source. Refer to the section [Additional notes for building a ROCm jaxlib for AMD GPUs](https://docs.jax.dev/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). (install-intel-gpu)= ## Intel GPU diff --git a/docs/jax-primitives.md b/docs/jax-primitives.md index abdc8be6d0a8..819d0418e894 100644 --- a/docs/jax-primitives.md +++ b/docs/jax-primitives.md @@ -21,7 +21,7 @@ kernelspec: A JAX primitive is the basic computational unit of a JAX program. This document explains the interface that a JAX primitive must support to allow JAX to perform all its transformations (this is not a how-to guide). -For example, the multiply-add operation can be implemented in terms of the low-level `jax.lax.*` primitives (which are like XLA operator wrappers) or `jax.core.Primitive("multiply_add")`, as demonstrated further below. +For example, the multiply-add operation can be implemented in terms of the low-level `jax.lax.*` primitives (which are like XLA operator wrappers) or `jax.extend.core.Primitive("multiply_add")`, as demonstrated further below. And JAX is able to take sequences of such primitive operations, and transform them via its composable transformations of Python functions, such as {func}`jax.jit`, {func}`jax.grad` and {func}`jax.vmap`. JAX implements these transforms in a *JAX-traceable* way. This means that when a Python function is executed, the only operations it applies to the data are either: @@ -171,7 +171,7 @@ The JAX traceability property is satisfied as long as the function is written in The right way to add support for multiply-add is in terms of existing JAX primitives, as shown above. However, to demonstrate how JAX primitives work, pretend that you want to add a new primitive to JAX for the multiply-add functionality. ```{code-cell} -from jax import core +from jax.extend import core multiply_add_p = core.Primitive("multiply_add") # Create the primitive @@ -300,7 +300,7 @@ def multiply_add_lowering(ctx, xc, yc, zc): return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result] # Now, register the lowering rule with JAX. 
-# For GPU, refer to the https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html +# For GPU, refer to the https://docs.jax.dev/en/latest/Custom_Operation_for_GPUs.html from jax.interpreters import mlir mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu') diff --git a/docs/jax.nn.rst b/docs/jax.nn.rst index adb13f89903d..339f07f4cdcc 100644 --- a/docs/jax.nn.rst +++ b/docs/jax.nn.rst @@ -40,6 +40,7 @@ Activation functions glu squareplus mish + identity Other functions --------------- @@ -53,3 +54,6 @@ Other functions standardize one_hot dot_product_attention + scaled_matmul + get_scaled_dot_general_config + scaled_dot_general diff --git a/docs/jax.rst b/docs/jax.rst index 98cd464cda15..2f16df613e5a 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -57,6 +57,7 @@ Configuration enable_custom_prng enable_custom_vjp_by_custom_transpose log_compiles + no_tracing numpy_rank_promotion transfer_guard diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index dcbb673997ad..3c436697e1be 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -69,6 +69,7 @@ jax.scipy.linalg lu lu_factor lu_solve + pascal polar qr rsf2csf diff --git a/docs/jax_array_migration.md b/docs/jax_array_migration.md index 95d4a632a295..3cc1629b2068 100644 --- a/docs/jax_array_migration.md +++ b/docs/jax_array_migration.md @@ -1,3 +1,6 @@ +--- +orphan: true +--- (jax-array-migration)= # jax.Array migration @@ -24,7 +27,7 @@ the unified jax.Array After the migration is complete `jax.Array` will be the only type of array in JAX. -This doc explains how to migrate existing codebases to `jax.Array`. For more information on using `jax.Array` and JAX parallelism APIs, see the [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) tutorial. +This doc explains how to migrate existing codebases to `jax.Array`. For more information on using `jax.Array` and JAX parallelism APIs, see the [Distributed arrays and automatic parallelization](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) tutorial. ### How to enable jax.Array? diff --git a/docs/jep/10657-sequencing-effects.md b/docs/jep/10657-sequencing-effects.md index 5f7eb0da4c04..ac3024519101 100644 --- a/docs/jep/10657-sequencing-effects.md +++ b/docs/jep/10657-sequencing-effects.md @@ -47,7 +47,7 @@ g() In many cases, JAX will execute `f` and `g` *in parallel*, dispatching the computations onto different threads -- `g` might actually be executed before `f`. Parallel execution is a nice performance optimization, especially if copying -to and from a device is expensive (see the [asynchronous dispatch note](https://jax.readthedocs.io/en/latest/async_dispatch.html) for more details). +to and from a device is expensive (see the [asynchronous dispatch note](https://docs.jax.dev/en/latest/async_dispatch.html) for more details). 
In practice, however, we often don't need to think about asynchronous dispatch because we're writing pure functions and only care about the inputs and outputs of functions -- we'll naturally block on future diff --git a/docs/jep/12049-type-annotations.md b/docs/jep/12049-type-annotations.md index 7a20958c5cab..bf6123b2bc7f 100644 --- a/docs/jep/12049-type-annotations.md +++ b/docs/jep/12049-type-annotations.md @@ -35,7 +35,7 @@ def slice(operand: Array, start_indices: Sequence[int], For the purposes of static type checking, this use of `Array = Any` for array type annotations puts no constraint on the argument values (`Any` is equivalent to no annotation at all), but it does serve as a form of useful in-code documentation for the developer. -For the sake of generated documentation, the name of the alias gets lost (the [HTML docs](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.slice.html) for `jax.lax.slice` report operand as type `Any`), so the documentation benefit does not go beyond the source code (though we could enable some `sphinx-autodoc` options to improve this: See [autodoc_type_aliases](https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#confval-autodoc_type_aliases)). +For the sake of generated documentation, the name of the alias gets lost (the [HTML docs](https://docs.jax.dev/en/latest/_autosummary/jax.lax.slice.html) for `jax.lax.slice` report operand as type `Any`), so the documentation benefit does not go beyond the source code (though we could enable some `sphinx-autodoc` options to improve this: See [autodoc_type_aliases](https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#confval-autodoc_type_aliases)). A benefit of this level of type annotation is that it is never wrong to annotate a value with `Any`, so it will provide a concrete benefit to developers and users in the form of documentation, without added complexity of satisfying the stricter needs of any particular static type checker. @@ -122,7 +122,7 @@ All told, the array-type-granularity challenge is less of an issue than the othe ### Challenge 5: imprecise APIs inherited from NumPy A large part of JAX’s user-facing API is inherited from NumPy within the {mod}`jax.numpy` submodule. -NumPy’s API was developed years before static type checking became part of the Python language, and follows Python’s historic recommendations to use a [duck-typing](https://docs.python.org/3/glossary.html#term-duck-typing)/[EAFP](https://docs.python.org/3/glossary.html#term-eafp) coding style, in which strict type-checking at runtime is discouraged. As a concrete example of this, consider the {func}`numpy.tile` function, which is defined like this: +NumPy’s API was developed years before static type checking became part of the Python language, and follows Python’s historic recommendations to use a [duck-typing](https://docs.python.org/3/glossary.html#term-duck-typing)/[EAFP](https://docs.python.org/3/glossary.html#term-EAFP) coding style, in which strict type-checking at runtime is discouraged. As a concrete example of this, consider the {func}`numpy.tile` function, which is defined like this: ```python def tile(A, reps): diff --git a/docs/jep/14273-shard-map.md b/docs/jep/14273-shard-map.md index 63742bc852c6..fa6681551d17 100644 --- a/docs/jep/14273-shard-map.md +++ b/docs/jep/14273-shard-map.md @@ -4,7 +4,7 @@ *January 2023* **This was the design doc proposing `shard_map`. 
You may instead want -[the up-to-date user docs](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html).** +[the up-to-date user docs](https://docs.jax.dev/en/latest/notebooks/shard_map.html).** ## Motivation @@ -18,7 +18,7 @@ We need great APIs for both, and rather than being mutually exclusive alternatives, they need to compose with each other. With `pjit` (now just `jit`) we have [a next-gen -API](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) +API](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) for the first school. But we haven't quite leveled-up the second school. `pmap` follows the second school, but over time we found it has [fatal flaws](#why-dont-pmap-or-xmap-already-solve-this). `xmap` solved those flaws, diff --git a/docs/jep/15856-jex.md b/docs/jep/15856-jex.md index a5625abf8930..a821405c399e 100644 --- a/docs/jep/15856-jex.md +++ b/docs/jep/15856-jex.md @@ -14,13 +14,13 @@ import jax.extend as jex Several projects depend on JAX's codebase internals, often to use its core machinery (e.g. to write a -[transformation over its IR](https://jax.readthedocs.io/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html)) +[transformation over its IR](https://docs.jax.dev/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html)) or to extend it (e.g. to [define new primitives](https://github.com/dfm/extending-jax)). Two challenges for these dependencies are (a) that our internals aren't all solidly designed for external use, and (b) that circumventing JAX's public API is -[unsupported](https://jax.readthedocs.io/en/latest/api_compatibility.html). +[unsupported](https://docs.jax.dev/en/latest/api_compatibility.html). In other words, our internals are often used like a library, but are neither structured nor updated like one. @@ -50,12 +50,12 @@ removed altogether. To keep development overhead low, `jax.extend` would not follow the public -[API compatibility](https://jax.readthedocs.io/en/latest/api_compatibility.html) +[API compatibility](https://docs.jax.dev/en/latest/api_compatibility.html) policy. It would promise no deprecation windows nor backwards compatibility between releases. Every release may break existing callers without simple recourse (e.g. without a flag reintroducing prior behavior). We would rely on the -[changelog](https://jax.readthedocs.io/en/latest/changelog.html) +[changelog](https://docs.jax.dev/en/latest/changelog.html) to call out such changes. Callers of `jax.extend` that need to upgrade their code regularly @@ -108,7 +108,7 @@ to process the Jaxpr IR (the output of At initialization, this module will contain many more symbols than what's needed to define primitives and rules, including various names used in setting up -["final-style transformations"](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing), +["final-style transformations"](https://docs.jax.dev/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing), such as the current `jax._src.core.Trace` and `Tracer` classes. We can revisit whether `jex.core` should also support final-style extensions alongside initial style approaches, and whether it can do so by a more @@ -137,7 +137,7 @@ tracer types from `jex`. This module plus `jex.core` ought to suffice for replicating today's custom primitive tutorials (e.g. 
-[ours](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html) +[ours](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html) and [dfm's](https://github.com/dfm/extending-jax)). For instance, defining a primitive and its behavior under `jax.jit` @@ -184,6 +184,6 @@ arrays. We have only one item in mind for now. The XLA compiler's array sharding format is more expressive than [those provided by -JAX](https://jax.readthedocs.io/en/latest/jax.sharding.html). We could +JAX](https://docs.jax.dev/en/latest/jax.sharding.html). We could provide this as `jex.sharding.XlaOpShardingProto`, corresponding to today's `jax._src.lib.xla_client.OpSharding` internally. diff --git a/docs/jep/17111-shmap-transpose.md b/docs/jep/17111-shmap-transpose.md index 2fdf5f822835..00d8a3f383fd 100644 --- a/docs/jep/17111-shmap-transpose.md +++ b/docs/jep/17111-shmap-transpose.md @@ -497,7 +497,7 @@ of every function instance along which the outputs are mapped, whereas for mesh axes over which the output is unmapped only one copy of the value is used. See [the `shmap` -JEP](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html) for examples +JEP](https://docs.jax.dev/en/latest/jep/14273-shard-map.html) for examples of unmapped inputs and outputs. For comparison, in `vmap` unmapped inputs/outputs are indicated by using `in_axes` / `out_axes` of `None` (rather than an `int`). diff --git a/docs/jep/2026-custom-derivatives.md b/docs/jep/2026-custom-derivatives.md index ce149fa6fb35..b09926425667 100644 --- a/docs/jep/2026-custom-derivatives.md +++ b/docs/jep/2026-custom-derivatives.md @@ -2,7 +2,7 @@ This is a design document, explaining some of the thinking behind the design and implementation of `jax.custom_jvp` and `jax.custom_vjp`. For user-oriented -documentation, see [the tutorial notebook](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). +documentation, see [the tutorial notebook](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). There are two ways to define differentiation rules in JAX: 1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation diff --git a/docs/jep/4008-custom-vjp-update.md b/docs/jep/4008-custom-vjp-update.md index 1e2270e052a6..c3f2be151ef7 100644 --- a/docs/jep/4008-custom-vjp-update.md +++ b/docs/jep/4008-custom-vjp-update.md @@ -4,7 +4,7 @@ _Oct 14 2020_ This doc assumes familiarity with `jax.custom_vjp`, as described in the [Custom derivative rules for JAX-transformable Python -functions](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) +functions](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) notebook. ## What to update diff --git a/docs/jep/4410-omnistaging.md b/docs/jep/4410-omnistaging.md index f95c15f404b6..5b4536864ac2 100644 --- a/docs/jep/4410-omnistaging.md +++ b/docs/jep/4410-omnistaging.md @@ -266,7 +266,7 @@ While tracing the function ex1 at ex1.py:4, this value became a tracer due to JA You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions. -See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information. +See https://docs.jax.dev/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information. 
Encountered tracer value: Tracedwith ``` diff --git a/docs/jep/9407-type-promotion.ipynb b/docs/jep/9407-type-promotion.ipynb index a1ede3177a3a..5f12877c97a9 100644 --- a/docs/jep/9407-type-promotion.ipynb +++ b/docs/jep/9407-type-promotion.ipynb @@ -12,7 +12,7 @@ "\n", "*Jake VanderPlas, December 2021*\n", "\n", - "One of the challenges faced in the design of any numerical computing library is the choice of how to handle operations between values of different types. This document outlines the thought process behind the promotion semantics used by JAX, summarized in [JAX Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html)." + "One of the challenges faced in the design of any numerical computing library is the choice of how to handle operations between values of different types. This document outlines the thought process behind the promotion semantics used by JAX, summarized in [JAX Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html)." ] }, { @@ -1335,7 +1335,7 @@ "However, these advantages comes with a few tradeoffs:\n", "\n", "- mixed float/integer promotion is very prone to precision loss: for example, `int64` (with a maximum value of $9.2 \\times 10^{18}$) can be promoted to `float16` (with a maximum value of $6.5 \\times 10^4$), meaning most representable values will become `inf`.\n", - "- as mentioned above, `f*` can no longer be thought of as a \"scalar type\", but as a different flavor of float64. In JAX's parlance, this is referred to as a [*weak type*](https://jax.readthedocs.io/en/latest/type_promotion.html#weakly-typed-values-in-jax), in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values.\n", + "- as mentioned above, `f*` can no longer be thought of as a \"scalar type\", but as a different flavor of float64. In JAX's parlance, this is referred to as a [*weak type*](https://docs.jax.dev/en/latest/type_promotion.html#weakly-typed-values-in-jax), in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values.\n", "\n", "Note that also, this approach still leaves the `uint64` promotion question unanswered, although it is perhaps reasonable to close the lattice by connecting `u64` to `f*`." ] @@ -1413,7 +1413,7 @@ "id": "o0-E2KWjYEXO" }, "source": [ - "The behavior resulting from this choice is summarized in [JAX Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html). Notably, aside from the inclusion of larger unsigned types (`u16`, `u32`, `u64`) and some details about the behavior of scalar/weak types (`i*`, `f*`, `c*`), this type promotion scheme turns out to be very close to that chosen by PyTorch.\n", + "The behavior resulting from this choice is summarized in [JAX Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html). Notably, aside from the inclusion of larger unsigned types (`u16`, `u32`, `u64`) and some details about the behavior of scalar/weak types (`i*`, `f*`, `c*`), this type promotion scheme turns out to be very close to that chosen by PyTorch.\n", "\n", "For those interested, the appendix below prints the full promotion tables used by NumPy, Tensorflow, PyTorch, and JAX." ] @@ -2883,7 +2883,7 @@ "source": [ "### JAX Type Promotion: `jax.numpy`\n", "\n", - "`jax.numpy` follows type promotion rules laid out at https://jax.readthedocs.io/en/latest/type_promotion.html. Here we use `i*`, `f*`, `c*` to indicate both Python scalars and weakly-typed arrays." 
+ "`jax.numpy` follows type promotion rules laid out at https://docs.jax.dev/en/latest/type_promotion.html. Here we use `i*`, `f*`, `c*` to indicate both Python scalars and weakly-typed arrays." ] }, { diff --git a/docs/jep/9407-type-promotion.md b/docs/jep/9407-type-promotion.md index ff67a8c21399..c047d76c1b18 100644 --- a/docs/jep/9407-type-promotion.md +++ b/docs/jep/9407-type-promotion.md @@ -20,7 +20,7 @@ kernelspec: *Jake VanderPlas, December 2021* -One of the challenges faced in the design of any numerical computing library is the choice of how to handle operations between values of different types. This document outlines the thought process behind the promotion semantics used by JAX, summarized in [JAX Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html). +One of the challenges faced in the design of any numerical computing library is the choice of how to handle operations between values of different types. This document outlines the thought process behind the promotion semantics used by JAX, summarized in [JAX Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html). +++ {"id": "Rod6OOyUVbQ8"} @@ -680,7 +680,7 @@ This is important because `f16` and `bf16` are not comparable because they utili However, these advantages comes with a few tradeoffs: - mixed float/integer promotion is very prone to precision loss: for example, `int64` (with a maximum value of $9.2 \times 10^{18}$) can be promoted to `float16` (with a maximum value of $6.5 \times 10^4$), meaning most representable values will become `inf`. -- as mentioned above, `f*` can no longer be thought of as a "scalar type", but as a different flavor of float64. In JAX's parlance, this is referred to as a [*weak type*](https://jax.readthedocs.io/en/latest/type_promotion.html#weakly-typed-values-in-jax), in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values. +- as mentioned above, `f*` can no longer be thought of as a "scalar type", but as a different flavor of float64. In JAX's parlance, this is referred to as a [*weak type*](https://docs.jax.dev/en/latest/type_promotion.html#weakly-typed-values-in-jax), in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values. Note that also, this approach still leaves the `uint64` promotion question unanswered, although it is perhaps reasonable to close the lattice by connecting `u64` to `f*`. @@ -730,7 +730,7 @@ nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos +++ {"id": "o0-E2KWjYEXO"} -The behavior resulting from this choice is summarized in [JAX Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html). Notably, aside from the inclusion of larger unsigned types (`u16`, `u32`, `u64`) and some details about the behavior of scalar/weak types (`i*`, `f*`, `c*`), this type promotion scheme turns out to be very close to that chosen by PyTorch. +The behavior resulting from this choice is summarized in [JAX Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html). Notably, aside from the inclusion of larger unsigned types (`u16`, `u32`, `u64`) and some details about the behavior of scalar/weak types (`i*`, `f*`, `c*`), this type promotion scheme turns out to be very close to that chosen by PyTorch. For those interested, the appendix below prints the full promotion tables used by NumPy, Tensorflow, PyTorch, and JAX. 
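As a short illustration of the weak-type behavior described above, Python scalars adopt the dtype of the array they combine with, while explicitly typed scalars trigger ordinary promotion:
```python
import jax.numpy as jnp

x = jnp.ones(3, dtype=jnp.float16)

print((x + 3.0).dtype)               # float16: the Python float is weakly typed
print((x + jnp.float32(3.0)).dtype)  # float32: an explicit f32 scalar promotes
```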
@@ -900,7 +900,7 @@ display.HTML(table.to_html())
 ### JAX Type Promotion: `jax.numpy`
 
-`jax.numpy` follows type promotion rules laid out at https://jax.readthedocs.io/en/latest/type_promotion.html. Here we use `i*`, `f*`, `c*` to indicate both Python scalars and weakly-typed arrays.
+`jax.numpy` follows type promotion rules laid out at https://docs.jax.dev/en/latest/type_promotion.html. Here we use `i*`, `f*`, `c*` to indicate both Python scalars and weakly-typed arrays.
 
 ```{code-cell}
 :cellView: form
diff --git a/docs/jep/9419-jax-versioning.md b/docs/jep/9419-jax-versioning.md
index b964aa2af45d..85b95257ebae 100644
--- a/docs/jep/9419-jax-versioning.md
+++ b/docs/jep/9419-jax-versioning.md
@@ -167,16 +167,16 @@ We maintain an additional version number (`_version`) in
 [`xla_client.py` in the XLA
 repository](https://github.com/openxla/xla/blob/main/xla/python/xla_client.py).
 The idea is that this version number, is defined in `xla/python`
 together with the C++ parts of JAX, is also accessible to JAX Python as
-`jax._src.lib.xla_extension_version`, and must
+`jax._src.lib.jaxlib_extension_version`, and must
 be incremented every time that a change is made to the XLA/Python code that has
 backwards compatibility implications for `jax`. The JAX Python code can then use
 this version number to maintain backwards compatibility, e.g.:
 
 ```
-from jax._src.lib import xla_extension_version
+from jax._src.lib import jaxlib_extension_version
 
 # 123 is the new version number for _version in xla_client.py
-if xla_extension_version >= 123:
+if jaxlib_extension_version >= 123:
   # Use new code path
   ...
 else:
diff --git a/docs/jit-compilation.md b/docs/jit-compilation.md
index 5e5be308068a..093f5ec4ab72 100644
--- a/docs/jit-compilation.md
+++ b/docs/jit-compilation.md
@@ -55,7 +55,7 @@ The {ref}`jax-internals-jaxpr` section of the documentation provides more inform
 Importantly, notice that the jaxpr does not capture the side-effect present in the function: there is nothing in it corresponding to `global_list.append(x)`.
 This is a feature, not a bug: JAX transformations are designed to understand side-effect-free (a.k.a. functionally pure) code.
-If *pure function* and *side-effect* are unfamiliar terms, this is explained in a little more detail in [🔪 JAX - The Sharp Bits 🔪: Pure Functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions).
+If *pure function* and *side-effect* are unfamiliar terms, this is explained in a little more detail in [🔪 JAX - The Sharp Bits 🔪: Pure Functions](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions).
 
 Impure functions are dangerous because under JAX transformations they are likely not to behave as intended; they might fail silently, or produce surprising downstream errors like leaked Tracers.
 Moreover, JAX often can't detect when side effects are present.
diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb
index a1435c4e557e..de6da98b7d62 100644
--- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb
+++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb
@@ -346,7 +346,7 @@
    "evalue": "ignored",
    "output_type": "error",
    "traceback": [
-     "\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m '' object does not support item assignment. JAX arrays are immutable.
Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html\n" + "\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m '' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html\n" ] } ], @@ -365,7 +365,7 @@ "source": [ "Allowing mutation of variables in-place makes program analysis and transformation difficult. JAX requires that programs are pure functions.\n", "\n", - "Instead, JAX offers a _functional_ array update using the [`.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at)." + "Instead, JAX offers a _functional_ array update using the [`.at` property on JAX arrays](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at)." ] }, { @@ -521,7 +521,7 @@ "id": "sTjJ3WuaDyqU" }, "source": [ - "For more details on indexed array updates, see the [documentation for the `.at` property](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at)." + "For more details on indexed array updates, see the [documentation for the `.at` property](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at)." ] }, { @@ -604,7 +604,7 @@ "id": "NAcXJNAcDi_v" }, "source": [ - "If you would like finer-grained control over the behavior for out-of-bound indices, you can use the optional parameters of [`ndarray.at`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html); for example:" + "If you would like finer-grained control over the behavior for out-of-bound indices, you can use the optional parameters of [`ndarray.at`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html); for example:" ] }, { @@ -971,7 +971,7 @@ "evalue": "ignored", "output_type": "error", "traceback": [ - "\u001b[0;31mNonConcreteBooleanIndexError\u001b[0m\u001b[0;31m:\u001b[0m Array boolean indices must be concrete; got ShapedArray(bool[5])\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError\n" + "\u001b[0;31mNonConcreteBooleanIndexError\u001b[0m\u001b[0;31m:\u001b[0m Array boolean indices must be concrete; got ShapedArray(bool[5])\n\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError\n" ] } ], @@ -1296,7 +1296,7 @@ "While `jax.numpy` makes every attempt to replicate the behavior of numpy's API, there do exist corner cases where the behaviors differ.\n", "Many such cases are discussed in detail in the sections above; here we list several other known places where the APIs diverge.\n", "\n", - "- For binary operations, JAX's type promotion rules differ somewhat from those used by NumPy. See [Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html) for more details.\n", + "- For binary operations, JAX's type promotion rules differ somewhat from those used by NumPy. See [Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html) for more details.\n", "- When performing unsafe type casts (i.e. casts in which the target dtype cannot represent the input value), JAX's behavior may be backend dependent, and in general may diverge from NumPy's behavior. 
Numpy allows control over the result in these scenarios via the `casting` argument (see [`np.ndarray.astype`](https://numpy.org/devdocs/reference/generated/numpy.ndarray.astype.html)); JAX does not provide any such configuration, instead directly inheriting the behavior of [XLA:ConvertElementType](https://www.tensorflow.org/xla/operation_semantics#convertelementtype).\n", "\n", " Here is an example of an unsafe cast with differing results between NumPy and JAX:\n", diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index 80ab69be1ed8..9fbc26a46c8f 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -201,7 +201,7 @@ jax_array[1, :] = 1.0 Allowing mutation of variables in-place makes program analysis and transformation difficult. JAX requires that programs are pure functions. -Instead, JAX offers a _functional_ array update using the [`.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at). +Instead, JAX offers a _functional_ array update using the [`.at` property on JAX arrays](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at). +++ {"id": "hfloZ1QXCS_J"} @@ -261,7 +261,7 @@ print(new_jax_array) +++ {"id": "sTjJ3WuaDyqU"} -For more details on indexed array updates, see the [documentation for the `.at` property](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at). +For more details on indexed array updates, see the [documentation for the `.at` property](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at). +++ {"id": "oZ_jE2WAypdL"} @@ -292,7 +292,7 @@ jnp.arange(10)[11] +++ {"id": "NAcXJNAcDi_v"} -If you would like finer-grained control over the behavior for out-of-bound indices, you can use the optional parameters of [`ndarray.at`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html); for example: +If you would like finer-grained control over the behavior for out-of-bound indices, you can use the optional parameters of [`ndarray.at`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html); for example: ```{code-cell} ipython3 :id: -0-MaFddO-xy @@ -664,7 +664,7 @@ x.dtype # --> dtype('float64') While `jax.numpy` makes every attempt to replicate the behavior of numpy's API, there do exist corner cases where the behaviors differ. Many such cases are discussed in detail in the sections above; here we list several other known places where the APIs diverge. -- For binary operations, JAX's type promotion rules differ somewhat from those used by NumPy. See [Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html) for more details. +- For binary operations, JAX's type promotion rules differ somewhat from those used by NumPy. See [Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html) for more details. - When performing unsafe type casts (i.e. casts in which the target dtype cannot represent the input value), JAX's behavior may be backend dependent, and in general may diverge from NumPy's behavior. 
Numpy allows control over the result in these scenarios via the `casting` argument (see [`np.ndarray.astype`](https://numpy.org/devdocs/reference/generated/numpy.ndarray.astype.html)); JAX does not provide any such configuration, instead directly inheriting the behavior of [XLA:ConvertElementType](https://www.tensorflow.org/xla/operation_semantics#convertelementtype). Here is an example of an unsafe cast with differing results between NumPy and JAX: diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb index e550cbf36da3..e80c7ae94687 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb @@ -17,9 +17,9 @@ "1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and\n", "2. defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems.\n", "\n", - "This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html).\n", + "This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html).\n", "\n", - "For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) and [jax.grad](https://jax.readthedocs.io/en/latest/jax.html#jax.grad), and the mathematical meaning of JVPs and VJPs." + "For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://docs.jax.dev/en/latest/jax.html#jax.jvp) and [jax.grad](https://docs.jax.dev/en/latest/jax.html#jax.grad), and the mathematical meaning of JVPs and VJPs." ] }, { @@ -2035,7 +2035,7 @@ "source": [ "### Working with `list` / `tuple` / `dict` containers (and other pytrees)\n", "\n", - "You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. \n", + "You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://docs.jax.dev/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. \n", "\n", "Here's a contrived example with `jax.custom_jvp`:" ] diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.md b/docs/notebooks/Custom_derivative_rules_for_Python_code.md index 8a63f142693e..82b97e195bd9 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.md +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.md @@ -24,9 +24,9 @@ There are two ways to define differentiation rules in JAX: 1. 
using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and 2. defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems. -This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html). +This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html). -For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) and [jax.grad](https://jax.readthedocs.io/en/latest/jax.html#jax.grad), and the mathematical meaning of JVPs and VJPs. +For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://docs.jax.dev/en/latest/jax.html#jax.jvp) and [jax.grad](https://docs.jax.dev/en/latest/jax.html#jax.grad), and the mathematical meaning of JVPs and VJPs. +++ {"id": "9Fg3NFNY-2RY"} @@ -1048,7 +1048,7 @@ Array(-0.91113025, dtype=float32) ### Working with `list` / `tuple` / `dict` containers (and other pytrees) -You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. +You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://docs.jax.dev/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. Here's a contrived example with `jax.custom_jvp`: diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb index 8abee469d552..90d92c4ea241 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb @@ -1276,7 +1276,7 @@ "id": "3qfPjJdhgerc" }, "source": [ - "So computation follows data placement: when we explicitly shard data with `jax.device_put`, and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding. This policy for sharded data is a generalization of [JAX's policy of following explicit device placement](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices)." + "So computation follows data placement: when we explicitly shard data with `jax.device_put`, and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding. This policy for sharded data is a generalization of [JAX's policy of following explicit device placement](https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices)." 
] }, { @@ -1382,7 +1382,7 @@ "id": "6ZYcK8eXrn0p" }, "source": [ - "We say arrays that have been explicitly placed or sharded with `jax.device_put` are _committed_ to their device(s), and so won't be automatically moved. See the [device placement FAQ](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) for more information.\n", + "We say arrays that have been explicitly placed or sharded with `jax.device_put` are _committed_ to their device(s), and so won't be automatically moved. See the [device placement FAQ](https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) for more information.\n", "\n", "When arrays are _not_ explicitly placed or sharded with `jax.device_put`, they are placed _uncommitted_ on the default device.\n", "Unlike committed arrays, uncommitted arrays can be moved and resharded automatically: that is, uncommitted arrays can be arguments to a computation even if other arguments are explicitly placed on different devices.\n", @@ -2339,7 +2339,7 @@ "source": [ "### Generating random numbers\n", "\n", - "JAX comes with a functional, deterministic [random number generator](https://jax.readthedocs.io/en/latest/jep/263-prng.html). It underlies the various sampling functions in the [`jax.random` module](https://jax.readthedocs.io/en/latest/jax.random.html), such as `jax.random.uniform`.\n", + "JAX comes with a functional, deterministic [random number generator](https://docs.jax.dev/en/latest/jep/263-prng.html). It underlies the various sampling functions in the [`jax.random` module](https://docs.jax.dev/en/latest/jax.random.html), such as `jax.random.uniform`.\n", "\n", "JAX's random numbers are produced by a counter-based PRNG, so in principle, random number generation should be a pure map over counter values. A pure map is a trivially partitionable operation in principle. It should require no cross-device communication, nor any redundant computation across devices.\n", "\n", diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md index c207f0ae4a00..79990fefb95d 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md @@ -427,7 +427,7 @@ jax.debug.visualize_array_sharding(w_copy) +++ {"id": "3qfPjJdhgerc"} -So computation follows data placement: when we explicitly shard data with `jax.device_put`, and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding. This policy for sharded data is a generalization of [JAX's policy of following explicit device placement](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices). +So computation follows data placement: when we explicitly shard data with `jax.device_put`, and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding. This policy for sharded data is a generalization of [JAX's policy of following explicit device placement](https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices). +++ {"id": "QRB95LaWuT80"} @@ -484,7 +484,7 @@ except ValueError as e: print_exception(e) +++ {"id": "6ZYcK8eXrn0p"} -We say arrays that have been explicitly placed or sharded with `jax.device_put` are _committed_ to their device(s), and so won't be automatically moved. 
See the [device placement FAQ](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) for more information. +We say arrays that have been explicitly placed or sharded with `jax.device_put` are _committed_ to their device(s), and so won't be automatically moved. See the [device placement FAQ](https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) for more information. When arrays are _not_ explicitly placed or sharded with `jax.device_put`, they are placed _uncommitted_ on the default device. Unlike committed arrays, uncommitted arrays can be moved and resharded automatically: that is, uncommitted arrays can be arguments to a computation even if other arguments are explicitly placed on different devices. @@ -854,7 +854,7 @@ outputId: 479c4d81-cb0b-40a5-89ba-394c10dc3297 ### Generating random numbers -JAX comes with a functional, deterministic [random number generator](https://jax.readthedocs.io/en/latest/jep/263-prng.html). It underlies the various sampling functions in the [`jax.random` module](https://jax.readthedocs.io/en/latest/jax.random.html), such as `jax.random.uniform`. +JAX comes with a functional, deterministic [random number generator](https://docs.jax.dev/en/latest/jep/263-prng.html). It underlies the various sampling functions in the [`jax.random` module](https://docs.jax.dev/en/latest/jax.random.html), such as `jax.random.uniform`. JAX's random numbers are produced by a counter-based PRNG, so in principle, random number generation should be a pure map over counter values. A pure map is a trivially partitionable operation in principle. It should require no cross-device communication, nor any redundant computation across devices. diff --git a/docs/notebooks/README.md b/docs/notebooks/README.md index 07be4441ade8..c945c197ad19 100644 --- a/docs/notebooks/README.md +++ b/docs/notebooks/README.md @@ -1,2 +1,2 @@ For instructions on how to change and test notebooks, see -[Update Documentation](https://jax.readthedocs.io/en/latest/developer.html#update-documentation). +[Update Documentation](https://docs.jax.dev/en/latest/developer.html#update-documentation). diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb index 00ba9186eeec..d22457c5d718 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb @@ -24,7 +24,7 @@ "\n", "Here we show how to add your own function transformations to the system, by writing a custom Jaxpr interpreter. And we'll get composability with all the other transformations for free.\n", "\n", - "**This example uses internal JAX APIs, which may break at any time. Anything not in [the API Documentation](https://jax.readthedocs.io/en/latest/jax.html) should be assumed internal.**" + "**This example uses internal JAX APIs, which may break at any time. 
Anything not in [the API Documentation](https://docs.jax.dev/en/latest/jax.html) should be assumed internal.**" ] }, { @@ -215,8 +215,8 @@ "# Importing Jax functions useful for tracing/interpreting.\n", "from functools import wraps\n", "\n", - "from jax import core\n", "from jax import lax\n", + "from jax.extend import core\n", "from jax._src.util import safe_map" ] }, diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.md b/docs/notebooks/Writing_custom_interpreters_in_Jax.md index 10c4e7cb6e3b..ad707a9746fc 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.md +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.md @@ -27,7 +27,7 @@ etc.) that enable writing concise, accelerated code. Here we show how to add your own function transformations to the system, by writing a custom Jaxpr interpreter. And we'll get composability with all the other transformations for free. -**This example uses internal JAX APIs, which may break at any time. Anything not in [the API Documentation](https://jax.readthedocs.io/en/latest/jax.html) should be assumed internal.** +**This example uses internal JAX APIs, which may break at any time. Anything not in [the API Documentation](https://docs.jax.dev/en/latest/jax.html) should be assumed internal.** ```{code-cell} ipython3 :id: s27RDKvKXFL8 @@ -147,8 +147,8 @@ Let's use `make_jaxpr` to trace a function into a Jaxpr. # Importing Jax functions useful for tracing/interpreting. from functools import wraps -from jax import core from jax import lax +from jax.extend import core from jax._src.util import safe_map ``` diff --git a/docs/notebooks/autodiff_remat.ipynb b/docs/notebooks/autodiff_remat.ipynb index feb906546341..d8a74e4b15fd 100644 --- a/docs/notebooks/autodiff_remat.ipynb +++ b/docs/notebooks/autodiff_remat.ipynb @@ -348,7 +348,7 @@ "source": [ "### Let's think step by step\n", "\n", - "You might want to first (re)read [the Autodiff Cookbook Part 1](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)." + "You might want to first (re)read [the Autodiff Cookbook Part 1](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html)." ] }, { diff --git a/docs/notebooks/autodiff_remat.md b/docs/notebooks/autodiff_remat.md index 8ba87dcfee18..12564bd91f30 100644 --- a/docs/notebooks/autodiff_remat.md +++ b/docs/notebooks/autodiff_remat.md @@ -156,7 +156,7 @@ print_fwd_bwd(f3, W1, W2, W3, x) ### Let's think step by step -You might want to first (re)read [the Autodiff Cookbook Part 1](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html). +You might want to first (re)read [the Autodiff Cookbook Part 1](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html). +++ {"id": "VMfwm_yinvoZ"} diff --git a/docs/notebooks/neural_network_with_tfds_data.ipynb b/docs/notebooks/neural_network_with_tfds_data.ipynb index c31a99746866..a909d9329e24 100644 --- a/docs/notebooks/neural_network_with_tfds_data.ipynb +++ b/docs/notebooks/neural_network_with_tfds_data.ipynb @@ -46,7 +46,7 @@ "\n", "![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png)\n", "\n", - "Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. 
We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).\n", + "Let's combine everything we showed in the [quickstart](https://docs.jax.dev/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).\n", "\n", "Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model." ] diff --git a/docs/notebooks/neural_network_with_tfds_data.md b/docs/notebooks/neural_network_with_tfds_data.md index 53b7d47358c2..9c153d704763 100644 --- a/docs/notebooks/neural_network_with_tfds_data.md +++ b/docs/notebooks/neural_network_with_tfds_data.md @@ -44,7 +44,7 @@ _Forked from_ `neural_network_and_data_loading.ipynb` ![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png) -Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P). +Let's combine everything we showed in the [quickstart](https://docs.jax.dev/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P). Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model. diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb index d73b0d4c0f3e..ecfa199c6b52 100644 --- a/docs/notebooks/shard_map.ipynb +++ b/docs/notebooks/shard_map.ipynb @@ -13,9 +13,9 @@ "\n", "`shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations.\n", "\n", - "`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. 
The two approaches can be mixed, matched, and composed as needed.\n", + "`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed.\n", "\n", - "If you're familiar with `pmap`, think of `shard_map` as an evolution. It's more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see [a detailed comparison to `pmap`.](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this))\n", + "If you're familiar with `pmap`, think of `shard_map` as an evolution. It's more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see [a detailed comparison to `pmap`.](https://docs.jax.dev/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this))\n", "\n", "By reading this tutorial, you'll learn how to use `shard_map` to get full control over your multi-device code. You'll see in detail how it composes with `jax.jit`'s automatic parallelization and `jax.grad`'s automatic differentiation. We'll also give some basic examples of neural network parallelization strategies.\n", "\n", @@ -499,7 +499,7 @@ "* `mesh` encodes devices arranged in an array and with associated axis names, just like it does for `sharding.NamedSharding`;\n", "* `in_specs` and `out_specs` are `PartitionSpec`s which can affinely mention axis names from `mesh` to express slicing/unconcatenation and concatenation of inputs and outputs, respectively, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively;\n", "* `auto` is an optional set of axis names corresponding to the subset of names of `mesh` to treat automatically in the body, as in the caller, rather than manually;\n", - "* `check_rep` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://jax.readthedocs.io/en/latest/jep/17111-shmap-transpose.html)).\n", + "* `check_rep` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html)).\n", "\n", "The shapes of the arguments passed to `f` have the same ranks as the arguments\n", "passed to `shard_map`-of-`f`, and the shape of an argument to `f` is computed\n", @@ -1520,7 +1520,7 @@ "source": [ "Compare these examples with the purely [automatic partitioning examples in the\n", "\"Distributed arrays and automatic partitioning\"\n", - "doc](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html).\n", + 
"doc](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html).\n", "While in those automatic partitioning examples we don't need to edit the model\n", "functions to use different parallelization strategies, with `shard_map` we\n", "often do.\n", @@ -1626,7 +1626,7 @@ "parameters from the forward pass for use on the backward pass. Instead, we want\n", "to gather them again on the backward pass. We can express that by using\n", "`jax.remat` with a [custom\n", - "policy](https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#custom-policies-for-what-s-saveable)\n", + "policy](https://docs.jax.dev/en/latest/notebooks/autodiff_remat.html#custom-policies-for-what-s-saveable)\n", "(or a `custom_vjp`), though XLA typically does that rematerialization\n", "automatically.\n", "\n", diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md index c52cf0e6d22b..095f37d0dde1 100644 --- a/docs/notebooks/shard_map.md +++ b/docs/notebooks/shard_map.md @@ -22,9 +22,9 @@ kernelspec: `shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations. -`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed. +`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed. -If you're familiar with `pmap`, think of `shard_map` as an evolution. It's more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see [a detailed comparison to `pmap`.](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this)) +If you're familiar with `pmap`, think of `shard_map` as an evolution. It's more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see [a detailed comparison to `pmap`.](https://docs.jax.dev/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this)) By reading this tutorial, you'll learn how to use `shard_map` to get full control over your multi-device code. 
You'll see in detail how it composes with `jax.jit`'s automatic parallelization and `jax.grad`'s automatic differentiation. We'll also give some basic examples of neural network parallelization strategies. @@ -346,7 +346,7 @@ where: * `mesh` encodes devices arranged in an array and with associated axis names, just like it does for `sharding.NamedSharding`; * `in_specs` and `out_specs` are `PartitionSpec`s which can affinely mention axis names from `mesh` to express slicing/unconcatenation and concatenation of inputs and outputs, respectively, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively; * `auto` is an optional set of axis names corresponding to the subset of names of `mesh` to treat automatically in the body, as in the caller, rather than manually; -* `check_rep` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://jax.readthedocs.io/en/latest/jep/17111-shmap-transpose.html)). +* `check_rep` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html)). The shapes of the arguments passed to `f` have the same ranks as the arguments passed to `shard_map`-of-`f`, and the shape of an argument to `f` is computed @@ -1061,7 +1061,7 @@ params, batch = init(jax.random.key(0), layer_sizes, batch_size) Compare these examples with the purely [automatic partitioning examples in the "Distributed arrays and automatic partitioning" -doc](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html). +doc](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html). While in those automatic partitioning examples we don't need to edit the model functions to use different parallelization strategies, with `shard_map` we often do. @@ -1137,7 +1137,7 @@ There's one other ingredient we need: we don't want to store the fully gathered parameters from the forward pass for use on the backward pass. Instead, we want to gather them again on the backward pass. We can express that by using `jax.remat` with a [custom -policy](https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#custom-policies-for-what-s-saveable) +policy](https://docs.jax.dev/en/latest/notebooks/autodiff_remat.html#custom-policies-for-what-s-saveable) (or a `custom_vjp`), though XLA typically does that rematerialization automatically. 
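To make the `mesh`/`in_specs`/`out_specs` description in the hunks above concrete, here is a minimal sketch of a `shard_map` call (illustrative only; it assumes a single-axis mesh over however many devices are available):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=('i',))

def block_sum(x_block):
    # x_block is this device's shard; psum communicates across the 'i' mesh axis.
    return jax.lax.psum(jnp.sum(x_block), 'i')

# in_specs=P('i') splits the input along axis 0 across the mesh;
# out_specs=P() asserts the scalar result is replicated on every device.
x = jnp.arange(8.0 * devices.size).reshape(devices.size, 8)
total = shard_map(block_sum, mesh=mesh, in_specs=P('i'), out_specs=P())(x)
```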
diff --git a/docs/notebooks/thinking_in_jax.ipynb b/docs/notebooks/thinking_in_jax.ipynb index 5ddcdd32e2b4..d6cbf6e02198 100644 --- a/docs/notebooks/thinking_in_jax.ipynb +++ b/docs/notebooks/thinking_in_jax.ipynb @@ -248,7 +248,7 @@ "id": "yRYF0YgO3F4H" }, "source": [ - "For updating individual elements, JAX provides an [indexed update syntax](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax-numpy-ndarray-at) that returns an updated copy:" + "For updating individual elements, JAX provides an [indexed update syntax](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax-numpy-ndarray-at) that returns an updated copy:" ] }, { @@ -423,7 +423,7 @@ "id": "0GPqgT7S0q8r" }, "source": [ - "Under the hood, this NumPy operation is translated to a much more general convolution implemented by [`lax.conv_general_dilated`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_general_dilated.html):" + "Under the hood, this NumPy operation is translated to a much more general convolution implemented by [`lax.conv_general_dilated`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.conv_general_dilated.html):" ] }, { @@ -461,7 +461,7 @@ "id": "7mdo6ycczlbd" }, "source": [ - "This is a batched convolution operation designed to be efficient for the types of convolutions often used in deep neural nets. It requires much more boilerplate, but is far more flexible and scalable than the convolution provided by NumPy (See [Convolutions in JAX](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html) for more detail on JAX convolutions).\n", + "This is a batched convolution operation designed to be efficient for the types of convolutions often used in deep neural nets. It requires much more boilerplate, but is far more flexible and scalable than the convolution provided by NumPy (See [Convolutions in JAX](https://docs.jax.dev/en/latest/notebooks/convolutions.html) for more detail on JAX convolutions).\n", "\n", "At their heart, all `jax.lax` operations are Python wrappers for operations in XLA; here, for example, the convolution implementation is provided by [XLA:ConvWithGeneralPadding](https://www.tensorflow.org/xla/operation_semantics#convwithgeneralpadding_convolution).\n", "Every JAX operation is eventually expressed in terms of these fundamental XLA operations, which is what enables just-in-time (JIT) compilation." 
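A minimal sketch of the indexed-update syntax referenced in the first hunk above (illustrative; variable names are arbitrary, and the out-of-bounds options mirror the `ndarray.at` parameters mentioned earlier in this patch):

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# x[1] = 10.0 would raise a TypeError: JAX arrays are immutable.
y = x.at[1].set(10.0)      # returns an updated copy; x is unchanged
z = y.at[1].add(2.5)       # other update methods include add, multiply, min, max

# Optional parameters control out-of-bounds behavior instead of silent clamping.
v = x.at[10].get(mode='fill', fill_value=jnp.nan)
```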
@@ -562,7 +562,7 @@ "id": "3GvisB-CA9M8" }, "source": [ - "But due to the compilation (which includes fusing of operations, avoidance of allocating temporary arrays, and a host of other tricks), execution times can be orders of magnitude faster in the JIT-compiled case (note the use of `block_until_ready()` to account for JAX's [asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html)):" + "But due to the compilation (which includes fusing of operations, avoidance of allocating temporary arrays, and a host of other tricks), execution times can be orders of magnitude faster in the JIT-compiled case (note the use of `block_until_ready()` to account for JAX's [asynchronous dispatch](https://docs.jax.dev/en/latest/async_dispatch.html)):" ] }, { @@ -650,7 +650,7 @@ "evalue": "ignored", "output_type": "error", "traceback": [ - "\u001b[0;31mNonConcreteBooleanIndexError\u001b[0m\u001b[0;31m:\u001b[0m Array boolean indices must be concrete; got ShapedArray(bool[10])\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError\n" + "\u001b[0;31mNonConcreteBooleanIndexError\u001b[0m\u001b[0;31m:\u001b[0m Array boolean indices must be concrete; got ShapedArray(bool[10])\n\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError\n" ] } ], @@ -835,7 +835,7 @@ "evalue": "ignored", "output_type": "error", "traceback": [ - "\u001b[0;31mTracerBoolConversionError\u001b[0m\u001b[0;31m:\u001b[0m Attempted boolean conversion of traced array with shape bool[]..\nThe error occurred while tracing the function f at :1 for jit. This concrete value was not available in Python because it depends on the value of the argument neg.\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError\n" + "\u001b[0;31mTracerBoolConversionError\u001b[0m\u001b[0;31m:\u001b[0m Attempted boolean conversion of traced array with shape bool[]..\nThe error occurred while tracing the function f at :1 for jit. 
This concrete value was not available in Python because it depends on the value of the argument neg.\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError\n" ] } ], diff --git a/docs/notebooks/thinking_in_jax.md b/docs/notebooks/thinking_in_jax.md index 0693f6ba8579..7b0bb0d9b8ce 100644 --- a/docs/notebooks/thinking_in_jax.md +++ b/docs/notebooks/thinking_in_jax.md @@ -117,7 +117,7 @@ x[0] = 10 +++ {"id": "yRYF0YgO3F4H"} -For updating individual elements, JAX provides an [indexed update syntax](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax-numpy-ndarray-at) that returns an updated copy: +For updating individual elements, JAX provides an [indexed update syntax](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax-numpy-ndarray-at) that returns an updated copy: ```{code-cell} ipython3 :id: 8zqPEAeP3UK5 @@ -189,7 +189,7 @@ jnp.convolve(x, y) +++ {"id": "0GPqgT7S0q8r"} -Under the hood, this NumPy operation is translated to a much more general convolution implemented by [`lax.conv_general_dilated`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_general_dilated.html): +Under the hood, this NumPy operation is translated to a much more general convolution implemented by [`lax.conv_general_dilated`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.conv_general_dilated.html): ```{code-cell} ipython3 :id: pi4f6ikjzc3l @@ -206,7 +206,7 @@ result[0, 0] +++ {"id": "7mdo6ycczlbd"} -This is a batched convolution operation designed to be efficient for the types of convolutions often used in deep neural nets. It requires much more boilerplate, but is far more flexible and scalable than the convolution provided by NumPy (See [Convolutions in JAX](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html) for more detail on JAX convolutions). +This is a batched convolution operation designed to be efficient for the types of convolutions often used in deep neural nets. It requires much more boilerplate, but is far more flexible and scalable than the convolution provided by NumPy (See [Convolutions in JAX](https://docs.jax.dev/en/latest/notebooks/convolutions.html) for more detail on JAX convolutions). At their heart, all `jax.lax` operations are Python wrappers for operations in XLA; here, for example, the convolution implementation is provided by [XLA:ConvWithGeneralPadding](https://www.tensorflow.org/xla/operation_semantics#convwithgeneralpadding_convolution). Every JAX operation is eventually expressed in terms of these fundamental XLA operations, which is what enables just-in-time (JIT) compilation. 
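For context on the `TracerBoolConversionError` shown in the traceback above, a minimal sketch of the failure mode and two common fixes (the function and argument names here are illustrative):

```python
import jax
import jax.numpy as jnp

def f(x, neg):
    return -x if neg else x  # `if neg:` needs a concrete bool at trace time

# jax.jit(f)(1.0, True) would raise TracerBoolConversionError, because `neg`
# is traced. Marking it static recompiles per concrete value instead:
f_static = jax.jit(f, static_argnames='neg')
print(f_static(1.0, True))   # -1.0

# Alternatively, keep everything traced by replacing Python control flow:
g = jax.jit(lambda x, neg: jnp.where(neg, -x, x))
print(g(1.0, True))          # -1.0
```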
@@ -261,7 +261,7 @@ np.allclose(norm(X), norm_compiled(X), atol=1E-6) +++ {"id": "3GvisB-CA9M8"} -But due to the compilation (which includes fusing of operations, avoidance of allocating temporary arrays, and a host of other tricks), execution times can be orders of magnitude faster in the JIT-compiled case (note the use of `block_until_ready()` to account for JAX's [asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html)): +But due to the compilation (which includes fusing of operations, avoidance of allocating temporary arrays, and a host of other tricks), execution times can be orders of magnitude faster in the JIT-compiled case (note the use of `block_until_ready()` to account for JAX's [asynchronous dispatch](https://docs.jax.dev/en/latest/async_dispatch.html)): ```{code-cell} ipython3 :id: 6mUB6VdDAEIY diff --git a/docs/notes.rst b/docs/notes.rst index 08265638000e..502385142b16 100644 --- a/docs/notes.rst +++ b/docs/notes.rst @@ -9,9 +9,6 @@ Dependencies and version compatibility: - :doc:`api_compatibility` outlines JAX's policies with regard to API compatibility across releases. - :doc:`deprecation` outlines JAX's policies with regard to compatibility with Python and NumPy. -Migrations and deprecations: - - :doc:`jax_array_migration` summarizes the changes to the default array type in jax v 0.4.1 - Memory and computation usage: - :doc:`async_dispatch` describes JAX's asynchronous dispatch model. - :doc:`concurrency` describes how JAX interacts with other Python concurrency. @@ -20,6 +17,10 @@ Memory and computation usage: Programmer guardrails: - :doc:`rank_promotion_warning` describes how to configure :mod:`jax.numpy` to avoid implicit rank promotion. +Arrays and data types: + - :doc:`type_promotion` describes JAX's implicit type promotion for functions of two or more values. + - :doc:`default_dtypes` describes how JAX determines the default dtype for array creation functions. + .. toctree:: :hidden: @@ -27,8 +28,9 @@ Programmer guardrails: api_compatibility deprecation - jax_array_migration async_dispatch concurrency gpu_memory_allocation - rank_promotion_warning \ No newline at end of file + rank_promotion_warning + type_promotion + default_dtypes diff --git a/docs/pallas/CHANGELOG.md b/docs/pallas/CHANGELOG.md index 2b1cad7c9a66..7533e6eda053 100644 --- a/docs/pallas/CHANGELOG.md +++ b/docs/pallas/CHANGELOG.md @@ -5,7 +5,7 @@ This is the list of changes specific to {class}`jax.experimental.pallas`. -For the overall JAX change log see [here](https://jax.readthedocs.io/en/latest/changelog.html). +For the overall JAX change log see [here](https://docs.jax.dev/en/latest/changelog.html).
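And a minimal sketch of the `block_until_ready()` timing pattern mentioned in the `thinking_in_jax` hunks above (illustrative; array sizes are arbitrary and the speedup depends on the backend):

```python
import numpy as np
import jax
import jax.numpy as jnp

def norm(X):
    X = X - X.mean(0)
    return X / X.std(0)

norm_compiled = jax.jit(norm)
X = jnp.asarray(np.random.rand(10000, 100))

# Dispatch is asynchronous, so block_until_ready() is needed to time the
# computation itself rather than just the call that enqueues it.
norm(X).block_until_ready()            # eager (op-by-op) execution
norm_compiled(X).block_until_ready()   # first call compiles
norm_compiled(X).block_until_ready()   # later calls reuse the compiled program
```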